diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 6dad40828..eb3456081 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -977,33 +977,6 @@ def assert_weight_spec(
     assert wqcomp_spec.sharding_type == sharding_type
 
 
-# TODO(ivankobzarev): Remove once torchrec is not in dynamo skipfiles
-@contextmanager
-# pyre-ignore
-def dynamo_skipfiles_allow(exclude_from_skipfiles_pattern: str):
-    replaced: bool = False
-    try:
-        # Temporary wrapping, as preparation for removal of trace_rules.FBCODE_SKIP_DIRS_RE
-        # Remove dynamo_skipfiles_allow once trace_rules.FBCODE_SKIP_DIRS removed
-        original_FBCODE_SKIP_DIRS_RE = copy.deepcopy(trace_rules.FBCODE_SKIP_DIRS_RE)
-        new_FBCODE_SKIP_DIRS = {
-            s
-            for s in trace_rules.FBCODE_SKIP_DIRS
-            if exclude_from_skipfiles_pattern not in s
-        }
-        trace_rules.FBCODE_SKIP_DIRS_RE = re.compile(
-            # pyre-ignore
-            f".*({'|'.join(map(re.escape, new_FBCODE_SKIP_DIRS))})"
-        )
-        replaced = True
-    except Exception:
-        pass
-    yield
-    if replaced:
-        # pyre-ignore
-        trace_rules.FBCODE_SKIP_DIRS_RE = original_FBCODE_SKIP_DIRS_RE
-
-
 class MockTBE(nn.Module):
     def __init__(
         self,
diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py
index 24db1518b..d427bfbb8 100644
--- a/torchrec/distributed/tests/test_comm.py
+++ b/torchrec/distributed/tests/test_comm.py
@@ -23,9 +23,10 @@ import torchrec.distributed.comm_ops as comm_ops
 from hypothesis import given, settings
 from torch.distributed.distributed_c10d import GroupMember
 
-from torchrec.distributed.test_utils.infer_utils import dynamo_skipfiles_allow
 from torchrec.test_utils import get_free_port, seed_and_log
 
+torch.ops.import_module("fbgemm_gpu.sparse_ops")
+
 
 @dataclass
 class _CompileConfig:
@@ -128,7 +129,10 @@ def _test_async_sync_compile(
     if compile_config.backend is not None:
         fn_transform = compile_config_to_fn_transform(compile_config)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             if compile_config.test_compiled_with_noncompiled_ranks and rank == 1:
                 # Turn off compilation for rank==1 to test compatibility of compiled rank and non-compiled
                 fn_transform = lambda x: x
@@ -241,7 +245,10 @@ def fn(*args, **kwargs) -> List[torch.Tensor]:
 
     fn_transform = compile_config_to_fn_transform(compile_config)
 
-    with dynamo_skipfiles_allow("torchrec"):
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
         v_embs_out = fn_transform(fn)(
             input_embeddings, out_split=out_split, group=pg if specify_pg else None
         )
diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 0ab5ce1e5..976e6180a 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -53,7 +53,6 @@
 from torchrec.distributed.test_utils.infer_utils import (
     assert_close,
     create_test_model_ebc_only,
-    dynamo_skipfiles_allow,
     KJTInputExportWrapper,
     prep_inputs,
     replace_registered_tbes_with_mock_tbes,
@@ -241,7 +240,10 @@ def _test_compile_fwd_bwd(
     eager_loss.backward()
     eager_bwd_out = _grad_detach_clone(eager_input)
 
-    with dynamo_skipfiles_allow("torchrec"):
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
         torch._dynamo.config.capture_scalar_outputs = True
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
@@ -281,7 +283,10 @@ def _test_kjt_input_module(
     test_aot_inductor: bool = True,
     test_pt2_ir_export: bool = False,
 ) -> None:
-    with dynamo_skipfiles_allow("torchrec"):
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
         EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys())
         em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs)
         eager_output = EM(*em_inputs)
@@ -332,7 +337,10 @@ def _test_kjt_input_module_dynamo_compile(
     inputs,
     backend: str = "eager",
 ) -> None:
-    with dynamo_skipfiles_allow("torchrec"):
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
         EM: torch.nn.Module = KJTInputExportWrapperWithStrides(
             kjt_input_module, kjt_keys
         )
@@ -442,7 +450,10 @@ def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
             torch._dynamo.decorators.mark_dynamic(t, 1)
         eager_output = m(inputs)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch_compile_backend = "eager"
             torch._dynamo.config.capture_scalar_outputs = True
 
@@ -581,7 +592,10 @@ def kjt_to_inputs(kjt):
 
         device: str = "cuda"
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             tracing_values = kjt.values()
             tracing_lengths = kjt.lengths()
             torch._dynamo.mark_dynamic(tracing_values, 0)
@@ -690,7 +704,10 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
         )
 
     def test_kjt_values_specialization(self):
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             from torch._dynamo.testing import CompileCounter
 
             kjt0 = KeyedJaggedTensor(
@@ -722,7 +739,10 @@ def f(kjt):
             self.assertEqual(counter.frame_count, 1)
 
     def test_kjt_values_specialization_utils(self):
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             from torch._dynamo.testing import CompileCounter
 
             kjt0 = KeyedJaggedTensor(
@@ -890,7 +910,10 @@ def get_weights(m):
         # COMPILE
         orig_compile_weights = get_weights(m_compile)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
@@ -969,7 +992,10 @@ def get_weights(m):
         # COMPILE
         orig_compile_weights = get_weights(m_compile)
 
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py
index 9960db8b8..77d8eff7d 100644
--- a/torchrec/distributed/tests/test_pt2_multiprocess.py
+++ b/torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -40,10 +40,7 @@
 from torchrec.distributed.planner.types import ShardingPlan
 from torchrec.distributed.sharding_plan import EmbeddingBagCollectionSharder
 
-from torchrec.distributed.test_utils.infer_utils import (
-    dynamo_skipfiles_allow,
-    TestModelInfo,
-)
+from torchrec.distributed.test_utils.infer_utils import TestModelInfo
 
 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
@@ -424,7 +421,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
 
         ##### COMPILE #####
         run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
             torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = (
@@ -457,7 +457,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
         ##### COMPILE END #####
 
         ##### NUMERIC CHECK #####
-        with dynamo_skipfiles_allow("torchrec"):
+        with unittest.mock.patch(
+            "torch._dynamo.config.skip_torchrec",
+            False,
+        ):
            for i in range(n_extra_numerics_checks):
                local_model_input = ins[1 + i][rank].to(device)
                kjt = local_model_input.idlist_features
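For context, the replacement pattern used throughout this patch can be exercised on its own: instead of the removed dynamo_skipfiles_allow() helper, the tests temporarily flip torch._dynamo.config.skip_torchrec with unittest.mock.patch so Dynamo traces into torchrec code for the duration of the block. Below is a minimal sketch of that pattern, not part of the patch itself; it assumes a PyTorch build that exposes torch._dynamo.config.skip_torchrec, and the toy module and helper name are illustrative only.

    # Minimal sketch of the mock.patch pattern this diff migrates to (illustrative, not from the patch).
    import unittest.mock

    import torch


    class ToyModule(torch.nn.Module):  # stand-in for a torchrec module under test
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    def run_with_torchrec_tracing(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Temporarily un-skip torchrec in Dynamo; the original config value is
        # restored automatically when the context manager exits.
        with unittest.mock.patch("torch._dynamo.config.skip_torchrec", False):
            compiled = torch.compile(module, backend="eager")
            return compiled(x)


    if __name__ == "__main__":
        print(run_with_torchrec_tracing(ToyModule(), torch.ones(4)))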