
Commit 7a495ea

Add unit test to test feature_processor appears in generate_submodule_to_device_str

Summary: As titled

Differential Revision: https://internalfb.com/D51486848

fbshipit-source-id: 8a4d7df771bbed34a481cfd002988a55504e5260
Qiang Zhang authored and facebook-github-bot committed Feb 1, 2024
1 parent f960c8d commit 7a495ea
Showing 2 changed files with 95 additions and 58 deletions.
84 changes: 84 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -8,6 +8,7 @@
#!/usr/bin/env python3

import copy
import math
import re
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -77,6 +78,7 @@
EmbeddingBagConfig,
)
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
@@ -890,3 +892,85 @@ def replace_sharded_quant_modules_tbes_with_mock_tbes(M: torch.nn.Module) -> None
for lookup in m._lookups:
for lookup_per_rank in lookup._embedding_lookups_per_rank:
replace_registered_tbes_with_mock_tbes(lookup_per_rank)


@torch.fx.wrap
def fx_wrap_fp_forward(
features: KeyedJaggedTensor,
feature_pow: Dict[str, float],
feature_min: Dict[str, float],
feature_max: Dict[str, float],
bucket_w_dict: Dict[str, torch.Tensor],
) -> KeyedJaggedTensor:
scores_list = []
for feature_name in features.keys():
jt = features[feature_name]
if feature_name in feature_min:
scores = jt.weights()
scores = torch.clamp(
scores,
min=feature_min[feature_name],
max=feature_max[feature_name],
)
indices = torch.floor(torch.pow(scores, feature_pow[feature_name]))
indices = indices.to(torch.int32)
scores = torch.index_select(bucket_w_dict[feature_name], 0, indices)
scores_list.append(scores)
else:
scores_list.append(
jt.weights()
if jt.weights_or_none() is not None
else torch.ones(jt.values().shape[0], device=jt.values().device)
)

return KeyedJaggedTensor(
keys=features.keys(),
values=features.values(),
weights=torch.cat(scores_list) if scores_list else features.weights_or_none(),
lengths=features.lengths(),
offsets=features.offsets(),
stride=features.stride(),
length_per_key=features.length_per_key(),
)


class TimeGapPoolingCollectionModule(FeatureProcessorsCollection):
def __init__(
self,
feature_pow: Dict[str, float],
feature_min: Dict[str, float],
feature_max: Dict[str, float],
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.feature_min = feature_min
self.feature_max = feature_max
self.feature_pow = feature_pow

self.bucket_w: nn.ParameterDict = nn.ParameterDict()
self.device = device
# needed since nn.ParameterDict isn't torchscriptable (get_items)
self.bucket_w_dict: Dict[str, torch.Tensor] = {}

for feature_name in feature_pow.keys():
param = torch.empty(
[
math.ceil(
math.pow(feature_max[feature_name], feature_pow[feature_name])
)
+ 2
],
device=device,
)
self.bucket_w[feature_name] = param
self.bucket_w_dict[feature_name] = param
self.register_buffer(f"bucket_w_dict_{feature_name}", param)

def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
return fx_wrap_fp_forward(
features,
self.feature_pow,
self.feature_min,
self.feature_max,
{k: getattr(self, f"bucket_w_dict_{k}") for k in self.bucket_w_dict.keys()},
)
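Usage note (not part of the diff): a minimal sketch of how the relocated TimeGapPoolingCollectionModule could be exercised on its own. The feature name "f1" and the tensor values are illustrative assumptions; the bucket weights are uninitialized torch.empty storage in this test utility, so only the shapes and the bucketization path matter here.

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
    from torchrec.distributed.test_utils.infer_utils import (
        TimeGapPoolingCollectionModule,
    )

    # Per-feature configuration dicts replace the scalar arguments of the
    # old test-local class (see the second file below).
    fp = TimeGapPoolingCollectionModule(
        feature_pow={"f1": 1.0},
        feature_min={"f1": -1.0},
        feature_max={"f1": 1.0},
        device=torch.device("cpu"),
    )

    # One feature, three values. Inside fx_wrap_fp_forward each weight is
    # clamped to [feature_min, feature_max], raised to feature_pow, floored
    # to a bucket index, and replaced by the corresponding bucket_w entry.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3]),
        lengths=torch.tensor([1, 2]),
        weights=torch.tensor([0.1, 0.5, 0.9]),
    )
    out = fp(kjt)  # KeyedJaggedTensor with bucketized per-value weights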
69 changes: 11 additions & 58 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -7,20 +7,14 @@

#!/usr/bin/env python3

import math
import unittest
from typing import Dict, List, Tuple

import hypothesis.strategies as st

import torch
from hypothesis import given, settings
from torchrec import (
EmbeddingBagConfig,
EmbeddingCollection,
EmbeddingConfig,
KeyedJaggedTensor,
)
from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
@@ -36,7 +30,6 @@
from torchrec.distributed.sharding_plan import (
column_wise,
construct_module_sharding_plan,
placement,
row_wise,
)
from torchrec.distributed.test_utils.infer_utils import (
@@ -54,63 +47,17 @@
shard_qec,
TestModelInfo,
TestQuantEBCSharder,
TimeGapPoolingCollectionModule,
)
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.types import ShardingEnv, ShardingPlan
from torchrec.fx import symbolic_trace
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

torch.fx.wrap("len")


class TimeGapPoolingCollectionModule(FeatureProcessorsCollection):
def __init__(
self,
feature_pow: float,
feature_min: float,
feature_max: float,
device: torch.device,
) -> None:
super().__init__()
self.feature_min = feature_min
self.feature_max = feature_max
self.feature_pow = feature_pow
self.device = device

param = torch.empty(
[math.ceil(math.pow(feature_max, feature_pow)) + 2],
device=device,
)
self.register_buffer("w", param)

def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
scores_list = []
for feature_name in features.keys():
jt = features[feature_name]
scores = jt.weights()
scores = torch.clamp(
scores,
min=self.feature_min,
max=self.feature_max,
)
indices = torch.floor(torch.pow(scores, self.feature_pow))
indices = indices.to(torch.int32)
scores = torch.index_select(self.w, 0, indices)
scores_list.append(scores)

return KeyedJaggedTensor(
keys=features.keys(),
values=features.values(),
weights=torch.cat(scores_list)
if scores_list
else features.weights_or_none(),
lengths=features.lengths(),
stride=features.stride(),
)


def placement_helper(device_type: str, index: int = 0) -> str:
if device_type == "cpu":
return f"rank:0/{device_type}" # cpu only use rank 0
@@ -1085,9 +1032,15 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
device=mi.sparse_device,
),
TimeGapPoolingCollectionModule(
feature_pow=1.0,
feature_min=-1.0,
feature_max=1.0,
feature_pow={
table.feature_names[0]: 1.0 for table in mi.tables
},
feature_min={
table.feature_names[0]: -1.0 for table in mi.tables
},
feature_max={
table.feature_names[0]: 1.0 for table in mi.tables
},
device=mi.sparse_device,
),
)
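Usage note (not part of the diff): the updated call site above builds the per-feature dicts from mi.tables and pairs the processor with an EmbeddingBagCollection, presumably inside a FeatureProcessedEmbeddingBagCollection (the enclosing constructor sits outside the hunk shown). A self-contained sketch of that pattern, with an assumed single table "table_0" serving feature "f1" and illustrative sizes:

    import torch
    from torchrec import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.modules.fp_embedding_modules import (
        FeatureProcessedEmbeddingBagCollection,
    )
    from torchrec.distributed.test_utils.infer_utils import (
        TimeGapPoolingCollectionModule,
    )

    tables = [
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ]
    feature_names = [table.feature_names[0] for table in tables]

    fp_ebc = FeatureProcessedEmbeddingBagCollection(
        # The EBC is built with is_weighted=True so the per-value weights
        # produced by the feature processor flow into pooling.
        EmbeddingBagCollection(
            tables=tables, is_weighted=True, device=torch.device("cpu")
        ),
        TimeGapPoolingCollectionModule(
            feature_pow={name: 1.0 for name in feature_names},
            feature_min={name: -1.0 for name in feature_names},
            feature_max={name: 1.0 for name in feature_names},
            device=torch.device("cpu"),
        ),
    )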
