Skip to content

Commit

Permalink
[mlir][scf] Allow unrolling loops with integer-typed IV. (#106164)
Browse files Browse the repository at this point in the history
SCF loops now can operate on integer-typed IV, thus I'm changing the
loop unroller correspondingly.
  • Loading branch information
htyu authored Aug 29, 2024
1 parent e05c224 commit c08c6a7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
35 changes: 21 additions & 14 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
int64_t divisor) {
assert(divisor > 0 && "expected positive divisor");
assert(dividend.getType().isIndex() && "expected index-typed value");
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");

Value divisorMinusOneCst =
builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
Value divisorCst = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(dividend.getType(), divisor));
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
Expand All @@ -285,9 +287,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
Value divisor) {
assert(dividend.getType().isIndex() && "expected index-typed value");

Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");
Value cstOne = builder.create<arith::ConstantOp>(
loc, builder.getOneAttr(dividend.getType()));
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
return builder.create<arith::DivUIOp>(loc, sum, divisor);
Expand Down Expand Up @@ -409,16 +412,18 @@ LogicalResult mlir::loopUnrollByFactor(
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
loc, upperBoundUnrolledCst);
upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
upperBoundUnrolledCst));
else
upperBoundUnrolled = forOp.getUpperBound();

// Create constant for 'stepUnrolled'.
stepUnrolled = stepCst == stepUnrolledCst
? step
: boundsBuilder.create<arith::ConstantIndexOp>(
loc, stepUnrolledCst);
: boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(
step.getType(), stepUnrolledCst));
} else {
// Dynamic loop bounds computation.
// TODO: Add dynamic asserts for negative lb/ub/step, or
Expand All @@ -428,8 +433,8 @@ LogicalResult mlir::loopUnrollByFactor(
Value diff =
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
Value unrollFactorCst =
boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
Value tripCountRem =
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
Expand Down Expand Up @@ -476,7 +481,9 @@ LogicalResult mlir::loopUnrollByFactor(
[&](unsigned i, Value iv, OpBuilder b) {
// iv' = iv + step * i;
auto stride = b.create<arith::MulIOp>(
loc, step, b.create<arith::ConstantIndexOp>(loc, i));
loc, step,
b.create<arith::ConstantOp>(loc,
b.getIntegerAttr(iv.getType(), i)));
return b.create<arith::AddIOp>(loc, iv, stride);
},
annotateFn, iterArgs, yieldedValues);
Expand Down
41 changes: 41 additions & 0 deletions mlir/test/Dialect/SCF/loop-unroll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,44 @@ func.func @loop_unroll_yield_iter_arg() {
// CHECK-NEXT: affine.yield %[[ITER_ARG]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return

// -----

// Test the loop unroller works with integer IV type.
func.func @static_loop_unroll_with_integer_iv() -> (f32, f32) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : i32
%ub = arith.constant 20 : i32
%step = arith.constant 1 : i32
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
%add = arith.addf %arg0, %arg1 : f32
%mul = arith.mulf %arg0, %arg1 : f32
scf.yield %add, %mul : f32, f32
}
return %result#0, %result#1 : f32, f32
}
// UNROLL-BY-3-LABEL: func @static_loop_unroll_with_integer_iv
//
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32

0 comments on commit c08c6a7

Please sign in to comment.