Skip to content

Commit

Permalink
Implement lowering of torch.aten.linalg_cross (#2986)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrifunovic98 authored Mar 13, 2024
1 parent 6fa21bd commit 524ff99
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11732,6 +11732,32 @@ def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
}];
}

def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLinalgCrossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenLinalgCrossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
90 changes: 90 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4278,6 +4278,96 @@ LogicalResult AtenPermuteOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//

LogicalResult AtenLinalgCrossOp::verify() {

auto selfType = getSelf().getType().cast<BaseTensorType>();
auto otherType = getOther().getType().cast<BaseTensorType>();

if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
!otherType.hasSizes()) {
return success();
}

Type selfDtype = selfType.getDtype();
Type otherDtype = otherType.getDtype();

// the operation succeeds only if both inputs have the same dtype
if (selfDtype != otherDtype) {
return emitOpError("input tensors must have the same dtype, but got ")
<< selfDtype << " and " << otherDtype;
}

// Check if any of the input tensors has torch.bool dtype.
// The operation does not support this type.
// The docs state that only float, double, cfloat and cdouble dtypes are
// supported, but, when testing, it fails only for boolean dtype. Update to
// fit the docs if necessary.
// https://pytorch.org/docs/stable/generated/torch.linalg.cross.html
if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
return emitOpError("input tensors must not have bool dtype");
}

ArrayRef<int64_t> selfShape = selfType.getSizes();
ArrayRef<int64_t> otherShape = otherType.getSizes();

int64_t selfRank = selfShape.size();
int64_t otherRank = otherShape.size();

// check if both input tensors have the same number of dims
if (selfRank != otherRank) {
return emitOpError("input tensors must have the same number of dimensions, "
"but got ")
<< selfRank << " and " << otherRank;
}

// convert dim to an integer type
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
return success();
}

// check if dim is in the correct range
if (dim >= selfRank || dim < -selfRank) {
return emitOpError("dim expected to be in rank of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
}

// compensate for possible negative dim value
if (dim < 0) {
dim += selfRank;
}

// check if the size of the dimensions specified by 'dim' is equal to 3
// (required by the operation)
if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) ||
(otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) {
return emitOpError("inputs dimension ")
<< dim << " must have length 3, but got " << selfShape[dim]
<< " and " << otherShape[dim];
}

// Check if there is a disparity between dimension sizes.
// Dimensions at the same index must either have the same size,
// or one of them must be equal to 1.
int32_t i = 0;
for (auto [selfCurrent, otherCurrent] :
llvm::zip_equal(selfShape, otherShape)) {
if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) {
return emitOpError("the size of first tensor (")
<< selfCurrent << ") must match the size of second tensor ("
<< otherCurrent << ") at dimension " << i
<< " or one of them must be 1";
}
++i;
}

return success();
}

//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//
Expand Down
76 changes: 76 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6793,6 +6793,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linalg_cross\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"the size of first tensor ({}) must match the size of second tensor ({}) at dimension {}\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: inputs must have the same number of dimensions\"\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" }\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %10 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" }\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.format(%str_0, %10, %11, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n"
" %13 = torch.aten.add.str %str, %12 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -10033,6 +10084,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_cross\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
Expand Down
112 changes: 112 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,117 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
// aten.add.Tensor and aten.mull.Tensor. See
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1):
// broadcast_shape = compute_broadcast_shape(self, other)
// a = torch.broadcast_to(self, broadcast_shape)
// b = torch.broadcast_to(other, broadcast_shape)
// idx = torch.arange(3)
// return a.index_select(dim, (idx + 1) % 3) *
// b.index_select(dim, (idx + 2) % 3) -
// a.index_select(dim, (idx + 2) % 3) *
// b.index_select(dim, (idx + 1) % 3)
namespace {
class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLinalgCrossOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value self = op.getSelf();
Value other = op.getOther();
Type opType = op.getType();
Value dim = op.getDim();

auto resType = self.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

Type dtype = resType.getDtype();
if (dtype.isa<mlir::ComplexType>()) {
return rewriter.notifyMatchFailure(
op, "lowering of aten.linalg_cross for complex inputs dtype is "
"currently unimplemented");
}

// calculate common shape for broadcast
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
broadcastShapeValue);

Type broadcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(broadcastShape), dtype);

Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
broadcastShapeValue);

// broadcast tensors to common shape
auto a = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, self,
indexBroadcastShapeTorchList);
auto b = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, other,
indexBroadcastShapeTorchList);

// create constants
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constTwo = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value constThree = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value none = rewriter.create<ConstantNoneOp>(loc);

// idx = torch.arange(3)
auto outType = opType.dyn_cast<BaseTensorType>();
auto arangeType = outType.getWithSizesAndDtype(
llvm::ArrayRef<int64_t>(3),
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
auto idx = rewriter.create<AtenArangeOp>(
loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);

// (idx + 1) and (idx + 2)
auto idxPlusOne = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
constOne, constOne);
auto idxPlusTwo = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
constTwo, constOne);

// (idx + 1) % 3 and (idx + 2) % 3
auto idxPlusOneRemainderThree = rewriter.create<AtenRemainderScalarOp>(
loc, arangeType, idxPlusOne, constThree);
auto idxPlusTwoRemainderThree = rewriter.create<AtenRemainderScalarOp>(
loc, arangeType, idxPlusTwo, constThree);

// a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3)
auto idxSelectAPlusOne = rewriter.create<AtenIndexSelectOp>(
loc, opType, a, dim, idxPlusOneRemainderThree);
auto idxSelectBPlusTwo = rewriter.create<AtenIndexSelectOp>(
loc, opType, b, dim, idxPlusTwoRemainderThree);
auto firstMul = rewriter.create<AtenMulTensorOp>(
loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo);

// a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
auto idxSelectAPlusTwo = rewriter.create<AtenIndexSelectOp>(
loc, opType, a, dim, idxPlusTwoRemainderThree);
auto idxSelectBPlusOne = rewriter.create<AtenIndexSelectOp>(
loc, opType, b, dim, idxPlusOneRemainderThree);
auto secondMul = rewriter.create<AtenMulTensorOp>(
loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne);

// subtract the results of the two multiplications from above
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, opType, firstMul,
secondMul, constOne);

return success();
}
};
} // namespace

// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
Expand Down Expand Up @@ -7081,6 +7192,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,9 @@
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",

# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
"AtenLinalgCrossDynamic_basic"
}

ONNX_CRASHING_SET = { }
Expand Down
Loading

0 comments on commit 524ff99

Please sign in to comment.