From 15cfb9a4b85960673d49e607a988dd12e622974c Mon Sep 17 00:00:00 2001
From: Branko Trifkovic
Date: Sun, 2 Jun 2024 17:40:12 +0200
Subject: [PATCH] Implement lowering of torch.aten.triu_indices

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  30 ++
 lib/Dialect/Torch/IR/TorchOps.cpp             |  35 ++
 .../Transforms/AbstractInterpLibrary.cpp      |  80 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 301 ++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 projects/pt1/e2e_testing/xfail_sets.py        |   3 +
 .../build_tools/abstract_interp_lib_gen.py    |  33 ++
 .../build_tools/torch_ods_gen.py              |   5 +
 .../test_suite/elementwise.py                 |  60 ++++
 9 files changed, 548 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c22f46ebe442..221171510523 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15362,6 +15362,36 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
   let hasCanonicalizer = 1;
 }
 
+def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_IntType:$row,
+    Torch_IntType:$col,
+    Torch_IntType:$offset,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTriuIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenTriuIndicesOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 61a0857a8894..53942db7d4a2 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5088,3 +5088,38 @@ LogicalResult BindSymbolicShapeOp::verify() {
 
   return success();
 }
+// AtenTriuIndicesOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenTriuIndicesOp::verify() {
+
+  // Check if row, col and offset are constant ints
+  int64_t row;
+  if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
+    return success();
+
+  int64_t col;
+  if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
+    return success();
+
+  int64_t offset;
+  if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
+    return success();
+
+  // Check if values of row and col are valid
+  if (row < 0)
+    return emitOpError("row must be non-negative, got ") << row;
+
+  if (col < 0)
+    return emitOpError("col must be non-negative, got ") << col;
+
+  // Check if dtype is valid (3 = torch.int32, 4 = torch.int64)
+  int64_t dtype;
+  if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
+    return success();
+  if (dtype != 3 && dtype != 4)
+    return emitOpError(
+        "'triu_indices' implemented only for torch.int32 and torch.int64");
+
+  return success();
+}
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 2eca3ab44961..c871b444fd16 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9713,6 +9713,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int, !torch.optional<list<int>>, !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
 "    return %1 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    }\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.aten.neg.int %arg2 : !torch.int -> !torch.int\n"
+"      %4 = torch.prim.min.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
+"      %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"      %6 = torch.prim.max.int %int0, %5 : !torch.int, !torch.int -> !torch.int\n"
+"      %7 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %8 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      %9 = torch.prim.If %8 -> (!torch.bool) {\n"
+"        torch.prim.If.yield %true : !torch.bool\n"
+"      } else {\n"
+"        %17 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %17 : !torch.bool\n"
+"      }\n"
+"      %10:2 = torch.prim.If %9 -> (!torch.int, !torch.int) {\n"
+"        torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n"
+"      } else {\n"
+"        %17 = torch.aten.gt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %18 = torch.prim.If %17 -> (!torch.int) {\n"
+"          %33 = torch.aten.add.int %int1, %7 : !torch.int, !torch.int -> !torch.int\n"
+"          %34 = torch.prim.min.int %arg1, %33 : !torch.int, !torch.int -> !torch.int\n"
+"          torch.prim.If.yield %34 : !torch.int\n"
+"        } else {\n"
+"          %33 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
+"          %34 = torch.aten.gt.int %33, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"          %35 = torch.aten.Int.bool %34 : !torch.bool -> !torch.int\n"
+"          torch.prim.If.yield %35 : !torch.int\n"
+"        }\n"
+"        %19 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
+"        %20 = torch.prim.min.int %arg1, %19 : !torch.int, !torch.int -> !torch.int\n"
+"        %21 = torch.prim.max.int %int0, %20 : !torch.int, !torch.int -> !torch.int\n"
+"        %22 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
+"        %23 = torch.prim.min.int %arg0, %22 : !torch.int, !torch.int -> !torch.int\n"
+"        %24 = torch.prim.max.int %int0, %23 : !torch.int, !torch.int -> !torch.int\n"
+"        %25 = torch.aten.sub.int %21, %18 : !torch.int, !torch.int -> !torch.int\n"
+"        %26 = torch.aten.add.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"        %27 = torch.aten.add.int %18, %21 : !torch.int, !torch.int -> !torch.int\n"
+"        %28 = torch.aten.mul.int %27, %26 : !torch.int, !torch.int -> !torch.int\n"
+"        %29 = torch.aten.floordiv.int %28, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"        %30 = torch.aten.sub.int %24, %26 : !torch.int, !torch.int -> !torch.int\n"
+"        %31 = torch.aten.mul.int %30, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"        %32 = torch.prim.max.int %int0, %31 : !torch.int, !torch.int -> !torch.int\n"
+"        torch.prim.If.yield %29, %32 : !torch.int, !torch.int\n"
+"      }\n"
+"      %11 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.add.int %10#0, %10#1 : !torch.int, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.sub.int %11, %12 : !torch.int, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.sub.int %13, %6 : !torch.int, !torch.int -> !torch.int\n"
+"      %15 = torch.aten.add.int %6, %14 : !torch.int, !torch.int -> !torch.int\n"
+"      %16 = torch.prim.ListConstruct %int2, %15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %16 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -13936,6 +14004,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int6 = torch.constant.int 6\n"
 "    return %int6 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int4 : !torch.int\n"
