torchrec change for dynamic embedding #2533

Open. Wants to merge 1 commit into base: release/v0.7.0
20 changes: 16 additions & 4 deletions torchrec/distributed/embedding.py
@@ -170,9 +170,12 @@ def create_sharding_infos_by_sharding(
if parameter_sharding.compute_kernel not in [
kernel.value for kernel in EmbeddingComputeKernel
]:
raise ValueError(
f"Compute kernel not supported {parameter_sharding.compute_kernel}"
)
compute_kernel_params = parameter_sharding.get_params_of_compute_kernel()
if compute_kernel_params.get("customized_compute_kernel") != parameter_sharding.compute_kernel:
raise ValueError(
f"Compute kernel not supported {parameter_sharding.compute_kernel}"
)

param_name = "embeddings." + config.name + ".weight"
assert param_name in parameter_by_name or param_name in state_dict
@@ -200,6 +203,7 @@ def create_sharding_infos_by_sharding(
per_table_fused_params, parameter_sharding
)
per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
per_table_fused_params.update(parameter_sharding.get_params_of_compute_kernel())

sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
(
@@ -507,12 +511,18 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_local_shards = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
customized_table_names: List[str] = []
for (
table_name,
parameter_sharding,
) in self.module_sharding_plan.items():
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
continue
if parameter_sharding.compute_kernel not in [
compute_kernel.value for compute_kernel in EmbeddingComputeKernel
]:
customized_table_names.append(table_name)
continue

self._model_parallel_name_to_local_shards[table_name] = []
model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
@@ -535,6 +545,8 @@ def _initialize_torch_state(self) -> None: # noqa
# save local_shards for transforming MP params to shardedTensor
for key, v in lookup.state_dict().items():
table_name = key[: -len(".weight")]
if table_name in customized_table_names:
continue
self._model_parallel_name_to_local_shards[table_name].extend(
v.local_shards()
)
@@ -798,7 +810,7 @@ def input_dist(
features_before_input_dist=features,
unbucketize_permute_tensor=(
input_dist.unbucketize_permute_tensor
if isinstance(input_dist, RwSparseFeaturesDist)
if hasattr(input_dist, "unbucketize_permute_tensor")
else None
),
)
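For orientation, a minimal sketch (not part of the diff) of what the embedding.py changes above expect from the sharding plan: get_params_of_compute_kernel() returns a dict whose "customized_compute_kernel" entry names the kernel, and the whole dict is merged into per_table_fused_params. The names "dyn_emb" and MyDynamicKernel below are placeholders, not identifiers from torchrec or from this PR.

class MyDynamicKernel:
    """Stub standing in for a user-provided compute kernel class."""

    def __init__(self, config, pg=None, device=None) -> None:
        self.config, self.pg, self.device = config, pg, device


# Expected shape of parameter_sharding.get_params_of_compute_kernel(): the
# "customized_compute_kernel" value must equal parameter_sharding.compute_kernel
# for the relaxed validation above to pass; the remaining entries travel to the
# lookup code via per_table_fused_params.
params_of_compute_kernel = {
    "customized_compute_kernel": "dyn_emb",
    "ComputeKernel": MyDynamicKernel,
}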
16 changes: 16 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -128,6 +128,14 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
) -> None:
def _exist_customized_compute_kernel(config: GroupedEmbeddingConfig) -> bool:
# True when the config requests a kernel outside EmbeddingComputeKernel that
# matches the customized kernel registered in fused_params.
exist_key = "customized_compute_kernel"
if config.fused_params and exist_key in config.fused_params:
if config.compute_kernel == config.fused_params[exist_key]:
return True
return False

# TODO rename to _create_embedding_kernel
def _create_lookup(
config: GroupedEmbeddingConfig,
@@ -147,6 +155,14 @@ def _create_lookup(
pg=pg,
device=device,
)
elif _exist_customized_compute_kernel(config):
assert "ComputeKernel" in config.fused_params
ComputeKernel = config.fused_params["ComputeKernel"]
return ComputeKernel(
config=config,
pg=pg,
device=device,
)
else:
raise ValueError(
f"Compute kernel not supported {config.compute_kernel}"
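The new branch above constrains only the constructor of the class stored under fused_params["ComputeKernel"]: it is called with keyword arguments config, pg, and device. A minimal sketch of that contract with placeholder names; everything beyond __init__ is up to the concrete kernel and is not implied by the diff.

from typing import Optional

import torch
import torch.distributed as dist


class MyDynamicKernel(torch.nn.Module):
    """Placeholder kernel showing the constructor signature used by _create_lookup."""

    def __init__(
        self,
        config,  # GroupedEmbeddingConfig for the tables grouped on this rank
        pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self._config = config
        self._pg = pg
        self._device = device
        # A real kernel would allocate (or attach to) its embedding storage here
        # and implement forward/state_dict to match the other lookup kernels.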
36 changes: 25 additions & 11 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -115,17 +115,31 @@ def create_input_dist(
) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()
return RwSparseFeaturesDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
)
if self._customized_dist:
return self._customized_dist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
dist_type_per_feature=self._dist_type_per_feature,
)
else:
return RwSparseFeaturesDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
)

def create_lookup(
self,
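Likewise, the customized distributor is constrained only by its call signature: the same keyword arguments as RwSparseFeaturesDist plus dist_type_per_feature. A sketch of that signature with placeholder names; the body is not implied by the diff.

from typing import Dict, List, Optional

import torch
import torch.distributed as dist


class MyCustomizedSparseFeaturesDist(torch.nn.Module):
    """Placeholder showing the constructor contract used by create_input_dist."""

    def __init__(
        self,
        pg: dist.ProcessGroup,
        num_features: int,
        feature_hash_sizes: List[int],
        device: Optional[torch.device] = None,
        is_sequence: bool = False,
        has_feature_processor: bool = False,
        need_pos: bool = False,
        dist_type_per_feature: Optional[Dict[str, str]] = None,
    ) -> None:
        super().__init__()
        self._dist_type_per_feature = dist_type_per_feature or {}
        # A real distributor would bucketize and redistribute the incoming
        # KeyedJaggedTensor per feature, mirroring RwSparseFeaturesDist.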
76 changes: 62 additions & 14 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -125,6 +125,7 @@ def __init__(
device = torch.device("cpu")
self._device: torch.device = device
sharded_tables_per_rank = self._shard(sharding_infos)
self._init_customized_distributor(sharding_infos)
self._need_pos = need_pos
self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
[]
@@ -161,6 +162,15 @@ def _shard(
),
)

if info.param_sharding.compute_kernel not in [
kernel.value for kernel in EmbeddingComputeKernel
]:
compute_kernel = info.param_sharding.get_params_of_compute_kernel()[
"customized_compute_kernel"
]
else:
compute_kernel = EmbeddingComputeKernel(
info.param_sharding.compute_kernel
)
for rank in range(self._world_size):
tables_per_rank[rank].append(
ShardedEmbeddingTable(
@@ -175,9 +185,7 @@
has_feature_processor=info.embedding_config.has_feature_processor,
local_rows=shards[rank].shard_sizes[0],
local_cols=info.embedding_config.embedding_dim,
compute_kernel=EmbeddingComputeKernel(
info.param_sharding.compute_kernel
),
compute_kernel=compute_kernel,
local_metadata=shards[rank],
global_metadata=global_metadata,
weight_init_max=info.embedding_config.weight_init_max,
@@ -187,6 +195,31 @@
)
return tables_per_rank

def _init_customized_distributor(self, sharding_infos: List[EmbeddingShardingInfo]) -> None:
common_dist_type = None
common_customized_dist = None

self._dist_type_per_feature: Dict[str, str] = {}
for sharding_info in sharding_infos:
if "dist_type" in sharding_info.fused_params:
dist_type = sharding_info.fused_params["dist_type"]
if common_dist_type is None:
common_dist_type = dist_type
else:
assert dist_type == common_dist_type, "Customized distributor type must be the same across tables."
dist_ = sharding_info.fused_params["Distributor"]
if common_customized_dist is None:
common_customized_dist = dist_
else:
assert common_customized_dist is dist_, "Customized distributor implementation must be the same across tables."
else:
dist_type = "continuous"
feature_names = sharding_info.embedding_config.feature_names
for f in feature_names:
self._dist_type_per_feature[f] = dist_type

self._customized_dist = common_customized_dist

def embedding_dims(self) -> List[int]:
embedding_dims = []
for grouped_config in self._grouped_embedding_configs:
@@ -465,17 +498,32 @@ def create_input_dist(
) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()
return RwSparseFeaturesDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=False,
has_feature_processor=self._has_feature_processor,
need_pos=self._need_pos,
)

if self._customized_dist:
return self._customized_dist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=False,
has_feature_processor=self._has_feature_processor,
need_pos=self._need_pos,
dist_type_per_feature=self._dist_type_per_feature,
)
else:
return RwSparseFeaturesDist(
# pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
pg=self._pg,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=False,
has_feature_processor=self._has_feature_processor,
need_pos=self._need_pos,
)

def create_lookup(
self,
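Finally, a sketch of the per-table fused_params consumed by _init_customized_distributor above: every table that opts in must carry the same "dist_type" label and the same "Distributor" class, and tables without a "dist_type" entry default to "continuous" in _dist_type_per_feature. The key names come from the diff; the values are placeholders.

class MyCustomizedSparseFeaturesDist:
    # Placeholder; see the distributor sketch after rw_sequence_sharding.py.
    ...


# Per-table fused_params as read by _init_customized_distributor; the assertions
# there fire if the label or the class differs between tables.
fused_params = {
    "dist_type": "hash",  # placeholder label, recorded per feature in _dist_type_per_feature
    "Distributor": MyCustomizedSparseFeaturesDist,
}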