migrate from deprecated torch dist collectives (#2497)
Summary:
Pull Request resolved: #2497

comm_ops currently uses the deprecated torch dist collective ops (dist._reduce_scatter_base and dist._all_gather_base), which emit deprecation warnings when invoked. Migrate these call sites to the new methods (dist.reduce_scatter_tensor and dist.all_gather_into_tensor) to address the warnings; a short API-mapping sketch follows the commit message.

Differential Revision: D64702417

fbshipit-source-id: 796d26fb034c583a99b88dd6e30324c8f3c41553
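
For context (not part of this commit): a minimal sketch of the one-to-one API mapping, assuming an initialized default process group and a 2-D tensor t whose first dimension is divisible by the world size; tensor names and shapes are made up for illustration.

import torch
import torch.distributed as dist

world_size = dist.get_world_size()
t = torch.randn(world_size * 4, 8)

# Reduce-scatter: each rank ends up with one reduced shard of the flat input.
out = t.new_empty(t.size(0) // world_size, t.size(1))
# Old (deprecated, warns): dist._reduce_scatter_base(out, t)
dist.reduce_scatter_tensor(out, t)

# All-gather: each rank ends up with the concatenation of every rank's shard.
gathered = out.new_empty(out.size(0) * world_size, out.size(1))
# Old (deprecated, warns): dist._all_gather_base(gathered, out)
dist.all_gather_into_tensor(gathered, out)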
Matthew Murphy authored and facebook-github-bot committed Oct 22, 2024
1 parent 0bf9802 commit 7cd3483
Showing 1 changed file with 10 additions and 10 deletions.
torchrec/distributed/comm_ops.py
@@ -2153,8 +2153,8 @@ def forward(
         if rsi.codecs is not None:
             inputs = rsi.codecs.forward.encode(inputs)
         output = inputs.new_empty((inputs.size(0) // my_size, inputs.size(1)))
-        with record_function("## reduce_scatter_base ##"):
-            req = dist._reduce_scatter_base(
+        with record_function("## reduce_scatter_tensor ##"):
+            req = dist.reduce_scatter_tensor(
                 output,
                 inputs,
                 group=pg,
@@ -2222,7 +2222,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
             grad_output = rsi.codecs.backward.encode(grad_output)
         grad_inputs = grad_output.new_empty(rsi.input_sizes)
         with record_function("## reduce_scatter_base_bw (all_gather) ##"):
-            req = dist._all_gather_base(
+            req = dist.all_gather_into_tensor(
                 grad_inputs,
                 grad_output.contiguous(),
                 group=ctx.pg,
@@ -2250,8 +2250,8 @@ def forward(
             input = agi.codecs.forward.encode(input)
 
         outputs = input.new_empty((input.size(0) * my_size, input.size(1)))
-        with record_function("## all_gather_base ##"):
-            req = dist._all_gather_base(
+        with record_function("## all_gather_into_tensor ##"):
+            req = dist.all_gather_into_tensor(
                 outputs,
                 input,
                 group=pg,
@@ -2319,7 +2319,7 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]:
             grad_outputs = agi.codecs.backward.encode(grad_outputs)
         grad_input = grad_outputs.new_empty(agi.input_size)
         with record_function("## all_gather_base_bw (reduce_scatter) ##"):
-            req = dist._reduce_scatter_base(
+            req = dist.reduce_scatter_tensor(
                 grad_input,
                 grad_outputs.contiguous(),
                 group=ctx.pg,
@@ -2349,11 +2349,11 @@ def forward(
 
         output = input.new_empty(rsi.input_sizes[my_rank])
 
-        # Use dist._reduce_scatter_base when a vector reduce-scatter is not needed
+        # Use dist.reduce_scatter_tensor when a vector reduce-scatter is not needed
         # else use dist.reduce_scatter which internally supports vector reduce-scatter
         if rsi.equal_splits:
-            with record_function("## reduce_scatter_base ##"):
-                req = dist._reduce_scatter_base(
+            with record_function("## reduce_scatter_tensor ##"):
+                req = dist.reduce_scatter_tensor(
                     output,
                     input,
                     group=pg,
@@ -2434,7 +2434,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
 
         if rsi.equal_splits:
             with record_function("## reduce_scatter_base_bw (all_gather) ##"):
-                req = dist._all_gather_base(
+                req = dist.all_gather_into_tensor(
                     grad_input,
                     grad_output.contiguous(),
                     group=ctx.pg,
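
Aside (not code from torchrec): the last two hunks sit on the variable-splits path, where the in-file comment keeps dist.reduce_scatter_tensor for equal splits and falls back to dist.reduce_scatter otherwise. A hedged sketch of that branch, with a hypothetical helper name, row-wise split sizes, and the assumption that splits has one entry per rank and sums to inp.size(0):

import torch
import torch.distributed as dist
from typing import List

def reduce_scatter_maybe_uneven(
    inp: torch.Tensor, splits: List[int], pg: dist.ProcessGroup
) -> torch.Tensor:
    # splits[i] = number of rows rank i receives after the reduction.
    my_rank = dist.get_rank(pg)
    output = inp.new_empty(splits[my_rank], inp.size(1))
    if len(set(splits)) == 1:
        # Equal splits: the flat-tensor collective is sufficient.
        dist.reduce_scatter_tensor(output, inp, group=pg)
    else:
        # Uneven splits: dist.reduce_scatter takes an explicit list of chunks
        # and handles the vector reduce-scatter internally.
        dist.reduce_scatter(output, list(torch.split(inp, splits)), group=pg)
    return output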