+"    } else {\n"
+"      %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %2 : !torch.int\n"
+"    }\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int3 = torch.constant.int 3\n"
 "    %int1 = torch.constant.int 1\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e1759ceb0769..74f939a57481 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -732,6 +732,306 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
 };
 } // namespace
 
+Value arange(PatternRewriter &rewriter, Location loc, Value size, Value dtype,
+             Value layout, Value device, Value pin_memory, Type type) {
+
+  auto arangeType = getTensorTypeFromShapeValues({size}, type);
+  return rewriter.create<AtenArangeOp>(loc, arangeType, size,
+                                       /*dtype=*/dtype, /*layout=*/layout,
+                                       /*device=*/device,
+                                       /*pin_memory=*/pin_memory);
+}
+
+void _getTrilSizes(PatternRewriter &rewriter, Location loc, int64_t row,
+                   int64_t col, int64_t offset,
+                   /*return values*/ int64_t &trapezoidSize,
+                   int64_t &rectangleSize) {
+
+  // Base case
+  if (row == 0 || col == 0) {
+    trapezoidSize = 0;
+    rectangleSize = 0;
+
+    return;
+  }
+
+  // Calculate mFirstRow size
+  int64_t mFirstRow;
+  if (offset > 0)
+    mFirstRow = (col < offset + 1) ? col : offset + 1;
+  else
+    mFirstRow = (row + offset > 0) ? 1 : 0;
+
+  // Calculate mLastRow size
+  int64_t minimum = (col < row + offset) ? col : row + offset;
+  int64_t mLastRow = (minimum > 0) ? minimum : 0;
+
+  // Calculate nRowAll
+  minimum = (row < row + offset) ? row : row + offset;
+  int64_t nRowAll = (minimum > 0) ? minimum : 0;
+
+  // Calculate nRowTrapezoid
+  int64_t nRowTrapezoid = mLastRow - mFirstRow + 1;
+
+  // Number of elements in top trapezoid - trapezoidSize
+  trapezoidSize = (mFirstRow + mLastRow) * nRowTrapezoid / 2;
+
+  // Number of elements in bottom rectangle - rectangleSize
+  int64_t diffRow = nRowAll - nRowTrapezoid;
+  rectangleSize = (diffRow * col > 0) ? diffRow * col : 0;
+}
+
+void _getTriuSizes(PatternRewriter &rewriter, Location loc, int64_t row,
+                   int64_t col, int64_t offset,
+                   /*return values*/ Value &trapezoidSize, Value &rectangleSize,
+                   Value &mFirstRow) {
+
+  // Constants
+  Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+
+  // Base case
+  if (row == 0 || col == 0) {
+    trapezoidSize = cstZero;
+    rectangleSize = cstZero;
+    mFirstRow = cstZero;
+
+    return;
+  }
+
+  // Calculate mFirstRow size
+  int64_t maximum = (col - offset > 0) ? col - offset : 0;
+  int64_t mFirstRowInt = (offset > 0) ? maximum : col;
+
+  // Number of elements in top rectangle - calculate rectangle size
+  int64_t minimum = (row < -offset) ? row : -offset;
+  int64_t rectangleSizeInt = (minimum * col > 0) ? minimum * col : 0;
+
+  // Number of elements in bottom trapezoid - calculate trapezoid size
+  int64_t trapezoidSizeTril;
+  int64_t rectangleSizeTril;
+
+  _getTrilSizes(rewriter, loc, row, col, offset - 1, trapezoidSizeTril,
+                rectangleSizeTril);
+  int64_t triuSize = row * col - (trapezoidSizeTril + rectangleSizeTril);
+  int64_t trapezoidSizeInt = triuSize - rectangleSizeInt;
+
+  // Create Value from int
+  trapezoidSize = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(trapezoidSizeInt));
+  rectangleSize = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(rectangleSizeInt));
+  mFirstRow = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(mFirstRowInt));
+}
+
+// Decomposition of torch.triu_indices, following
+// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5829
+namespace {
+class DecomposeAtenTriuIndicesOp : public OpRewritePattern<AtenTriuIndicesOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTriuIndicesOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    // Required parameters
+    Value row = op.getRow();
+    Value col = op.getCol();
+    Value offset = op.getOffset();
+
+    // Check if row, col and offset are constant ints
+    int64_t rowInt;
+    if (!matchPattern(row, m_TorchConstantInt(&rowInt)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented: row not constant int");
+
+    int64_t colInt;
+    if (!matchPattern(col, m_TorchConstantInt(&colInt)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Unimplemented: col not constant int");
+
+    int64_t offsetInt;
+    if (!matchPattern(offset, m_TorchConstantInt(&offsetInt)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: offset not constant int");
+
+    // Optional parameters
+    Value dtype = op.getDtype();
+    Value layout = op.getLayout();
+    Value device = op.getDevice();
+    Value pinMemory = op.getPinMemory();
+
+    // Get int value for dtype
+    int64_t dtypeInt;
+    if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+      return rewriter.notifyMatchFailure(
"Unimplemented: dtype not constant int"); + + FailureOr dtypeType = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + if (failed(dtypeType)) + return rewriter.notifyMatchFailure(op, "dtype is undefined"); + + // Constants + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstTwo = rewriter.create(loc, 2); + Value cstFalse = rewriter.create(loc, false); + Value cstMinusZeroPointFive = rewriter.create( + loc, rewriter.getF64FloatAttr(-0.5)); + Value cstMinusTwoFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(-2.0)); + + // Calculte trapezoidSize, rectangleSize and mFirstRow + Value trapezoidSize; + Value rectangleSize; + Value mFirstRow; + + _getTriuSizes(rewriter, loc, rowInt, colInt, offsetInt, trapezoidSize, + rectangleSize, mFirstRow); + + // Get const ints from Values + int64_t rectangleSizeInt; + matchPattern(rectangleSize, m_TorchConstantInt(&rectangleSizeInt)); + int64_t trapezoidSizeInt; + matchPattern(trapezoidSize, m_TorchConstantInt(&trapezoidSizeInt)); + int64_t mFirstRowInt; + matchPattern(mFirstRow, m_TorchConstantInt(&mFirstRowInt)); + + // Calculte column offset + Value colOffset = (offsetInt > 0) ? offset : cstZero; + + // Calculate indices for top rectangle + // Type type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); + Value xs2 = arange(rewriter, loc, rectangleSize, dtype, layout, device, + pinMemory, *dtypeType); + + // Calculate row_indices2 and column_idices 2 + Value rowInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + Value colInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + + // Bottom trapezoid + auto f64DtypeInt = + getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); + Value xs1 = arange(rewriter, loc, trapezoidSize, f64DtypeInt, layout, + device, pinMemory, rewriter.getF64Type()); + + // b = -0.5 - m_first_row + Value mFirstRowFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value b = rewriter.create(loc, cstMinusZeroPointFive, + mFirstRowFloat); + + // Implements this piece of code: row_inds1 = torch.floor(-b - torch.sqrt(b + // * b - 2 * xs1)) + Value bSquare = rewriter.create(loc, b, b); + + Value twoTimesXs1 = rewriter.create(loc, xs1.getType(), + xs1, cstMinusTwoFloat); + Value sqrtInput = rewriter.create( + loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + + Value sqrt = + rewriter.create(loc, sqrtInput.getType(), sqrtInput); + Value negativeSqrt = rewriter.create(loc, sqrt.getType(), sqrt); + + Value rowInds1 = rewriter.create( + loc, negativeSqrt.getType(), negativeSqrt, b, cstOne); + rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + + // Implements this piece of code: col_inds1 = torch.floor(xs1 - ((2 * + // m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + Value twoTimesMFirstRow = + rewriter.create(loc, cstTwo, mFirstRow); + twoTimesMFirstRow = + rewriter.create(loc, twoTimesMFirstRow, cstOne); + Value negativeRowInds1 = + rewriter.create(loc, rowInds1.getType(), rowInds1); + + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, twoTimesMFirstRow, + cstOne); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, rowInds1); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, + cstMinusZeroPointFive); + + Value colInds1 = rewriter.create(loc, xs1.getType(), xs1, + negativeRowInds1, cstOne); + colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + + // Convert to dtype 
+    Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true);
+
+    auto rowInds1Type = cast<BaseTensorType>(rowInds1.getType());
+    ArrayRef<int64_t> sizes = rowInds1Type.getSizes();
+    Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type);
+    rowInds1 = rewriter.create<AtenToDtypeOp>(
+        loc, finalRowType, rowInds1, dtype,
+        /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+        /*memory_format=*/cstOne);
+
+    auto colInds1Type = cast<BaseTensorType>(colInds1.getType());
+    sizes = colInds1Type.getSizes();
+    Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type);
+    colInds1 = rewriter.create<AtenToDtypeOp>(
+        loc, finalColType, colInds1, dtype,
+        /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+        /*memory_format=*/cstOne);
+
+    // Final calculation for row and col indices
+    if (colInt) {
+
+      Value rectangleSizeDivCol = rewriter.create<Torch::ConstantIntOp>(
+          loc, rectangleSizeInt / colInt);
+
+      rowInds1 = rewriter.create<AtenAddScalarOp>(
+          loc, rowInds1.getType(), rowInds1, rectangleSizeDivCol, cstOne);
+    }
+
+    colInds1 = rewriter.create<AtenAddScalarOp>(loc, colInds1.getType(),
+                                                colInds1, colOffset, cstOne);
+
+    Type listElemType =
+        cast<BaseTensorType>(rowInds1.getType())
+            .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                  /*optionalDtype=*/nullptr);
+    Type listType = Torch::ListType::get(listElemType);
+
+    Value sequenceRow = rewriter.create<PrimListConstructOp>(
+        loc, listType, SmallVector<Value>{rowInds2, rowInds1});
+    Value sequenceCol = rewriter.create<PrimListConstructOp>(
+        loc, listType, SmallVector<Value>{colInds2, colInds1});
+
+    // Concatenate row and col indices
+    Type finalCatType = colInds1Type.getWithSizesAndDtype(
+        {rectangleSizeInt + trapezoidSizeInt}, int64Type);
+
+    Value catRow = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceRow,
+                                              /*dim=*/cstZero);
+    Value catCol = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceCol,
+                                              /*dim=*/cstZero);
+
+    // Make return value
+    Value sequence = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(context, rowInds1.getType()),
+        ValueRange{catRow, catCol});
+    Type finalStackType = colInds1Type.getWithSizesAndDtype(
+        ArrayRef<int64_t>{2, rectangleSizeInt + trapezoidSizeInt}, int64Type);
+
+    rewriter.replaceOpWithNewOp<AtenStackOp>(op, finalStackType, sequence,
+                                             cstZero);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -8233,6 +8533,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     // More specific conv ops
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index fb5dd7ea8b2b..8c465173b7f8 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -539,6 +539,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenTriuIndicesOp>();
   target.addIllegalOp();
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index be9498a53252..b2fbd7c3cd7b 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1313,6 +1313,9 @@
     "TorchPrimLoopForLikeTensorArgModule_basic",
     "TransposeIntModule_basic",
     "TransposeIntNegDimsModule_basic",
