Skip to content

Commit

Permalink
debug 2c
Browse files Browse the repository at this point in the history
fbshipit-source-id: 409d502ab179c1e8c77b283af4ce0aab47c7a6a5
  • Loading branch information
s4ayub authored and facebook-github-bot committed Feb 1, 2024
1 parent 7a495ea commit 2d52760
Showing 1 changed file with 12 additions and 36 deletions.
48 changes: 12 additions & 36 deletions torchrec/distributed/tests/test_infer_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import torch
from hypothesis import given, settings

from tgif.lib.generate_model_package_pass import generate_submodule_to_device_str
from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
Expand Down Expand Up @@ -66,10 +68,7 @@ def placement_helper(device_type: str, index: int = 0) -> str:


class InferShardingsTest(unittest.TestCase):
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -132,10 +131,7 @@ def test_rw(self, weight_dtype: torch.dtype) -> None:
ShardingType.ROW_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
test_permute=st.booleans(),
Expand Down Expand Up @@ -238,10 +234,7 @@ def test_cw(self, test_permute: bool, weight_dtype: torch.dtype) -> None:
ShardingType.COLUMN_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
emb_dim=st.sampled_from([192, 128]),
Expand Down Expand Up @@ -325,10 +318,7 @@ def test_cw_with_smaller_emb_dim(
ShardingType.COLUMN_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -420,10 +410,7 @@ def test_cw_multiple_tables_with_permute(self, weight_dtype: torch.dtype) -> Non
ShardingType.COLUMN_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 3,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -512,10 +499,7 @@ def test_cw_irregular_shard_placement(self, weight_dtype: torch.dtype) -> None:
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -613,10 +597,7 @@ def test_cw_sequence(self, weight_dtype: torch.dtype) -> None:
ShardingType.COLUMN_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-ignore
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -718,10 +699,7 @@ def test_rw_sequence(self, weight_dtype: torch.dtype) -> None:
ShardingType.ROW_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 2,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -819,10 +797,7 @@ def test_rw_uneven_sharding(
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(
torch.cuda.device_count() <= 3,
"Not enough GPUs available",
)
@unittest.skipIf(True, "")
# pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
Expand Down Expand Up @@ -1133,6 +1108,7 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
],
)

breakpoint()
# Check that FP was traced as a call_module
fp_call_module: int = 0
for node in gm.graph.nodes:
Expand Down

0 comments on commit 2d52760

Please sign in to comment.