diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 4e9390451..f51d1ae58 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -585,6 +585,25 @@ def _gen_named_parameters_by_table_ssd( yield (table_name, weight) +def _gen_named_parameters_by_table_ssd_pmt( + emb_module: SSDTableBatchedEmbeddingBags, + table_name_to_count: Dict[str, int], + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, +) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Return an PartiallyMaterializedTensor to indicate that the table is on remote device. + """ + pmts = emb_module.split_embedding_weights() + for table_config, pmt in zip(config.embedding_tables, pmts): + table_name = table_config.name + emb_table = pmt + weight: nn.Parameter = nn.Parameter(emb_table) + # pyre-ignore + weight._in_backward_optimizers = [EmptyFusedOptimizer()] + yield (table_name, weight) + + def _gen_named_parameters_by_table_fused( emb_module: SplitTableBatchedEmbeddingBagsCodegen, table_name_to_count: Dict[str, int], @@ -1257,7 +1276,7 @@ def __init__( pg, ) self._param_per_table: Dict[str, nn.Parameter] = dict( - _gen_named_parameters_by_table_ssd( + _gen_named_parameters_by_table_ssd_pmt( emb_module=self._emb_module, table_name_to_count=self.table_name_to_count.copy(), config=self._config, @@ -1291,11 +1310,20 @@ def state_dict( destination: Optional[Dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False, + no_snapshot: bool = True, ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - - return destination + self.flush() + tmp = self.split_embedding_weights(no_snapshot=no_snapshot) + for emb_table in self._config.embedding_tables: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + self._config.embedding_tables, + tmp, + self._pg, + destination, + prefix, + ) + return ret def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -1326,6 +1354,16 @@ def named_split_embedding_weights( key = append_prefix(prefix, f"{config.name}.weight") yield key, tensor + def named_pmts( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(no_snapshot=False), + ): + key = append_prefix(prefix, f"{config.name}") + yield key, tensor + def flush(self) -> None: """ Flush the embeddings in cache back to SSD. Should be pretty expensive. @@ -1340,11 +1378,14 @@ def purge(self) -> None: self.emb_module.lxu_cache_weights.zero_() self.emb_module.lxu_cache_state.fill_(-1) - def split_embedding_weights(self) -> List[torch.Tensor]: - """ - Return fake tensors. - """ - return [param.data for param in self._param_per_table.values()] + def split_embedding_weights( + self, no_snapshot: bool = True + ) -> List[torch._tensor.Tensor]: + pmts = self.emb_module.split_embedding_weights(no_snapshot) + emb_weights = [] + for pmt, _emb_table in zip(pmts, self._config.embedding_tables): + emb_weights.append(pmt) + return emb_weights class BatchedFusedEmbeddingBag( diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index 7c286f81c..84970e4b2 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -99,7 +99,9 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str: qbias = param[2] param = param[0] - assert embedding_table.local_rows == param.size(0) # pyre-ignore[16] + assert embedding_table.local_rows == param.size( # pyre-ignore[16] + 0 + ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16] if qscale is not None: assert embedding_table.local_cols == param.size(1) # pyre-ignore[16] diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index f7e5846fd..eee4702d4 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -649,6 +649,11 @@ def named_parameters_by_table( ) in embedding_kernel.named_parameters_by_table(): yield (table_name, tbe_slice) + def named_pmts(self) -> Iterator[Tuple[str, torch.Tensor]]: + for emb_module in self._emb_modules: + if isinstance(emb_module, KeyValueEmbeddingBag): + yield from emb_module.named_pmts() + def flush(self) -> None: for emb_module in self._emb_modules: emb_module.flush() diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 052049f4c..541c41029 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -830,7 +830,9 @@ def _initialize_torch_state(self) -> None: # noqa self._model_parallel_name_to_local_shards = OrderedDict() self._model_parallel_name_to_shards_wrapper = OrderedDict() self._model_parallel_name_to_sharded_tensor = OrderedDict() + self._model_parallel_name_to_sharded_pmt = OrderedDict() self._model_parallel_name_to_dtensor = OrderedDict() + self._name_to_kv_module = OrderedDict() model_parallel_name_to_compute_kernel: Dict[str, str] = {} for ( @@ -903,10 +905,6 @@ def _initialize_torch_state(self) -> None: # noqa self.embedding_bags[table_name].weight._in_backward_optimizers = [ EmptyFusedOptimizer() ] - if model_parallel_name_to_compute_kernel[table_name] in { - EmbeddingComputeKernel.KEY_VALUE.value - }: - continue if self._output_dtensor: if shards_wrapper_map["local_tensors"]: @@ -937,13 +935,21 @@ def _initialize_torch_state(self) -> None: # noqa ) else: # created ShardedTensors once in init, use in post_state_dict_hook - self._model_parallel_name_to_sharded_tensor[table_name] = ( - ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=self._env.process_group, - ) + sharded_tensor = ShardedTensor._init_from_local_shards( + local_shards, + self._name_to_table_size[table_name], + process_group=self._env.process_group, ) + if model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.KEY_VALUE.value + }: + self._model_parallel_name_to_sharded_pmt[table_name] = ( + sharded_tensor + ) + else: + self._model_parallel_name_to_sharded_tensor[table_name] = ( + sharded_tensor + ) def post_state_dict_hook( module: ShardedEmbeddingBagCollection, @@ -965,6 +971,23 @@ def post_state_dict_hook( destination_key = f"{prefix}embedding_bags.{table_name}.weight" destination[destination_key] = d_tensor + sharded_pmts = copy.deepcopy(module._model_parallel_name_to_sharded_pmt) + for lookup, sharding in zip(module._lookups, module._embedding_shardings): + if isinstance(sharding, DpPooledEmbeddingSharding): + # unwrap DDP + lookup = lookup.module + else: + for key, v in lookup.named_pmts(): + destination_key = f"{prefix}embedding_bags.{key}.weight" + assert key in sharded_pmts + sharded_pmts[key].local_shards()[0].tensor = v + for ( + table_name, + sharded_pmt, + ) in sharded_pmts.items(): + destination_key = f"{prefix}embedding_bags.{table_name}.weight" + destination[destination_key] = sharded_pmt + self.register_state_dict_pre_hook(self._pre_state_dict_hook) self._register_state_dict_hook(post_state_dict_hook) self._register_load_state_dict_pre_hook( diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py index d1a8854de..c28b48baa 100644 --- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -198,6 +198,82 @@ def test_ssd_load_state_dict( self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) self._compare_models(m1, m2, is_deterministic=is_deterministic) + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.KEY_VALUE.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_ssd_partially_materialized_wrapper_load_state_dict( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. + """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + src_state_dict = m1.state_dict() + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", src_state_dict)) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models(m1, m2, is_deterministic=is_deterministic) + @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU",