Back out "Not dynamo trace through conditions"
Differential Revision: D55443823
s4ayub authored and facebook-github-bot committed Mar 27, 2024
1 parent 26b6899 commit 85f0c52
Showing 3 changed files with 30 additions and 98 deletions.
6 changes: 1 addition & 5 deletions torchrec/distributed/comm_ops.py
@@ -1009,11 +1009,7 @@ def reduce_scatter_v_pooled(
[ip_split if d == 0 else input_size[d] for d in range(len(input_size))]
for ip_split in input_splits
]

equal_splits = False
if not torch.compiler.is_dynamo_compiling():
# We can not check during tracing equality of splits -> fallback on general
equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)
equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)

rsvi = ReduceScatterVInfo(
input_sizes=input_sizes,
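For context on the hunk above: the deleted lines follow a common pattern for keeping data-dependent Python logic out of a Dynamo trace. A minimal standalone sketch of that pattern, with an invented helper name and plain ints rather than TorchRec's split metadata:

```python
# Illustrative sketch of the guard removed above (invented helper, not TorchRec code).
# In eager mode we may inspect concrete Python ints; under torch.compile/Dynamo the
# flag stays False so the general (variable-split) path is always taken.
from typing import List

import torch


def pick_reduce_scatter_path(input_splits: List[int]) -> str:
    equal_splits = False
    if not torch.compiler.is_dynamo_compiling():
        equal_splits = all(s == input_splits[0] for s in input_splits)
    return "equal-splits path" if equal_splits else "general path"


print(pick_reduce_scatter_path([8, 8, 8, 8]))  # equal-splits path (eager mode)
print(pick_reduce_scatter_path([8, 4, 8, 8]))  # general path
```

The back-out restores the unconditional `all(...)` check, so the equality is evaluated even when compiling.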
57 changes: 11 additions & 46 deletions torchrec/distributed/dist_data.py
@@ -104,17 +104,10 @@ def _get_recat(
for j in feature_order: # range(num_splits):
recat.append(i + j * local_split)

vb_condition: bool = batch_size_per_rank is not None
if not torch.compiler.is_dynamo_compiling():
vb_condition = vb_condition and any(
# pyre-ignore
bs != batch_size_per_rank[0]
# pyre-ignore
for bs in batch_size_per_rank
)

# variable batch size
if vb_condition:
if batch_size_per_rank is not None and any(
bs != batch_size_per_rank[0] for bs in batch_size_per_rank
):
batch_size_per_feature = list(
itertools.chain.from_iterable(
itertools.repeat(x, local_split) for x in batch_size_per_rank
@@ -234,8 +227,6 @@ def __init__(
self._device: torch.device = device
self._input = input
self._splits = splits
self._input_splits_list = input_splits
self._output_splits_list = output_splits
self._input_splits: Dict[str, List[int]] = dict(zip(labels, input_splits))
self._output_splits: Dict[str, List[int]] = dict(zip(labels, output_splits))
self._keys = keys
@@ -253,7 +244,6 @@ def __init__(

self._output_tensors: List[torch.Tensor] = []
self._awaitables: List[dist.Work] = []
self._world_size: int = self._pg.size()

for input_split, output_split, input_tensor, label in zip(
input_splits,
@@ -384,20 +374,6 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
self._output_splits = output_list[:-1]
self._stride_per_rank = output_list[-1]

if torch.compiler.is_dynamo_compiling():
rank: int = self._pg.rank()
for i in range(len(self._output_splits)):
for j in range(len(self._output_splits[i])):
torch._check_is_size(self._output_splits[i][j])
torch._check(
self._output_splits[i][rank] == self._input_splits[i][rank]
)
if self._stride_per_rank is not None:
# pyre-ignore
for i in range(len(self._stride_per_rank)):
# pyre-ignore
torch._check_is_size(self._stride_per_rank[i])

return KJTAllToAllTensorsAwaitable(
pg=self._pg,
input=self._input,
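The block deleted in this hunk fed size hints to the compiler. A hypothetical, self-contained sketch of what those `torch._check*` calls assert (the helper name and sample values are invented):

```python
# Illustrative only: the removed code told the tracer that split values read back
# from the all-to-all are valid sizes, and that each rank's output split for a key
# equals its input split (a rank keeps its own shard).
import torch


def register_split_hints(output_splits, input_splits, rank):
    for per_key in output_splits:
        for split in per_key:
            torch._check_is_size(split)  # non-negative, usable as a tensor size
    for out_per_key, in_per_key in zip(output_splits, input_splits):
        torch._check(out_per_key[rank] == in_per_key[rank])


# Runs eagerly as plain assertions on concrete ints.
register_split_hints([[2, 3], [4, 1]], [[9, 3], [7, 1]], rank=1)
```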
@@ -476,7 +452,7 @@ def __init__(
stagger: int = 1,
) -> None:
super().__init__()
torch._check(len(splits) == pg.size())
assert len(splits) == pg.size()
self._pg: dist.ProcessGroup = pg
self._splits = splits
self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
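The `torch._check(...)` to `assert` change in the hunk above is small but deliberate; a short illustration with made-up values:

```python
# Illustrative values only: both forms validate the same invariant eagerly.
# A plain assert can be stripped with `python -O`, while torch._check is never
# stripped and, when its argument is symbolic, also records the condition for
# the compiler.
import torch

splits, world_size = [2, 2, 2, 2], 4

assert len(splits) == world_size         # form restored by this back-out
torch._check(len(splits) == world_size)  # compile-oriented form being removed
```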
@@ -1028,25 +1004,14 @@ def forward(
PooledEmbeddingsAwaitable: awaitable of pooled embeddings of tensor of shape [batch_size, dimension].
"""

# Dynamo can not trace through data dependent condition: len(set(input_splits)) > 1
if torch.compiler.is_dynamo_compiling():
if input_splits is not None:
tensor_awaitable = reduce_scatter_v_pooled(
local_embs, input_splits, self._pg, codecs=self._codecs
)
else:
tensor_awaitable = reduce_scatter_base_pooled(
local_embs, self._pg, codecs=self._codecs
)
if input_splits and len(set(input_splits)) > 1:
tensor_awaitable = reduce_scatter_v_pooled(
local_embs, input_splits, self._pg, codecs=self._codecs
)
else:
if input_splits and len(set(input_splits)) > 1:
tensor_awaitable = reduce_scatter_v_pooled(
local_embs, input_splits, self._pg, codecs=self._codecs
)
else:
tensor_awaitable = reduce_scatter_base_pooled(
local_embs, self._pg, codecs=self._codecs
)
tensor_awaitable = reduce_scatter_base_pooled(
local_embs, self._pg, codecs=self._codecs
)
return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable)


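The `forward` hunk above restores a single dispatch rule for pooled embeddings: ragged input splits go to the variable-length reduce-scatter, everything else to the fixed-size one. A small sketch of just that decision, using a string return in place of the real collectives:

```python
# Illustrative dispatch only; reduce_scatter_v_pooled / reduce_scatter_base_pooled
# are named after the TorchRec collectives but not called here.
from typing import List, Optional


def choose_reduce_scatter(input_splits: Optional[List[int]]) -> str:
    if input_splits and len(set(input_splits)) > 1:
        return "reduce_scatter_v_pooled"    # per-rank shard sizes differ
    return "reduce_scatter_base_pooled"     # uniform shards, or no splits given


print(choose_reduce_scatter([4, 4, 4, 4]))  # reduce_scatter_base_pooled
print(choose_reduce_scatter([4, 2, 6, 4]))  # reduce_scatter_v_pooled
print(choose_reduce_scatter(None))          # reduce_scatter_base_pooled
```

The deleted branch short-circuited this `len(set(...)) > 1` test under Dynamo because it is data dependent.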
65 changes: 18 additions & 47 deletions torchrec/sparse/jagged_tensor.py
@@ -14,9 +14,6 @@
import torch
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec

# pyre-ignore
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node

from torchrec.streamable import Pipelineable
@@ -676,10 +673,6 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
def _assert_tensor_has_no_elements_or_has_integers(
tensor: torch.Tensor, tensor_name: str
) -> None:
if torch.compiler.is_dynamo_compiling():
# Skipping assert on tensor.numel() == 0 for dynamo to avoid DataDependentError
return

assert tensor.numel() == 0 or tensor.dtype in [
torch.long,
torch.int,
@@ -796,27 +789,15 @@ def _maybe_compute_length_per_key(
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
elif len(keys) and lengths is not None:
_length: List[int] = []
if variable_stride_per_key:
_length = _length_per_key_from_stride_per_key(lengths, stride_per_key)
else:
cond: bool = False
if (
torch.compiler.is_dynamo_compiling()
and not torch.jit.is_scripting()
):
# pyre-ignore
cond = guard_size_oblivious(lengths.numel() != 0)
else:
cond = lengths.numel() != 0

_length = (
torch.jit.annotate(
List[int], torch.sum(lengths.view(-1, stride), dim=1).tolist()
)
if cond
_length: List[int] = (
_length_per_key_from_stride_per_key(lengths, stride_per_key)
if variable_stride_per_key
else (
torch.sum(lengths.view(-1, stride), dim=1).tolist()
if lengths.numel() != 0
else [0] * len(keys)
)
)
else:
_length: List[int] = []
length_per_key = _length
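The rewrite above also drops the `guard_size_oblivious` branch (and its import near the top of the file). A sketch of what that removed guard did, assuming a recent PyTorch where `guard_size_oblivious` is available:

```python
# Illustrative reconstruction of the removed branch: guard_size_oblivious answers
# a size comparison on a traced tensor without forcing a data-dependent guard,
# roughly by treating unbacked sizes as generic non-0/1 values. In eager mode the
# plain comparison is used.
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious


def lengths_nonempty(lengths: torch.Tensor) -> bool:
    if torch.compiler.is_dynamo_compiling() and not torch.jit.is_scripting():
        return guard_size_oblivious(lengths.numel() != 0)
    return lengths.numel() != 0


print(lengths_nonempty(torch.tensor([1, 2, 3])))  # True
print(lengths_nonempty(torch.empty(0)))           # False
```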
@@ -1324,7 +1305,7 @@ def __init__(
self._stride_per_key_per_rank = stride_per_key_per_rank
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
self._variable_stride_per_key = True
if stride_per_key_per_rank is not None:
if not stride_per_key_per_rank:
self._stride = 0
elif all(s == self.stride_per_key()[0] for s in self.stride_per_key()):
self._stride = self.stride_per_key()[0]
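The one-line condition change in the hunk above is subtle; a plain-Python illustration of how the two forms differ for an empty versus a populated list:

```python
# Illustrative only: compare the removed "is not None" test with the restored
# truthiness test for an empty and a populated per-key stride list.
for stride_per_key_per_rank in ([], [[2, 2], [3, 1]]):
    print(
        stride_per_key_per_rank,
        stride_per_key_per_rank is not None,  # removed form: True in both cases
        not stride_per_key_per_rank,          # restored form: True only when empty
    )
```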
@@ -2183,20 +2164,17 @@ def dist_init(
cumsum_lengths[strides_cumsum[1:]] - cumsum_lengths[strides_cumsum[:-1]]
)
with record_function("## all2all_data:recat_values ##"):
recat_cond: bool = recat is not None
if recat_cond and not is_torchdynamo_compiling():
recat_cond = torch.jit._unwrap_optional(recat).numel() > 0
if recat_cond:
if recat is not None and recat.numel() > 0:
lengths, _ = _permute_tensor_by_segments(
lengths,
stride_per_rank_per_key,
torch.jit._unwrap_optional(recat),
recat,
None,
)
values, weights = _permute_tensor_by_segments(
values,
length_per_key,
torch.jit._unwrap_optional(recat),
recat,
weights,
)
if not stride_per_key_per_rank:
@@ -2223,31 +2201,24 @@ def dist_init(
else:
assert stride_per_rank is not None
with record_function("## all2all_data:recat_values ##"):
recat_cond: bool = recat is not None
if recat_cond and not is_torchdynamo_compiling():
recat_cond = torch.jit._unwrap_optional(recat).numel() > 0

if recat_cond:
if recat is not None and recat.numel() > 0:
stride = stride_per_rank[0]

# dynamo don't handle generators well
# so had to unroll the original generator into
# this for loop.
single_batch_per_rank = False
if not is_torchdynamo_compiling():
# Dynamo symbolic shapes could not pass through s != stride condition without hints => Dynamo always use VB path
single_batch_per_rank = True
for s in stride_per_rank:
if s != stride:
single_batch_per_rank = False
single_batch_per_rank = True
for s in stride_per_rank:
if s != stride:
single_batch_per_rank = False

if single_batch_per_rank:
(
lengths,
values,
weights,
) = torch.ops.fbgemm.permute_2D_sparse_data(
torch.jit._unwrap_optional(recat),
recat,
lengths.view(-1, stride),
values,
weights,
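The hunk above keeps the unrolled loop that decides between the 2D and 1D permute kernels but drops the Dynamo-only short circuit. A standalone sketch of the restored check (invented helper, plain ints):

```python
# Illustrative only: every rank sending the same batch size ("stride") allows the
# dense 2D permute path; any mismatch falls back to the 1D variable-batch path.
# The generator-free loop mirrors the removed comment about Dynamo and generators.
from typing import List


def single_batch_per_rank(stride_per_rank: List[int]) -> bool:
    stride = stride_per_rank[0]
    result = True
    for s in stride_per_rank:
        if s != stride:
            result = False
    return result


print(single_batch_per_rank([32, 32, 32, 32]))  # True  -> permute_2D_sparse_data
print(single_batch_per_rank([32, 16, 32, 32]))  # False -> permute_1D_sparse_data
```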
@@ -2260,7 +2231,7 @@ def dist_init(
values,
weights,
) = torch.ops.fbgemm.permute_1D_sparse_data(
torch.jit._unwrap_optional(recat),
recat,
lengths.view(-1),
values,
weights,
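Several hunks in this file also replace `torch.jit._unwrap_optional(recat)` with a bare `recat`. `_unwrap_optional` is a private TorchScript helper that asserts a value is not None and narrows `Optional[T]` to `T`; the removed code needed it because its None check was stored in a separate `recat_cond` flag, which TorchScript cannot use for type refinement, whereas the restored inline `recat is not None` check refines `recat` directly. A small illustration with made-up values:

```python
# Illustrative only: both lines below yield the same tensor once recat is known
# to be non-None; the explicit unwrap matters mainly for TorchScript's type checker.
from typing import Optional

import torch

recat: Optional[torch.Tensor] = torch.tensor([1, 0])

if recat is not None and recat.numel() > 0:
    direct = recat                                 # restored style
    unwrapped = torch.jit._unwrap_optional(recat)  # removed style
    assert torch.equal(direct, unwrapped)
```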
