Skip to content

Commit

Permalink
[flang] Introduce ws loop nest generation for HLFIR lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanradanov committed Aug 1, 2024
1 parent 8068d60 commit c2cbd77
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 39 deletions.
12 changes: 7 additions & 5 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,22 @@ hlfir::ElementalOp genElementalOp(

/// Structure to describe a loop nest.
struct LoopNest {
fir::DoLoopOp outerLoop;
fir::DoLoopOp innerLoop;
mlir::Operation *outerOp;
mlir::Block *body;
llvm::SmallVector<mlir::Value> oneBasedIndices;
};

/// Generate a fir.do_loop nest looping from 1 to extents[i].
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered = false);
mlir::ValueRange extents, bool isUnordered = false,
bool emitWsLoop = false);
inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Value shape, bool isUnordered = false) {
mlir::Value shape, bool isUnordered = false,
bool emitWsLoop = false) {
return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
isUnordered);
isUnordered, emitWsLoop);
}

/// Inline the body of an hlfir.elemental at the current insertion point
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
callContext.stmtCtx.pushScope();
for (auto &preparedActual : loweredActuals)
if (preparedActual)
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
// know this won't miss any opportuinties for clever elemental inlining
hlfir::LoopNest nest = hlfir::genLoopNest(
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
builder.setInsertionPointToStart(nest.innerLoop.getBody());
builder.setInsertionPointToStart(nest.body);
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
Expand All @@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder, loc, redId, refTy, lhsEle, rhsEle);
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

builder.setInsertionPointAfter(nest.outerLoop);
builder.setInsertionPointAfter(nest.outerOp);
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}

Expand Down
43 changes: 31 additions & 12 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <optional>

// Return explicit extents. If the base is a fir.box, this won't read it to
Expand Down Expand Up @@ -855,26 +856,44 @@ mlir::Value hlfir::inlineElementalOp(

hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered) {
mlir::ValueRange extents, bool isUnordered,
bool emitWsLoop) {
hlfir::LoopNest loopNest;
assert(!extents.empty() && "must have at least one extent");
auto insPt = builder.saveInsertionPoint();
mlir::OpBuilder::InsertionGuard guard(builder);
loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
// Build loop nest from column to row.
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
mlir::Type indexType = builder.getIndexType();
unsigned dim = extents.size() - 1;
for (auto extent : llvm::reverse(extents)) {
auto ub = builder.createConvert(loc, indexType, extent);
loopNest.innerLoop =
builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
// Reverse the indices so they are in column-major order.
loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
if (!loopNest.outerLoop)
loopNest.outerLoop = loopNest.innerLoop;

if (emitWsLoop) {
auto wsloop = builder.create<mlir::omp::WsloopOp>(loc, mlir::ArrayRef<mlir::NamedAttribute>());
loopNest.outerOp = wsloop;
builder.createBlock(wsloop.getBody());
mlir::omp::LoopNestOperands lnops;
lnops.loopInclusive = builder.getUnitAttr();
for (auto extent : llvm::reverse(extents)) {
lnops.loopLowerBounds.push_back(one);
lnops.loopUpperBounds.push_back(extent);
lnops.loopSteps.push_back(one);
}
auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
builder.create<mlir::omp::YieldOp>(loc);
builder.createBlock(&lnOp.getRegion().front());
builder.create<mlir::omp::YieldOp>(loc);
} else {
for (auto extent : llvm::reverse(extents)) {
auto ub = builder.createConvert(loc, indexType, extent);
auto doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
loopNest.body = doLoop.getBody();
builder.setInsertionPointToStart(loopNest.body);
// Reverse the indices so they are in column-major order.
loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
if (!loopNest.outerOp)
loopNest.outerOp = doLoop;
}
}
builder.restoreInsertionPoint(insPt);
return loopNest;
}

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/TypeSwitch.h"

namespace hlfir {
Expand Down Expand Up @@ -793,7 +794,7 @@ struct ElementalOpConversion
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue(yield.getElementValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
// if the LHS is not).
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
builder.setInsertionPointToStart(elementalLoopNest->body);
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
elementalLoopNest->oneBasedIndices);
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
Expand All @@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
for (auto &cleanupConversion : argConversionCleanups)
cleanupConversion();
if (elementalLoopNest)
builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
builder.setInsertionPointAfter(elementalLoopNest->outerOp);
} else {
// TODO: preserve allocatable assignment aspects for forall once
// they are conveyed in hlfir.region_assign.
Expand All @@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest)
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
loweredLhs.vectorSubscriptLoopNest->outerOp);
generateCleanupIfAny(oldRhsYield);
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}
Expand All @@ -518,16 +518,16 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
hlfir::Entity savedMask{maybeSaved->first};
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
whereLoopNest->oneBasedIndices);
generateMaskIfOp(cdt);
if (maybeSaved->second) {
// If this is the same run as the one that saved the value, the clean-up
// was left-over to be done now.
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
generateCleanupIfAny(maybeSaved->second);
builder.restoreInsertionPoint(insertionPoint);
}
Expand All @@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
mask.generateNoneElementalPart(builder, mapper);
mlir::Value shape = mask.generateShape(builder, mapper);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = generateMaskedEntity(mask);
generateMaskIfOp(cdt);
return;
Expand Down Expand Up @@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
loc, builder, loweredLhs.vectorSubscriptShape.value());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
loweredLhs.vectorSubscriptLoopNest->body);
}
loweredLhs.lhs = temp->second.fetch(loc, builder);
return loweredLhs;
Expand All @@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
!elementalAddrLhs.isOrdered());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
loweredLhs.vectorSubscriptLoopNest->body);
mapper.map(elementalAddrLhs.getIndices(),
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
Expand All @@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
if (!maskedExpr.noneElementalPartWasGenerated) {
// Generate none elemental part before the where loops (but inside the
// current forall loops if any).
builder.setInsertionPoint(whereLoopNest->outerLoop);
builder.setInsertionPoint(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalPart(builder, mapper);
}
// Generate the none elemental part cleanup after the where loops.
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
// Generate the value of the current element for the masked expression
// at the current insertion point (inside the where loops, and any fir.if
Expand Down Expand Up @@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
fir::factory::TemporaryStorage *temp = nullptr;
if (loweredLhs.vectorSubscriptLoopNest)
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
// Vector subscripted entity for which the shape must also be saved on top
// of the element addresses (e.g. the shape may change in each forall
Expand All @@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
// subscripted LHS.
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
vectorTmp.pushShape(loc, builder, shape);
builder.restoreInsertionPoint(insertionPoint);
} else {
Expand All @@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
if (loweredLhs.vectorSubscriptLoopNest) {
constructStack.pop_back();
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
loweredLhs.vectorSubscriptLoopNest->outerOp);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// hlfir.elemental region inside the inner loop
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue{yield.getElementValue()};
Expand Down Expand Up @@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
Expand Down Expand Up @@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto rhsArrayElement =
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
Expand Down

0 comments on commit c2cbd77

Please sign in to comment.