make PartiallyMaterializedTensor work with checkpointing (#2531)
Summary:

Create a ShardedTensor from a PartiallyMaterializedTensor so that KVTensorWrapper can be used for checkpointing.
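
The core trick is to hand checkpointing a ShardedTensor whose local shard tensor is the PartiallyMaterializedTensor itself rather than a dense copy. A minimal single-rank sketch of that wrapping (illustrative only; wrap_as_sharded_tensor and the dense stand-in tensor are not part of this change):

import torch
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor

def wrap_as_sharded_tensor(pmt_like: torch.Tensor, rows: int, cols: int) -> ShardedTensor:
    # One shard covering the whole table on rank 0; the real code derives the
    # shard metadata from the sharding plan and uses the PMT as the shard tensor.
    md = ShardMetadata(
        shard_offsets=[0, 0],
        shard_sizes=[rows, cols],
        placement="rank:0/cpu",
    )
    return ShardedTensor._init_from_local_shards(
        [Shard(tensor=pmt_like, metadata=md)],
        (rows, cols),  # global table size
        process_group=dist.group.WORLD,
    )

# Usage sketch, assuming a single-rank gloo group is already initialized:
#   dist.init_process_group("gloo", rank=0, world_size=1, ...)
#   st = wrap_as_sharded_tensor(torch.zeros(100, 16), 100, 16)
#   st.local_shards()[0].tensor  # in this diff, that tensor is the PMT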

Reviewed By: pradeepfn

Differential Revision: D65281052
Yulu Jia authored and facebook-github-bot committed Nov 1, 2024
1 parent d2ed744 commit 5e4a0b8
Showing 5 changed files with 168 additions and 21 deletions.
61 changes: 51 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -585,6 +585,25 @@ def _gen_named_parameters_by_table_ssd(
yield (table_name, weight)


def _gen_named_parameters_by_table_ssd_pmt(
emb_module: SSDTableBatchedEmbeddingBags,
table_name_to_count: Dict[str, int],
config: GroupedEmbeddingConfig,
pg: Optional[dist.ProcessGroup] = None,
) -> Iterator[Tuple[str, nn.Parameter]]:
"""
Return a PartiallyMaterializedTensor to indicate that the table is on a remote device.
"""
pmts = emb_module.split_embedding_weights()
for table_config, pmt in zip(config.embedding_tables, pmts):
table_name = table_config.name
emb_table = pmt
weight: nn.Parameter = nn.Parameter(emb_table)
# pyre-ignore
weight._in_backward_optimizers = [EmptyFusedOptimizer()]
yield (table_name, weight)


def _gen_named_parameters_by_table_fused(
emb_module: SplitTableBatchedEmbeddingBagsCodegen,
table_name_to_count: Dict[str, int],
@@ -1257,7 +1276,7 @@ def __init__(
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
@@ -1291,11 +1310,20 @@ def state_dict(
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()

return destination
self.flush()
tmp = self.split_embedding_weights(no_snapshot=no_snapshot)
for emb_table in self._config.embedding_tables:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
self._config.embedding_tables,
tmp,
self._pg,
destination,
prefix,
)
return ret

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -1326,6 +1354,16 @@ def named_split_embedding_weights(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def named_pmts(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor

def flush(self) -> None:
"""
Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -1340,11 +1378,14 @@ def purge(self) -> None:
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)

def split_embedding_weights(self) -> List[torch.Tensor]:
"""
Return fake tensors.
"""
return [param.data for param in self._param_per_table.values()]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[torch._tensor.Tensor]:
pmts = self.emb_module.split_embedding_weights(no_snapshot)
emb_weights = []
for pmt, _emb_table in zip(pmts, self._config.embedding_tables):
emb_weights.append(pmt)
return emb_weights


class BatchedFusedEmbeddingBag(
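Taken together, the kernel changes above expose PMT-backed weights to the checkpoint path. A hedged sketch of the intended consumption (sharded_ebc, kernel, and the table name table_0 are hypothetical placeholders, not objects from this diff):

# sharded_ebc: a ShardedEmbeddingBagCollection whose tables use the
# EmbeddingComputeKernel.KEY_VALUE compute kernel; kernel: the underlying
# SSD-backed KeyValueEmbeddingBag kernel for one lookup.
state = sharded_ebc.state_dict()
weight = state["embedding_bags.table_0.weight"]  # ShardedTensor
pmt = weight.local_shards()[0].tensor            # PartiallyMaterializedTensor, not a dense copy

# The kernel-level hooks added above can also be used directly:
#   sd = kernel.state_dict(no_snapshot=False)  # flushes the cache, returns PMT-backed weights
#   for name, t in kernel.named_pmts():
#       ...  # t is a PartiallyMaterializedTensor for one table
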
4 changes: 3 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -99,7 +99,9 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
qbias = param[2]
param = param[0]

assert embedding_table.local_rows == param.size(0) # pyre-ignore[16]
assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]

if qscale is not None:
assert embedding_table.local_cols == param.size(1) # pyre-ignore[16]
5 changes: 5 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -649,6 +649,11 @@ def named_parameters_by_table(
) in embedding_kernel.named_parameters_by_table():
yield (table_name, tbe_slice)

def named_pmts(self) -> Iterator[Tuple[str, torch.Tensor]]:
for emb_module in self._emb_modules:
if isinstance(emb_module, KeyValueEmbeddingBag):
yield from emb_module.named_pmts()

def flush(self) -> None:
for emb_module in self._emb_modules:
emb_module.flush()
43 changes: 33 additions & 10 deletions torchrec/distributed/embeddingbag.py
@@ -830,7 +830,9 @@ def _initialize_torch_state(self) -> None:  # noqa
self._model_parallel_name_to_local_shards = OrderedDict()
self._model_parallel_name_to_shards_wrapper = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_sharded_pmt = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()
self._name_to_kv_module = OrderedDict()

model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
@@ -903,10 +905,6 @@ def _initialize_torch_state(self) -> None:  # noqa
self.embedding_bags[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]
if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
continue

if self._output_dtensor:
if shards_wrapper_map["local_tensors"]:
@@ -937,13 +935,21 @@ def _initialize_torch_state(self) -> None:  # noqa
)
else:
# created ShardedTensors once in init, use in post_state_dict_hook
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
)
sharded_tensor = ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
)
if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
self._model_parallel_name_to_sharded_pmt[table_name] = (
sharded_tensor
)
else:
self._model_parallel_name_to_sharded_tensor[table_name] = (
sharded_tensor
)

def post_state_dict_hook(
module: ShardedEmbeddingBagCollection,
@@ -965,6 +971,23 @@ def post_state_dict_hook(
destination_key = f"{prefix}embedding_bags.{table_name}.weight"
destination[destination_key] = d_tensor

sharded_pmts = copy.deepcopy(module._model_parallel_name_to_sharded_pmt)
for lookup, sharding in zip(module._lookups, module._embedding_shardings):
if isinstance(sharding, DpPooledEmbeddingSharding):
# unwrap DDP
lookup = lookup.module
else:
for key, v in lookup.named_pmts():
destination_key = f"{prefix}embedding_bags.{key}.weight"
assert key in sharded_pmts
sharded_pmts[key].local_shards()[0].tensor = v
for (
table_name,
sharded_pmt,
) in sharded_pmts.items():
destination_key = f"{prefix}embedding_bags.{table_name}.weight"
destination[destination_key] = sharded_pmt

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
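For reference, the hook mechanism used above can be illustrated with a small self-contained example (a toy module, not torchrec code): _register_state_dict_hook runs after state_dict() is assembled and may rewrite destination entries in place, which is how the PMT-backed ShardedTensors are injected into the checkpoint.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2, 2))

def post_state_dict_hook(module, destination, prefix, local_metadata):
    # Swap the dense entry for a transformed one; the real hook swaps in a
    # ShardedTensor whose local shard tensor is a PartiallyMaterializedTensor.
    destination[f"{prefix}weight"] = destination[f"{prefix}weight"] * 2
    return destination

m = Toy()
m._register_state_dict_hook(post_state_dict_hook)
print(m.state_dict()["weight"])  # tensor([[2., 2.], [2., 2.]])
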
@@ -198,6 +198,82 @@ def test_ssd_load_state_dict(
self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
self._compare_models(m1, m2, is_deterministic=is_deterministic)

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
# pyre-ignore[56]
@given(
sharder_type=st.sampled_from(
[
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.KEY_VALUE.value,
]
),
sharding_type=st.sampled_from(
[
ShardingType.ROW_WISE.value,
]
),
is_training=st.booleans(),
stochastic_rounding=st.booleans(),
dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
)
@settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
def test_ssd_partially_materialized_wrapper_load_state_dict(
self,
sharder_type: str,
kernel_type: str,
sharding_type: str,
is_training: bool,
stochastic_rounding: bool,
dtype: DataType,
) -> None:
"""
This test checks that the SSD TBE is deterministic: if two SSD
TBEs start with the same state, they produce the same output.
"""
self._set_table_weights_precision(dtype)

fused_params = {
"learning_rate": 0.1,
"stochastic_rounding": stochastic_rounding,
}
is_deterministic = dtype == DataType.FP32 or not stochastic_rounding
constraints = {
table.name: ParameterConstraints(
sharding_types=[sharding_type],
compute_kernels=[kernel_type],
)
for i, table in enumerate(self.tables)
}
sharders = [
create_test_sharder(
sharder_type,
sharding_type,
kernel_type,
fused_params=fused_params,
),
]

# pyre-ignore
models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints)
m1, m2 = models

# load state dict for dense modules
src_state_dict = m1.state_dict()
m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", src_state_dict))
self._copy_ssd_emb_modules(m1, m2)

if is_training:
self._train_models(m1, m2, batch)
self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
self._compare_models(m1, m2, is_deterministic=is_deterministic)

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",