Skip to content

Commit

Permalink
[mlir][linalg] Add transform operator for Winograd Conv2D algorithm (#…
Browse files Browse the repository at this point in the history
…96182)

Add a transform operation structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operations.

Reviewers: ftynse, Max191, GeorgeARM, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191

Pull Request: #96182
  • Loading branch information
Hsiangkai authored Jul 11, 2024
1 parent ddbad86 commit d9c26b9
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2646,4 +2646,55 @@ def MapCopyToThreadsOp :
}];
}

//===----------------------------------------------------------------------===//
// Winograd Conv2D
//===----------------------------------------------------------------------===//

def WinogradConv2DOp : Op<Transform_Dialect,
"structured.winograd_conv2d",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

#### Return modes:

This operation produces a silenceable failure if `target` is unsupported.
Otherwise, the operation succeeds and returns a handle of the sequence that
replaces the original convolution.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$m,
I64Attr:$r);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // LINALG_TRANSFORM_OPS
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);

/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
/// size of filter.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3711,6 +3711,37 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// WinogradConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeTransformed = failure();
bool supported = TypeSwitch<Operation *, bool>(target)
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
maybeTransformed =
winogradConv2D(rewriter, op, getM(), getR());
return true;
})
.Default([&](Operation *op) { return false; });

if (!supported) {
return emitSilenceableError()
<< "this operation is not supported to convert to Winograd Conv2D";
}

if (supported && failed(maybeTransformed)) {
return emitSilenceableError() << "apply Winograd Conv2D failed";
}

results.push_back(*maybeTransformed);
return DiagnosedSilenceableFailure::success();
}

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
Expand Down
9 changes: 8 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
Expand Down Expand Up @@ -156,7 +158,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
auto filterType = cast<ShapedType>(filter.getType());
auto outputType = cast<ShapedType>(output.getType());

// TODO: Should we support dynamic shapes?
if (!inputType.hasStaticShape())
return rewriter.notifyMatchFailure(convOp,
"expected a static shape for the input");
Expand Down Expand Up @@ -316,6 +317,12 @@ class WinogradConv2DNhwcFhwc final
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r) {
return winogradConv2DHelper(rewriter, op, m, r);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
Expand Down
76 changes: 76 additions & 0 deletions mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s

func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %0 : tensor<2x8x8x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// CHECK-LABEL: func.func @conv2d
// CHECK: linalg.winograd_filter_transform m(4) r(3)
// CHECK: linalg.winograd_input_transform m(4) r(3)
// CHECK: linalg.batch_matmul
// CHECK: linalg.winograd_output_transform m(4) r(3)

// -----

func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
return %0 : tensor<2x9x9x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// CHECK-LABEL: func.func @conv2d_unaligned
// CHECK: linalg.winograd_filter_transform m(4) r(3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: linalg.winograd_input_transform m(4) r(3)
// CHECK: tensor.pad
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: linalg.winograd_output_transform m(4) r(3)

// -----

func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %0 : tensor<2x8x8x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// -----

func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
return %0 : tensor<2x?x?x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{apply Winograd Conv2D failed}}
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

0 comments on commit d9c26b9

Please sign in to comment.