diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp index 20f45296a8159a..a147db2cb5d59a 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp @@ -152,6 +152,14 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) { return false; } +/// We clone pure operations in both the parallel and single blocks. This walks +/// the block in reverse and erases the ones left trivially dead, users first. +static void cleanupBlock(Block *block) { + for (Operation &op : llvm::make_early_inc_range(llvm::reverse(*block))) + if (isOpTriviallyDead(&op)) + op.erase(); +} + static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, IRMapping &rootMapping, Location loc) { OpBuilder rootBuilder(sourceRegion.getContext()); @@ -258,13 +266,8 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, singleOperands.copyprivateVars = moveToSingle(std::get(opOrSingle), allocaBuilder, singleBuilder, parallelBuilder); + cleanupBlock(singleBlock); for (auto var : singleOperands.copyprivateVars) { - Type ty; - if (auto firAlloca = var.getDefiningOp()) { - ty = firAlloca.getAllocatedType(); - } else { - ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext()); - } mlir::func::FuncOp funcOp = createCopyFunc(loc, var.getType(), firCopyFuncBuilder); singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp)); @@ -302,6 +305,9 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, rootBuilder.clone(*block.getTerminator(), rootMapping); } + + for (Block &targetBlock : targetRegion) + cleanupBlock(&targetBlock); } /// Lowers workshare to a sequence of single-thread regions and parallel loops @@ -372,20 +378,6 @@ class LowerWorksharePass lowerWorkshare(wsOp); }); - - // Do folding - for (Operation *isolatedParent : parents) { - RewritePatternSet patterns(&getContext()); - GreedyRewriteConfig config; - // prevent the pattern driver form merging blocks - 
config.enableRegionSimplification = - mlir::GreedySimplifyRegionLevel::Disabled; - if (failed(applyPatternsAndFoldGreedily(isolatedParent, - std::move(patterns), config))) { - emitError(isolatedParent->getLoc(), "error in lower workshare\n"); - signalPassFailure(); - } - } } }; } // namespace diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir index 84eded94503282..aee95a464a31bd 100644 --- a/flang/test/Transforms/OpenMP/lower-workshare3.mlir +++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir @@ -3,7 +3,7 @@ // tests if the correct values are stored -func.func @wsfunc(%arg0: !fir.ref>) { +func.func @wsfunc() { omp.parallel { // CHECK: fir.alloca // CHECK: fir.alloca diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir new file mode 100644 index 00000000000000..6cff0075b4fe50 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir @@ -0,0 +1,55 @@ +// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s + +func.func @wsfunc() { + %a = fir.alloca i32 + omp.parallel { + omp.workshare { + %t1 = "test.test1"() : () -> i32 + + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : index + + %c2 = arith.constant 2 : index + "test.test3"(%c2) : (index) -> () + + "omp.workshare_loop_wrapper"() ({ + omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) { + "test.test2"() : () -> () + omp.yield + } + omp.terminator + }) : () -> () + omp.terminator + } + omp.terminator + } + return +} + +// CHECK-LABEL: func.func @wsfunc() { +// CHECK: %[[VAL_0:.*]] = fir.alloca i32 +// CHECK: omp.parallel { +// CHECK: %[[VAL_1:.*]] = arith.constant true +// CHECK: fir.if %[[VAL_1]] { +// CHECK: omp.single { +// CHECK: %[[VAL_2:.*]] = "test.test1"() : () -> i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index +// CHECK: "test.test3"(%[[VAL_3]]) : (index) -> () +// CHECK: 
omp.terminator +// CHECK: } +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 42 : index +// CHECK: omp.wsloop nowait { +// CHECK: omp.loop_nest (%[[VAL_6:.*]]) : index = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_4]]) { +// CHECK: "test.test2"() : () -> () +// CHECK: omp.yield +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.barrier +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +