Index dedup integration in Embedding Collection (pytorch#1277)
Summary:
Pull Request resolved: pytorch#1277

Integrate the index dedup operator into Embedding Collection.

Reviewed By: xing-liu

Differential Revision: D47192546

fbshipit-source-id: 4144bcffed4c00df73dabdec25509f870b90b8da
GD06 authored and facebook-github-bot committed Aug 15, 2023
1 parent 7b76356 commit ed1a11b
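
A minimal usage sketch of the new toggle, assuming an otherwise standard torchrec EmbeddingCollection setup (only set_ec_index_dedup/get_ec_index_dedup come from this commit; the surrounding setup is illustrative):

from torchrec.distributed.embedding import get_ec_index_dedup, set_ec_index_dedup

# Enable index dedup globally. ShardedEmbeddingCollection reads this flag once
# in its constructor, so it must be set before the EmbeddingCollection is sharded.
set_ec_index_dedup(True)
assert get_ec_index_dedup()

Once enabled, input_dist deduplicates indices per sharding with torch.ops.fbgemm.jagged_unique_indices, and the output awaitable uses the stored reverse indices to restore the original (pre-dedup) ordering.
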
Showing 1 changed file with 138 additions and 7 deletions.
torchrec/distributed/embedding.py (138 additions, 7 deletions)
@@ -7,8 +7,10 @@


import copy
import logging
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, cast, Dict, List, MutableMapping, Optional, Type, Union

import torch
@@ -53,9 +55,7 @@
)
from torchrec.distributed.utils import (
add_params_from_parameter_sharding,
append_prefix,
convert_to_fbgemm_types,
filter_state_dict,
merge_fused_params,
optimizer_type_to_emb_opt_type,
)
@@ -78,6 +78,21 @@
except OSError:
pass

logger: logging.Logger = logging.getLogger(__name__)


EC_INDEX_DEDUP: bool = False


def set_ec_index_dedup(val: bool) -> None:
global EC_INDEX_DEDUP
EC_INDEX_DEDUP = val


def get_ec_index_dedup() -> bool:
global EC_INDEX_DEDUP
return EC_INDEX_DEDUP


def create_embedding_sharding(
sharding_type: str,
@@ -207,7 +222,14 @@ def _construct_jagged_tensors(
embedding_names: List[str],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
original_features: Optional[KeyedJaggedTensor] = None,
reverse_indices: Optional[torch.Tensor] = None,
) -> Dict[str, JaggedTensor]:
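# When index dedup is enabled, `original_features` is the pre-dedup KJT and
# `reverse_indices` maps each original index position back to its row in the
# deduplicated embedding output (see input_dist below).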
if original_features is not None:
features = original_features
if reverse_indices is not None:
embeddings = torch.index_select(embeddings, 0, reverse_indices.to(torch.int32))

ret: Dict[str, JaggedTensor] = {}
stride = features.stride()
length_per_key = features.length_per_key()
@@ -248,10 +270,17 @@ def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
@dataclass
class EmbeddingCollectionContext(Multistreamable):
sharding_contexts: List[SequenceShardingContext] = field(default_factory=list)
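# The two fields below are populated (one entry per sharding) only when index
# dedup is enabled: the pre-dedup features and the reverse indices returned by
# jagged_unique_indices.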
input_features: List[KeyedJaggedTensor] = field(default_factory=list)
reverse_indices: List[torch.Tensor] = field(default_factory=list)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for ctx in self.sharding_contexts:
ctx.record_stream(stream)
for f in self.input_features:
f.record_stream(stream)
for r in self.reverse_indices:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
r.record_stream(stream)


class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]):
@@ -260,6 +289,7 @@ def __init__(
awaitables_per_sharding: List[Awaitable[torch.Tensor]],
features_per_sharding: List[KeyedJaggedTensor],
embedding_names_per_sharding: List[List[str]],
ctx: EmbeddingCollectionContext,
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
) -> None:
@@ -269,21 +299,36 @@ def __init__(
self._need_indices = need_indices
self._features_to_permute_indices = features_to_permute_indices
self._embedding_names_per_sharding = embedding_names_per_sharding
self._ctx = ctx

def _wait_impl(self) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
for w, f, e in zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
for i, (w, f, e) in enumerate(
zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
)
):
original_features = (
None
if i >= len(self._ctx.input_features)
else self._ctx.input_features[i]
)
reverse_indices = (
None
if i >= len(self._ctx.reverse_indices)
else self._ctx.reverse_indices[i]
)
jt_dict.update(
_construct_jagged_tensors(
embeddings=w.wait(),
features=f,
embedding_names=e,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
original_features=original_features,
reverse_indices=reverse_indices,
)
)
return jt_dict
@@ -360,6 +405,8 @@ def __init__(
self._features_order: List[int] = []

self._has_uninitialized_input_dist: bool = True
self._ec_index_dedup: bool = get_ec_index_dedup()
logger.info(f"EC index dedup enabled: {self._ec_index_dedup}.")

# Get all fused optimizers and combine them.
optims = []
@@ -621,6 +668,54 @@ def _generate_permute_indices_per_feature(
else:
self._features_to_permute_indices[feature_name] = permute_indices

def _create_hash_size_info(
self,
feature_names: List[str],
) -> None:
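# For each sharding, build buffers of cumulative hash sizes and their offsets
# across the grouped tables' features; input_dist passes them to
# torch.ops.fbgemm.jagged_unique_indices when dedup is enabled.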
feature_index = 0
for i, lookup in enumerate(self._lookups):
feature_hash_size: List[int] = []
feature_hash_size_lengths: List[int] = []
for group_config in lookup.grouped_configs:
for table in group_config.embedding_tables:
table_hash_size = [0] * table.num_features()
table_hash_size[-1] = table.num_embeddings
feature_hash_size.extend(table_hash_size)

table_hash_size = [0] * table.num_features()
table_hash_size[0] = table.num_features()
feature_hash_size_lengths.extend(table_hash_size)

# Sanity check for feature orders
for f in range(table.num_features()):
assert (
feature_names[feature_index + f] == table.feature_names[f]
)
feature_index += table.num_features()

feature_hash_size_cumsum: List[int] = [0] + list(
accumulate(feature_hash_size)
)
feature_hash_size_offset: List[int] = [0] + list(
accumulate(feature_hash_size_lengths)
)

# Register buffers for this shard
self.register_buffer(
f"_hash_size_cumsum_tensor_{i}",
torch.tensor(
feature_hash_size_cumsum, device=self._device, dtype=torch.int64
),
persistent=False,
)
self.register_buffer(
f"_hash_size_offset_tensor_{i}",
torch.tensor(
feature_hash_size_offset, device=self._device, dtype=torch.int64
),
persistent=False,
)

def _create_input_dist(
self,
input_feature_names: List[str],
@@ -645,6 +740,9 @@ def _create_input_dist(
persistent=False,
)

if self._ec_index_dedup:
self._create_hash_size_info(feature_names)

def _create_lookups(self) -> None:
for sharding in self._sharding_type_to_sharding.values():
self._lookups.append(sharding.create_lookup())
@@ -670,9 +768,40 @@ def input_dist(
self._features_order,
self._features_order_tensor,
)
features_by_shards = features.split(

input_feature_splits = features.split(
self._feature_splits,
)

if not self._ec_index_dedup:
features_by_shards = input_feature_splits
else:
features_by_shards = []
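# Deduplicate indices per sharding: jagged_unique_indices returns the deduped
# lengths/offsets/values plus reverse indices. The original features and the
# reverse indices are stashed on the context so the output awaitable can
# restore the pre-dedup ordering when building the JaggedTensors.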
for i, input_feature in enumerate(input_feature_splits):
hash_size_cumsum = self.get_buffer(f"_hash_size_cumsum_tensor_{i}")
hash_size_offset = self.get_buffer(f"_hash_size_offset_tensor_{i}")
(
lengths,
offsets,
unique_indices,
reverse_indices,
) = torch.ops.fbgemm.jagged_unique_indices(
hash_size_cumsum,
hash_size_offset,
input_feature.offsets().to(torch.int64),
input_feature.values().to(torch.int64),
)
dedup_features = KeyedJaggedTensor(
keys=input_feature.keys(),
lengths=lengths,
offsets=offsets,
values=unique_indices,
)

ctx.input_features.append(input_feature)
ctx.reverse_indices.append(reverse_indices)
features_by_shards.append(dedup_features)

awaitables = []
for input_dist, features in zip(self._input_dists, features_by_shards):
awaitables.append(input_dist(features))
@@ -725,6 +854,7 @@ def output_dist(
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
ctx=ctx,
)

def compute_and_output_dist(
@@ -757,6 +887,7 @@ def compute_and_output_dist(
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
ctx=ctx,
)

def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
