Detect duplicated references to module instance in AutoUnit #893

Closed · wants to merge 1 commit
92 changes: 91 additions & 1 deletion tests/framework/test_auto_unit.py
@@ -14,6 +14,7 @@
import torch

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
from torch import nn

from torch.distributed import GradBucket
from torchtnt.framework._test_utils import (
@@ -39,7 +40,7 @@
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import DDPStrategy
+from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
from torchtnt.utils.progress import Progress
from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -294,6 +295,95 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
)
self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)

@skip_if_not_distributed
def test_module_attr_duplicate_reference_validation(self) -> None:
spawn_multi_process(
2,
"gloo",
self._test_module_attr_duplicate_reference_validation,
)

@staticmethod
def _test_module_attr_duplicate_reference_validation() -> None:
        error_msg = (
            "Attribute '{name}' of the custom TNT Unit stores a reference to the model managed "
            "by AutoUnit. This is known to cause errors on checkpointing and model training "
            "mode. Please remove this attribute and access the existing `self.module` instead."
        )

        # Unit that stores a second reference to the (unwrapped) module
class ChildUnit(AutoUnit):
def __init__(self, module, strategy):
super().__init__(module=module, strategy=strategy)
self._module = self.module.module if strategy else self.module

def compute_loss(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.Tensor([1]), torch.Tensor([1])

def configure_optimizers_and_lr_scheduler(
self, module: torch.nn.Module
) -> Tuple[torch.optim.Optimizer, TLRScheduler]:
return MagicMock(), MagicMock()

# Test with two levels of inheritance
class GrandchildUnit(DummyAutoUnit):
def __init__(self, module, strategy):
super().__init__(module=module, strategy=strategy)
self._module = module

# Test duplicated references to module
test_cases = [
(DummyAutoUnit, None, False),
(ChildUnit, None, True),
(ChildUnit, FSDPStrategy(), True),
(ChildUnit, DDPStrategy(), True),
(GrandchildUnit, None, True),
]
for unit_type, strategy, expect_error in test_cases:
module = nn.Linear(2, 2)
error_container = []
with patch(
"torchtnt.framework.auto_unit.logging.Logger.error",
side_effect=error_container.append,
):
unit = unit_type(module=module, strategy=strategy)

tc = unittest.TestCase()
expected_errors = [error_msg.format(name="_module")] if expect_error else []
tc.assertEqual(error_container, expected_errors)
tc.assertIs(module, unit.module.module if strategy else unit.module)

def test_module_attr_reassignment_validation(self) -> None:
# Test reassignment of module attribute
class ReassigningUnit1(DummyAutoUnit):
def __init__(self, module):
super().__init__(module=module)
self.module = module

class ReassigningUnit2(DummyAutoUnit):
def __init__(self, module):
super().__init__(module=module)
self.configure_model()

def configure_model(self):
self.module = torch.nn.Linear(3, 3)

for unit_type in (ReassigningUnit1, ReassigningUnit2):
module = nn.Linear(2, 2)
warning_container = []
with patch(
"torchtnt.framework.auto_unit.logging.Logger.warning",
side_effect=warning_container.append,
):
unit_type(module=module)

expected_warnings = [
"The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
]
self.assertEqual(warning_container, expected_warnings)

@skip_if_not_distributed
def test_auto_unit_ddp(self) -> None:
"""
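As a side note on the testing pattern above: the tests capture logger output by patching the logger method with a list's `append` as `side_effect`. A minimal self-contained sketch of that pattern (the `my_pkg` logger name and `do_work` function are made up for the example):

```python
import logging
import unittest
from unittest.mock import patch

_logger = logging.getLogger("my_pkg")


def do_work() -> None:
    _logger.warning("something to flag")


class LoggerCaptureTest(unittest.TestCase):
    def test_warning_is_captured(self) -> None:
        captured = []
        # Patching Logger.warning at the class level routes every call's message
        # into `captured` instead of the logging handlers.
        with patch("logging.Logger.warning", side_effect=captured.append):
            do_work()
        self.assertEqual(captured, ["something to flag"])


if __name__ == "__main__":
    unittest.main()
```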
40 changes: 40 additions & 0 deletions torchtnt/framework/auto_unit.py
@@ -8,6 +8,7 @@


import contextlib
import logging
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
@@ -52,6 +53,8 @@
from torchtnt.utils.swa import AveragedModel
from typing_extensions import Literal

_logger: logging.Logger = logging.getLogger(__name__)


TData = TypeVar("TData")

@@ -550,6 +553,43 @@ def __init__(
self.lr_scheduler: Optional[TLRScheduler] = None
self.swa_scheduler: Optional[SWALR] = None

def __setattr__(self, name: str, value: object) -> None:
if isinstance(value, torch.nn.Module):
self._validate_module_attr(name, value)

super().__setattr__(name, value)

def _validate_module_attr(self, name: str, module: torch.nn.Module) -> None:
"""
The AutoUnit is designed to manage the input model using the `self.module` attribute,
        which should not be reassigned. Additionally, if a subclass stores another attribute
        referencing the same model instance (wrapped or unwrapped), that instance will appear
        twice in ``tracked_modules``, which is problematic for checkpointing and for handling
        evaluation/training mode.
"""
        # The module attribute is first set during the AutoUnit's own initialization; skip that case
if not hasattr(self, "module"):
return

# Value of self.module should not be changed after initialization
if name == "module":
_logger.warning(
"The self.module attribute is managed by AutoUnit and is not meant to be reassigned."
)
return

        # Otherwise, double-check that this is not a duplicate reference to the self.module instance
        managed_modules = [self.module]
        if isinstance(self.module, (DDP, FSDP)):
managed_modules.append(self.module.module)

if any(module is managed_module for managed_module in managed_modules):
_logger.error(
f"Attribute '{name}' of the custom TNT Unit stores a reference to the model managed"
+ "by AutoUnit. This is known to cause errors on checkpointing and model training "
+ "mode. Please remove this attribute and access the existing `self.module` instead."
)

@abstractmethod
def configure_optimizers_and_lr_scheduler(
self, module: torch.nn.Module
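For illustration only, here is a minimal standalone sketch of the guard introduced above; a hypothetical `ManagedModuleHolder` stands in for `AutoUnit`, so none of the torchtnt machinery is needed:

```python
import logging

from torch import nn

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)


class ManagedModuleHolder:
    """Hypothetical stand-in for AutoUnit: owns ``self.module`` and flags duplicate references."""

    def __init__(self, module: nn.Module) -> None:
        self.module = module  # the first assignment happens before any validation applies

    def __setattr__(self, name: str, value: object) -> None:
        # Only nn.Module values are validated, and only once `self.module` already exists.
        if isinstance(value, nn.Module) and hasattr(self, "module"):
            if name == "module":
                _logger.warning("self.module is managed and not meant to be reassigned.")
            elif value is self.module:
                _logger.error(f"Attribute '{name}' duplicates the managed module reference.")
        super().__setattr__(name, value)  # as in the diff, the assignment itself still goes through


holder = ManagedModuleHolder(nn.Linear(2, 2))
holder._module = holder.module   # logged as an error: duplicate reference to the managed module
holder.module = nn.Linear(3, 3)  # logged as a warning: reassignment of the managed attribute
```

The implementation in the diff additionally unwraps DDP/FSDP-wrapped modules, so a stored reference to the inner module is flagged as well.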