QAT support in core loop (#892)
Summary: Pull Request resolved: #892

Reviewed By: ywwwer, diego-urgell

Differential Revision: D61935530

fbshipit-source-id: 6c85ffdccf3a1014441e4bdfc1b527769a44fae9
JKSenthil authored and facebook-github-bot committed Aug 30, 2024
1 parent 1545b34 commit 24e6af6
Showing 2 changed files with 67 additions and 9 deletions.
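
Background for the change: a model exported with PT2 export (the path quantization-aware training goes through) is a torch.fx.GraphModule whose op behavior is baked into the graph, so plain .train()/.eval() calls cannot be relied on to switch modes of ops such as dropout and batch norm. PyTorch instead rewrites the graph via torch.ao.quantization.move_exported_model_to_train/_eval, which is the path this commit wires into the core loop. A minimal sketch of the semantics, using only public PyTorch APIs; the tiny Sequential model is illustrative, not from the commit:

import torch
from torch.ao.quantization.pt2e.export_utils import model_is_exported

# Export a toy model; the result is a GraphModule, not a regular nn.Module.
m = torch.export.export(
    torch.nn.Sequential(torch.nn.Linear(4, 4)), (torch.rand(4, 4),)
).module()
assert model_is_exported(m)

# Switch op behavior by rewriting the graph (a no-op here, since this toy
# graph has no dropout/batchnorm), then fix up the bookkeeping flag that
# the rewrite does not touch -- mirroring what the loop utils below do.
m = torch.ao.quantization.move_exported_model_to_eval(m)
m.training = False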
42 changes: 42 additions & 0 deletions tests/framework/test_loop_utils.py

@@ -12,6 +12,7 @@
 
 import torch
 from torch import distributed as dist, nn
+from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.distributed import launcher
 
 from torch.utils.data import DataLoader
@@ -88,6 +89,47 @@ def test_set_module_training_mode(self) -> None:
         self.assertFalse(prior_module_train_states["module"])
         self.assertFalse(prior_module_train_states["loss_fn"])
 
+    def test_set_module_training_mode_qat(self) -> None:
+        """
+        Test _set_module_training_mode with an exported (QAT) module
+        """
+
+        # define a floating point model
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                x = self.fc(x)
+                return x
+
+        loss_fn = nn.CrossEntropyLoss()
+        module = torch.export.export(M(), (torch.rand(4, 4),)).module()
+
+        tracked_modules: Dict[str, torch.nn.Module] = {
+            "module": module,
+            "loss_fn": loss_fn,
+        }
+
+        self.assertTrue(model_is_exported(module))
+        prior_module_train_states = _set_module_training_mode(tracked_modules, False)
+
+        self.assertFalse(module.training)
+        self.assertFalse(loss_fn.training)
+
+        self.assertTrue(prior_module_train_states["module"])
+        self.assertTrue(prior_module_train_states["loss_fn"])
+
+        # set back to True
+        prior_module_train_states = _set_module_training_mode(tracked_modules, True)
+
+        self.assertTrue(module.training)
+        self.assertTrue(loss_fn.training)
+
+        self.assertFalse(prior_module_train_states["module"])
+        self.assertFalse(prior_module_train_states["loss_fn"])
+
     def test_reset_module_training_mode(self) -> None:
         """
         Test _reset_module_training_mode
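
The new test exports a plain float model; in a full QAT flow the exported graph would additionally carry fake-quantize observers before training. A rough sketch of that surrounding flow for context, assuming the pt2e QAT APIs (prepare_qat_pt2e, convert_pt2e, XNNPACKQuantizer) and noting that the export entry point has shifted across PyTorch releases (capture_pre_autograd_graph in older ones, torch.export.export_for_training in newer ones); none of this is part of the commit itself:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):  # same toy model as in the test above
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

example_inputs = (torch.rand(4, 4),)
exported = torch.export.export_for_training(M(), example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(exported, quantizer)  # inserts fake-quant ops

# ... train `prepared`; with this commit, torchtnt's train/eval toggles
# route through move_exported_model_to_train/_eval automatically ...

quantized = convert_pt2e(prepared)  # lower to the final quantized model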
34 changes: 25 additions & 9 deletions torchtnt/framework/_loop_utils.py

@@ -96,14 +96,15 @@ def _set_module_training_mode(
         if _EXPORT_UTILS_AVAIL and model_is_exported(
             module.module if is_ddp else module
         ):
-            if mode:
-                module = torch.ao.quantization.move_exported_model_to_train(
-                    module.module if is_ddp else module
-                )
-            else:
-                module = torch.ao.quantization.move_exported_model_to_eval(
-                    module.module if is_ddp else module
-                )
+            move_fn = (
+                torch.ao.quantization.move_exported_model_to_train
+                if mode
+                else torch.ao.quantization.move_exported_model_to_eval
+            )
+            move_fn(module.module if is_ddp else module)
+            module.training = mode
+            if is_ddp:
+                module.module.training = mode
         else:
             module.train(mode)
 
@@ -118,7 +119,22 @@ def _reset_module_training_mode(
     # returning back to the user
     for name, module in modules.items():
         if name in prior_modes:
-            module.train(prior_modes[name])
+            is_ddp = isinstance(module, DistributedDataParallel)
+
+            if _EXPORT_UTILS_AVAIL and model_is_exported(
+                module.module if is_ddp else module
+            ):
+                move_fn = (
+                    torch.ao.quantization.move_exported_model_to_train
+                    if prior_modes[name]
+                    else torch.ao.quantization.move_exported_model_to_eval
+                )
+                move_fn(module.module if is_ddp else module)
+                module.training = prior_modes[name]
+                if is_ddp:
+                    module.module.training = prior_modes[name]
+            else:
+                module.train(prior_modes[name])
 
 
 def _log_api_usage(entry_point: str) -> None:

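Taken together, the two helpers bracket a phase of the loop: _set_module_training_mode flips every tracked module into the phase's mode and records the prior modes, and _reset_module_training_mode restores them afterwards, with both now taking the graph-rewrite path for exported models. A paraphrased sketch of that pairing (run_eval_phase is a hypothetical caller, not torchtnt code):

from typing import Dict

import torch

from torchtnt.framework._loop_utils import (
    _reset_module_training_mode,
    _set_module_training_mode,
)

def run_eval_phase(tracked_modules: Dict[str, torch.nn.Module]) -> None:
    # Flip all tracked modules to eval; exported (QAT) modules are rewritten
    # via move_exported_model_to_eval instead of module.eval().
    prior_modes = _set_module_training_mode(tracked_modules, False)
    try:
        pass  # run evaluation steps here
    finally:
        # Restore whatever mode each module was in before this phase.
        _reset_module_training_mode(tracked_modules, prior_modes)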