diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index db6182caa..69058b0f6 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -31,9 +31,17 @@ Union, ) +import torch from torch import distributed as dist -from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +if not torch._running_with_deploy(): + from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +else: + + class FSDP2: + pass + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.fx.immutable_collections import ( immutable_dict as fx_immutable_dict,