diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index a95555e88f6f1..85cefe7d43966 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -35,6 +35,7 @@
     onlyCUDA,
     ops,
     PYTORCH_CUDA_MEMCHECK,
+    skipCPUIf,
     skipCUDAIf,
     skipCUDAIfRocm,
     skipMeta,
@@ -238,7 +239,7 @@ def get_op_name(layout):

 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_dense_to_nested_tensor(values):
+def convert_dense_to_nested_tensor_legacy(values):
     offsets = torch.arange(
         0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
     )
@@ -251,7 +252,7 @@ def convert_dense_to_nested_tensor(values):

 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_jagged_to_nested_tensor(
+def convert_jagged_to_nested_tensor_legacy(
     values: torch.Tensor, offsets: torch.Tensor, max_length: int
 ) -> torch.Tensor:
     metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
@@ -261,10 +262,33 @@ def convert_jagged_to_nested_tensor(

 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_nt_to_jagged(nt):
+def convert_nt_to_jagged_legacy(nt):
     return buffer_from_jagged(nt)


+# Helper function for test_dummy_mha_with_nt
+@torch.fx.wrap
+def convert_dense_to_nested_tensor(values):
+    nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
+    return nt
+
+
+# Helper function for test_dummy_mha_with_nt
+@torch.fx.wrap
+def convert_jagged_to_nested_tensor(
+    values: torch.Tensor, offsets: torch.Tensor, max_length: int
+) -> torch.Tensor:
+    nt = torch.nested.nested_tensor_from_jagged(
+        values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
+    )
+    return nt
+
+
+# Helper function for test_dummy_mha_with_nt
+def convert_nt_to_jagged(nt):
+    return nt.values()
+
+
 @markDynamoStrictTest
 class TestNestedTensor(NestedTensorTestCase):
     @parametrize("batch_size", [2, 4])
@@ -6677,11 +6701,13 @@ def fn(values, lengths):
     @skipIfTorchDynamo("compiles internally")
     @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
     @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
-    def test_dummy_mha_with_nt(self, device):
+    @parametrize("use_legacy_api", [True, False])
+    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
+    def test_dummy_mha_with_nt(self, device, use_legacy_api):
         bs = 3
         d1 = 2
         d2 = 4
-        d3 = 6
+        d3 = 16
         n_heads = 2
         d_head = d3 // n_heads
         max_length_1 = 10
@@ -6689,36 +6715,59 @@ def test_dummy_mha_with_nt(self, device, use_legacy_api):
         torch.manual_seed(0)

         class mha(torch.nn.Module):
-            def __init__(self) -> None:
+            def __init__(self, use_legacy_api) -> None:
                 super().__init__()
                 torch.manual_seed(0)
                 self.linear = torch.nn.Linear(d2, d3, device=device)
+                self.use_legacy_api = use_legacy_api

             def forward(self, query, value, offsets):
                 value = self.linear(value)
-                key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
-                value = convert_jagged_to_nested_tensor(value, offsets, max_length_2)
-                query = convert_dense_to_nested_tensor(query)
+                if self.use_legacy_api:
+                    key = convert_jagged_to_nested_tensor_legacy(
+                        value, offsets, max_length_1
+                    )
+                    value = convert_jagged_to_nested_tensor_legacy(
+                        value, offsets, max_length_2
+                    )
+                    query = convert_dense_to_nested_tensor_legacy(query)
+                else:
+                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
+                    value = convert_jagged_to_nested_tensor(
+                        value, offsets, max_length_2
+                    )
+                    query = convert_dense_to_nested_tensor(query)
                 q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
                 k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
                 v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    attn_mask=None,
-                    dropout_p=0.0,
-                    is_causal=False,
-                )
+
+                with torch.nn.attention.sdpa_kernel(
+                    [
+                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                    ]
+                ):
+                    attn_output = torch.nn.functional.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        attn_mask=None,
+                        dropout_p=0.0,
+                        is_causal=False,
+                    )
                 attn_output = attn_output.transpose(1, 2)
-                attn_output = convert_nt_to_jagged(attn_output)
+                if self.use_legacy_api:
+                    attn_output = convert_nt_to_jagged_legacy(attn_output)
+                else:
+                    attn_output = convert_nt_to_jagged(attn_output)
                 return attn_output, key._max_seqlen, value._max_seqlen

         query = torch.rand(bs, d1, d3, device=device)
-        value = torch.rand(6, d2, requires_grad=True, device=device)
-        offsets = torch.tensor([0, 2, 3, 6], device=device)
+        value = torch.rand(30, d2, requires_grad=True, device=device)
+        # total_length must be > max_length, otherwise the flash_attn backward will fail
+        offsets = torch.tensor([0, 2, 3, 30], device=device)

-        m = mha()
+        m = mha(use_legacy_api)
         symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
         m = torch.compile(symbolic_traced)
         attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
@@ -6736,7 +6785,8 @@ def forward(self, query, value, offsets):
         self.assertEqual(cached_value_max_seqlen, max_length_2)

         # check if the output is numerically equivalent with the eager mode
-        m_eager = mha()
+        m_eager = mha(use_legacy_api)
+        value.grad = None
         attn_output_eager, _, _ = m_eager(query, value, offsets)
         attn_output_eager.sum().backward()
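
For context, a minimal sketch (not part of the patch) of the non-legacy conversion path the new helpers wrap, assuming a PyTorch build where torch.nested.nested_tensor_from_jagged and the torch.jagged layout are available; the offsets mirror the test's [0, 2, 3, 30]:

import torch

# Packed jagged batch: 3 sequences of lengths 2, 1, and 27 in one dense buffer.
values = torch.rand(30, 4)
offsets = torch.tensor([0, 2, 3, 30])

# Build a jagged-layout nested tensor; min/max seqlens are cached on the NT
# so downstream ops (e.g. SDPA) need not recompute them from offsets.
nt = torch.nested.nested_tensor_from_jagged(
    values, offsets, lengths=None, min_seqlen=1, max_seqlen=27
)

# .values() recovers the packed buffer, which is what the new
# convert_nt_to_jagged() returns in place of the legacy buffer_from_jagged(nt).
assert nt.values().shape == values.shape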