diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 8ff170992..729de3dec 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -10,6 +10,7 @@

 import copy
 import unittest
+from contextlib import ExitStack
 from dataclasses import dataclass
 from functools import partial
 from typing import cast, List, Optional, Tuple, Type, Union
@@ -862,7 +863,7 @@ def custom_model_fwd(
             batch_size = pred.size(0)
             return loss, pred.expand(batch_size * 2, -1)

-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -2203,15 +2204,22 @@ def gpu_preproc(x: StageOut) -> StageOut:
 class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
     def setUp(self) -> None:
         super().setUp()
+        torch.manual_seed(42)
         self.pipeline_class = TrainPipelineSparseDistCompAutograd
         torch._dynamo.reset()
         counters["compiled_autograd"].clear()
         # Compiled Autograd don't work with Anomaly Mode
         torch.autograd.set_detect_anomaly(False)
+        self._exit_stack = ExitStack()
+        self._exit_stack.enter_context(
+            torch._dynamo.config.patch(
+                optimize_ddp="python_reducer_without_compiled_forward"
+            ),
+        )

     def tearDown(self) -> None:
-        # Every single test has two captures, one for forward and one for backward
-        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
+        self._exit_stack.close()
+        self.assertEqual(counters["compiled_autograd"]["captures"], 3)
         return super().tearDown()

     @unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
@@ -2219,3 +2227,14 @@ def tearDown(self) -> None:
     @given(execute_all_batches=st.booleans())
     def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
         super().test_pipelining_fsdp_pre_trace()
+
+    @unittest.skip(
+        "TrainPipelineSparseDistTest.test_equal_to_non_pipelined was called from multiple different executors, which fails the hypothesis HealthCheck, so we skip it here"
+    )
+    def test_equal_to_non_pipelined(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        execute_all_batches: bool,
+    ) -> None:
+        super().test_equal_to_non_pipelined()
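For context on the setUp/tearDown change above: a minimal, self-contained sketch of the ExitStack pattern the new test class relies on, assuming a recent PyTorch where `torch._dynamo.config.optimize_ddp` accepts the `"python_reducer_without_compiled_forward"` mode. The class and test names here are hypothetical and only illustrate the pattern; the real class derives from `TrainPipelineSparseDistTest`, not `unittest.TestCase`.

```python
import unittest
from contextlib import ExitStack

import torch


class DynamoConfigPatchSketchTest(unittest.TestCase):
    """Illustrative only: scopes a dynamo config override across setUp/tearDown."""

    def setUp(self) -> None:
        super().setUp()
        self._exit_stack = ExitStack()
        # torch._dynamo.config.patch(...) returns a context manager; entering it
        # on an ExitStack keeps the override active for the whole test instead
        # of limiting it to a single `with` block.
        self._exit_stack.enter_context(
            torch._dynamo.config.patch(
                optimize_ddp="python_reducer_without_compiled_forward"
            )
        )

    def tearDown(self) -> None:
        # Closing the stack exits the patch and restores the previous setting.
        self._exit_stack.close()
        super().tearDown()

    def test_config_is_patched(self) -> None:
        self.assertEqual(
            torch._dynamo.config.optimize_ddp,
            "python_reducer_without_compiled_forward",
        )
```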