From 7a495ea809de3131b25356c47b5c6637927a9db7 Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Wed, 31 Jan 2024 17:15:43 -0800
Subject: [PATCH] Add unit test to verify feature_processor appears in
 generate_submodule_to_device_str

Summary: As titled

Differential Revision: https://internalfb.com/D51486848

fbshipit-source-id: 8a4d7df771bbed34a481cfd002988a55504e5260
---
 .../distributed/test_utils/infer_utils.py     | 84 +++++++++++++++++++
 .../distributed/tests/test_infer_shardings.py | 69 +++------------
 2 files changed, 95 insertions(+), 58 deletions(-)

diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index cf1e64d6c..be1e69f50 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -8,6 +8,7 @@
 #!/usr/bin/env python3
 
 import copy
+import math
 import re
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -77,6 +78,7 @@
     EmbeddingBagConfig,
 )
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
 )
@@ -890,3 +892,85 @@ def replace_sharded_quant_modules_tbes_with_mock_tbes(M: torch.nn.Module) -> Non
             for lookup in m._lookups:
                 for lookup_per_rank in lookup._embedding_lookups_per_rank:
                     replace_registered_tbes_with_mock_tbes(lookup_per_rank)
+
+
+@torch.fx.wrap
+def fx_wrap_fp_forward(
+    features: KeyedJaggedTensor,
+    feature_pow: Dict[str, float],
+    feature_min: Dict[str, float],
+    feature_max: Dict[str, float],
+    bucket_w_dict: Dict[str, torch.Tensor],
+) -> KeyedJaggedTensor:
+    scores_list = []
+    for feature_name in features.keys():
+        jt = features[feature_name]
+        if feature_name in feature_min:
+            scores = jt.weights()
+            scores = torch.clamp(
+                scores,
+                min=feature_min[feature_name],
+                max=feature_max[feature_name],
+            )
+            indices = torch.floor(torch.pow(scores, feature_pow[feature_name]))
+            indices = indices.to(torch.int32)
+            scores = torch.index_select(bucket_w_dict[feature_name], 0, indices)
+            scores_list.append(scores)
+        else:
+            scores_list.append(
+                jt.weights()
+                if jt.weights_or_none() is not None
+                else torch.ones(jt.values().shape[0], device=jt.values().device)
+            )
+
+    return KeyedJaggedTensor(
+        keys=features.keys(),
+        values=features.values(),
+        weights=torch.cat(scores_list) if scores_list else features.weights_or_none(),
+        lengths=features.lengths(),
+        offsets=features.offsets(),
+        stride=features.stride(),
+        length_per_key=features.length_per_key(),
+    )
+
+
+class TimeGapPoolingCollectionModule(FeatureProcessorsCollection):
+    def __init__(
+        self,
+        feature_pow: Dict[str, float],
+        feature_min: Dict[str, float],
+        feature_max: Dict[str, float],
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self.feature_min = feature_min
+        self.feature_max = feature_max
+        self.feature_pow = feature_pow
+
+        self.bucket_w: nn.ParameterDict = nn.ParameterDict()
+        self.device = device
+        # needed since nn.ParameterDict isn't torchscriptable (get_items)
+        self.bucket_w_dict: Dict[str, torch.Tensor] = {}
+
+        for feature_name in feature_pow.keys():
+            param = torch.empty(
+                [
+                    math.ceil(
+                        math.pow(feature_max[feature_name], feature_pow[feature_name])
+                    )
+                    + 2
+                ],
+                device=device,
+            )
+            self.bucket_w[feature_name] = param
+            self.bucket_w_dict[feature_name] = param
+            self.register_buffer(f"bucket_w_dict_{feature_name}", param)
+
+    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
+        return fx_wrap_fp_forward(
+            features,
+            self.feature_pow,
+            self.feature_min,
+            self.feature_max,
+            {k: getattr(self, f"bucket_w_dict_{k}") for k in self.bucket_w_dict.keys()},
+        )
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 0b185d0e5..9a28b0c6f 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -7,7 +7,6 @@
 
 #!/usr/bin/env python3
 
-import math
 import unittest
 from typing import Dict, List, Tuple
 
@@ -15,12 +14,7 @@
 import torch
 from hypothesis import given, settings
 
-from torchrec import (
-    EmbeddingBagConfig,
-    EmbeddingCollection,
-    EmbeddingConfig,
-    KeyedJaggedTensor,
-)
+from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
@@ -36,7 +30,6 @@
 from torchrec.distributed.sharding_plan import (
     column_wise,
     construct_module_sharding_plan,
-    placement,
     row_wise,
 )
 from torchrec.distributed.test_utils.infer_utils import (
@@ -54,63 +47,17 @@
     shard_qec,
     TestModelInfo,
     TestQuantEBCSharder,
+    TimeGapPoolingCollectionModule,
 )
 from torchrec.distributed.test_utils.test_model import ModelInput
 from torchrec.distributed.types import ShardingEnv, ShardingPlan
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 torch.fx.wrap("len")
 
 
-class TimeGapPoolingCollectionModule(FeatureProcessorsCollection):
-    def __init__(
-        self,
-        feature_pow: float,
-        feature_min: float,
-        feature_max: float,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-        self.feature_min = feature_min
-        self.feature_max = feature_max
-        self.feature_pow = feature_pow
-        self.device = device
-
-        param = torch.empty(
-            [math.ceil(math.pow(feature_max, feature_pow)) + 2],
-            device=device,
-        )
-        self.register_buffer("w", param)
-
-    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
-        scores_list = []
-        for feature_name in features.keys():
-            jt = features[feature_name]
-            scores = jt.weights()
-            scores = torch.clamp(
-                scores,
-                min=self.feature_min,
-                max=self.feature_max,
-            )
-            indices = torch.floor(torch.pow(scores, self.feature_pow))
-            indices = indices.to(torch.int32)
-            scores = torch.index_select(self.w, 0, indices)
-            scores_list.append(scores)
-
-        return KeyedJaggedTensor(
-            keys=features.keys(),
-            values=features.values(),
-            weights=torch.cat(scores_list)
-            if scores_list
-            else features.weights_or_none(),
-            lengths=features.lengths(),
-            stride=features.stride(),
-        )
-
-
 def placement_helper(device_type: str, index: int = 0) -> str:
     if device_type == "cpu":
         return f"rank:0/{device_type}"  # cpu only use rank 0
@@ -1085,9 +1032,15 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
                 device=mi.sparse_device,
             ),
             TimeGapPoolingCollectionModule(
-                feature_pow=1.0,
-                feature_min=-1.0,
-                feature_max=1.0,
+                feature_pow={
+                    table.feature_names[0]: 1.0 for table in mi.tables
+                },
+                feature_min={
+                    table.feature_names[0]: -1.0 for table in mi.tables
+                },
+                feature_max={
+                    table.feature_names[0]: 1.0 for table in mi.tables
+                },
                 device=mi.sparse_device,
             ),
         )
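
Usage sketch (not part of the patch): the snippet below illustrates how the relocated TimeGapPoolingCollectionModule is constructed with the new per-feature dict arguments introduced by this change. The feature name "f1" and the example KeyedJaggedTensor are made-up illustration data, and the sketch assumes the patch is applied on top of a recent TorchRec/PyTorch build.

import torch

from torchrec.distributed.test_utils.infer_utils import TimeGapPoolingCollectionModule
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Per-feature dicts replace the old scalar feature_pow/feature_min/feature_max.
fp = TimeGapPoolingCollectionModule(
    feature_pow={"f1": 1.0},
    feature_min={"f1": -1.0},
    feature_max={"f1": 1.0},
    device=torch.device("cpu"),
)

# One feature with three ids in a single batch element; the weights are the
# scores that forward() clamps, bucketizes, and replaces with values looked up
# from the per-feature bucket_w table (left uninitialized at construction).
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([3]),
    weights=torch.tensor([0.1, 0.5, 0.9]),
)
out = fp(kjt)  # KeyedJaggedTensor whose weights come from bucket_w_dict["f1"]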