Skip to content

Commit

Permalink
Implement lowering of torch.aten.triu_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
Branko Trifkovic committed Jun 14, 2024
1 parent a02e14e commit 15cfb9a
Show file tree
Hide file tree
Showing 9 changed files with 548 additions and 0 deletions.
30 changes: 30 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15362,6 +15362,36 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
let hasCanonicalizer = 1;
}

def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$row,
Torch_IntType:$col,
Torch_IntType:$offset,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTriuIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenTriuIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
let hasVerifier = 1;
}

def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
35 changes: 35 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5088,3 +5088,38 @@ LogicalResult BindSymbolicShapeOp::verify() {

return success();
}
// AtenTriuIndicesOp
//===----------------------------------------------------------------------===//

LogicalResult AtenTriuIndicesOp::verify() {

// Check if row, col and offset are constant ints
int64_t row;
if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
return success();

int64_t col;
if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
return success();

int64_t offset;
if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
return success();

// Check if values of row, and col are valid
if (row < 0)
return emitOpError("row must be non-negative, got ") << row;

if (col < 0)
return emitOpError("col must be non-negative, got ") << col;

// Check if dtype is valid
int64_t dtype;
if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
return success();
if (dtype != 3 && dtype != 4)
return emitOpError(
"'triu_indices' implemented only for torch.int32 and torch.int64");

return success();
}
80 changes: 80 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9713,6 +9713,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int, !torch.optional<list<int>>, !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %3 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.aten.neg.int %arg2 : !torch.int -> !torch.int\n"
" %4 = torch.prim.min.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
" %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %6 = torch.prim.max.int %int0, %5 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %17 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %17 : !torch.bool\n"
" }\n"
" %10:2 = torch.prim.If %9 -> (!torch.int, !torch.int) {\n"
" torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n"
" } else {\n"
" %17 = torch.aten.gt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %18 = torch.prim.If %17 -> (!torch.int) {\n"
" %33 = torch.aten.add.int %int1, %7 : !torch.int, !torch.int -> !torch.int\n"
" %34 = torch.prim.min.int %arg1, %33 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %34 : !torch.int\n"
" } else {\n"
" %33 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
" %34 = torch.aten.gt.int %33, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %35 = torch.aten.Int.bool %34 : !torch.bool -> !torch.int\n"
" torch.prim.If.yield %35 : !torch.int\n"
" }\n"
" %19 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
" %20 = torch.prim.min.int %arg1, %19 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.prim.max.int %int0, %20 : !torch.int, !torch.int -> !torch.int\n"
" %22 = torch.aten.add.int %arg0, %7 : !torch.int, !torch.int -> !torch.int\n"
" %23 = torch.prim.min.int %arg0, %22 : !torch.int, !torch.int -> !torch.int\n"
" %24 = torch.prim.max.int %int0, %23 : !torch.int, !torch.int -> !torch.int\n"
" %25 = torch.aten.sub.int %21, %18 : !torch.int, !torch.int -> !torch.int\n"
" %26 = torch.aten.add.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %27 = torch.aten.add.int %18, %21 : !torch.int, !torch.int -> !torch.int\n"
" %28 = torch.aten.mul.int %27, %26 : !torch.int, !torch.int -> !torch.int\n"
" %29 = torch.aten.floordiv.int %28, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %30 = torch.aten.sub.int %24, %26 : !torch.int, !torch.int -> !torch.int\n"
" %31 = torch.aten.mul.int %30, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %32 = torch.prim.max.int %int0, %31 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %29, %32 : !torch.int, !torch.int\n"
" }\n"
" %11 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %12 = torch.aten.add.int %10#0, %10#1 : !torch.int, !torch.int -> !torch.int\n"
" %13 = torch.aten.sub.int %11, %12 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.aten.sub.int %13, %6 : !torch.int, !torch.int -> !torch.int\n"
" %15 = torch.aten.add.int %6, %14 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.prim.ListConstruct %int2, %15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %16 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
Expand Down Expand Up @@ -13936,6 +14004,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int6 = torch.constant.int 6\n"
" return %int6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int4 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int3 = torch.constant.int 3\n"
" %int1 = torch.constant.int 1\n"
Expand Down
Loading

0 comments on commit 15cfb9a

Please sign in to comment.