2024-09-19 nightly release (3262651)
pytorchbot committed Sep 19, 2024
1 parent 745a0a9 commit e725c16
Showing 5 changed files with 336 additions and 116 deletions.
211 changes: 155 additions & 56 deletions torchrec/distributed/mc_modules.py
@@ -8,6 +8,7 @@
# pyre-strict

import copy
import logging
from collections import defaultdict, OrderedDict
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type

@@ -54,6 +55,8 @@
from torchrec.modules.utils import construct_jagged_tensors
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)


class ManagedCollisionCollectionAwaitable(LazyAwaitable[KeyedJaggedTensor]):
def __init__(
@@ -140,6 +143,7 @@ def __init__(
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
super().__init__()
self.need_preprocess: bool = module.need_preprocess
self._device = device
self._env = env
self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = (
@@ -249,15 +253,33 @@ def _create_managed_collision_modules(
) -> None:

self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict()
self._feature_to_offset: Dict[str, int] = {}
# To map mch output indices from local to global. key: table_name
self._table_to_offset: Dict[str, int] = {}

# the split sizes of tables belonging to each sharding. outer len is # shardings
self._sharding_per_table_feature_splits: List[List[int]] = []
# the split sizes of features per sharding. len is # shardings
self._sharding_feature_splits: List[int] = []
# the split sizes of features per table. len is # tables sum over all shardings
self._table_feature_splits: List[int] = []
self._feature_names: List[str] = []

# table names of each sharding
self._sharding_tables: List[List[str]] = []
self._sharding_features: List[List[str]] = []

for sharding in self._embedding_shardings:
assert isinstance(sharding, BaseRwEmbeddingSharding)
self._sharding_tables.append([])
self._sharding_features.append([])
self._sharding_per_table_feature_splits.append([])

grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
sharding._grouped_embedding_configs
)
self._sharding_feature_splits.append(len(sharding.feature_names()))

num_sharding_features = 0
for group_config in grouped_embedding_configs:
for table in group_config.embedding_tables:
# pyre-ignore [16]
@@ -271,6 +293,9 @@
] + [table.num_embeddings]
mc_module = module._managed_collision_modules[table.name]

self._sharding_tables[-1].append(table.name)
self._sharding_features[-1].extend(table.feature_names)
self._feature_names.extend(table.feature_names)
self._managed_collision_modules[table.name] = (
mc_module.rebuild_with_output_id_range(
output_id_range=(
@@ -315,21 +340,42 @@ def _create_managed_collision_modules(
zch_size,
zch_size_cumsum[-1],
)
for feature in table.feature_names:
self._feature_to_offset[feature] = new_min_output_id
self._table_to_offset[table.name] = new_min_output_id

self._table_feature_splits.append(len(table.feature_names))
self._sharding_per_table_feature_splits[-1].append(
self._table_feature_splits[-1]
)
num_sharding_features += self._table_feature_splits[-1]

assert num_sharding_features == len(
sharding.feature_names()
), f"Shared feature is not supported. {num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}"

if self._sharding_features[-1] != sharding.feature_names():
logger.warn(
"The order of tables of this sharding is altered due to grouping: "
f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}"
)

logger.info(f"{self._table_feature_splits=}")
logger.info(f"{self._sharding_per_table_feature_splits=}")
logger.info(f"{self._feature_names=}")
logger.info(f"{self._table_to_offset=}")
logger.info(f"{self._sharding_tables=}")
logger.info(f"{self._sharding_features=}")

def _create_input_dists(
self,
input_feature_names: List[str],
) -> None:
feature_names: List[str] = []
self._feature_splits: List[int] = []
for sharding in self._embedding_shardings:
for sharding, sharding_features in zip(
self._embedding_shardings, self._sharding_features
):
assert isinstance(sharding, BaseRwEmbeddingSharding)
feature_hash_sizes: List[int] = [
self._managed_collision_modules[self._feature_to_table[f]].input_size()
for f in sharding.feature_names()
for f in sharding_features
]

input_dist = RwSparseFeaturesDist(
@@ -344,11 +390,9 @@ def _create_input_dists(
keep_original_indices=True,
)
self._input_dists.append(input_dist)
feature_names.extend(sharding.feature_names())
self._feature_splits.append(len(sharding.feature_names()))

self._features_order: List[int] = []
for f in feature_names:
for f in self._feature_names:
self._features_order.append(input_feature_names.index(f))
self._features_order = (
[]
@@ -386,30 +430,48 @@ def input_dist(
self._has_uninitialized_input_dists = False

with torch.no_grad():
features_dict = features.to_dict()
output: Dict[str, JaggedTensor] = features_dict.copy()
for table, mc_module in self._managed_collision_modules.items():
feature_list: List[str] = self._table_to_features[table]
mc_input: Dict[str, JaggedTensor] = {}
for feature in feature_list:
mc_input[feature] = features_dict[feature]
mc_input = mc_module.preprocess(mc_input)
output.update(mc_input)

# NOTE shared features not currently supported
features = KeyedJaggedTensor.from_jt_dict(output)

if self._features_order:
features = features.permute(
self._features_order,
self._features_order_tensor,
)
features_by_sharding = features.split(
self._feature_splits,
)

feature_splits: List[KeyedJaggedTensor] = []
if self.need_preprocess:
# NOTE: No shared features allowed!
feature_splits = features.split(self._table_feature_splits)
else:
feature_splits = features.split(self._sharding_feature_splits)

ti: int = 0
awaitables = []
for input_dist, features in zip(self._input_dists, features_by_sharding):
awaitables.append(input_dist(features))
for i, tables in enumerate(self._sharding_tables):
if self.need_preprocess:
output: Dict[str, JaggedTensor] = {}
for table in tables:
kjt: KeyedJaggedTensor = feature_splits[ti]
mc_module = self._managed_collision_modules[table]
# TODO: change to Dict[str, Tensor]
mc_input: Dict[str, JaggedTensor] = {
table: JaggedTensor(
values=kjt.values(),
lengths=kjt.lengths(),
)
}
mc_input = mc_module.preprocess(mc_input)
output.update(mc_input)
ti += 1
shard_kjt = KeyedJaggedTensor(
keys=self._sharding_features[i],
values=torch.cat([jt.values() for jt in output.values()]),
lengths=torch.cat([jt.lengths() for jt in output.values()]),
)
else:
shard_kjt = feature_splits[i]

input_dist = self._input_dists[i]

awaitables.append(input_dist(shard_kjt))
ctx.sharding_contexts.append(
SequenceShardingContext(
features_before_input_dist=features,
@@ -426,26 +488,31 @@ def _kjt_list_to_tensor_list(
def _kjt_list_to_tensor_list(
self,
kjt_list: KJTList,
feature_to_offset: Dict[str, int],
) -> List[torch.Tensor]:
remapped_ids_ret: List[torch.Tensor] = []
# TODO: find a better solution
for kjt in kjt_list:
jt_dict = kjt.to_dict()
for feature, jt in jt_dict.items():
offset = feature_to_offset[feature]
jt._values = jt.values().add(offset)
new_kjt = KeyedJaggedTensor.from_jt_dict(jt_dict)
remapped_ids_ret.append(new_kjt.values().view(-1, 1))
# TODO: find a better solution, could be padding
for kjt, tables, splits in zip(
kjt_list, self._sharding_tables, self._sharding_per_table_feature_splits
):
if len(splits) > 1:
feature_splits = kjt.split(splits)
vals: List[torch.Tensor] = []
# assert len(feature_splits) == len(sharding.embedding_tables())
for feature_split, table in zip(feature_splits, tables):
offset = self._table_to_offset[table]
vals.append(feature_split.values() + offset)
remapped_ids_ret.append(torch.cat(vals).view(-1, 1))
else:
remapped_ids_ret.append(kjt.values() + self._table_to_offset[tables[0]])
return remapped_ids_ret

def global_to_local_index(
self,
feature_dict: Dict[str, JaggedTensor],
jt_dict: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
for feature, jt in feature_dict.items():
jt._values = jt.values() - self._feature_to_offset[feature]
return feature_dict
for table, jt in jt_dict.items():
jt._values = jt.values() - self._table_to_offset[table]
return jt_dict

def compute(
self,
Expand All @@ -454,26 +521,59 @@ def compute(
) -> KJTList:
remapped_kjts: List[KeyedJaggedTensor] = []

for features, sharding_ctx in zip(
# per shard
for features, sharding_ctx, tables, splits, fns in zip(
dist_input,
ctx.sharding_contexts,
self._sharding_tables,
self._sharding_per_table_feature_splits,
self._sharding_features,
):
sharding_ctx.lengths_after_input_dist = features.lengths().view(
-1, features.stride()
)
features_dict = features.to_dict()
output: Dict[str, JaggedTensor] = features_dict.copy()
for table, mc_module in self._managed_collision_modules.items():
feature_list: List[str] = self._table_to_features[table]
mc_input: Dict[str, JaggedTensor] = {}
for feature in feature_list:
mc_input[feature] = features_dict[feature]
mc_input = mc_module.profile(mc_input)
mc_input = mc_module.remap(mc_input)
mc_input = self.global_to_local_index(mc_input)
output.update(mc_input)

remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(output))
values: torch.Tensor
if len(splits) > 1:
# features per shard split by tables
feature_splits = features.split(splits)
output: Dict[str, JaggedTensor] = {}
for table, kjt in zip(tables, feature_splits):
# TODO: Dict[str, Tensor]
mc_input: Dict[str, JaggedTensor] = {
table: JaggedTensor(
values=kjt.values(),
lengths=kjt.lengths(),
)
}
mcm = self._managed_collision_modules[table]
mc_input = mcm.profile(mc_input)
mc_input = mcm.remap(mc_input)
mc_input = self.global_to_local_index(mc_input)
output.update(mc_input)
values = torch.cat([jt.values() for jt in output.values()])
else:
table: str = tables[0]
mc_input: Dict[str, JaggedTensor] = {
table: JaggedTensor(
values=features.values(),
lengths=features.lengths(),
)
}
mcm = self._managed_collision_modules[table]
mc_input = mcm.profile(mc_input)
mc_input = mcm.remap(mc_input)
mc_input = self.global_to_local_index(mc_input)
values = mc_input[table].values()

remapped_kjts.append(
KeyedJaggedTensor(
keys=fns,
values=values,
lengths=features.lengths(),
weights=features.weights_or_none(),
)
)
return KJTList(remapped_kjts)

def evict(self) -> Dict[str, Optional[torch.Tensor]]:
@@ -505,8 +605,7 @@ def output_dist(
ctx: ManagedCollisionCollectionContext,
output: KJTList,
) -> LazyAwaitable[KeyedJaggedTensor]:

global_remapped = self._kjt_list_to_tensor_list(output, self._feature_to_offset)
global_remapped = self._kjt_list_to_tensor_list(output)
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
for odist, remapped_ids, sharding_ctx in zip(
(Diffs for the remaining 4 changed files were not loaded in this view.)
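For orientation, the bookkeeping lists introduced in `_create_managed_collision_modules` (`_table_feature_splits`, `_sharding_per_table_feature_splits`, `_sharding_feature_splits`) are easiest to read with a concrete layout in mind. The following is a minimal sketch with hypothetical table and feature names, not code from this commit:

```python
# Hypothetical layout: two shardings; sharding 0 holds tables "t0" (features
# "f0", "f1") and "t1" (feature "f2"); sharding 1 holds table "t2" (feature "f3").
sharding_tables = [["t0", "t1"], ["t2"]]
table_to_features = {"t0": ["f0", "f1"], "t1": ["f2"], "t2": ["f3"]}

table_feature_splits = []               # feature count per table, flattened over shardings
sharding_per_table_feature_splits = []  # same counts, grouped per sharding
sharding_feature_splits = []            # total feature count per sharding

for tables in sharding_tables:
    per_table = [len(table_to_features[t]) for t in tables]
    sharding_per_table_feature_splits.append(per_table)
    table_feature_splits.extend(per_table)
    sharding_feature_splits.append(sum(per_table))

assert table_feature_splits == [2, 1, 1]
assert sharding_per_table_feature_splits == [[2, 1], [1]]
assert sharding_feature_splits == [3, 1]
```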

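The same toy layout also illustrates the split-and-offset pattern used in `input_dist`, `compute`, and `_kjt_list_to_tensor_list`: a per-sharding KJT is split by the per-table feature splits, each table's values are handled by that table's managed-collision module, and the table's `_table_to_offset` entry shifts local remapped ids into that table's global id range. Below is a minimal sketch of just the split and offset arithmetic, with hypothetical offsets and no real managed-collision remapping:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Per-sharding KJT for sharding 0 of the toy layout above: features "f0", "f1"
# belong to table "t0", feature "f2" to table "t1"; stride (batch size) is 2.
kjt = KeyedJaggedTensor(
    keys=["f0", "f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)
per_table_feature_splits = [2, 1]        # sharding_per_table_feature_splits[0]
table_to_offset = {"t0": 0, "t1": 1000}  # hypothetical global output-id offsets

per_table_kjts = kjt.split(per_table_feature_splits)
global_values = []
for table, table_kjt in zip(["t0", "t1"], per_table_kjts):
    # In the commit, preprocess (in input_dist) or profile/remap (in compute)
    # run on each table's values at this point; this sketch keeps the values
    # as-is and only applies the per-table offset that maps table-local ids
    # into the table's global id range, as in _kjt_list_to_tensor_list.
    global_values.append(table_kjt.values() + table_to_offset[table])

remapped = torch.cat(global_values).view(-1, 1)
print(remapped.flatten().tolist())  # [1, 2, 3, 4, 1005, 1006]
```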