From 7d823d228b9ed9021d4501de98cf2c462957a2f8 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 16 Sep 2024 13:57:01 -0700
Subject: [PATCH] [torch] Add dynamic support for `tm_tensor.attention` (#18527)

The result tensor was assumed to be statically shaped. Updated to make a
dynamic result tensor for the destination.

---------

Signed-off-by: Rob Suderman
---
 .../ConvertTMTensorToLinalgExt.cpp            | 15 +++++++++++-
 .../Torch/InputConversion/test/attention.mlir | 23 +++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index 4938a07d2668..d775c1ccad5b 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -102,8 +102,21 @@ struct AttentionOpConversion
     Value value = op.getValue();
     ShapedType outputType = op.getOutputType();
+
+    SmallVector<Value> dynSizes;
+    for (int i = 0, s = outputType.getRank() - 1; i < s; ++i) {
+      if (outputType.isDynamicDim(i)) {
+        dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
+      }
+    }
+
+    if (outputType.getShape().back() == ShapedType::kDynamic) {
+      dynSizes.push_back(
+          rewriter.create<tensor::DimOp>(loc, value, outputType.getRank() - 1));
+    }
+
     Value result = rewriter.create<tensor::EmptyOp>(
-        loc, outputType.getShape(), outputType.getElementType());
+        loc, outputType.getShape(), outputType.getElementType(), dynSizes);
 
     // TODO: This is a hack. This should be replaced with a simple getScale()
     // when support for scaling is plumbed to TMTensor on the torch-mlir side.
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
index 865a8fbeb496..06e85e753a55 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -55,3 +55,26 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2:
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
 // CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
 // CHECK: return %[[ATTN]] : tensor<1x3x4xf32>
+
+// -----
+func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %arg2: tensor<?x?x4xf32>, %arg3: tensor<?x?x4xf32>) -> (tensor<?x?x4xf32>) {
+  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>) outs(%arg3: tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
+  return %0 : tensor<?x?x4xf32>
+}
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @attention_dyn(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x4xf32>, %[[ARG1:.*]]: tensor<?x?x4xf32>, %[[ARG2:.*]]: tensor<?x?x4xf32>,
+// CHECK: %arg3: tensor<?x?x4xf32>) -> tensor<?x?x4xf32> {
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
+// CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
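
Note for reviewers: the destination-materialization technique the patch applies (query each dynamic extent with tensor.dim, then pass the collected SSA sizes to tensor.empty) can be sketched in isolation as below. createEmptyLike is a hypothetical helper written against the upstream MLIR tensor-dialect builders, not code taken from this patch.

// Minimal sketch, not part of the patch: build an empty destination tensor
// whose dynamic extents are read from a same-ranked source value.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper: for every dynamic dimension of `type`, emit a
// tensor.dim on `source` and feed the collected sizes to tensor.empty.
static mlir::Value createEmptyLike(mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::ShapedType type, mlir::Value source) {
  llvm::SmallVector<mlir::Value> dynSizes;
  for (int64_t i = 0, rank = type.getRank(); i < rank; ++i)
    if (type.isDynamicDim(i))
      dynSizes.push_back(b.create<mlir::tensor::DimOp>(loc, source, i));
  return b.create<mlir::tensor::EmptyOp>(loc, type.getShape(),
                                         type.getElementType(), dynSizes);
}

The patch itself splits the sizes across two sources rather than using a single one: the leading dims are read from `query`, while the trailing dim is read from `value`, because the attention output's last extent matches V, not Q.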