-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN] 【Infer Symbolic Shape BUAA 】Add flashmask_attention op #68385
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1543,6 +1543,75 @@ bool FlashAttnOpInferSymbolicShape( | |
// return true; | ||
// } | ||
|
||
bool FlashmaskAttentionOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &q = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const symbol::ShapeOrDataDimExprs &k = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
const symbol::ShapeOrDataDimExprs &v = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
|
||
PADDLE_ENFORCE_EQ(q.shape().size(), | ||
4, | ||
common::errors::InvalidArgument( | ||
"flash_attn receive input with dim " | ||
"[batch_size, seq_len, num_heads, head_dim]")); | ||
|
||
infer_context->AddEqualCstr(q.shape()[0], k.shape()[0]); | ||
infer_context->AddEqualCstr(q.shape()[0], v.shape()[0]); | ||
infer_context->AddEqualCstr(k.shape()[1], v.shape()[1]); | ||
|
||
if (op->operand_source(3)) { | ||
const std::vector<symbol::DimExpr> &startend_row_indices = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(4)).shape(); | ||
PADDLE_ENFORCE_EQ( | ||
startend_row_indices.size(), | ||
4, | ||
common::errors::InvalidArgument( | ||
"flashmask_attention receive startend_row_indices with dim " | ||
"[batch_size, num_heads,seq_len, mask_bounds]")); | ||
} | ||
std::vector<symbol::DimExpr> out_shape = q.shape(); | ||
|
||
out_shape.back() = v.shape().back(); | ||
|
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape)); | ||
|
||
// GPU has round for seqlen, but XPU has not. Here we align with the GPU | ||
// version. | ||
auto round_multiple = [](symbol::DimExpr x) { | ||
auto m = symbol::DimExpr{128}; | ||
auto m_minus_one = symbol::DimExpr{127}; | ||
return (x + m_minus_one) / m * m; | ||
}; | ||
auto batch_size_expr = q.shape()[0]; | ||
auto num_heads_expr = q.shape()[2]; | ||
auto seqlen_q_rounded_expr = round_multiple(q.shape()[1]); | ||
auto seqlen_k_rounded_expr = round_multiple(k.shape()[1]); | ||
|
||
if (op->result(1)) { | ||
std::vector<symbol::DimExpr> softmax_shape{batch_size_expr, | ||
num_heads_expr, | ||
seqlen_q_rounded_expr, | ||
seqlen_k_rounded_expr}; | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(1), symbol::TensorShapeOrDataDimExprs(softmax_shape)); | ||
} | ||
if (op->result(2)) { | ||
std::vector<symbol::DimExpr> softmax_lse_shape{ | ||
batch_size_expr, num_heads_expr, seqlen_q_rounded_expr}; | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(2), symbol::TensorShapeOrDataDimExprs(softmax_lse_shape)); | ||
} | ||
if (op->result(3)) { | ||
std::vector<symbol::DimExpr> seed_offset_shape{symbol::DimExpr{2}}; | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(3), symbol::TensorShapeOrDataDimExprs(out_shape)); | ||
} | ||
return true; | ||
Comment on lines
+1589
to
+1613
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分推导逻辑没在kernel、infermeta、以及FlashAttnOpInferSymbolicShape中找到对应依据 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 与刘旭东的类似 |
||
} | ||
bool FusedBatchNormActOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
return BatchNormOpInferSymbolicShape(op, infer_context); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对比下和 FlashAttnOpInferSymbolicShape 的区别,看看是不是能直接复用
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
两个函数在qkv的处理和output的计算是一样的。
但是输入参数所取的operand位置不一样,同时输入的参数也不太相同,不能直接复用。