Support FlashMask bidirectional attention function and improve performance #68381

Merged · 7 commits · Sep 25, 2024
26 changes: 26 additions & 0 deletions paddle/fluid/pybind/cuda_streams_py.cc
@@ -375,6 +375,32 @@ void BindCudaStream(py::module *m_ptr) {
>>> event = paddle.device.cuda.Event()
>>> is_done = event.query()

)DOC")
.def(
"elapsed_time",
[](phi::CudaEvent &self, phi::CudaEvent &end_event) {
return self.ElapsedTime(&end_event);
},
R"DOC(
Returns the time elapsed in milliseconds between the recording of this
event and the recording of end_event.

Returns: A float which indicates the elapsed time in milliseconds.

Examples:
.. code-block:: python

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle

>>> paddle.set_device('gpu')
>>> e1 = paddle.device.Event(enable_timing=True)
>>> e1.record()

>>> e2 = paddle.device.Event(enable_timing=True)
>>> e2.record()
>>> e1.elapsed_time(e2)

)DOC")
.def(
"synchronize",
14 changes: 14 additions & 0 deletions paddle/phi/api/profiler/event.h
@@ -196,6 +196,20 @@ class CudaEvent {
return false;
}

float ElapsedTime(CudaEvent *end_event) {
float milliseconds = 0;
#ifdef PADDLE_WITH_HIP
hipEventSynchronize(end_event->GetRawCudaEvent());
PADDLE_ENFORCE_GPU_SUCCESS(hipEventElapsedTime(
&milliseconds, event_, end_event->GetRawCudaEvent()));
#else
cudaEventSynchronize(end_event->GetRawCudaEvent());
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(
&milliseconds, event_, end_event->GetRawCudaEvent()));
#endif
return milliseconds;
}

void Synchronize() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
21 changes: 16 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -639,9 +639,9 @@ void FlashAttnGradBaseKernel(
int num_splits = get_num_split();

DenseTensor flashmask_maxmin, downstart_row_indices, upend_row_indices,
downend_row_indices;
downend_row_indices, upstart_row_indices;
void *downstart_row_indices_data = nullptr, *upend_row_indices_data = nullptr,
*downend_row_indices_data = nullptr;
*downend_row_indices_data = nullptr, *upstart_row_indices_data = nullptr;
bool is_flashmask = params.startend_row_indices != nullptr;
if (is_flashmask) {
PADDLE_ENFORCE_EQ(
@@ -652,10 +652,11 @@
"[batch_size, num_heads,seq_len, mask_bounds]"));
PADDLE_ENFORCE_EQ(
startend_row_indices->dims()[3] == 1 ||
startend_row_indices->dims()[3] == 2,
startend_row_indices->dims()[3] == 2 ||
startend_row_indices->dims()[3] == 4,
true,
phi::errors::InvalidArgument("flashmask_attention startend_row_indices "
"mask_bounds in [1,2] are supported now"));
"mask_bounds must in [1, 2, 4]"));
auto flashmask_maxmin_shape = params.startend_row_indices->dims();
flashmask_maxmin_shape[2] = (flashmask_maxmin_shape[2] + 31) / 32 * 8;
flashmask_maxmin.set_type(phi::DataType::INT32);
@@ -675,6 +676,16 @@ void FlashAttnGradBaseKernel(
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
downend_row_indices_data = downend_row_indices.data();
}
} else if (startend_row_indices->dims()[3] == 4) {
upend_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {3}, {4});
upend_row_indices_data = upend_row_indices.data();
downend_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
downend_row_indices_data = downend_row_indices.data();
upstart_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {2}, {3});
upstart_row_indices_data = upstart_row_indices.data();
}
}

@@ -715,7 +726,7 @@ void FlashAttnGradBaseKernel(
is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
is_flashmask ? upend_row_indices_data : nullptr,
is_flashmask ? downend_row_indices_data : nullptr,
nullptr,
is_flashmask ? upstart_row_indices_data : nullptr,
is_flashmask ? flashmask_maxmin.data() : nullptr,
q.strides()[1],
k.strides()[1],
21 changes: 16 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -393,9 +393,9 @@ void FlashAttnBaseKernel(
if (!out->IsInitialized()) ctx.template Alloc<T>(out);

DenseTensor flashmask_maxmin, downstart_row_indices, upend_row_indices,
downend_row_indices;
downend_row_indices, upstart_row_indices;
void *downstart_row_indices_data = nullptr, *upend_row_indices_data = nullptr,
*downend_row_indices_data = nullptr;
*downend_row_indices_data = nullptr, *upstart_row_indices_data = nullptr;
bool is_flashmask = params.startend_row_indices != nullptr;
if (is_flashmask) {
PADDLE_ENFORCE_EQ(
@@ -406,10 +406,11 @@
"[batch_size, num_heads,seq_len, mask_bounds]"));
PADDLE_ENFORCE_EQ(
startend_row_indices->dims()[3] == 1 ||
startend_row_indices->dims()[3] == 2,
startend_row_indices->dims()[3] == 2 ||
startend_row_indices->dims()[3] == 4,
true,
phi::errors::InvalidArgument("flashmask_attention startend_row_indices "
"mask_bounds in [1,2] are supported now"));
"mask_bounds must in [1, 2, 4]"));
auto flashmask_maxmin_shape = params.startend_row_indices->dims();
flashmask_maxmin_shape[2] = (flashmask_maxmin_shape[2] + 31) / 32 * 8;
flashmask_maxmin.set_type(phi::DataType::INT32);
@@ -429,6 +430,16 @@ void FlashAttnBaseKernel(
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
downend_row_indices_data = downend_row_indices.data();
}
} else if (startend_row_indices->dims()[3] == 4) {
upend_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {3}, {4});
upend_row_indices_data = upend_row_indices.data();
downend_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
downend_row_indices_data = downend_row_indices.data();
upstart_row_indices =
phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {2}, {3});
upstart_row_indices_data = upstart_row_indices.data();
}
}

@@ -466,7 +477,7 @@ void FlashAttnBaseKernel(
is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
is_flashmask ? upend_row_indices_data : nullptr,
is_flashmask ? downend_row_indices_data : nullptr,
nullptr,
is_flashmask ? upstart_row_indices_data : nullptr,
is_flashmask ? flashmask_maxmin.data() : nullptr,
q.strides()[1],
k.strides()[1],
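For readers of the two kernel changes above: when `startend_row_indices` has a last dimension of 4, the kernels slice it into four per-column pointers. Below is a minimal Python-side sketch of that layout, assuming the column order `[down_start, down_end, up_start, up_end]` implied by the `phi::Slice` calls and by the `flashmask_attention` docstring; the shapes and values are illustrative only, and this particular setting masks nothing (full bidirectional attention).

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> from paddle.nn.functional.flash_attention import flashmask_attention

>>> bsz, seqlen, num_heads, head_dim = 2, 8, 4, 32
>>> # Columns: rows [down_start, down_end) of the lower-left triangle are masked,
>>> # rows [up_start, up_end) of the upper-right triangle are masked.
>>> down_start = paddle.full([bsz, 1, seqlen], seqlen, dtype='int32')
>>> down_end = paddle.full([bsz, 1, seqlen], seqlen, dtype='int32')
>>> up_start = paddle.full([bsz, 1, seqlen], 0, dtype='int32')
>>> up_end = paddle.full([bsz, 1, seqlen], 0, dtype='int32')
>>> startend_row_indices = paddle.stack(
...     [down_start, down_end, up_start, up_end], axis=-1)  # [bsz, 1, seqlen, 4]

>>> q = paddle.randn([bsz, seqlen, num_heads, head_dim]).astype('float16')
>>> out = flashmask_attention(
...     q, q, q, startend_row_indices=startend_row_indices, causal=False)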
13 changes: 8 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -197,11 +197,14 @@ struct FlashAttnParamsBase {
startend_row_indices ? startend_row_indices.get_ptr() : nullptr,
max_seqlen_q);

PADDLE_ENFORCE_NE(attn_mask_tensor && startend_row_indices,
true,
phi::errors::InvalidArgument(
"attn_mask and attn_mask_start_row_indices cannot be "
"set at same time."));
if (startend_row_indices.is_initialized()) {
PADDLE_ENFORCE_EQ(
attn_mask_tensor,
nullptr,
phi::errors::InvalidArgument(
"attn_mask and attn_mask_start_row_indices cannot be "
"set at same time."));
}
}
};

2 changes: 1 addition & 1 deletion python/paddle/device/__init__.py
@@ -584,7 +584,7 @@ def elapsed_time(self, end_event):
>>> e1.elapsed_time(e2)

'''
return 0
return self.event_base.elapsed_time(end_event.event_base)

def synchronize(self):
'''
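With `Event.elapsed_time` now wired to the C++ binding above, a rough end-to-end timing sketch looks like the following (GPU-only; the matmul is just an arbitrary workload, and the underlying `ElapsedTime` synchronizes on the end event before measuring).

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle

>>> paddle.set_device('gpu')
>>> start = paddle.device.Event(enable_timing=True)
>>> end = paddle.device.Event(enable_timing=True)

>>> start.record()
>>> x = paddle.randn([1024, 1024])
>>> y = paddle.matmul(x, x)  # arbitrary GPU work to time
>>> end.record()

>>> ms = start.elapsed_time(end)  # float, in milliseconds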
165 changes: 105 additions & 60 deletions python/paddle/nn/functional/flash_attention.py
@@ -1165,12 +1165,14 @@ def flashmask_attention(
- When `causal=False` and the shape is [batch_size, num_heads, seq_len, 2],
indicating bidirectional attention. The values represent the starting row index of the left
lower triangular mask and the ending row index of the right upper triangular mask in the dense mask. The values r1, r2 in startend_row_indices indicate that elements in the lower left triangle of the Score matrix starting from the r1-th row downwards (inclusive) will be masked, and elements in the upper right triangle starting from the r2-th row upwards (exclusive) will be masked.
- When `causal=False` and the shape is [batch_size, num_heads, seq_len, 4] (not implemented),
- When `causal=False` and the shape is [batch_size, num_heads, seq_len, 4],
indicating bidirectional attention. The values represent the start and end row indices of the
left lower triangular mask and the start and end row indices of the right upper triangular mask in the dense mask. The values r1, r2, r3, r4 in startend_row_indices indicate that elements in the lower left triangle of the Score matrix starting from the r1-th row downwards (inclusive) but above the r2-th row (exclusive) will be masked, and elements in the upper right triangle starting from the r3-th row downwards (inclusive) but above the r4-th row (exclusive) will be masked.
- **dropout** (float) - The dropout ratio. Default is 0.0.
- **causal** (bool) - Whether to enable causal mode. Default is False.
- **window_size* (int, optional) - Window size in sliding window attention. Default is None, will not use sliding window attention. (Not implemented!)
- **window_size** (int|tuple, optional) - Indicates the window size of sliding window local attention.
If causal mode is enabled, Query at position i will only attend to keys between [i - window_size, i] or [i - window_size[0], i].
If causal mode is disabled, Query at position i will only attend to keys between [i - window_size, i + window_size] or [i - window_size[0], i + window_size[1]].
- **return_softmax_lse** (bool) - Whether to return the log-sum-exp of the softmax. Default is False.
- **return_seed_offset** (bool) - Whether to return the random seed offset. Default is False.
- **fixed_seed_offset** (Tensor, optional): With fixed seed, offset for dropout mask.
@@ -1229,69 +1231,112 @@ def flashmask_attention(
>>> # doctest: -SKIP

"""
assert window_size is None, "window is not implemented now."

assert (
startend_row_indices is not None
), f"startend_row_indices must be not None, but got {startend_row_indices}"
assert (
startend_row_indices.dtype == paddle.int32
), f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
assert (
len(startend_row_indices.shape) == 4
), f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"

assert (
startend_row_indices.shape[0] == key.shape[0]
), f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"

assert (
startend_row_indices.shape[2] == key.shape[1]
), f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
assert startend_row_indices.shape[1] in [
1,
key.shape[2],
], "startend_row_indices head_num must be equal to 1(broadcast) or hean_num_k."

if causal:
if startend_row_indices.shape[-1] not in [1, 2]:
raise ValueError(
f"Invalid shape of startend_row_indices, when causal is True, the last dimension should be either 1 or 2 but got {startend_row_indices.shape[-1]}"

if window_size is not None:
if isinstance(window_size, int):
window_size = (window_size, window_size)
sq = query.shape[1]
bsz = query.shape[0]
assert (
startend_row_indices is None
), "can't use window_size with startend_row_indices"
if causal:
startend_row_indices = paddle.arange(
window_size[0] + 1, sq + window_size[0] + 1, dtype="int32"
).reshape((1, 1, sq, 1))
startend_row_indices = paddle.clip(
startend_row_indices, max=sq
).repeat_interleave(bsz, 0)

else:
startend_row_indices = paddle.empty((1, 1, sq, 2), dtype="int32")
startend_row_indices[0, 0, :, 0] = paddle.arange(
window_size[0] + 1, sq + window_size[0] + 1, dtype="int32"
)
else:
if startend_row_indices.shape[-1] == 2:
pass
elif startend_row_indices.shape[-1] == 4:
raise NotImplementedError(
"ending row index is not implemented yet."
startend_row_indices[0, 0, :, 1] = paddle.arange(
-window_size[1], sq - window_size[1], dtype="int32"
)
startend_row_indices = paddle.clip(
startend_row_indices, min=0, max=sq
).repeat_interleave(bsz, 0)

if startend_row_indices is None:
(
out,
result_softmax,
result_softmax_lse,
result_seed_offset,
) = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
None,
dropout,
causal,
False,
not training,
rng_name,
)

else:
assert (
startend_row_indices.dtype == paddle.int32
), f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}"
assert (
len(startend_row_indices.shape) == 4
), f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}"

assert (
startend_row_indices.shape[0] == key.shape[0]
), f"startend_row_indices.shape[0] must be equal to batch_size, but got {startend_row_indices.shape[0]} and {key.shape[0]}"

assert (
startend_row_indices.shape[2] == key.shape[1]
), f"startend_row_indices.shape[2] must be equal to seqlen_k, but got {startend_row_indices.shape[2]} and {key.shape[2]}"
assert startend_row_indices.shape[1] in [
1,
key.shape[2],
], "startend_row_indices head_num must be equal to 1(broadcast) or hean_num_k."

if causal:
if startend_row_indices.shape[-1] == 1:
has_end = False
elif startend_row_indices.shape[-1] == 2:
has_end = True
else:
raise ValueError(
f"Invalid shape of startend_row_indices, when causal is True, the last dimension should be either 1 or 2 but got {startend_row_indices.shape[-1]}"
)
else:
raise ValueError(
f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"
)
if startend_row_indices.shape[-1] == 2:
has_end = False
elif startend_row_indices.shape[-1] == 4:
has_end = True
else:
raise ValueError(
f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"
)

(
out,
result_softmax,
result_softmax_lse,
result_seed_offset,
) = _C_ops.flashmask_attention(
query,
key,
value,
startend_row_indices,
fixed_seed_offset,
dropout,
causal,
False,
not training,
rng_name,
)

return_softmax = False

(
out,
result_softmax,
result_softmax_lse,
result_seed_offset,
) = _C_ops.flashmask_attention(
query,
key,
value,
startend_row_indices,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
outputs = [out]
if return_softmax:
outputs += [result_softmax]
if return_softmax_lse:
outputs += [result_softmax_lse]
if return_seed_offset:
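A short usage sketch of the newly supported `window_size` path (parameter values are illustrative; per the code above, an int is expanded to a `(left, right)` tuple and converted internally into `startend_row_indices`, so `startend_row_indices` itself must be left as None).

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> from paddle.nn.functional.flash_attention import flashmask_attention

>>> paddle.set_device('gpu')
>>> q = paddle.randn([2, 128, 8, 32]).astype('float16')

>>> # causal sliding window: position i attends to keys in [i - 16, i]
>>> out = flashmask_attention(
...     q, q, q, startend_row_indices=None, window_size=16, causal=True)

>>> # bidirectional sliding window: position i attends to keys in [i - 16, i + 32]
>>> out = flashmask_attention(
...     q, q, q, startend_row_indices=None, window_size=(16, 32), causal=False)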
2 changes: 1 addition & 1 deletion third_party/flashattn
Submodule flashattn updated 44 files
+2 −0 .gitignore
+41 −53 csrc/CMakeLists.txt
+18 −20 csrc/capi/flash_attn.cu
+4 −18 csrc/flash_attn/flash_api.cpp
+2 −2 csrc/flash_attn/src/flash.h
+0 −20 csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
+0 −26 csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
+0 −16 csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
+0 −16 csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
+0 −16 csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
+0 −35 csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
+0 −20 csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
+0 −22 csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
+1,621 −134 csrc/flash_attn/src/flash_bwd_kernel.h
+69 −80 csrc/flash_attn/src/flash_bwd_launch_template.h
+0 −19 csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
+0 −32 csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
+0 −17 csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
+0 −27 csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
+0 −16 csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
+0 −27 csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
+0 −9 csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
+0 −9 csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
+0 −9 csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
+0 −9 csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
+0 −10 csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
+0 −23 csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
+0 −19 csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
+0 −26 csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
+0 −17 csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
+0 −23 csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
+1,153 −159 csrc/flash_attn/src/flash_fwd_kernel.h
+131 −145 csrc/flash_attn/src/flash_fwd_launch_template.h
+121 −0 csrc/flash_attn/src/generate_kernels.py
+28 −16 csrc/flash_attn/src/softmax.h
+1 −1 csrc/flash_attn/src/static_switch.h