Skip to content

Commit

Permalink
[mlir][emitc] Add a structured for operation (llvm#68206)
Browse files Browse the repository at this point in the history
Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C; The translator
now handles emitc.for instead.
  • Loading branch information
aniragil authored Oct 26, 2023
1 parent 585da26 commit 2633d94
Show file tree
Hide file tree
Showing 9 changed files with 426 additions and 122 deletions.
3 changes: 0 additions & 3 deletions mlir/docs/Dialects/emitc.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,5 @@ translating the following operations:
* `func.constant`
* `func.func`
* `func.return`
* 'scf' Dialect
* `scf.for`
* `scf.yield`
* 'arith' Dialect
* `arith.constant`
64 changes: 63 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,67 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}

def EmitC_ForOp : EmitC_Op<"for",
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
let description = [{
The `emitc.for` operation represents a C loop of the following form:

```c++
for (T i = lb; i < ub; i += step) { /* ... */ } // where T is typeof(lb)
```

The operation takes 3 SSA values as operands that represent the lower bound,
upper bound and step respectively, and defines an SSA value for its
induction variable. It has one region capturing the loop body. The induction
variable is represented as an argument of this region. This SSA value is a
signless integer or index. The step is a value of same type.

This operation has no result. The body region must contain exactly one block
that terminates with `emitc.yield`. Calling ForOp::build will create such a
region and insert the terminator implicitly if none is defined, so will the
parsing even in cases when it is absent from the custom format. For example:

```mlir
// Index case.
emitc.for %iv = %lb to %ub step %step {
... // body
}
...
// Integer case.
emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
... // body
}
```
}];
let arguments = (ins IntegerIndexOrOpaqueType:$lowerBound,
IntegerIndexOrOpaqueType:$upperBound,
IntegerIndexOrOpaqueType:$step);
let results = (outs);
let regions = (region SizedRegion<1>:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
];

let extraClassDeclaration = [{
using BodyBuilderFn =
function_ref<void(OpBuilder &, Location, Value)>;
Value getInductionVar() { return getBody()->getArgument(0); }
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

def EmitC_IncludeOp
: EmitC_Op<"include", [HasParent<"ModuleOp">]> {
let summary = "Include operation";
Expand Down Expand Up @@ -430,7 +491,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
}

def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
Expand Down
123 changes: 99 additions & 24 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};

// Lower scf::if to emitc::if, implementing return values as emitc::variable's
// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
};

// Create an uninitialized emitc::variable op for each result of the given op.
template <typename T>
static SmallVector<Value> createVariablesForResults(T op,
PatternRewriter &rewriter) {
SmallVector<Value> resultVariables;

if (!op.getNumResults())
return resultVariables;

Location loc = op->getLoc();
MLIRContext *context = op.getContext();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
resultVariables.push_back(var);
}

return resultVariables;
}

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables =
createVariablesForResults(forOp, rewriter);
SmallVector<Value> iterArgsVariables =
createVariablesForResults(forOp, rewriter);

assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);

emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());

Block *loweredBody = loweredFor.getBody();

// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(iterArgsVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

// Copy iterArgs into results after the for loop.
assignValues(iterArgsVariables, resultVariables, rewriter, loc);

rewriter.replaceOp(forOp, resultVariables);
return success();
}

// Lower scf::if to emitc::if, implementing result values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
Expand All @@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();

SmallVector<Value> resultVariables;

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
if (ifOp.getNumResults()) {
MLIRContext *context = ifOp.getContext();
rewriter.setInsertionPoint(ifOp);
for (OpResult result : ifOp.getResults()) {
Type resultType = result.getType();
auto noInit = emitc::OpaqueAttr::get(context, "");
auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
resultVariables.push_back(var);
}
}
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
Expand All @@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
Location terminatorLoc = terminator->getLoc();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&loweredRegion.back());
for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
Value resultValue = std::get<0>(value2Var);
Value resultVar = std::get<1>(value2Var);
rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
}
rewriter.create<emitc::YieldOp>(terminatorLoc);
rewriter.eraseOp(terminator);
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
};

Region &thenRegion = ifOp.getThenRegion();
Expand All @@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
}

Expand All @@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::IfOp>();
target.addIllegalOp<scf::ForOp, scf::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down
95 changes: 95 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,101 @@ LogicalResult emitc::ConstantOp::verify() {

OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, BodyBuilderFn bodyBuilder) {
result.addOperands({lb, ub, step});
Type t = lb.getType();
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArgument(t, result.location);

// Create the default terminator if the builder is not provided.
if (!bodyBuilder) {
ForOp::ensureTerminator(*bodyRegion, builder, result.location);
} else {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&bodyBlock);
bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
}
}

void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}

ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
Type type;

OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;

// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
parser.parseOperand(lb) || parser.parseKeyword("to") ||
parser.parseOperand(ub) || parser.parseKeyword("step") ||
parser.parseOperand(step))
return failure();

// Parse the optional initial iteration arguments.
SmallVector<OpAsmParser::Argument, 4> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
regionArgs.push_back(inductionVariable);

// Parse optional type, else assume Index.
if (parser.parseOptionalColon())
type = builder.getIndexType();
else if (parser.parseType(type))
return failure();

// Resolve input operands.
regionArgs.front().type = type;
if (parser.resolveOperand(lb, type, result.operands) ||
parser.resolveOperand(ub, type, result.operands) ||
parser.resolveOperand(step, type, result.operands))
return failure();

// Parse the body region.
Region *body = result.addRegion();
if (parser.parseRegion(*body, regionArgs))
return failure();

ForOp::ensureTerminator(*body, builder, result.location);

// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();

return success();
}

void ForOp::print(OpAsmPrinter &p) {
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();

p << ' ';
if (Type t = getInductionVar().getType(); !t.isIndex())
p << " : " << t << ' ';
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult ForOp::verifyRegions() {
// Check that the body defines as single block argument for the induction
// variable.
if (getInductionVar().getType() != getLowerBound().getType())
return emitOpError(
"expected induction variable to be same type as bounds and step");

return success();
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 2633d94

Please sign in to comment.