Commit d5c2d1a: removed contiguous check.
ganeshcolfax committed Jul 19, 2024
1 parent: cdc966e
Showing 3 changed files with 7 additions and 7 deletions.
hopper/benchmark_flash_attention_fp8.py (2 additions, 2 deletions)

@@ -219,8 +219,8 @@ def time_fwd(func, *args, **kwargs):
 # dtype = torch.float16
 dtype = torch.float8_e4m3fn
 
-#bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
-bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)]
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
+#bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)]
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
 # bs_seqlen_vals = [(4, 8448)]
 causal_vals = [False, True]
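
The change above re-enables the full (batch_size, seqlen) sweep instead of the large-sequence subset. Each pair keeps the total token count roughly constant (about 16K tokens), so what grows across the sweep is the quadratic per-sequence attention cost. A rough sketch of that scaling follows; the helper and its default head counts are illustrative, not taken from the benchmark.

# Illustrative only: approximate attention FLOPs per (batch_size, seqlen) pair,
# assuming nheads=16 and headdim=128; the real benchmark computes its own numbers.
def approx_attn_flops(batch_size, seqlen, nheads=16, headdim=128, causal=False):
    flops = 4 * batch_size * seqlen ** 2 * nheads * headdim  # QK^T and PV matmuls
    return flops // 2 if causal else flops

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
for batch_size, seqlen in bs_seqlen_vals:
    print(batch_size, seqlen, approx_attn_flops(batch_size, seqlen))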

hopper/flash_attn_interface.py (4 additions, 4 deletions)

@@ -13,8 +13,8 @@
 
 
 def _flash_attn_forward(q, k, v, softmax_scale, causal):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
-    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    #maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    #q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
         k,

@@ -39,9 +39,9 @@ def _flash_attn_backward(
     softmax_scale,
     causal
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    #maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
-    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    #dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
         dout,
         q,
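
For reference, the removed helper is small enough to restate on its own. The sketch below is standalone and uses arbitrary shapes (it is not part of the repo); it shows the behavior callers now have to provide themselves: tensors whose last dimension is not unit-stride, such as transposed views, used to be copied into contiguous memory before reaching the CUDA kernels.

import torch

# Standalone sketch of the removed check: copy into contiguous memory only when
# the last dimension is not unit-stride. Shapes are arbitrary examples.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

x = torch.randn(2, 8, 4, 16)             # contiguous: x.stride(-1) == 1
y = x.transpose(1, 3)                    # strided view: y.stride(-1) == 64
print(maybe_contiguous(x) is x)          # True  -> already contiguous, no copy
print(maybe_contiguous(y).stride(-1))    # 1     -> materialized as a contiguous copy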

hopper/test_flash_attn.py (1 addition, 1 deletion)

@@ -301,7 +301,7 @@ def test_flash_attn_output_fp8(
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
-    out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1,3).clone(), causal=causal)
+    out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1,3).contiguous().clone(), causal=causal)
     q = q.to(dtype).to(torch.float16)
     k = k.to(dtype).to(torch.float16)
     v = v.to(dtype).to(torch.float16)
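
With the interface no longer normalizing layouts, the test handles it explicitly: transpose(1,3) only permutes strides, so the added .contiguous() is what actually rewrites the data into a unit-stride last dimension before the clone. A minimal illustration with made-up shapes (not the test's parameters):

import torch

# transpose() returns a view with permuted strides; contiguous() rewrites the
# storage so the last dimension is unit-stride again.
v = torch.randn(2, 128, 8, 64)            # (batch, seqlen, nheads, headdim)
vt = v.transpose(1, 3)                    # view, shape (2, 64, 8, 128)
print(vt.is_contiguous(), vt.stride(-1))  # False 512
vc = vt.contiguous()
print(vc.is_contiguous(), vc.stride(-1))  # True 1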
