Ensure state_dict fqn unchanged for pipelined postproc modules (#2503)
Summary:

To avoid issues with checkpointing and restoring, `PipelinedPreproc` should be "transparent" to operations related to the model structure, specifically those that save or load `state_dict`. To achieve this, we preserve the original FQN of the preproc module inside the `PipelinedPreproc` class that wraps it. The relevant methods are:

1) named_modules
2) named_parameters
3) named_buffers
4) state_dict
5) load_state_dict

**Potential limitation:** This solution relies on the `load_state_dict` override in `DistributedModelParallel` to adjust how model modules restore their state, the same override that `ShardedModule` relies on. This means that using `PipelinedPreproc` outside `DistributedModelParallel` might cause the model to fail to restore from a checkpoint. However, like `ShardedModule`, `PipelinedPreproc` is not meant to be used directly in models; it is injected as part of the model rewrite for the semi-sync pipeline. **TL;DR:** this should not happen unless someone is actively doing the wrong thing.
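For illustration, a minimal sketch of the idea (hypothetical `NaiveWrapper`/`TransparentWrapper` classes, not TorchRec code): a wrapper that simply registers the wrapped module as a child inserts its attribute name into every FQN, while delegating `state_dict` with the caller's prefix keeps the keys unchanged.

```python
from collections import OrderedDict

import torch


class Inner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("weight", torch.tensor(1.0))


class NaiveWrapper(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module) -> None:
        super().__init__()
        # Registered as a child module, so "_inner." prefixes every key.
        self._inner = inner


class TransparentWrapper(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module) -> None:
        super().__init__()
        self._inner = inner

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # Delegate with the caller's prefix so the wrapper adds no key segment.
        if destination is None:
            destination = OrderedDict()
        self._inner.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        return destination


print(list(NaiveWrapper(Inner()).state_dict().keys()))  # ['_inner.weight']
print(list(TransparentWrapper(Inner()).state_dict().keys()))  # ['weight']
```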

Differential Revision: D64572844
che-sh authored and facebook-github-bot committed Oct 24, 2024
1 parent 7855789 commit 9e061fe
Showing 5 changed files with 208 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/model_parallel.py
@@ -474,7 +474,7 @@ def _load_state_dict(
state_dict, prefix
)
add_prefix_to_state_dict(state_dict, prefix + _DDP_STATE_DICT_PREFIX)
if isinstance(module, ShardedModule):
if getattr(module, "_FORCE_STATE_DICT_LOAD", False):
return module.load_state_dict(state_dict, strict=strict)
else:
module._load_from_state_dict(
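Here the `isinstance(module, ShardedModule)` special case becomes a duck-typed opt-in: any module class that sets `_FORCE_STATE_DICT_LOAD = True` (now `ShardedModule` and `PipelinedPreproc`, see the `types.py` and `utils.py` changes below) is routed through its own `load_state_dict` instead of `_load_from_state_dict`. A minimal sketch of the dispatch, with hypothetical class names:

```python
import torch


class OptedIn(torch.nn.Module):
    # Class-level flag; getattr() on an instance falls back to the class attribute.
    _FORCE_STATE_DICT_LOAD = True


class OptedOut(torch.nn.Module):
    pass


# Mirrors the dispatch in _load_state_dict above.
assert getattr(OptedIn(), "_FORCE_STATE_DICT_LOAD", False)
assert not getattr(OptedOut(), "_FORCE_STATE_DICT_LOAD", False)
```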
3 changes: 3 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
@@ -1788,13 +1788,16 @@ class TestNegSamplingModule(torch.nn.Module):
ModelInput
"""

TEST_BUFFER_NAME = "test_buffer"

def __init__(
self,
extra_input: ModelInput,
has_params: bool = False,
) -> None:
super().__init__()
self._extra_input = extra_input
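# Buffer gives the preproc module state of its own, so the snapshot tests
# below can assert that it survives a state_dict save/load round-trip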
self.register_buffer(self.TEST_BUFFER_NAME, torch.zeros(1))
if has_params:
self._linear: nn.Module = nn.Linear(30, 30)

@@ -7,10 +7,14 @@

# pyre-strict

import copy
import enum
import unittest
from unittest.mock import MagicMock

import torch

from parameterized import parameterized
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule

@@ -21,13 +25,19 @@
_get_node_args,
_rewrite_model,
PipelinedForward,
PipelinedPreproc,
TrainPipelineContext,
)
from torchrec.distributed.types import ShardingType

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class ModelType(enum.Enum):
VANILLA = "vanilla"
SHARDED = "sharded"
PIPELINED = "pipelined"


class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
@@ -79,6 +89,7 @@ def test_rewrite_model(self) -> None:
dist_stream=None,
pipeline_preproc=True,
)

self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward)
self.assertIsInstance(
sharded_model.module.sparse.weighted_ebc.forward, PipelinedForward
@@ -93,6 +104,134 @@ def test_rewrite_model(self) -> None:
],
sharded_model.module.preproc_module,
)
state_dict = sharded_model.state_dict()
missing_keys, unexpected_keys = sharded_model.load_state_dict(state_dict)
self.assertEqual(missing_keys, [])
self.assertEqual(unexpected_keys, [])

def test_pipelined_preproc_state_dict(self) -> None:
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.tensor(1.0))

def forward(self, x):
return x

class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.test_module = TestModule()

def forward(self, x):
return self.test_module(x)

model = TestModel()

rewritten_model = copy.deepcopy(model)
# pyre-ignore[8]
rewritten_model.test_module = PipelinedPreproc(
preproc_module=rewritten_model.test_module,
fqn="test_module",
args=[],
context=TrainPipelineContext(),
)
# Self-check: the state dict should be the same between the vanilla model and the rewritten model
self.assertDictEqual(model.state_dict(), rewritten_model.state_dict())
state_dict = rewritten_model.state_dict()
self.assertEqual(list(state_dict.keys()), ["test_module.weight"])

def _create_model_for_snapshot_test(
self, source_model_type: ModelType
) -> torch.nn.Module:
if source_model_type == ModelType.VANILLA:
extra_input = ModelInput.generate(
tables=self.tables,
weighted_tables=self.weighted_tables,
batch_size=10,
world_size=1,
num_float_features=10,
randomize_indices=False,
)[0].to(self.device)

preproc_module = TestNegSamplingModule(
extra_input=extra_input,
)
model = self._setup_model(preproc_module=preproc_module)
model.to_empty(device=self.device)
return model
elif source_model_type == ModelType.SHARDED:
model = self._create_model_for_snapshot_test(ModelType.VANILLA)
sharded_model, optim = self._generate_sharded_model_and_optimizer(
model,
ShardingType.TABLE_WISE.value,
EmbeddingComputeKernel.FUSED.value,
{},
)
return sharded_model
elif source_model_type == ModelType.PIPELINED:
model = self._create_model_for_snapshot_test(ModelType.SHARDED)
_rewrite_model(
model=model,
batch=None,
context=TrainPipelineContext(),
dist_stream=None,
pipeline_preproc=True,
)
return model
else:
raise ValueError(f"Unknown model type {source_model_type}")

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
# pyre-ignore[56]: cannot infer type of name_func
# pyre-ignore[16]: cannot find parameterized.expand
@parameterized.expand(
[
# Self-consistency checks: a model should be able to load its own state
(ModelType.VANILLA, ModelType.VANILLA),
(ModelType.SHARDED, ModelType.SHARDED),
(ModelType.PIPELINED, ModelType.PIPELINED),
# Production case - saved from pipelined, restored to sharded
(ModelType.PIPELINED, ModelType.SHARDED),
# Nice-to-haves:
(ModelType.SHARDED, ModelType.PIPELINED),
(ModelType.VANILLA, ModelType.PIPELINED),
(ModelType.VANILLA, ModelType.SHARDED),
# Won't work: restoring sharded/pipelined into vanilla fails with
# "'Parameter' object has no attribute 'local_shards'"
# ... which is expected, as the vanilla model is not sharded
# (ModelType.SHARDED, ModelType.VANILLA),
# (ModelType.PIPELINED, ModelType.VANILLA),
],
name_func=lambda fn, _num, p: f"{fn.__name__}_{p.args[0].value}_{p.args[1].value}",
)
def test_restore_from_snapshot(
self, source_model_type: ModelType, recipient_model_type: ModelType
) -> None:
source_model = self._create_model_for_snapshot_test(source_model_type)
recipient_model = self._create_model_for_snapshot_test(recipient_model_type)

# Self-check: source and recipient should expose the same state dict keys,
# although this is not strictly necessary.
# Asserting only on keys, since comparing the entire state dicts fails with
# "Boolean value of Tensor with more than one value is ambiguous"
# (tensor equality is elementwise, so assertEqual cannot reduce it to a single bool)
self.assertEqual(
source_model.state_dict().keys(), recipient_model.state_dict().keys()
)

state_dict = source_model.state_dict()
self.assertTrue(
f"preproc_module.{TestNegSamplingModule.TEST_BUFFER_NAME}"
in state_dict.keys()
)

missing_keys, unexpected_keys = recipient_model.load_state_dict(state_dict)
# if both are empty, restoring the state dict was successful
self.assertEqual(missing_keys, [])
self.assertEqual(unexpected_keys, [])


class TestUtils(unittest.TestCase):
63 changes: 62 additions & 1 deletion torchrec/distributed/train_pipeline/utils.py
@@ -10,7 +10,7 @@
import copy
import itertools
import logging
from collections import defaultdict
from collections import defaultdict, OrderedDict
from dataclasses import dataclass, field

from itertools import chain
@@ -48,6 +48,7 @@ class FSDP2:
immutable_list as fx_immutable_list,
)
from torch.fx.node import Node
from torch.nn.modules.module import _IncompatibleKeys
from torch.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
from torchrec.distributed.embedding_sharding import (
@@ -249,6 +250,8 @@ class PipelinedPreproc(torch.nn.Module):
setattr(model, fqn, preproc)
"""

_FORCE_STATE_DICT_LOAD = True

def __init__(
self,
preproc_module: torch.nn.Module,
@@ -320,6 +323,64 @@ def set_context(self, context: TrainPipelineContext) -> None:
def get_context(self) -> TrainPipelineContext:
return self._context

def named_modules(
self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = "",
remove_duplicate: bool = True,
) -> Iterator[Tuple[str, torch.nn.Module]]:
if memo is None:
memo = set()
if self not in memo:
if remove_duplicate:
memo.add(self)
# This is needed because otherwise the rewrite won't find the existing preproc and will create a new one.
# Also, `named_modules` needs to include self; see the base implementation in nn.modules.Module.
yield prefix, self
# The difference from the base implementation: the child name (_preproc_module) is not added to the prefix.
yield from self._preproc_module.named_modules(
memo, prefix, remove_duplicate
)

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
yield from self._preproc_module.named_parameters(
prefix,
recurse,
remove_duplicate,
)

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
yield from self._preproc_module.named_buffers(prefix, recurse, remove_duplicate)

# pyre-ignore [14]
def state_dict(
self,
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
) -> Dict[str, Any]:
# Intentionally not calling super().state_dict(destination, prefix, keep_vars):
# the wrapper itself has no state of its own to add
if destination is None:
destination = OrderedDict()
# pyre-ignore [16]
destination._metadata = OrderedDict()
self._preproc_module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
return destination

# pyre-ignore [14]
def load_state_dict(
self,
state_dict: OrderedDict[str, torch.Tensor],
strict: bool = True,
) -> _IncompatibleKeys:
return self._preproc_module.load_state_dict(state_dict, strict=strict)


class BaseForward:
def __init__(
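To see the effect of the `named_modules` override above in isolation, here is a self-contained sketch (a hypothetical `Transparent` wrapper, simplified from `PipelinedPreproc`): because the child reuses the parent's prefix, the wrapper level never appears in module or parameter FQNs, and even the default `named_parameters` (which walks `named_modules`) yields the original names.

```python
from typing import Iterator, Optional, Set, Tuple

import torch


class Transparent(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module) -> None:
        super().__init__()
        self._inner = inner

    def named_modules(
        self,
        memo: Optional[Set[torch.nn.Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, torch.nn.Module]]:
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            # Reuse the parent's prefix: "_inner" never enters the FQN.
            yield from self._inner.named_modules(memo, prefix, remove_duplicate)


wrapped = Transparent(torch.nn.Linear(2, 2))
print([name for name, _ in wrapped.named_modules()])
# ['', '']: wrapper and inner module share the same (outer) FQN
print([name for name, _ in wrapped.named_parameters()])
# ['weight', 'bias'], not ['_inner.weight', '_inner.bias']
```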
2 changes: 2 additions & 0 deletions torchrec/distributed/types.py
@@ -873,6 +873,8 @@ class ShardedModule(
from data parallel to model parallel and vice versa.
"""

_FORCE_STATE_DICT_LOAD = True

@abc.abstractmethod
def __init__(
self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None
