From 24e6af69049f4fb05389d707c7876cee47e78979 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Fri, 30 Aug 2024 14:14:50 -0700
Subject: [PATCH] QAT support in core loop (#892)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/892

Reviewed By: ywwwer, diego-urgell

Differential Revision: D61935530

fbshipit-source-id: 6c85ffdccf3a1014441e4bdfc1b527769a44fae9
---
 tests/framework/test_loop_utils.py | 42 ++++++++++++++++++++++++++++++
 torchtnt/framework/_loop_utils.py  | 34 +++++++++++++++++-------
 2 files changed, 67 insertions(+), 9 deletions(-)

diff --git a/tests/framework/test_loop_utils.py b/tests/framework/test_loop_utils.py
index e0c7dc60f4..9be63bb113 100644
--- a/tests/framework/test_loop_utils.py
+++ b/tests/framework/test_loop_utils.py
@@ -12,6 +12,7 @@
 
 import torch
 from torch import distributed as dist, nn
+from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.distributed import launcher
 from torch.utils.data import DataLoader
 
@@ -88,6 +89,47 @@ def test_set_module_training_mode(self) -> None:
         self.assertFalse(prior_module_train_states["module"])
         self.assertFalse(prior_module_train_states["loss_fn"])
 
+    def test_set_module_training_mode_qat(self) -> None:
+        """
+        Test _set_module_training_mode
+        """
+
+        # define a floating point model
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                x = self.fc(x)
+                return x
+
+        loss_fn = nn.CrossEntropyLoss()
+        module = torch.export.export(M(), (torch.rand(4, 4),)).module()
+
+        tracked_modules: Dict[str, torch.nn.Module] = {
+            "module": module,
+            "loss_fn": loss_fn,
+        }
+
+        self.assertTrue(model_is_exported(module))
+        prior_module_train_states = _set_module_training_mode(tracked_modules, False)
+
+        self.assertFalse(module.training)
+        self.assertFalse(loss_fn.training)
+
+        self.assertTrue(prior_module_train_states["module"])
+        self.assertTrue(prior_module_train_states["loss_fn"])
+
+        # set back to True
+        prior_module_train_states = _set_module_training_mode(tracked_modules, True)
+
+        self.assertTrue(module.training)
+        self.assertTrue(loss_fn.training)
+
+        self.assertFalse(prior_module_train_states["module"])
+        self.assertFalse(prior_module_train_states["loss_fn"])
+
     def test_reset_module_training_mode(self) -> None:
         """
         Test _reset_module_training_mode
diff --git a/torchtnt/framework/_loop_utils.py b/torchtnt/framework/_loop_utils.py
index 7e2a5bf608..74a139aec3 100644
--- a/torchtnt/framework/_loop_utils.py
+++ b/torchtnt/framework/_loop_utils.py
@@ -96,14 +96,15 @@ def _set_module_training_mode(
         if _EXPORT_UTILS_AVAIL and model_is_exported(
             module.module if is_ddp else module
         ):
-            if mode:
-                module = torch.ao.quantization.move_exported_model_to_train(
-                    module.module if is_ddp else module
-                )
-            else:
-                module = torch.ao.quantization.move_exported_model_to_eval(
-                    module.module if is_ddp else module
-                )
+            move_fn = (
+                torch.ao.quantization.move_exported_model_to_train
+                if mode
+                else torch.ao.quantization.move_exported_model_to_eval
+            )
+            move_fn(module.module if is_ddp else module)
+            module.training = mode
+            if is_ddp:
+                module.module.training = mode
         else:
             module.train(mode)
 
@@ -118,7 +119,22 @@ def _reset_module_training_mode(
     # returning back to the user
     for name, module in modules.items():
         if name in prior_modes:
-            module.train(prior_modes[name])
+            is_ddp = isinstance(module, DistributedDataParallel)
+
+            if _EXPORT_UTILS_AVAIL and model_is_exported(
+                module.module if is_ddp else module
+            ):
+                move_fn = (
+                    torch.ao.quantization.move_exported_model_to_train
+                    if prior_modes[name]
+                    else torch.ao.quantization.move_exported_model_to_eval
+                )
+                move_fn(module.module if is_ddp else module)
+                module.training = prior_modes[name]
+                if is_ddp:
+                    module.module.training = prior_modes[name]
+            else:
+                module.train(prior_modes[name])
 
 
 def _log_api_usage(entry_point: str) -> None:
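Illustrative note (not part of the patch): below is a minimal, self-contained sketch of the pattern this change introduces, assuming a recent torch build where torch.export and the PT2E quantization utilities are available. The helper name set_training_mode is hypothetical; only model_is_exported, move_exported_model_to_train/_to_eval, and the manual .training patching mirror the _set_module_training_mode logic in the diff above.

# Illustrative sketch only -- `set_training_mode` is a hypothetical name that
# condenses the move_fn logic from torchtnt/framework/_loop_utils.py above.
# Assumes torch.export and torch.ao.quantization PT2E helpers are available.
import torch
from torch.ao.quantization.pt2e.export_utils import model_is_exported


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def set_training_mode(module: torch.nn.Module, mode: bool) -> None:
    if model_is_exported(module):
        # Exported graph modules are switched with the dedicated quantization
        # helpers instead of .train()/.eval(), then .training is patched so
        # downstream checks of module.training still see the right value.
        move_fn = (
            torch.ao.quantization.move_exported_model_to_train
            if mode
            else torch.ao.quantization.move_exported_model_to_eval
        )
        move_fn(module)
        module.training = mode
    else:
        module.train(mode)


if __name__ == "__main__":
    exported = torch.export.export(M(), (torch.rand(4, 4),)).module()
    set_training_mode(exported, False)  # eval, e.g. before a validation pass
    set_training_mode(exported, True)   # back to train
    print(exported.training)            # -> True

The motivation, as reflected in the diff, is that exported (QAT/PT2E) models are not switched between train and eval via the usual .train()/.eval() calls; the loop utilities therefore route exported modules through move_exported_model_to_* and update the .training flag by hand, for both the module itself and the wrapped module when it is inside DistributedDataParallel.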