From 2d52760b540752386124c686625ee5df435d16d9 Mon Sep 17 00:00:00 2001
From: shababayub
Date: Wed, 31 Jan 2024 17:15:43 -0800
Subject: [PATCH] debug 2c

fbshipit-source-id: 409d502ab179c1e8c77b283af4ce0aab47c7a6a5
---
 .../distributed/tests/test_infer_shardings.py | 48 +++++--------------
 1 file changed, 12 insertions(+), 36 deletions(-)

diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 9a28b0c6f..31c2f72c9 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -14,6 +14,8 @@ import torch
 
 from hypothesis import given, settings
+
+from tgif.lib.generate_model_package_pass import generate_submodule_to_device_str
 from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
@@ -66,10 +68,7 @@ def placement_helper(device_type: str, index: int = 0) -> str:
 
 
 class InferShardingsTest(unittest.TestCase):
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-ignore
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -132,10 +131,7 @@ def test_rw(self, weight_dtype: torch.dtype) -> None:
             ShardingType.ROW_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
     @given(
         test_permute=st.booleans(),
@@ -238,10 +234,7 @@ def test_cw(self, test_permute: bool, weight_dtype: torch.dtype) -> None:
             ShardingType.COLUMN_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
     @given(
         emb_dim=st.sampled_from([192, 128]),
@@ -325,10 +318,7 @@ def test_cw_with_smaller_emb_dim(
             ShardingType.COLUMN_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-ignore
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -420,10 +410,7 @@ def test_cw_multiple_tables_with_permute(self, weight_dtype: torch.dtype) -> Non
             ShardingType.COLUMN_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 3,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-ignore
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -512,10 +499,7 @@ def test_cw_irregular_shard_placement(self, weight_dtype: torch.dtype) -> None:
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-ignore
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -613,10 +597,7 @@ def test_cw_sequence(self, weight_dtype: torch.dtype) -> None:
             ShardingType.COLUMN_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-ignore
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -718,10 +699,7 @@ def test_rw_sequence(self, weight_dtype: torch.dtype) -> None:
             ShardingType.ROW_WISE.value,
         )
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 2,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -819,10 +797,7 @@ def test_rw_uneven_sharding(
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)
 
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 3,
-        "Not enough GPUs available",
-    )
+    @unittest.skipIf(True, "")
     # pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
     @given(
         weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
@@ -1133,6 +1108,7 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
             ],
         )
 
+        breakpoint()
         # Check that FP was traced as a call_module
         fp_call_module: int = 0
         for node in gm.graph.nodes: