fix ddp train mode (#794)
Summary:
Pull Request resolved: #794

Since D53659696, train mode was applied only to the inner module of a DDP-wrapped model, so the `DistributedDataParallel` wrapper's own `training` flag was never updated. This diff fixes that by calling `train(mode)` on the wrapper itself.
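Why this matters: `nn.Module.train(mode)` sets the flag on the module it is called on and recurses into children, but never touches a parent wrapper. The sketch below is an illustration only (a plain-Python stand-in mimicking that semantics, not torchtnt or PyTorch code): calling `train()` on the unwrapped inner module leaves the wrapper's flag stale, while calling it on the wrapper propagates downward.

```python
# Minimal stand-in mimicking nn.Module.train semantics (illustration only,
# not torchtnt or PyTorch code): train(mode) sets the flag on a module and
# recurses into its children, but never touches a parent wrapper.
class Module:
    def __init__(self, *children):
        self.training = True
        self.children = list(children)

    def train(self, mode=True):
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self


inner = Module()
ddp_wrapper = Module(inner)  # plays the role of DistributedDataParallel

# Pre-fix behavior: train mode applied only to the unwrapped inner module.
ddp_wrapper.train(True)
inner.train(False)
print(ddp_wrapper.training, inner.training)  # wrapper flag is stale: True False

# Post-fix behavior: train mode applied to the wrapper propagates inward.
ddp_wrapper.train(False)
print(ddp_wrapper.training, inner.training)  # False False
```

The fix follows the second pattern: for the non-exported path, `module.train(mode)` is now called on the DDP wrapper rather than on `module.module`.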

Reviewed By: diego-urgell

Differential Revision: D56424247

fbshipit-source-id: 179e14180cdb8bbc08fde595220bb76f75a37c02
JKSenthil authored and facebook-github-bot committed Apr 22, 2024
1 parent f02d654 commit e6739ab
Showing 1 changed file with 11 additions and 5 deletions.
torchtnt/framework/_loop_utils.py (11 additions, 5 deletions)

```diff
@@ -64,13 +64,19 @@ def _set_module_training_mode(
     prior_module_train_states = {}
     for name, module in modules.items():
         prior_module_train_states[name] = module.training
-        if isinstance(module, DistributedDataParallel):
-            module = module.module
-        if torch.ao.quantization.pt2e.export_utils.model_is_exported(module):
+        is_ddp = isinstance(module, DistributedDataParallel)
+
+        if torch.ao.quantization.pt2e.export_utils.model_is_exported(
+            module.module if is_ddp else module
+        ):
             if mode:
-                module = torch.ao.quantization.move_exported_model_to_train(module)
+                module = torch.ao.quantization.move_exported_model_to_train(
+                    module.module if is_ddp else module
+                )
             else:
-                module = torch.ao.quantization.move_exported_model_to_eval(module)
+                module = torch.ao.quantization.move_exported_model_to_eval(
+                    module.module if is_ddp else module
+                )
         else:
             module.train(mode)
```
