Skip to content

Commit

Permalink
Fixing python script for generation
Browse files Browse the repository at this point in the history
  • Loading branch information
jjomier committed Sep 19, 2023
1 parent 6dcda48 commit 61d8e0d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
17 changes: 16 additions & 1 deletion applications/object_detection_torch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,22 @@ If resolution is updated in entity generation, it must be updated in the followi

## Building the application

Please refer to the top level Holohub README.md file for information on how to build this application.
The best way to run this application is inside the container, as it would provide all the required third-party packages:

```bash
# Create the container image for this application
./dev_container build --docker_file applications/object_detection_torch/Dockerfile --img object_detection_torch
# Launch the container
./dev_container launch --img object_detection_torch
# Build the application. Note that this downloads the video data as well
./run build object_detection_torch
# Generate the pytorch model
python3 applications/object_detection_torch/generate_resnet_model.py data/object_detection_torch/frcnn_resnet50_t.pt
# Run the application
./run launch object_detection_torch
```

Please refer to the top level Holohub README.md file for more information on how to build this application.

## Running the application

Expand Down
8 changes: 5 additions & 3 deletions applications/object_detection_torch/generate_resnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import sys

import torch
from torchvision.models import detection

os.environ['TORCH_HOME'] = os.getcwd()
os.environ["TORCH_HOME"] = os.getcwd()

model_file = "frcnn_resnet50_t.pt"
if len(sys.argv) > 1:
model_file = sys.argv[1]

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

det_model = detection.fasterrcnn_resnet50_fpn(progress=True).to(DEVICE)
det_model = detection.fasterrcnn_resnet50_fpn(
pretrained=True, progress=True, weights_backbone=True
).to(DEVICE)

det_model.eval()
det_model_script = torch.jit.script(det_model)
Expand Down

0 comments on commit 61d8e0d

Please sign in to comment.