Commit

Add --torchcompile-mode args to train, validation, inference, benchmark scripts
rwightman committed Oct 2, 2024
1 parent 14d55a7 commit 4d4bdd6
Showing 4 changed files with 13 additions and 5 deletions.
5 changes: 4 additions & 1 deletion benchmark.py
@@ -120,6 +120,8 @@
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ def __init__(
             device='cuda',
             torchscript=False,
             torchcompile=None,
+            torchcompile_mode=None,
             aot_autograd=False,
             reparam=False,
             precision='float32',
@@ -278,7 +281,7 @@ def __init__(
         elif torchcompile:
             assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            self.model = torch.compile(self.model, backend=torchcompile)
+            self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
             self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
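
For context (not part of the commit): the mode argument of torch.compile selects a compilation preset. A minimal sketch of what the new plumbing enables, using a stand-in model rather than timm code:

    import torch

    # mode=None is equivalent to "default"; "reduce-overhead" targets
    # small-batch latency via CUDA graphs, and "max-autotune" spends
    # longer compiling while searching for faster kernels.
    model = torch.nn.Linear(128, 10)
    compiled = torch.compile(model, backend='inductor', mode='max-autotune')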
4 changes: 3 additions & 1 deletion inference.py
@@ -114,6 +114,8 @@
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
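
Note the torch._dynamo.reset() call that precedes each compile in benchmark.py, inference.py, and validate.py: it clears Dynamo's cached compilation state so the model is retraced fresh. An illustrative sketch with a stand-in model, not timm code:

    import torch

    model = torch.nn.Linear(16, 4)
    compiled = torch.compile(model, mode='reduce-overhead')
    # Clear cached graphs and guards before compiling again with
    # different settings, as the scripts above do before each compile.
    torch._dynamo.reset()
    compiled = torch.compile(model, mode='max-autotune')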
4 changes: 3 additions & 1 deletion train.py
@@ -161,6 +161,8 @@
                    help='Head initialization scale')
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
+group.add_argument('--torchcompile-mode', type=str, default=None,
+                   help="torch.compile mode (default: None).")
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
     if args.torchcompile:
         # torch compile should be done after DDP
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
 
     # create the train and eval datasets
     if args.data and not args.data_dir:
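
The "torch compile should be done after DDP" comment explains why the call sits where it does: the already-wrapped module is compiled so DDP's gradient-sync hooks are captured. A sketch of that ordering, assuming a process group has already been initialized (illustrative helper, not timm's code):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def setup_model(model, local_rank, compile_mode=None):
        # Assumes torch.distributed.init_process_group() has already run.
        model = model.to(local_rank)
        model = DDP(model, device_ids=[local_rank])  # wrap with DDP first...
        # ...then compile the wrapped module, mirroring train.py above.
        return torch.compile(model, backend='inductor', mode=compile_mode)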
5 changes: 3 additions & 2 deletions validate.py
@@ -139,7 +139,8 @@
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
-
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
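
With these changes, all four scripts accept the mode from the command line. An illustrative invocation (the model name and data path are placeholders):

    python validate.py /data/imagenet --model resnet50 \
        --torchcompile inductor --torchcompile-mode max-autotune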
