Detect duplicated references to module instance in AutoUnit (pytorch#893)

Summary: Pull Request resolved: pytorch#893

Reviewed By: JKSenthil

Differential Revision: D62053414

fbshipit-source-id: 93dd9d73807d12707561ed00318adf9a4cdf90af
diego-urgell authored and facebook-github-bot committed Sep 6, 2024
1 parent 24e6af6 commit b5b0b03
Showing 2 changed files with 131 additions and 1 deletion.
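
For context, a minimal sketch of the anti-pattern this commit detects (hypothetical subclass name; only the AutoUnit/DummyAutoUnit behavior shown in the diff below is assumed):

    # Hypothetical subclass illustrating the anti-pattern: caching a second
    # reference to the model that AutoUnit already manages. With this change,
    # AutoUnit logs an error because the same nn.Module instance would appear
    # twice in tracked_modules, which breaks checkpointing and train/eval mode
    # handling. Read self.module directly instead of storing another attribute.
    class LeakyUnit(DummyAutoUnit):
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__(module=module)
            self._module = module  # duplicate reference -> error is logged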
92 changes: 91 additions & 1 deletion tests/framework/test_auto_unit.py
@@ -14,6 +14,7 @@
import torch

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
from torch import nn

from torch.distributed import GradBucket
from torchtnt.framework._test_utils import (
@@ -39,7 +40,7 @@
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
from torchtnt.utils.progress import Progress
from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -294,6 +295,95 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
)
self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)

@skip_if_not_distributed
def test_module_attr_duplicate_reference_validation(self) -> None:
spawn_multi_process(
2,
"gloo",
self._test_module_attr_duplicate_reference_validation,
)

@staticmethod
def _test_module_attr_duplicate_reference_validation() -> None:
error_msg = (
"Attribute '{name}' of the custom TNT Unit stores a reference to the model managed"
"by AutoUnit. This is known to cause errors on checkpointing and model training "
"mode. Please remove this attribute and access the existing `self.module` instead."
)

# Unit that stores unwrapped module
class ChildUnit(AutoUnit):
def __init__(self, module, strategy):
super().__init__(module=module, strategy=strategy)
self._module = self.module.module if strategy else self.module

def compute_loss(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.Tensor([1]), torch.Tensor([1])

def configure_optimizers_and_lr_scheduler(
self, module: torch.nn.Module
) -> Tuple[torch.optim.Optimizer, TLRScheduler]:
return MagicMock(), MagicMock()

# Test with two levels of inheritance
class GrandchildUnit(DummyAutoUnit):
def __init__(self, module, strategy):
super().__init__(module=module, strategy=strategy)
self._module = module

# Test duplicated references to module
test_cases = [
(DummyAutoUnit, None, False),
(ChildUnit, None, True),
(ChildUnit, FSDPStrategy(), True),
(ChildUnit, DDPStrategy(), True),
(GrandchildUnit, None, True),
]
for unit_type, strategy, expect_error in test_cases:
module = nn.Linear(2, 2)
error_container = []
with patch(
"torchtnt.framework.auto_unit.logging.Logger.error",
side_effect=error_container.append,
):
unit = unit_type(module=module, strategy=strategy)

tc = unittest.TestCase()
expected_errors = [error_msg.format(name="_module")] if expect_error else []
tc.assertEqual(error_container, expected_errors)
tc.assertIs(module, unit.module.module if strategy else unit.module)

def test_module_attr_reassignment_validation(self) -> None:
# Test reassignment of module attribute
class ReassigningUnit1(DummyAutoUnit):
def __init__(self, module):
super().__init__(module=module)
self.module = module

class ReassigningUnit2(DummyAutoUnit):
def __init__(self, module):
super().__init__(module=module)
self.configure_model()

def configure_model(self):
self.module = torch.nn.Linear(3, 3)

for unit_type in (ReassigningUnit1, ReassigningUnit2):
module = nn.Linear(2, 2)
warning_container = []
with patch(
"torchtnt.framework.auto_unit.logging.Logger.warning",
side_effect=warning_container.append,
):
unit_type(module=module)

expected_warnings = [
"The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
]
self.assertEqual(warning_container, expected_warnings)

@skip_if_not_distributed
def test_auto_unit_ddp(self) -> None:
"""
40 changes: 40 additions & 0 deletions torchtnt/framework/auto_unit.py
@@ -8,6 +8,7 @@


import contextlib
import logging
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
@@ -52,6 +53,8 @@
from torchtnt.utils.swa import AveragedModel
from typing_extensions import Literal

_logger: logging.Logger = logging.getLogger(__name__)


TData = TypeVar("TData")

@@ -550,6 +553,43 @@ def __init__(
self.lr_scheduler: Optional[TLRScheduler] = None
self.swa_scheduler: Optional[SWALR] = None

def __setattr__(self, name: str, value: object) -> None:
if isinstance(value, torch.nn.Module):
self._validate_module_attr(name, value)

super().__setattr__(name, value)

def _validate_module_attr(self, name: str, module: torch.nn.Module) -> None:
"""
The AutoUnit is designed to manage the input model using the `self.module` attribute,
which should not be reassigned. Additionally, if a subclass saves another attribute
referencing the same model instance (wrapped or unwrapped), that instance will
appear twice in tracked_modules. This is problematic for checkpointing and for
handling evaluation/training mode.
"""
# First time the module attribute is set is in the AutoUnit's initialization
if not hasattr(self, "module"):
return

# Value of self.module should not be changed after initialization
if name == "module":
_logger.warning(
"The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
)
return

# Otherwise, double check that this is not a duplicate reference to the self.module instance
managed_modules = [self.module]
if isinstance(self.module, (DDP, FSDP)):
managed_modules.append(self.module.module)

if any(module is managed_module for managed_module in managed_modules):
_logger.error(
f"Attribute '{name}' of the custom TNT Unit stores a reference to the model managed"
+ "by AutoUnit. This is known to cause errors on checkpointing and model training "
+ "mode. Please remove this attribute and access the existing `self.module` instead."
)

@abstractmethod
def configure_optimizers_and_lr_scheduler(
self, module: torch.nn.Module
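
For subclass authors, a possible alternative (a sketch, not part of this commit) is to resolve the managed model through `self.module` at the point of use instead of caching a second attribute, mirroring the unwrapping used in the new test:

    # Sketch of an accessor that unwraps DDP/FSDP on demand rather than storing
    # a duplicate reference on the unit (hypothetical helper method).
    def _unwrapped_module(self) -> torch.nn.Module:
        m = self.module
        return m.module if isinstance(m, (DDP, FSDP)) else m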
