Skip to content

Commit

Permalink
[torch] Add dynamic support for tm_tensor.attention (#18527)
Browse files Browse the repository at this point in the history
The result tensor was assumed to be statically shaped. Updated to make a
dynamic result tensor for the destination.

---------

Signed-off-by: Rob Suderman <[email protected]>
  • Loading branch information
rsuderman committed Sep 16, 2024
1 parent 898a95f commit 7d823d2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,21 @@ struct AttentionOpConversion
Value value = op.getValue();

ShapedType outputType = op.getOutputType();

SmallVector<Value> dynSizes;
for (int i = 0, s = outputType.getRank() - 1; i < s; ++i) {
if (outputType.isDynamicDim(i)) {
dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
}
}

if (outputType.getShape().back() == ShapedType::kDynamic) {
dynSizes.push_back(
rewriter.create<tensor::DimOp>(loc, value, outputType.getRank() - 1));
}

Value result = rewriter.create<tensor::EmptyOp>(
loc, outputType.getShape(), outputType.getElementType());
loc, outputType.getShape(), outputType.getElementType(), dynSizes);

// TODO: This is a hack. This should be replaced with a simple getScale()
// when support for scaling is plumbed to TMTensor on the torch-mlir side.
Expand Down
23 changes: 23 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/test/attention.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,26 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2:
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>

// -----
func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %arg2: tensor<?x?x4xf32>, %arg3: tensor<?x?x4xf32>) -> (tensor<?x?x4xf32>) {
%0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>) outs(%arg3: tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
return %0 : tensor<?x?x4xf32>
}

// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK-LABEL: func.func @attention_dyn(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x4xf32>, %[[ARG1:.*]]: tensor<?x?x4xf32>, %[[ARG2:.*]]: tensor<?x?x4xf32>,
// CHECK: %arg3: tensor<?x?x4xf32>) -> tensor<?x?x4xf32> {
// CHECK-DAG: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
// CHECK: return %[[ATTN]] : tensor<?x?x4xf32>

0 comments on commit 7d823d2

Please sign in to comment.