diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index d507e58b164dd8..433efb16c2d699 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1044,14 +1044,6 @@ bool ClauseProcessor::processReduction( }); } -bool ClauseProcessor::processSectionsReduction( - mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const { - return findRepeatableClause( - [&](const omp::clause::Reduction &, const parser::CharBlock &) { - TODO(currentLocation, "OMPC_Reduction"); - }); -} - bool ClauseProcessor::processTo( llvm::SmallVectorImpl &result) const { return findRepeatableClause( diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 43795d5c253996..ff39eb72ff24c8 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -125,8 +125,6 @@ class ClauseProcessor { llvm::SmallVectorImpl *reductionTypes = nullptr, llvm::SmallVectorImpl *reductionSyms = nullptr) const; - bool processSectionsReduction(mlir::Location currentLocation, - mlir::omp::ReductionClauseOps &result) const; bool processTo(llvm::SmallVectorImpl &result) const; bool processUseDeviceAddr( mlir::omp::UseDeviceAddrClauseOps &result, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 17804ff58edc03..d60f1dd43a1c48 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1068,13 +1068,15 @@ static void genParallelClauses( cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); } -static void genSectionsClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - const List &clauses, mlir::Location loc, - mlir::omp::SectionsClauseOps &clauseOps) { +static void genSectionsClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::SectionsClauseOps &clauseOps, + llvm::SmallVectorImpl &reductionTypes, + llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); - cp.processSectionsReduction(loc, clauseOps); + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); cp.processNowait(clauseOps); // TODO Support delayed privatization. } @@ -1481,27 +1483,20 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return genOpWithBody(genInfo, queue, item, clauseOps); } -static mlir::omp::SectionOp -genSectionOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::Location loc, const ConstructQueue &queue, - ConstructQueue::iterator item) { - // Currently only private/firstprivate clause is handled, and - // all privatization is done within `omp.section` operations. - return genOpWithBody( - OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, - llvm::omp::Directive::OMPD_section) - .setClauses(&item->clauses), - queue, item); -} - +/// This breaks the normal prototype of the gen*Op functions: adding the +/// sectionBlocks argument so that the enclosed section constructs can be +/// lowered here with correct reduction symbol remapping. static mlir::omp::SectionsOp genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, ConstructQueue::iterator item) { + const ConstructQueue &queue, ConstructQueue::iterator item, + const parser::OmpSectionBlocks §ionBlocks) { + llvm::SmallVector reductionTypes; + llvm::SmallVector reductionSyms; mlir::omp::SectionsClauseOps clauseOps; - genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps); + genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, + reductionTypes, reductionSyms); auto &builder = converter.getFirOpBuilder(); @@ -1530,11 +1525,52 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, } // SECTIONS construct. - mlir::omp::SectionsOp sectionsOp = genOpWithBody( - OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, - llvm::omp::Directive::OMPD_sections) - .setClauses(&nonDsaClauses), - queue, item, clauseOps); + auto sectionsOp = builder.create(loc, clauseOps); + + auto reductionCallback = [&](mlir::Operation *op) { + genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + return reductionSyms; + }; + + reductionCallback(sectionsOp); + // genReductionVars adds a hlfir.declare for the reduction block argument + // but only terminators and sectionOps are allowed inside of a SectionsOp + llvm::SmallVector toErase; + toErase.reserve(reductionSyms.size()); + for (auto decl : sectionsOp.getOps()) + toErase.push_back(decl); + for (mlir::Operation *op : toErase) + op->erase(); + + mlir::Operation *terminator = + lower::genOpenMPTerminator(builder, sectionsOp, loc); + + // Generate nested SECTION constructs. + // This is done here rather than in genOMP([...], OpenMPSectionConstruct ) + // because we need to run genReductionVars on each omp.section so that the + // reduction variable gets mapped to the private version + for (auto [construct, nestedEval] : + llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { + const auto *sectionConstruct = + std::get_if(&construct.u); + if (!sectionConstruct) { + assert(false && + "unexpected construct nested inside of SECTIONS construct"); + continue; + } + + ConstructQueue sectionQueue{buildConstructQueue( + converter.getFirOpBuilder().getModule(), semaCtx, nestedEval, + sectionConstruct->source, llvm::omp::Directive::OMPD_section, {})}; + + builder.setInsertionPoint(terminator); + genOpWithBody( + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, + llvm::omp::Directive::OMPD_section) + .setClauses(§ionQueue.begin()->clauses) + .setGenRegionEntryCb(reductionCallback), + sectionQueue, sectionQueue.begin()); + } if (!lastprivates.empty()) { mlir::Region §ionsBody = sectionsOp.getRegion(); @@ -2049,10 +2085,11 @@ static void genOMPDispatch(lower::AbstractConverter &converter, genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item); break; case llvm::omp::Directive::OMPD_section: - genSectionOp(converter, symTable, semaCtx, eval, loc, queue, item); + // Lowered in the enclosing genSectionsOp. break; case llvm::omp::Directive::OMPD_sections: - genSectionsOp(converter, symTable, semaCtx, eval, loc, queue, item); + // Called directly from genOMP([...], OpenMPSectionsConstruct) because it + // has a different prototype. break; case llvm::omp::Directive::OMPD_simd: genSimdOp(converter, symTable, semaCtx, eval, loc, queue, item, *loopDsp); @@ -2464,11 +2501,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPSectionConstruct §ionConstruct) { - mlir::Location loc = converter.getCurrentLocation(); - ConstructQueue queue{buildConstructQueue( - converter.getFirOpBuilder().getModule(), semaCtx, eval, - sectionConstruct.source, llvm::omp::Directive::OMPD_section, {})}; - genOMPDispatch(converter, symTable, semaCtx, eval, loc, queue, queue.begin()); + // Do nothing here. SECTION is lowered inside of the lowering for Sections } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, @@ -2481,6 +2514,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, std::get(beginSectionsDirective.t), semaCtx); const auto &endSectionsDirective = std::get(sectionsConstruct.t); + const auto §ionBlocks = + std::get(sectionsConstruct.t); clauses.append(makeClauses( std::get(endSectionsDirective.t), semaCtx)); mlir::Location currentLocation = converter.getCurrentLocation(); @@ -2492,8 +2527,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue queue{ buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx, eval, source, directive, clauses)}; - genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, - queue.begin()); + ConstructQueue::iterator next = queue.begin(); + // Generate constructs that come first e.g. Parallel + while (next != queue.end() && + next->id != llvm::omp::Directive::OMPD_sections) { + genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, + next); + next = std::next(next); + } + + // call genSectionsOp directly (not via genOMPDispatch) so that we can add the + // sectionBlocks argument + assert(next != queue.end()); + assert(next->id == llvm::omp::Directive::OMPD_sections); + genSectionsOp(converter, symTable, semaCtx, eval, currentLocation, queue, + next, sectionBlocks); + assert(std::next(next) == queue.end()); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, diff --git a/flang/test/Lower/OpenMP/sections-reduction.f90 b/flang/test/Lower/OpenMP/sections-reduction.f90 new file mode 100644 index 00000000000000..854f9ea22a7ddd --- /dev/null +++ b/flang/test/Lower/OpenMP/sections-reduction.f90 @@ -0,0 +1,105 @@ +! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +subroutine sectionsReduction(x,y) + real :: x, y + + !$omp parallel + !$omp sections reduction(+:x,y) + x = x + 1 + y = x + !$omp section + x = x + 2 + y = x + !$omp end sections + !$omp end parallel + + !$omp parallel sections reduction(+:x) reduction(+:y) + x = x + 1 + y = x + !$omp section + x = x + 2 + y = x + !$omp end parallel sections +end subroutine + +! CHECK-LABEL: omp.declare_reduction @add_reduction_f32 : f32 init { +! CHECK: ^bb0(%[[VAL_0:.*]]: f32): +! CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 +! CHECK: omp.yield(%[[VAL_1]] : f32) +! CHECK-LABEL: } combiner { +! CHECK: ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32): +! CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] fastmath : f32 +! CHECK: omp.yield(%[[VAL_2]] : f32) +! CHECK: } + +! CHECK-LABEL: func.func @_QPsectionsreduction( +! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref {fir.bindc_name = "x"}, +! CHECK-SAME: %[[VAL_1:.*]]: !fir.ref {fir.bindc_name = "y"}) { +! CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) +! CHECK: omp.parallel { +! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref) { +! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref, %[[VAL_6:.*]]: !fir.ref): +! CHECK: omp.section { +! CHECK: ^bb0(%[[VAL_7:.*]]: !fir.ref, %[[VAL_8:.*]]: !fir.ref): +! CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_7]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref +! CHECK: %[[VAL_12:.*]] = arith.constant 1.000000e+00 : f32 +! CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath : f32 +! CHECK: hlfir.assign %[[VAL_13]] to %[[VAL_9]]#0 : f32, !fir.ref +! CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref +! CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_10]]#0 : f32, !fir.ref +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.section { +! CHECK: ^bb0(%[[VAL_15:.*]]: !fir.ref, %[[VAL_16:.*]]: !fir.ref): +! CHECK: %[[VAL_17:.*]]:2 = hlfir.declare %[[VAL_15]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_18:.*]]:2 = hlfir.declare %[[VAL_16]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_17]]#0 : !fir.ref +! CHECK: %[[VAL_20:.*]] = arith.constant 2.000000e+00 : f32 +! CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] fastmath : f32 +! CHECK: hlfir.assign %[[VAL_21]] to %[[VAL_17]]#0 : f32, !fir.ref +! CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_17]]#0 : !fir.ref +! CHECK: hlfir.assign %[[VAL_22]] to %[[VAL_18]]#0 : f32, !fir.ref +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.parallel { +! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref) { +! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref, %[[VAL_24:.*]]: !fir.ref): +! CHECK: omp.section { +! CHECK: ^bb0(%[[VAL_25:.*]]: !fir.ref, %[[VAL_26:.*]]: !fir.ref): +! CHECK: %[[VAL_27:.*]]:2 = hlfir.declare %[[VAL_25]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_28:.*]]:2 = hlfir.declare %[[VAL_26]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_29:.*]] = fir.load %[[VAL_27]]#0 : !fir.ref +! CHECK: %[[VAL_30:.*]] = arith.constant 1.000000e+00 : f32 +! CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] fastmath : f32 +! CHECK: hlfir.assign %[[VAL_31]] to %[[VAL_27]]#0 : f32, !fir.ref +! CHECK: %[[VAL_32:.*]] = fir.load %[[VAL_27]]#0 : !fir.ref +! CHECK: hlfir.assign %[[VAL_32]] to %[[VAL_28]]#0 : f32, !fir.ref +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.section { +! CHECK: ^bb0(%[[VAL_33:.*]]: !fir.ref, %[[VAL_34:.*]]: !fir.ref): +! CHECK: %[[VAL_35:.*]]:2 = hlfir.declare %[[VAL_33]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_36:.*]]:2 = hlfir.declare %[[VAL_34]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_37:.*]] = fir.load %[[VAL_35]]#0 : !fir.ref +! CHECK: %[[VAL_38:.*]] = arith.constant 2.000000e+00 : f32 +! CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] fastmath : f32 +! CHECK: hlfir.assign %[[VAL_39]] to %[[VAL_35]]#0 : f32, !fir.ref +! CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_35]]#0 : !fir.ref +! CHECK: hlfir.assign %[[VAL_40]] to %[[VAL_36]]#0 : f32, !fir.ref +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: return +! CHECK: }