Skip to content

Commit

Permalink
Fix yolov3 train accuracy test (#2321)
Browse files Browse the repository at this point in the history
Summary:
Setup the correct example inputs for the yolov3 train accuracy test
Fixes #2248

Pull Request resolved: #2321

Reviewed By: aaronenyeshi

Differential Revision: D58823036

Pulled By: xuzhao9

fbshipit-source-id: d63c069bdf0da6ba496f274510e64d76e3c10f76
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jun 20, 2024
1 parent 5254910 commit caa76d8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchbenchmark/models/yolov3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.num_epochs = 1
self.train_num_batch = 1
self.prefetch = True
if test == "train":
if test == "eval" or self.dargs.accuracy:
self.model, self.example_inputs = self.prep_eval()
elif test == "train":
train_args = split(
f"--data {DATA_DIR}/coco128.data --img 416 --batch {self.batch_size} --nosave --notest \
--epochs {self.num_epochs} --device {self.device_str} --weights '' \
Expand All @@ -64,8 +66,6 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.training_loop, self.model, self.example_inputs = prepare_training_loop(
train_args
)
elif test == "eval":
self.model, self.example_inputs = self.prep_eval()
self.amp_context = nullcontext

def prep_eval(self):
Expand Down

0 comments on commit caa76d8

Please sign in to comment.