diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 10bb68985..11164e3e0 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -474,7 +474,7 @@ def _load_state_dict(
             state_dict, prefix
         )
         add_prefix_to_state_dict(state_dict, prefix + _DDP_STATE_DICT_PREFIX)
-    if isinstance(module, ShardedModule):
+    if getattr(module, "_FORCE_STATE_DICT_LOAD", False):
         return module.load_state_dict(state_dict, strict=strict)
     else:
         module._load_from_state_dict(
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 91d37edc0..b91ed8eb4 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -1788,6 +1788,8 @@ class TestNegSamplingModule(torch.nn.Module):
         ModelInput
     """
 
+    TEST_BUFFER_NAME = "test_buffer"
+
     def __init__(
         self,
         extra_input: ModelInput,
@@ -1795,6 +1797,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self._extra_input = extra_input
+        self.register_buffer(self.TEST_BUFFER_NAME, torch.zeros(1))
         if has_params:
             self._linear: nn.Module = nn.Linear(30, 30)
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index f8dcf08fb..a0a8fc097 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -7,10 +7,13 @@
 
 # pyre-strict
 
+import copy
+import enum
 import unittest
 from unittest.mock import MagicMock
 
 import torch
+
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
 
@@ -21,13 +24,19 @@
     _get_node_args,
     _rewrite_model,
     PipelinedForward,
+    PipelinedPreproc,
     TrainPipelineContext,
 )
 from torchrec.distributed.types import ShardingType
-
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
+class ModelType(enum.Enum):
+    VANILLA = "vanilla"
+    SHARDED = "sharded"
+    PIPELINED = "pipelined"
+
+
 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
@@ -79,6 +88,7 @@ def test_rewrite_model(self) -> None:
             dist_stream=None,
             pipeline_preproc=True,
         )
+        self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward)
         self.assertIsInstance(
             sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward
         )
@@ -93,6 +103,135 @@ def test_rewrite_model(self) -> None:
             ],
             sharded_model.module.preproc_module,
         )
+        state_dict = sharded_model.state_dict()
+        missing_keys, unexpected_keys = sharded_model.load_state_dict(state_dict)
+        self.assertEqual(missing_keys, [])
+        self.assertEqual(unexpected_keys, [])
+
+    def test_pipelined_preproc_state_dict(self) -> None:
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("weight", torch.tensor(1.0))
+
+            def forward(self, x):
+                return x
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.test_module = TestModule()
+
+            def forward(self, x):
+                return self.test_module(x)
+
+        model = TestModel()
+
+        rewritten_model = copy.deepcopy(model)
+        # pyre-ignore[8]
+        rewritten_model.test_module = PipelinedPreproc(
+            preproc_module=rewritten_model.test_module,
+            fqn="test_module",
+            args=[],
+            context=TrainPipelineContext(),
+        )
+        # self-check - we want the state dict to be the same between the vanilla model and the "rewritten" model
+        self.assertDictEqual(model.state_dict(), rewritten_model.state_dict())
+        state_dict = rewritten_model.state_dict()
+        self.assertEqual(list(state_dict.keys()), ["test_module.weight"])
+
+    def _create_model_for_snapshot_test(
+        self, source_model_type: ModelType
+    ) -> torch.nn.Module:
+        if source_model_type == ModelType.VANILLA:
+            extra_input = ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=10,
+                world_size=1,
+                num_float_features=10,
+                randomize_indices=False,
+            )[0].to(self.device)
+
+            preproc_module = TestNegSamplingModule(
+                extra_input=extra_input,
+            )
+            model = self._setup_model(preproc_module=preproc_module)
+            model.to_empty(device=self.device)
+            return model
+        elif source_model_type == ModelType.SHARDED:
+            model = self._create_model_for_snapshot_test(ModelType.VANILLA)
+            sharded_model, optim = self._generate_sharded_model_and_optimizer(
+                model,
+                ShardingType.TABLE_WISE.value,
+                EmbeddingComputeKernel.FUSED.value,
+                {},
+            )
+            return sharded_model
+        elif source_model_type == ModelType.PIPELINED:
+            model = self._create_model_for_snapshot_test(ModelType.SHARDED)
+            _rewrite_model(
+                model=model,
+                batch=None,
+                context=TrainPipelineContext(),
+                dist_stream=None,
+                pipeline_preproc=True,
+            )
+            return model
+        else:
+            raise ValueError(f"Unknown model type {source_model_type}")
+
+    def _test_restore_from_snapshot(
+        self, source_model_type: ModelType, recipient_model_type: ModelType
+    ) -> None:
+        source_model = self._create_model_for_snapshot_test(source_model_type)
+        recipient_model = self._create_model_for_snapshot_test(recipient_model_type)
+
+        # self-check - we want the state dict to be the same between source and recipient,
+        # although this is not strictly necessary
+        # Asserting only on keys, since asserting on the entire state dict fails with
+        # "Boolean value of Tensor with more than one value is ambiguous" (not sure why)
+        self.assertEqual(
+            source_model.state_dict().keys(), recipient_model.state_dict().keys()
+        )
+
+        state_dict = source_model.state_dict()
+        self.assertTrue(
+            f"preproc_module.{TestNegSamplingModule.TEST_BUFFER_NAME}"
+            in state_dict.keys()
+        )
+
+        missing_keys, unexpected_keys = recipient_model.load_state_dict(state_dict)
+        # if both are empty, restoring the state dict was successful
+        self.assertEqual(missing_keys, [])
+        self.assertEqual(unexpected_keys, [])
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_restore_from_snapshot(self) -> None:
+        # makeshift parameterized test - to avoid introducing new dependencies
+        variants = [
+            # Self-consistency checks - a model should be able to load its own state
+            (ModelType.VANILLA, ModelType.VANILLA),
+            (ModelType.SHARDED, ModelType.SHARDED),
+            (ModelType.PIPELINED, ModelType.PIPELINED),
+            # Production case - saved from pipelined, restored to sharded
+            (ModelType.PIPELINED, ModelType.SHARDED),
+            # Nice-to-haves:
+            (ModelType.SHARDED, ModelType.PIPELINED),
+            (ModelType.VANILLA, ModelType.PIPELINED),
+            (ModelType.VANILLA, ModelType.SHARDED),
+            # Won't work - restoring sharded/pipelined into vanilla fails with
+            # "'Parameter' object has no attribute 'local_shards'"
+            # ... which is totally expected, as the vanilla model is not sharded
+            # (ModelType.SHARDED, ModelType.VANILLA),
+            # (ModelType.PIPELINED, ModelType.VANILLA),
+        ]
+        for source_model_type, recipient_model_type in variants:
+            self._test_restore_from_snapshot(source_model_type, recipient_model_type)
 
 
 class TestUtils(unittest.TestCase):
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 69058b0f6..154c0a60d 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -10,7 +10,7 @@
 
 import copy
 import itertools
 import logging
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from itertools import chain
@@ -48,6 +48,7 @@ class FSDP2:
     immutable_list as fx_immutable_list,
 )
 from torch.fx.node import Node
+from torch.nn.modules.module import _IncompatibleKeys
 from torch.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
 from torchrec.distributed.embedding_sharding import (
@@ -249,6 +250,8 @@ class PipelinedPreproc(torch.nn.Module):
         setattr(model, fqn, preproc)
     """
 
+    _FORCE_STATE_DICT_LOAD = True
+
     def __init__(
         self,
         preproc_module: torch.nn.Module,
@@ -320,6 +323,64 @@ def set_context(self, context: TrainPipelineContext) -> None:
 
     def get_context(self) -> TrainPipelineContext:
         return self._context
 
+    def named_modules(
+        self,
+        memo: Optional[Set[torch.nn.Module]] = None,
+        prefix: str = "",
+        remove_duplicate: bool = True,
+    ) -> Iterator[Tuple[str, torch.nn.Module]]:
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            if remove_duplicate:
+                memo.add(self)
+            # This is needed because otherwise the rewrite won't find the existing preproc and will create a new one
+            # Also, `named_modules` needs to include self - see the base implementation in nn.modules.Module
+            yield prefix, self
+            # Difference from the base implementation is here - the child name (_preproc_module) is not added to the prefix
+            yield from self._preproc_module.named_modules(
+                memo, prefix, remove_duplicate
+            )
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        yield from self._preproc_module.named_parameters(
+            prefix,
+            recurse,
+            remove_duplicate,
+        )
+
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        yield from self._preproc_module.named_buffers(prefix, recurse, remove_duplicate)
+
+    # pyre-ignore [14]
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> Dict[str, Any]:
+        # Intentionally not calling super().state_dict() - state collection is delegated to the wrapped preproc module
+        if destination is None:
+            destination = OrderedDict()
+            # pyre-ignore [16]
+            destination._metadata = OrderedDict()
+        self._preproc_module.state_dict(
+            destination=destination, prefix=prefix, keep_vars=keep_vars
+        )
+        return destination
+
+    # pyre-ignore [14]
+    def load_state_dict(
+        self,
+        state_dict: OrderedDict[str, torch.Tensor],
+        strict: bool = True,
+    ) -> _IncompatibleKeys:
+        return self._preproc_module.load_state_dict(state_dict, strict=strict)
+
 
 class BaseForward:
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index b6bdec3bf..141ae049c 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -872,6 +872,8 @@ class ShardedModule(
     from data-parallel to model parallel and vise-versa.
     """
 
+    _FORCE_STATE_DICT_LOAD = True
+
     @abc.abstractmethod
     def __init__(
         self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None