[NJT]Add unit tests that cover the internal use cases using new NJT A…
YuqingJ authored and pytorchmergebot committed Aug 22, 2024
1 parent 1a7e8e5 commit b459ca7
Showing 1 changed file with 72 additions and 22 deletions.
test/test_nestedtensor.py: 94 changes (72 additions & 22 deletions)
@@ -35,6 +35,7 @@
     onlyCUDA,
     ops,
     PYTORCH_CUDA_MEMCHECK,
+    skipCPUIf,
     skipCUDAIf,
     skipCUDAIfRocm,
     skipMeta,
@@ -238,7 +239,7 @@ def get_op_name(layout):
 
 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_dense_to_nested_tensor(values):
+def convert_dense_to_nested_tensor_legacy(values):
     offsets = torch.arange(
         0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
     )
@@ -251,7 +252,7 @@ def convert_dense_to_nested_tensor(values):
 
 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_jagged_to_nested_tensor(
+def convert_jagged_to_nested_tensor_legacy(
     values: torch.Tensor, offsets: torch.Tensor, max_length: int
 ) -> torch.Tensor:
     metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
@@ -261,10 +262,33 @@ def convert_jagged_to_nested_tensor(
 
 # Helper function for test_dummy_mha_with_nt
 @torch.fx.wrap
-def convert_nt_to_jagged(nt):
+def convert_nt_to_jagged_legacy(nt):
     return buffer_from_jagged(nt)
 
 
+# Helper function for test_dummy_mha_with_nt
+@torch.fx.wrap
+def convert_dense_to_nested_tensor(values):
+    nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
+    return nt
+
+
+# Helper function for test_dummy_mha_with_nt
+@torch.fx.wrap
+def convert_jagged_to_nested_tensor(
+    values: torch.Tensor, offsets: torch.Tensor, max_length: int
+) -> torch.Tensor:
+    nt = torch.nested.nested_tensor_from_jagged(
+        values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
+    )
+    return nt
+
+
+# Helper function for test_dummy_mha_with_nt
+def convert_nt_to_jagged(nt):
+    return nt.values()
+
+
 @markDynamoStrictTest
 class TestNestedTensor(NestedTensorTestCase):
     @parametrize("batch_size", [2, 4])
@@ -6677,48 +6701,73 @@ def fn(values, lengths):
     @skipIfTorchDynamo("compiles internally")
     @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
     @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
-    def test_dummy_mha_with_nt(self, device):
+    @parametrize("use_legacy_api", [True, False])
+    @skipCPUIf(True, "SDPA Math NT fallback causes failure: see issue #133644")
+    def test_dummy_mha_with_nt(self, device, use_legacy_api):
         bs = 3
         d1 = 2
         d2 = 4
-        d3 = 6
+        d3 = 16
         n_heads = 2
         d_head = d3 // n_heads
         max_length_1 = 10
         max_length_2 = 20
         torch.manual_seed(0)
 
         class mha(torch.nn.Module):
-            def __init__(self) -> None:
+            def __init__(self, use_legacy_api) -> None:
                 super().__init__()
+                torch.manual_seed(0)
                 self.linear = torch.nn.Linear(d2, d3, device=device)
+                self.use_legacy_api = use_legacy_api
 
             def forward(self, query, value, offsets):
                 value = self.linear(value)
-                key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
-                value = convert_jagged_to_nested_tensor(value, offsets, max_length_2)
-                query = convert_dense_to_nested_tensor(query)
+                if self.use_legacy_api:
+                    key = convert_jagged_to_nested_tensor_legacy(
+                        value, offsets, max_length_1
+                    )
+                    value = convert_jagged_to_nested_tensor_legacy(
+                        value, offsets, max_length_2
+                    )
+                    query = convert_dense_to_nested_tensor_legacy(query)
+                else:
+                    key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
+                    value = convert_jagged_to_nested_tensor(
+                        value, offsets, max_length_2
+                    )
+                    query = convert_dense_to_nested_tensor(query)
                 q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
                 k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
                 v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    attn_mask=None,
-                    dropout_p=0.0,
-                    is_causal=False,
-                )
 
+                with torch.nn.attention.sdpa_kernel(
+                    [
+                        torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                    ]
+                ):
+                    attn_output = torch.nn.functional.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        attn_mask=None,
+                        dropout_p=0.0,
+                        is_causal=False,
+                    )
                 attn_output = attn_output.transpose(1, 2)
-                attn_output = convert_nt_to_jagged(attn_output)
+                if self.use_legacy_api:
+                    attn_output = convert_nt_to_jagged_legacy(attn_output)
+                else:
+                    attn_output = convert_nt_to_jagged(attn_output)
                 return attn_output, key._max_seqlen, value._max_seqlen
 
         query = torch.rand(bs, d1, d3, device=device)
-        value = torch.rand(6, d2, requires_grad=True, device=device)
-        offsets = torch.tensor([0, 2, 3, 6], device=device)
+        value = torch.rand(30, d2, requires_grad=True, device=device)
+        # total_length must be > max_length, otherwise flash_attn backward will fail
+        offsets = torch.tensor([0, 2, 3, 30], device=device)
 
-        m = mha()
+        m = mha(use_legacy_api)
         symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
         m = torch.compile(symbolic_traced)
         attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
@@ -6736,7 +6785,8 @@ def forward(self, query, value, offsets):
         self.assertEqual(cached_value_max_seqlen, max_length_2)
 
         # check that the output is numerically equivalent to eager mode
-        m_eager = mha()
+        m_eager = mha(use_legacy_api)
+
         value.grad = None
         attn_output_eager, _, _ = m_eager(query, value, offsets)
         attn_output_eager.sum().backward()
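For reference, a minimal sketch (not part of the commit) of the new public NJT construction API that these tests exercise, assuming a PyTorch build that exposes `torch.nested.nested_tensor_from_jagged`; the shapes mirror the test's `value`/`offsets` setup:

```python
import torch

# Packed jagged buffer: 30 total rows with feature dim 4, split into a batch
# of 3 sequences (lengths 2, 1, 27) by `offsets`.
values = torch.rand(30, 4)
offsets = torch.tensor([0, 2, 3, 30])

# New API: build a jagged-layout NestedTensor directly from the packed buffer,
# passing min/max sequence lengths so they are cached on the tensor.
nt = torch.nested.nested_tensor_from_jagged(
    values, offsets, lengths=None, min_seqlen=1, max_seqlen=27
)

# New API: recover the packed buffer (replaces the legacy buffer_from_jagged).
flat = nt.values()
assert flat.shape == values.shape

# Dense (rectangular) inputs convert via as_nested_tensor with a jagged layout.
dense = torch.rand(3, 2, 16)
nt_dense = torch.nested.as_nested_tensor(dense, layout=torch.jagged)
```

Unlike the legacy helpers, which thread a private `metadata_cache` into construction, the public constructors accept the seqlen metadata directly; the test's `key._max_seqlen`/`value._max_seqlen` assertions appear to check that this cached metadata survives fx tracing and torch.compile.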
