2024-09-14 nightly release (3c815c6)
pytorchbot committed Sep 14, 2024
1 parent 87e3c92 commit be33e9e
Showing 2 changed files with 38 additions and 5 deletions.
10 changes: 9 additions & 1 deletion torchrec/sparse/jagged_tensor.py
@@ -2943,7 +2943,15 @@ def _kt_unflatten(


def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
-    return _kt_flatten(kt)[0]
+    _keys, _length_per_key = spec.context
+    # please read https://fburl.com/workplace/8bei5iju for more context;
+    # you can also consider using short_circuit_pytree_ebc_regroup with KTRegroupAsDict
+    logger.warning(
+        "KT's key order might change from the spec captured in torch.export; this could have a perf impact. "
+        f"{kt.keys()} vs {_keys}"
+    )
+    res = permute_multi_embedding([kt], [_keys])
+    return [res[0]]


# The assumption here in torch.exporting KeyedTensor is that _length_per_key is static
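For readers skimming the hunk above: the flatten-spec path now reorders the runtime KeyedTensor's values into the key order recorded in the exported TreeSpec via permute_multi_embedding, instead of returning kt._values as-is. A minimal sketch of that reordering, assuming permute_multi_embedding is importable from torchrec.sparse.jagged_tensor as used in the hunk; the keys and tensors here are made up for illustration:

import torch

from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding

# A KT whose keys arrive in a different order than the spec expects.
kt = KeyedTensor.from_tensor_list(
    ["f2", "f1"],
    [torch.ones(2, 3) * 2.0, torch.ones(2, 2)],
    cat_dim=1,
    key_dim=1,
)
spec_keys = ["f1", "f2"]  # key order captured at torch.export time

# permute_multi_embedding regroups the concatenated values per key group;
# a single group in spec order yields one tensor aligned with spec_keys.
(regrouped,) = permute_multi_embedding([kt], [spec_keys])
print(regrouped.shape)  # expected torch.Size([2, 5]): f1's 2 columns, then f2's 3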
33 changes: 29 additions & 4 deletions torchrec/sparse/tests/test_jagged_tensor.py
@@ -13,6 +13,7 @@

import torch
import torch.utils._pytree as pytree
+from torch.fx._pytree import tree_flatten_spec
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
@@ -2691,21 +2692,45 @@ def test_string_values(self) -> None:

def test_pytree(self) -> None:
tensor_list = [
-            torch.Tensor([[1.0, 1.0]]),
-            torch.Tensor([[2.0, 2.0], [3.0, 3.0]]),
+            torch.Tensor([[1.0, 1.0]]).T,
+            torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T,
]
keys = ["dense_0", "dense_1"]
-        kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0)
-
+        kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1)
+        # generate the out_spec in the torch.export run
flattened, out_spec = pytree.tree_flatten(kt)

+        # the first element of the flattened list should be kt._values
+        self.assertTrue(torch.equal(flattened[0], kt.values()))
+        # reconstruct the unflattened kt from the flattened list plus the out_spec
unflattened = pytree.tree_unflatten(flattened, out_spec)

self.assertTrue(isinstance(unflattened, KeyedTensor))
self.assertListEqual(unflattened.keys(), keys)
self.assertListEqual(unflattened._length_per_key, kt._length_per_key)

+        # for IR export, the key order in the KT could change
+        tensor_list = [
+            torch.Tensor([[2.0, 2.0], [3.0, 3.0]]).T,
+            torch.Tensor([[1.0, 1.0]]).T,
+        ]
+        keys = ["dense_1", "dense_0"]
+        kt2 = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=1, key_dim=1)
+
+        # flatten kt2 based on the previously generated out_spec;
+        # this mimics the exported_program module run, where kt2 can have
+        # a different key order but the out_spec stays the same
+        flattened2 = tree_flatten_spec(kt2, out_spec)
+
+        # reconstruct kt2 from the flattened list plus the out_spec;
+        # the rebuilt kt2 should contain the same effective data as kt (ignoring key order)
+        unflattened2 = pytree.tree_unflatten(flattened2, out_spec)
+        self.assertTrue(isinstance(unflattened2, KeyedTensor))
+        self.assertSetEqual(set(unflattened.keys()), set(unflattened2.keys()))
+        for key in kt.keys():
+            torch.testing.assert_close(unflattened[key], unflattened2[key])
+            torch.testing.assert_close(kt[key], unflattened2[key])


class TestKeyedTensorRegroupOp(unittest.TestCase):
@repeat_test(device_str=["cpu", "meta", "cuda"])
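Tying the two hunks together: the out_spec produced by pytree.tree_flatten in the test is what carries the (keys, length_per_key) context that the new _kt_flatten_spec unpacks via spec.context. A quick inspection sketch, assuming the KeyedTensor pytree registration stores exactly that pair as the flatten context; the printed values are illustrative:

import torch
import torch.utils._pytree as pytree

from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor.from_tensor_list(
    ["dense_0", "dense_1"],
    [torch.ones(2, 1), torch.ones(2, 2)],
    cat_dim=1,
    key_dim=1,
)
_flattened, spec = pytree.tree_flatten(kt)
# If the registration matches the unpacking above, this prints something like:
# (['dense_0', 'dense_1'], [1, 2])
print(spec.context)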
