[CINN][Infer Symbolic Shape BUAA] Add flashmask_attention op #68385
Conversation
@@ -1543,6 +1543,84 @@ bool FlashAttnOpInferSymbolicShape(
// return true;
// }

bool FlashmaskAttentionOpInferSymbolicShape(
Compare this with FlashAttnOpInferSymbolicShape and check whether it can be reused directly.
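If the two inferences do turn out to match, a minimal delegation sketch (hypothetical; it assumes flashmask_attention shares flash_attn's input layout and output set, which this review has not confirmed):

```cpp
// Hypothetical reuse, assuming flashmask_attention's q/k/v inputs and its
// out/softmax/softmax_lse/seed_offset outputs mirror those of flash_attn.
bool FlashmaskAttentionOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  return FlashAttnOpInferSymbolicShape(op, infer_context);
}
```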
  PADDLE_ENFORCE_EQ(
      infer_context->IsEqual(startend_row_indices[3], symbol::DimExpr{1}) ||
          infer_context->IsEqual(startend_row_indices[3],
                                 symbol::DimExpr{2}) ||
          infer_context->IsEqual(startend_row_indices[3], symbol::DimExpr{4}),
      true,
      common::errors::InvalidArgument(
          "flashmask_attention startend_row_indices "
          "mask_bounds must be in [1, 2, 4]"));
Please remove this enforce for now; equality checks inside control flow are not yet supported at compile time. If you do want to write it, you first need to check whether startend_row_indices[3] is an int type.
Deleted.
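Should the check be re-added later, a rough sketch of the guard the reviewer describes — assuming symbol::DimExpr offers isa<>/dyn_cast<> for concrete int64_t dims, which is an assumption about the PIR symbol API:

```cpp
// Hypothetical guard: only enforce the bound when the dim is a concrete
// integer, since symbolic equality cannot yet be decided at compile time.
if (startend_row_indices[3].isa<std::int64_t>()) {
  const std::int64_t bounds = startend_row_indices[3].dyn_cast<std::int64_t>();
  PADDLE_ENFORCE_EQ(bounds == 1 || bounds == 2 || bounds == 4,
                    true,
                    common::errors::InvalidArgument(
                        "flashmask_attention startend_row_indices "
                        "mask_bounds must be in [1, 2, 4], but got %lld.",
                        bounds));
}
```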
  auto batch_size_expr = q.shape()[0];
  auto num_heads_expr = q.shape()[2];
  auto seqlen_q_rounded_expr = round_multiple(q.shape()[1]);
  auto seqlen_k_rounded_expr = round_multiple(k.shape()[1]);

  if (op->result(1)) {
    std::vector<symbol::DimExpr> softmax_shape{batch_size_expr,
                                               num_heads_expr,
                                               seqlen_q_rounded_expr,
                                               seqlen_k_rounded_expr};
    infer_context->SetShapeOrDataForValue(
        op->result(1), symbol::TensorShapeOrDataDimExprs(softmax_shape));
  }
  if (op->result(2)) {
    std::vector<symbol::DimExpr> softmax_lse_shape{
        batch_size_expr, num_heads_expr, seqlen_q_rounded_expr};
    infer_context->SetShapeOrDataForValue(
        op->result(2), symbol::TensorShapeOrDataDimExprs(softmax_lse_shape));
  }
  if (op->result(3)) {
    std::vector<symbol::DimExpr> seed_offset_shape{symbol::DimExpr{2}};
    infer_context->SetShapeOrDataForValue(
        op->result(3), symbol::TensorShapeOrDataDimExprs(seed_offset_shape));
  }
  return true;
This inference logic has no corresponding basis in the kernel, the infermeta, or FlashAttnOpInferSymbolicShape.
It is similar to Liu Xudong's implementation.
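For context, the round_multiple helper this block relies on presumably mirrors the one used by FlashAttnOpInferSymbolicShape to pad the softmax sequence lengths; a sketch, where the factor 128 is an assumption taken from the flash-attention kernel convention rather than from this PR:

```cpp
// Hypothetical helper: round a sequence-length DimExpr up to the next
// multiple of 128, matching how flash-attention pads its softmax shape.
auto round_multiple = [](const symbol::DimExpr &x) {
  symbol::DimExpr m{128};
  return (x + m - symbol::DimExpr{1}) / m * m;  // ceil(x / m) * m
};
```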
Sorry to inform you that 266fe33's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
LGTM
PR Category
CINN
PR Types
Improvements
Description
Add the symbolic shape inference interface for the flashmask_attention op.
Unit tests found:
/test/legacy_test/test_flashmask.py
/test/legacy_test/test_flash_attention.py
Only unit tests exist.
Pcard-67164