From c62554ada5e92845be3ecef2376010dfc9aee3d8 Mon Sep 17 00:00:00 2001
From: Shuai Yang
Date: Mon, 29 Jul 2024 23:38:07 -0700
Subject: [PATCH] Revise skip torchrec logic (#2240)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2240

X-link: https://github.com/pytorch/pytorch/pull/130783

The previous logic added the skipped files when the file was imported,
which happens at a very early stage. However, skip_torchrec may be set
at a later stage (e.g., in APS we set it during trainer execution). In
that case the skip logic would still take effect, because the skipped
files had already been added.

So in this diff, we revise the logic so that it adapts to changes to
skip_torchrec made at later stages.

Since we revise the skip-torchrec logic, we also update the related
torchrec tests.

Reviewed By: IvanKobzarev, yanboliang

Differential Revision: D59779153
---
 .../distributed/test_utils/infer_utils.py    | 27 -----------
 torchrec/distributed/tests/test_comm.py      | 13 ++++--
 torchrec/distributed/tests/test_pt2.py       | 46 +++++++++++++++----
 .../tests/test_pt2_multiprocess.py           | 15 +++---
 4 files changed, 55 insertions(+), 46 deletions(-)

diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 6dad40828..eb3456081 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -977,33 +977,6 @@ def assert_weight_spec(
         assert wqcomp_spec.sharding_type == sharding_type
 
 
-# TODO(ivankobzarev): Remove once torchrec is not in dynamo skipfiles
-@contextmanager
-# pyre-ignore
-def dynamo_skipfiles_allow(exclude_from_skipfiles_pattern: str):
-    replaced: bool = False
-    try:
-        # Temporary wrapping, as preparation for removal of trace_rules.FBCODE_SKIP_DIRS_RE
-        # Remove dynamo_skipfiles_allow once trace_rules.FBCODE_SKIP_DIRS removed
-        original_FBCODE_SKIP_DIRS_RE = copy.deepcopy(trace_rules.FBCODE_SKIP_DIRS_RE)
-        new_FBCODE_SKIP_DIRS = {
-            s
-            for s in trace_rules.FBCODE_SKIP_DIRS
-            if exclude_from_skipfiles_pattern not in s
-        }
-        trace_rules.FBCODE_SKIP_DIRS_RE = re.compile(
-            # pyre-ignore
-            f".*({'|'.join(map(re.escape, new_FBCODE_SKIP_DIRS))})"
-        )
-        replaced = True
-    except Exception:
-        pass
-    yield
-    if replaced:
-        # pyre-ignore
-        trace_rules.FBCODE_SKIP_DIRS_RE = original_FBCODE_SKIP_DIRS_RE
-
-
 class MockTBE(nn.Module):
     def __init__(
         self,
diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py
index 24db1518b..d427bfbb8 100644
--- a/torchrec/distributed/tests/test_comm.py
+++ b/torchrec/distributed/tests/test_comm.py
@@ -23,9 +23,10 @@
 import torchrec.distributed.comm_ops as comm_ops
 from hypothesis import given, settings
 from torch.distributed.distributed_c10d import GroupMember
-from torchrec.distributed.test_utils.infer_utils import dynamo_skipfiles_allow
 from torchrec.test_utils import get_free_port, seed_and_log
 
+torch.ops.import_module("fbgemm_gpu.sparse_ops")
+
 
 @dataclass
 class _CompileConfig:
@@ -128,7 +129,10 @@ def _test_async_sync_compile(
 
         if compile_config.backend is not None:
             fn_transform = compile_config_to_fn_transform(compile_config)
-            with dynamo_skipfiles_allow("torchrec"):
+            with unittest.mock.patch(
+                "torch._dynamo.config.skip_torchrec",
+                False,
+            ):
                 if compile_config.test_compiled_with_noncompiled_ranks and rank == 1:
                     # Turn off compilation for rank==1 to test compatibility of compiled rank and non-compiled
                     fn_transform = lambda x: x
@@ -241,7 +245,10 @@ def fn(*args, **kwargs) -> List[torch.Tensor]:
 
         fn_transform = compile_config_to_fn_transform(compile_config)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             v_embs_out = fn_transform(fn)(
                 input_embeddings, out_split=out_split, group=pg if specify_pg else None
             )
diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 0ab5ce1e5..976e6180a 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -53,7 +53,6 @@
 from torchrec.distributed.test_utils.infer_utils import (
     assert_close,
     create_test_model_ebc_only,
-    dynamo_skipfiles_allow,
     KJTInputExportWrapper,
     prep_inputs,
     replace_registered_tbes_with_mock_tbes,
@@ -241,7 +240,10 @@ def _test_compile_fwd_bwd(
 
         eager_loss.backward()
         eager_bwd_out = _grad_detach_clone(eager_input)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -281,7 +283,10 @@ def _test_kjt_input_module(
         test_aot_inductor: bool = True,
         test_pt2_ir_export: bool = False,
     ) -> None:
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys())
             em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs)
             eager_output = EM(*em_inputs)
@@ -332,7 +337,10 @@ def _test_kjt_input_module_dynamo_compile(
         inputs,
         backend: str = "eager",
     ) -> None:
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             EM: torch.nn.Module = KJTInputExportWrapperWithStrides(
                 kjt_input_module, kjt_keys
             )
@@ -442,7 +450,10 @@ def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
             torch._dynamo.decorators.mark_dynamic(t, 1)
         eager_output = m(inputs)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch_compile_backend = "eager"
 
             torch._dynamo.config.capture_scalar_outputs = True
@@ -581,7 +592,10 @@ def kjt_to_inputs(kjt):
 
         device: str = "cuda"
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             tracing_values = kjt.values()
             tracing_lengths = kjt.lengths()
             torch._dynamo.mark_dynamic(tracing_values, 0)
@@ -690,7 +704,10 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
         )
 
     def test_kjt_values_specialization(self):
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             from torch._dynamo.testing import CompileCounter
 
             kjt0 = KeyedJaggedTensor(
@@ -722,7 +739,10 @@ def f(kjt):
         self.assertEqual(counter.frame_count, 1)
 
     def test_kjt_values_specialization_utils(self):
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             from torch._dynamo.testing import CompileCounter
 
             kjt0 = KeyedJaggedTensor(
@@ -890,7 +910,10 @@ def get_weights(m):
 
         # COMPILE
         orig_compile_weights = get_weights(m_compile)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -969,7 +992,10 @@ def get_weights(m):
 
         # COMPILE
         orig_compile_weights = get_weights(m_compile)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py
index 9960db8b8..77d8eff7d 100644
--- a/torchrec/distributed/tests/test_pt2_multiprocess.py
+++ b/torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -40,10 +40,7 @@
 from torchrec.distributed.planner.types import ShardingPlan
 from torchrec.distributed.sharding_plan import EmbeddingBagCollectionSharder
 
-from torchrec.distributed.test_utils.infer_utils import (
-    dynamo_skipfiles_allow,
-    TestModelInfo,
-)
+from torchrec.distributed.test_utils.infer_utils import TestModelInfo
 
 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
@@ -424,7 +421,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
         ##### COMPILE #####
         run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
             torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = (
                 True
             )
@@ -457,7 +457,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
         ##### COMPILE END #####
 
         ##### NUMERIC CHECK #####
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             for i in range(n_extra_numerics_checks):
                 local_model_input = ins[1 + i][rank].to(device)
                 kjt = local_model_input.idlist_features
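
Note: a minimal sketch of the pattern the updated tests rely on, assuming a PyTorch
build in which torch._dynamo.config exposes the skip_torchrec flag referenced by this
diff. Because the revised logic consults skip_torchrec when Dynamo traces, instead of
baking torchrec into a skip list at import time, the flag can be flipped long after
torchrec has been imported; the hunks above do this with unittest.mock.patch. The
helper name and the "eager" backend choice below are illustrative, not part of the patch.

    import unittest.mock

    import torch


    def run_with_torchrec_traced(module: torch.nn.Module, *example_inputs):
        # Temporarily allow Dynamo to trace into torchrec code. With the revised
        # logic this takes effect even though torchrec was imported earlier,
        # because the flag is read at trace time rather than at import time.
        with unittest.mock.patch("torch._dynamo.config.skip_torchrec", False):
            compiled = torch.compile(module, backend="eager")
            return compiled(*example_inputs)

A test would wrap its torch.compile call in such a block, which is exactly what the
with unittest.mock.patch(...) blocks added throughout this diff do in place of the
removed dynamo_skipfiles_allow context manager.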