Skip to content

Commit

Permalink
[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456)
Browse files Browse the repository at this point in the history
  • Loading branch information
william0021224 authored Jun 14, 2024
1 parent 919b599 commit a02e14e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
5 changes: 1 addition & 4 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
"ScaledDotProductAttentionDifferentModule_basic",
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
"AtenEmbeddingBagStaticModule_basic",
# Lowering not present for this case
Expand Down Expand Up @@ -731,7 +729,6 @@
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
Expand Down Expand Up @@ -1978,6 +1975,7 @@
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
Expand Down Expand Up @@ -3349,7 +3347,6 @@
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _get_decomposition_table():
aten.sigmoid_backward,
aten._native_batch_norm_legit,
aten.squeeze,
aten._scaled_dot_product_flash_attention_for_cpu,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir/extras/fx_decomp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
torch.ops.aten.triu.default,
torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
]


Expand Down

0 comments on commit a02e14e

Please sign in to comment.