From b6e784ee8fcc9405a1ccedb3f491a3b718e25f18 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Fri, 11 Oct 2024 00:47:50 -0700
Subject: [PATCH] add forward/backward test for _fbgemm_permute_pooled_embs
 (#2480)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2480

# context
* S443491 was caused by using a customized [version](https://www.internalfb.com/code/fbsource/[552b2a3cb49a261daa48b68b3647e8a951a3aa1b]/fbcode/minimal_viable_ai/models/main_feed_mtml/pytorch_modules.py?lines=2610) (fb.permute_pooled_embs_auto_grad) of fbgemm.permute_pooled_embs_auto_grad
* the fb version doesn't dispatch to autograd; it instead relied on a bug in fbgemm.permute_pooled_embs_auto_grad, which was fixed by D48574563
* the SEV was mitigated by switching to the fbgemm version: D62040883
* this diff adds more tests covering fbgemm.permute_pooled_embs_auto_grad

# details
* `permute_pooled_embs_auto_grad` is called in the `_fbgemm_permute_pooled_embs` function
* add forward and backward tests for the `_fbgemm_permute_pooled_embs` function; a standalone sketch of the property under test follows
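For illustration only (not part of this diff): a minimal sketch of the invariant the new test exercises, namely that `_fbgemm_permute_pooled_embs` agrees with the pure-PyTorch reference `_regroup_keyed_tensors` on both forward outputs and gradients. It assumes fbgemm_gpu's CPU ops are loadable; the `make_kt`/`loss_of` helpers and the key/dim choices are illustrative stand-ins for the `build_kts`/`build_groups` test utilities used in the real test.

```python
import torch
from torchrec.sparse.jagged_tensor import (
    _fbgemm_permute_pooled_embs,
    _regroup_keyed_tensors,
    KeyedTensor,
)

def make_kt(keys, dims, batch_size):
    # Hypothetical helper: one KeyedTensor whose leaf values tensor we can
    # take gradients against.
    values = torch.randn(batch_size, sum(dims), requires_grad=True)
    return values, KeyedTensor(keys=keys, length_per_key=dims, values=values)

def loss_of(groups_out):
    # Mirror the diff's pattern: combine both regrouped tensors so the loss is
    # sensitive to which features land in which group.
    return groups_out[0].sum(dim=1).mul(groups_out[1].sum(dim=1)).sum()

v0, kt0 = make_kt(["f1", "f2"], [4, 8], batch_size=2)
v1, kt1 = make_kt(["f3"], [6], batch_size=2)
kts = [kt0, kt1]
groups = [["f1", "f3"], ["f2"]]

# forward: the fused fbgemm path should match the reference regroup
actual = _fbgemm_permute_pooled_embs(kts, groups)
expected = _regroup_keyed_tensors(kts, groups)
for a, e in zip(actual, expected):
    assert torch.allclose(a, e)

# backward: gradients w.r.t. the original pooled values should match too
actual_grads = torch.autograd.grad(loss_of(actual), [v0, v1])
expected_grads = torch.autograd.grad(loss_of(expected), [v0, v1])
for a, e in zip(actual_grads, expected_grads):
    assert torch.allclose(a, e)
```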
Reviewed By: ge0405

Differential Revision: D64195848

fbshipit-source-id: 237ad75028eb9583bb02a2f305defb083f0f280d
---
 torchrec/sparse/tests/test_jagged_tensor.py | 26 +++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index b341cd584..782728a81 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -17,6 +17,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _fbgemm_permute_pooled_embs,
     _kt_regroup_arguments,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
@@ -2342,6 +2343,7 @@ def test_regroup_multiple_kt(self) -> None:
             KeyedTensor.regroup,
             regroup_kts,
             permute_multi_embedding,
+            _fbgemm_permute_pooled_embs,
         ],
         device_str=["cpu", "cuda", "meta"],
     )
@@ -2376,6 +2378,7 @@ def test_regroup_kts(
             KeyedTensor.regroup,
             regroup_kts,
             permute_multi_embedding,
+            _fbgemm_permute_pooled_embs,
         ],
         device_str=["cpu", "cuda", "meta"],
     )
@@ -2446,18 +2449,33 @@ def test_regroup_backward_skips_and_duplicates(self) -> None:
         torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
         torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

-    def test_regroup_backward(self) -> None:
+    @repeat_test(
+        regroup_func=[
+            KeyedTensor.regroup,
+            regroup_kts,
+            permute_multi_embedding,
+            _fbgemm_permute_pooled_embs,
+        ],
+        device_str=["cpu", "cuda"],
+    )
+    def test_regroup_backward(
+        self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str
+    ) -> None:
+        if device_str == "cuda" and not torch.cuda.is_available():
+            return
+        else:
+            device = torch.device(device_str)
         kts = build_kts(
             dense_features=20,
             sparse_features=20,
             dim_dense=64,
             dim_sparse=128,
             batch_size=128,
-            device=torch.device("cpu"),
+            device=device,
             run_backward=True,
         )
         groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
-        labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
+        labels = torch.randint(0, 1, (128,), device=device).float()

         tensor_groups = KeyedTensor.regroup(kts, groups)
         pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
@@ -2473,7 +2491,7 @@ def test_regroup_backward(self) -> None:
         kts[0].values().grad = None
         kts[1].values().grad = None

-        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        tensor_groups = regroup_func(kts, groups)
         pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
         loss = torch.nn.functional.l1_loss(pred1, labels).sum()
         expected_kt_0_grad = torch.autograd.grad(