+    "TriuIndicesModule_basic",
+    "TriuIndicesAllZerosModule_basic",
+    "TriuIndicesNegativeOffsetModule_basic",
"TupleModule_basic", "TypeAsDifferentModule_basic", "TypeAsSameModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 3aa1a5ef26de..89fcc730a1cc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1746,6 +1746,36 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode, per_sample_weights, padding_idx) +def aten〇triu_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + if row == 0 or col == 0: + return [2, 0] + + # Number of elements in top rectangle + rectangle_size = max(0, min(row, -offset) * col) + + # _get_tril_indices + offset_tril = offset - 1 + if row == 0 or col == 0: + trapezoid_size_tril = 0 + rectangle_size_tril = 0 + else: + m_first_row = min(col, 1 + offset_tril) if offset_tril > 0 else int(row + offset_tril > 0) + m_last_row = max(0, min(col, row + offset_tril)) + n_row_all = max(0, min(row, row + offset_tril)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size_tril = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size_tril = max(0, diff_row * col) + + # Number of elements in bottom trapezoid + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + trapezoid_size = triu_size - rectangle_size + + return [2, rectangle_size + trapezoid_size] + @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case. Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim. @@ -4911,6 +4941,9 @@ def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> int: return torch.float32 +def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.int64 if dtype is None else dtype + def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype if (self_dtype == torch.quint8): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 106fa18ae630..9fbbc31fe456 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1054,6 +1054,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)") emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True) + emit( + "aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) 
-> (Tensor)", + has_verifier=True, + ) + # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f3bcefc95330..ce000264efec 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -6223,3 +6223,63 @@ def forward(self, x): ) def FakeQuantizePerTensorAffineRoundToEvenModule_basic(module, tu: TestUtils): module.forward(torch.FloatTensor([0.5, 1.5, -0.5, -1.5])) + + +# ============================================================================== + + +class TriuIndicesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(4, 3, 1) + + +@register_test_case(module_factory=lambda: TriuIndicesModule()) +def TriuIndicesModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesAllZerosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(0, 0, 0) + + +@register_test_case(module_factory=lambda: TriuIndicesAllZerosModule()) +def TriuIndicesAllZerosModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesNegativeOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(5, 16, -2) + + +@register_test_case(module_factory=lambda: TriuIndicesNegativeOffsetModule()) +def TriuIndicesNegativeOffsetModule_basic(module, tu: TestUtils): + module.forward()