GRID_SHARD in planner only if specified in constraints (#2494)
Summary:
Pull Request resolved: #2494

To keep the change minimally intrusive and to avoid users unexpectedly getting grid sharding, GRID_SHARD must be specified in a table's parameter constraints for the planner to consider it as a sharding option. Otherwise it will not show up in sharding plans.
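
As an illustrative sketch (not part of this diff; the table name and topology values are placeholders), a user opts in per table through ParameterConstraints when constructing the planner:

    # Sketch: request grid sharding for a single table via parameter constraints.
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import ShardingType

    constraints = {
        # GRID_SHARD is only considered for tables that list it explicitly.
        "table_0": ParameterConstraints(
            sharding_types=[ShardingType.GRID_SHARD.value],
        ),
    }

    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=16, local_world_size=8, compute_device="cuda"),
        constraints=constraints,
    )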

Reviewed By: Nayef211

Differential Revision: D64610523

fbshipit-source-id: 66dc11e1fc51e4db7b1b6ea74a76892ba556ccda
iamzainhuda authored and facebook-github-bot committed Oct 22, 2024
1 parent 55682a9 commit 3d08f7b
Showing 4 changed files with 108 additions and 19 deletions.
9 changes: 7 additions & 2 deletions torchrec/distributed/planner/enumerators.py
@@ -235,11 +235,16 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
     def _filter_sharding_types(
         self, name: str, allowed_sharding_types: List[str]
     ) -> List[str]:
+        # GRID_SHARD is only supported if specified by user in parameter constraints
         if not self._constraints or not self._constraints.get(name):
-            return allowed_sharding_types
+            return [
+                t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
+            ]
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return allowed_sharding_types
+            return [
+                t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
+            ]
         constrained_sharding_types: List[str] = constraints.sharding_types

         filtered_sharding_types = list(
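
For reference, a minimal standalone sketch of the resulting filtering behavior (simplified: constraints are modeled as a plain dict of table name to sharding_types, and result order is preserved rather than going through the set intersection the real code uses):

    # Simplified stand-in for the enumerator's _filter_sharding_types; illustrative only.
    from typing import Dict, List, Optional

    GRID_SHARD = "grid_shard"  # stands in for ShardingType.GRID_SHARD.value

    def filter_sharding_types(
        name: str,
        allowed: List[str],
        constraints: Optional[Dict[str, List[str]]],  # table name -> requested sharding_types
    ) -> List[str]:
        requested = (constraints or {}).get(name)
        if not requested:
            # No constraints (or no sharding_types) for this table: drop GRID_SHARD.
            return [t for t in allowed if t != GRID_SHARD]
        # Constraints present: keep only requested types that are also allowed,
        # so GRID_SHARD survives only if the user asked for it.
        return [t for t in allowed if t in set(requested)]

    allowed = ["table_wise", "row_wise", GRID_SHARD]
    print(filter_sharding_types("table_0", allowed, None))
    # -> ['table_wise', 'row_wise']
    print(filter_sharding_types("table_0", allowed, {"table_0": [GRID_SHARD]}))
    # -> ['grid_shard']
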
22 changes: 22 additions & 0 deletions torchrec/distributed/planner/tests/test_proposers.py
@@ -111,6 +111,7 @@ def setUp(self) -> None:
self.uniform_proposer = UniformProposer()
self.grid_search_proposer = GridSearchProposer()
self.dynamic_programming_proposer = DynamicProgrammingProposer()
self._sharding_types = [x.value for x in ShardingType]

def test_greedy_two_table(self) -> None:
tables = [
@@ -127,6 +128,17 @@ def test_greedy_two_table(self) -> None:
feature_names=["feature_1"],
),
]
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present. This means
the greedy proposer will have a different order of sharding types on each test invocation
which we cannot have a harcoded "correct" answer for. We mock the call to _filter_sharding_types
to ensure the order of the sharding types list is always the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)

model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
search_space = self.enumerator.enumerate(
@@ -335,6 +347,16 @@ def test_grid_search_three_table(self) -> None:
for i in range(1, 4)
]
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)
search_space = self.enumerator.enumerate(
module=model,
sharders=[
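
For context on the mocks added above (a hedged illustration, not part of this diff): when constraints are present, the real filter intersects Python sets of strings, and set iteration order depends on string hash randomization, so it can differ between interpreter runs:

    # Roughly the constrained code path; the element order of `filtered` is not
    # stable across runs when PYTHONHASHSEED varies, hence mocking _filter_sharding_types.
    allowed = ["table_wise", "row_wise", "column_wise", "grid_shard"]
    requested = ["grid_shard", "row_wise", "table_wise"]
    filtered = list(set(requested) & set(allowed))
    print(filtered)  # same elements every run, but not necessarily the same order
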
32 changes: 31 additions & 1 deletion torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -11,7 +11,7 @@
import unittest
from typing import cast, Dict, List, Tuple

-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

import torch
import torchrec.optim as trec_optim
@@ -59,6 +59,7 @@ def setUp(self) -> None:
self.enumerator = EmbeddingEnumerator(
topology=self.topology, batch_size=BATCH_SIZE, estimator=self.estimator
)
self._sharding_types = [x.value for x in ShardingType]

def test_1_table_perf(self) -> None:
tables = [
@@ -70,6 +71,16 @@ def test_1_table_perf(self) -> None:
)
]
model = TestSparseNN(tables=tables, weighted_tables=[])
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)
sharding_options = self.enumerator.enumerate(
module=model,
sharders=[
@@ -321,6 +332,17 @@ def test_1_table_perf_with_fp8_comm(self) -> None:
)
)

"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)

sharding_options = self.enumerator.enumerate(
module=model,
sharders=[
@@ -530,6 +552,14 @@ def cacheability(self) -> float:
estimator=self.estimator,
constraints=constraints,
)
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
enumerator._filter_sharding_types = MagicMock(return_value=self._sharding_types)
model = TestSparseNN(tables=tables, weighted_tables=[])
sharding_options = enumerator.enumerate(
module=model,
64 changes: 48 additions & 16 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -725,14 +725,30 @@ def test_sharding_grid(
backend=self.backend,
qcomms_config=qcomms_config,
constraints={
"table_0": ParameterConstraints(min_partition=8),
"table_1": ParameterConstraints(min_partition=12),
"table_2": ParameterConstraints(min_partition=16),
"table_3": ParameterConstraints(min_partition=20),
"table_4": ParameterConstraints(min_partition=8),
"table_5": ParameterConstraints(min_partition=12),
"weighted_table_0": ParameterConstraints(min_partition=8),
"weighted_table_1": ParameterConstraints(min_partition=12),
"table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_2": ParameterConstraints(
min_partition=16, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_3": ParameterConstraints(
min_partition=20, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_4": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_5": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
@@ -800,14 +816,30 @@ def test_sharding_grid_8gpu(
backend=self.backend,
qcomms_config=qcomms_config,
constraints={
"table_0": ParameterConstraints(min_partition=8),
"table_1": ParameterConstraints(min_partition=12),
"table_2": ParameterConstraints(min_partition=8),
"table_3": ParameterConstraints(min_partition=10),
"table_4": ParameterConstraints(min_partition=4),
"table_5": ParameterConstraints(min_partition=6),
"weighted_table_0": ParameterConstraints(min_partition=2),
"weighted_table_1": ParameterConstraints(min_partition=3),
"table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_2": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_3": ParameterConstraints(
min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_4": ParameterConstraints(
min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_5": ParameterConstraints(
min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_0": ParameterConstraints(
min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_1": ParameterConstraints(
min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value]
),
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
