From 9922aadf9e9d1b9d10dd69882d8515757f127a91 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Tue, 17 Oct 2023 15:00:39 -0500 Subject: [PATCH] [OpenMPIRBuilder] Added `if` clause for `teams` (#69139) This patch adds support for the `if` clause on `teams` construct. The value of the argument must be an integer value. If the value evaluates to true (non-zero) integer, then the number of threads is determined by `num_threads` clause (or default and ICV if `num_threads` is absent). When the condition evaluates to false (zero), then the bounds are set to 1. ([OpenMP 5.2 Section 10.2](https://www.openmp.org/spec-html/5.2/openmpse58.html)) This essentially means that ``` upperbound = ifexpr ? upperbound : 1 lowerbound = ifexpr ? lowerbound : 1 ``` --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 11 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 21 ++- .../Frontend/OpenMPIRBuilderTest.cpp | 146 +++++++++++++++++- 3 files changed, 165 insertions(+), 13 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 9d2adf229b7865..00b4707a7f820d 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1923,11 +1923,12 @@ class OpenMPIRBuilder { /// \param NumTeamsUpper Upper bound on the number of teams. /// \param ThreadLimit on the number of threads that may participate in a /// contention group created by each team. - InsertPointTy createTeams(const LocationDescription &Loc, - BodyGenCallbackTy BodyGenCB, - Value *NumTeamsLower = nullptr, - Value *NumTeamsUpper = nullptr, - Value *ThreadLimit = nullptr); + /// \param IfExpr is the integer argument value of the if condition on the + /// teams clause. + InsertPointTy + createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB, + Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr, + Value *ThreadLimit = nullptr, Value *IfExpr = nullptr); /// Generate conditional branch and relevant BasicBlocks through which private /// threads copy the 'copyin' variables from Master copy to threadprivate diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a658990f2d4535..5b24e9fe2e0c5b 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare( OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower, - Value *NumTeamsUpper, Value *ThreadLimit) { + Value *NumTeamsUpper, Value *ThreadLimit, + Value *IfExpr) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); // Push num_teams - if (NumTeamsLower || NumTeamsUpper || ThreadLimit) { + if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) { assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) && "if lowerbound is non-null, then upperbound must also be non-null " "for bounds on num_teams"); @@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, if (NumTeamsLower == nullptr) NumTeamsLower = NumTeamsUpper; + if (IfExpr) { + assert(IfExpr->getType()->isIntegerTy() && + "argument to if clause must be an integer value"); + + // upper = ifexpr ? upper : 1 + if (IfExpr->getType() != Int1) + IfExpr = Builder.CreateICmpNE(IfExpr, + ConstantInt::get(IfExpr->getType(), 0)); + NumTeamsUpper = Builder.CreateSelect( + IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper"); + + // lower = ifexpr ? lower : 1 + NumTeamsLower = Builder.CreateSelect( + IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower"); + } + if (ThreadLimit == nullptr) ThreadLimit = Builder.getInt32(0); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index d770facc173025..97cfc339675f65 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) { }; OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB)); + Builder.restoreIP(OMPBuilder.createTeams( + Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr, + /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr)); OMPBuilder.finalize(); Builder.CreateRetVoid(); @@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr, - /*ThreadLimit=*/F->arg_begin())); + /*ThreadLimit=*/F->arg_begin(), + /*IfExpr=*/nullptr)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) { // `num_teams` Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, - /*NumTeamsUpper=*/F->arg_begin())); + /*NumTeamsUpper=*/F->arg_begin(), + /*ThreadLimit=*/nullptr, + /*IfExpr=*/nullptr)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) { // `F` already has an integer argument, so we use that as upper bound to // `num_teams` Builder.restoreIP( - OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper)); + OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, + /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { }; OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, - NumTeamsUpper, ThreadLimit)); + Builder.restoreIP(OMPBuilder.createTeams( + Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); } +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(), + Builder.CreateAlloca(Builder.getInt1Ty())); + + Function *FakeFunction = + Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + // `F` already has an integer argument, so we use that as upper bound to + // `num_teams` + Builder.restoreIP(OMPBuilder.createTeams( + Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr, + /*ThreadLimit=*/nullptr, IfExpr)); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2); + Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3); + Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4); + + // Check the lower_bound + ASSERT_NE(NumTeamsLower, nullptr); + SelectInst *NumTeamsLowerSelectInst = dyn_cast(NumTeamsLower); + ASSERT_NE(NumTeamsLowerSelectInst, nullptr); + EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr); + EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0)); + EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1)); + + // Check the upper_bound + ASSERT_NE(NumTeamsUpper, nullptr); + SelectInst *NumTeamsUpperSelectInst = dyn_cast(NumTeamsUpper); + ASSERT_NE(NumTeamsUpperSelectInst, nullptr); + EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr); + EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0)); + EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1)); + + // Check thread_limit + EXPECT_EQ(ThreadLimit, Builder.getInt32(0)); +} + +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + Value *IfExpr = Builder.CreateLoad( + Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty())); + Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5)); + Value *NumTeamsUpper = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10)); + Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20)); + + Function *FakeFunction = + Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + // `F` already has an integer argument, so we use that as upper bound to + // `num_teams` + Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, + NumTeamsUpper, ThreadLimit, IfExpr)); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2); + Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3); + Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4); + + // Get the boolean conversion of if expression + ASSERT_EQ(IfExpr->getNumUses(), 1U); + User *IfExprInst = IfExpr->user_back(); + ICmpInst *IfExprCmpInst = dyn_cast(IfExprInst); + ASSERT_NE(IfExprCmpInst, nullptr); + EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE); + EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr); + EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0)); + + // Check the lower_bound + ASSERT_NE(NumTeamsLowerArg, nullptr); + SelectInst *NumTeamsLowerSelectInst = dyn_cast(NumTeamsLowerArg); + ASSERT_NE(NumTeamsLowerSelectInst, nullptr); + EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst); + EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower); + EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1)); + + // Check the upper_bound + ASSERT_NE(NumTeamsUpperArg, nullptr); + SelectInst *NumTeamsUpperSelectInst = dyn_cast(NumTeamsUpperArg); + ASSERT_NE(NumTeamsUpperSelectInst, nullptr); + EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst); + EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper); + EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1)); + + // Check thread_limit + EXPECT_EQ(ThreadLimitArg, ThreadLimit); +} + /// Returns the single instruction of InstTy type in BB that uses the value V. /// If there is more than one such instruction, returns null. template