diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 031804937..d3d52c91f 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -124,6 +124,15 @@ def _wait_impl(self) -> W: return ret +def wait_req(req: Request[W]) -> None: + if is_torchdynamo_compiling(): + assert req.tensor is not None + torch.ops._c10d_functional.wait_tensor(req.tensor) + else: + assert isinstance(req.req, dist.Work) + req.req.wait() + + @dataclass class All2AllPooledInfo(object): """ @@ -1334,7 +1343,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: myreq = ctx.myreq a2ai = myreq.a2ai assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_output = myreq.tensor dim_sum_per_rank = a2ai.dim_sum_per_rank @@ -1368,7 +1377,7 @@ def forward( a2ai = myreq.a2ai ctx.a2ai = a2ai assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) sharded_output_embeddings = myreq.tensor myreq.req = None myreq.tensor = None @@ -1573,9 +1582,9 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: myreq = ctx.myreq a2ai = myreq.a2ai assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) if isinstance(myreq.req, dist.Work): - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_output = myreq.tensor @@ -1606,7 +1615,7 @@ def forward( ctx.a2ai = a2ai assert myreq.req is not None if isinstance(myreq.req, dist.Work): - myreq.req.wait() + wait_req(myreq) sharded_output_embeddings = myreq.tensor myreq.req = None myreq.tensor = None @@ -1797,7 +1806,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: a2ai.permuted_lengths_after_sparse_data_all2all ) assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) sharded_grad_input = myreq.tensor if a2ai.codecs is not None: codecs = none_throws(a2ai.codecs) @@ -1845,7 +1854,7 @@ def forward( D = a2ai.embedding_dim ctx.a2ai = a2ai assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None sharded_output_embeddings = myreq.tensor myreq.tensor = None @@ -1952,7 +1961,7 @@ def forward( def backward(ctx, *grad_output): a2ai = ctx.a2ai myreq = ctx.myreq - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_input = myreq.tensor if a2ai.codecs is not None: @@ -1980,7 +1989,7 @@ def forward( a2ai = myreq.a2ai ctx.a2ai = a2ai assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None output = myreq.tensor myreq.tensor = None @@ -2067,7 +2076,7 @@ def forward( def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]: myreq = ctx.myreq assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_inputs = list(myreq.tensor) rsi = myreq.rsi @@ -2095,7 +2104,7 @@ def forward( *dummy_tensor: Tensor, ) -> Tensor: assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None output = myreq.tensor myreq.tensor = None @@ -2174,7 +2183,7 @@ def forward( # pyre-fixme[2]: Parameter must be annotated. def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]: myreq = ctx.myreq - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_inputs = myreq.tensor rsi = myreq.rsi @@ -2199,7 +2208,7 @@ def forward( *dummy_Tensor: Tensor, ) -> Tensor: assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None output = myreq.tensor myreq.tensor = None @@ -2270,7 +2279,7 @@ def forward( def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]: myreq = ctx.myreq assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None agi = myreq.agi grad_input = myreq.tensor @@ -2296,7 +2305,7 @@ def forward( *dummy_tensor: Tensor, ) -> Tensor: assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None outputs = myreq.tensor myreq.tensor = None @@ -2382,7 +2391,7 @@ def forward( def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]: myreq = ctx.myreq assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None grad_input = myreq.tensor rsi = myreq.rsi @@ -2407,7 +2416,7 @@ def forward( *dummy_tensor: Tensor, ) -> Tensor: assert myreq.req is not None - myreq.req.wait() + wait_req(myreq) myreq.req = None # pyre-ignore output: torch.Tensor = myreq.tensor