From 4da93bb2a99ac1d59d4924c518503c94ec81c659 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 31 Jul 2024 14:12:34 +0900 Subject: [PATCH] [flang] Introduce ws loop nest generation for HLFIR lowering --- .../flang/Optimizer/Builder/HLFIRTools.h | 12 +++-- flang/lib/Lower/ConvertCall.cpp | 2 +- flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 4 +- flang/lib/Optimizer/Builder/HLFIRTools.cpp | 52 ++++++++++++++----- .../HLFIR/Transforms/BufferizeHLFIR.cpp | 3 +- .../LowerHLFIROrderedAssignments.cpp | 30 +++++------ .../Transforms/OptimizedBufferization.cpp | 6 +-- 7 files changed, 69 insertions(+), 40 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 6b41025eea0780..14e42c6f358e46 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp( /// Structure to describe a loop nest. struct LoopNest { - fir::DoLoopOp outerLoop; - fir::DoLoopOp innerLoop; + mlir::Operation *outerOp; + mlir::Block *body; llvm::SmallVector oneBasedIndices; }; @@ -366,11 +366,13 @@ struct LoopNest { /// \p isUnordered specifies whether the loops in the loop nest /// are unordered. LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::ValueRange extents, bool isUnordered = false); + mlir::ValueRange extents, bool isUnordered = false, + bool emitWsLoop = false); inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::Value shape, bool isUnordered = false) { + mlir::Value shape, bool isUnordered = false, + bool emitWsLoop = false) { return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape), - isUnordered); + isUnordered, emitWsLoop); } /// Inline the body of an hlfir.elemental at the current insertion point diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index fd873f55dd844e..0689d6e033dd9c 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -2128,7 +2128,7 @@ class ElementalCallBuilder { hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered); mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices; auto insPt = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); + builder.setInsertionPointToStart(loopNest.body); callContext.stmtCtx.pushScope(); for (auto &preparedActual : loweredActuals) if (preparedActual) diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index c3c1f363033c27..72a90dd0d6f29d 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, // know this won't miss any opportuinties for clever elemental inlining hlfir::LoopNest nest = hlfir::genLoopNest( loc, builder, shapeShift.getExtents(), /*isUnordered=*/true); - builder.setInsertionPointToStart(nest.innerLoop.getBody()); + builder.setInsertionPointToStart(nest.body); mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); auto lhsEleAddr = builder.create( loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, @@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, builder, loc, redId, refTy, lhsEle, rhsEle); builder.create(loc, scalarReduction, lhsEleAddr); - builder.setInsertionPointAfter(nest.outerLoop); + builder.setInsertionPointAfter(nest.outerOp); builder.create(loc, lhsAddr); } diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 8d0ae2f195178c..cd07cb741eb4bb 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/TypeSwitch.h" +#include #include // Return explicit extents. If the base is a fir.box, this won't read it to @@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp( hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::ValueRange extents, bool isUnordered) { + mlir::ValueRange extents, bool isUnordered, + bool emitWsLoop) { hlfir::LoopNest loopNest; assert(!extents.empty() && "must have at least one extent"); - auto insPt = builder.saveInsertionPoint(); + mlir::OpBuilder::InsertionGuard guard(builder); loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{}); // Build loop nest from column to row. auto one = builder.create(loc, 1); mlir::Type indexType = builder.getIndexType(); - unsigned dim = extents.size() - 1; - for (auto extent : llvm::reverse(extents)) { - auto ub = builder.createConvert(loc, indexType, extent); - loopNest.innerLoop = - builder.create(loc, one, ub, one, isUnordered); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); - // Reverse the indices so they are in column-major order. - loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar(); - if (!loopNest.outerLoop) - loopNest.outerLoop = loopNest.innerLoop; + if (emitWsLoop) { + auto wsloop = builder.create( + loc, mlir::ArrayRef()); + loopNest.outerOp = wsloop; + builder.createBlock(&wsloop.getRegion()); + mlir::omp::LoopNestOperands lnops; + lnops.loopInclusive = builder.getUnitAttr(); + for (auto extent : llvm::reverse(extents)) { + lnops.loopLowerBounds.push_back(one); + lnops.loopUpperBounds.push_back(extent); + lnops.loopSteps.push_back(one); + } + auto lnOp = builder.create(loc, lnops); + builder.create(loc); + mlir::Block *block = builder.createBlock(&lnOp.getRegion()); + for (auto extent : llvm::reverse(extents)) + block->addArgument(extent.getType(), extent.getLoc()); + loopNest.body = block; + builder.create(loc); + for (unsigned dim = 0; dim < extents.size(); dim++) + loopNest.oneBasedIndices[extents.size() - dim - 1] = + lnOp.getRegion().front().getArgument(dim); + } else { + unsigned dim = extents.size() - 1; + for (auto extent : llvm::reverse(extents)) { + auto ub = builder.createConvert(loc, indexType, extent); + auto doLoop = + builder.create(loc, one, ub, one, isUnordered); + loopNest.body = doLoop.getBody(); + builder.setInsertionPointToStart(loopNest.body); + // Reverse the indices so they are in column-major order. + loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar(); + if (!loopNest.outerOp) + loopNest.outerOp = doLoop; + } } - builder.restoreInsertionPoint(insPt); return loopNest; } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index a70a6b388c4b1a..b608677c526310 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -31,6 +31,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "llvm/ADT/TypeSwitch.h" namespace hlfir { @@ -793,7 +794,7 @@ struct ElementalOpConversion hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered()); auto insPt = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); + builder.setInsertionPointToStart(loopNest.body); auto yield = hlfir::inlineElementalOp(loc, builder, elemental, loopNest.oneBasedIndices); hlfir::Entity elementValue(yield.getElementValue()); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp index 85dd517cb57914..645abf65d10a32 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp @@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { // if the LHS is not). mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity); elementalLoopNest = hlfir::genLoopNest(loc, builder, shape); - builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody()); + builder.setInsertionPointToStart(elementalLoopNest->body); lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity, elementalLoopNest->oneBasedIndices); rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity, @@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { for (auto &cleanupConversion : argConversionCleanups) cleanupConversion(); if (elementalLoopNest) - builder.setInsertionPointAfter(elementalLoopNest->outerLoop); + builder.setInsertionPointAfter(elementalLoopNest->outerOp); } else { // TODO: preserve allocatable assignment aspects for forall once // they are conveyed in hlfir.region_assign. @@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { generateCleanupIfAny(loweredLhs.elementalCleanup); if (loweredLhs.vectorSubscriptLoopNest) builder.setInsertionPointAfter( - loweredLhs.vectorSubscriptLoopNest->outerLoop); + loweredLhs.vectorSubscriptLoopNest->outerOp); generateCleanupIfAny(oldRhsYield); generateCleanupIfAny(loweredLhs.nonElementalCleanup); } @@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) { hlfir::Entity savedMask{maybeSaved->first}; mlir::Value shape = hlfir::genShape(loc, builder, savedMask); whereLoopNest = hlfir::genLoopNest(loc, builder, shape); - constructStack.push_back(whereLoopNest->outerLoop.getOperation()); - builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody()); + constructStack.push_back(whereLoopNest->outerOp); + builder.setInsertionPointToStart(whereLoopNest->body); mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask, whereLoopNest->oneBasedIndices); generateMaskIfOp(cdt); @@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) { // If this is the same run as the one that saved the value, the clean-up // was left-over to be done now. auto insertionPoint = builder.saveInsertionPoint(); - builder.setInsertionPointAfter(whereLoopNest->outerLoop); + builder.setInsertionPointAfter(whereLoopNest->outerOp); generateCleanupIfAny(maybeSaved->second); builder.restoreInsertionPoint(insertionPoint); } @@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) { mask.generateNoneElementalPart(builder, mapper); mlir::Value shape = mask.generateShape(builder, mapper); whereLoopNest = hlfir::genLoopNest(loc, builder, shape); - constructStack.push_back(whereLoopNest->outerLoop.getOperation()); - builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody()); + constructStack.push_back(whereLoopNest->outerOp); + builder.setInsertionPointToStart(whereLoopNest->body); mlir::Value cdt = generateMaskedEntity(mask); generateMaskIfOp(cdt); return; @@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS( loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest( loc, builder, loweredLhs.vectorSubscriptShape.value()); builder.setInsertionPointToStart( - loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody()); + loweredLhs.vectorSubscriptLoopNest->body); } loweredLhs.lhs = temp->second.fetch(loc, builder); return loweredLhs; @@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS( hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape, !elementalAddrLhs.isOrdered()); builder.setInsertionPointToStart( - loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody()); + loweredLhs.vectorSubscriptLoopNest->body); mapper.map(elementalAddrLhs.getIndices(), loweredLhs.vectorSubscriptLoopNest->oneBasedIndices); for (auto &op : elementalAddrLhs.getBody().front().without_terminator()) @@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) { if (!maskedExpr.noneElementalPartWasGenerated) { // Generate none elemental part before the where loops (but inside the // current forall loops if any). - builder.setInsertionPoint(whereLoopNest->outerLoop); + builder.setInsertionPoint(whereLoopNest->outerOp); maskedExpr.generateNoneElementalPart(builder, mapper); } // Generate the none elemental part cleanup after the where loops. - builder.setInsertionPointAfter(whereLoopNest->outerLoop); + builder.setInsertionPointAfter(whereLoopNest->outerOp); maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper); // Generate the value of the current element for the masked expression // at the current insertion point (inside the where loops, and any fir.if @@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide( LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region); fir::factory::TemporaryStorage *temp = nullptr; if (loweredLhs.vectorSubscriptLoopNest) - constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop); + constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp); if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) { // Vector subscripted entity for which the shape must also be saved on top // of the element addresses (e.g. the shape may change in each forall @@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide( // subscripted LHS. auto &vectorTmp = temp->cast(); auto insertionPoint = builder.saveInsertionPoint(); - builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop); + builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp); vectorTmp.pushShape(loc, builder, shape); builder.restoreInsertionPoint(insertionPoint); } else { @@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide( if (loweredLhs.vectorSubscriptLoopNest) { constructStack.pop_back(); builder.setInsertionPointAfter( - loweredLhs.vectorSubscriptLoopNest->outerLoop); + loweredLhs.vectorSubscriptLoopNest->outerOp); } } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index c5b809514c54c6..c4aed6b79df923 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite( // hlfir.elemental region inside the inner loop hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered()); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); + builder.setInsertionPointToStart(loopNest.body); auto yield = hlfir::inlineElementalOp(loc, builder, elemental, loopNest.oneBasedIndices); hlfir::Entity elementValue{yield.getElementValue()}; @@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite( hlfir::getIndexExtents(loc, builder, shape); hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); + builder.setInsertionPointToStart(loopNest.body); auto arrayElement = hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); builder.create(loc, rhs, arrayElement); @@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite( hlfir::getIndexExtents(loc, builder, shape); hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); - builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); + builder.setInsertionPointToStart(loopNest.body); auto rhsArrayElement = hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);