Skip to content

Commit

Permalink
[OpenMPIRBuilder] Added if clause for teams (llvm#69139)
Browse files Browse the repository at this point in the history
This patch adds support for the `if` clause on `teams` construct. The
value of the argument must be an integer value. If the value evaluates
to true (non-zero) integer, then the number of threads is determined by
`num_threads` clause (or default and ICV if `num_threads` is absent).
When the condition evaluates to false (zero), then the bounds are set to
1. ([OpenMP 5.2 Section
10.2](https://www.openmp.org/spec-html/5.2/openmpse58.html))

This essentially means that
```
upperbound = ifexpr ? upperbound : 1
lowerbound = ifexpr ? lowerbound : 1
```
  • Loading branch information
shraiysh authored Oct 17, 2023
1 parent 122064a commit 9922aad
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 13 deletions.
11 changes: 6 additions & 5 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1923,11 +1923,12 @@ class OpenMPIRBuilder {
/// \param NumTeamsUpper Upper bound on the number of teams.
/// \param ThreadLimit on the number of threads that may participate in a
/// contention group created by each team.
InsertPointTy createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB,
Value *NumTeamsLower = nullptr,
Value *NumTeamsUpper = nullptr,
Value *ThreadLimit = nullptr);
/// \param IfExpr is the integer argument value of the if condition on the
/// teams clause.
InsertPointTy
createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr,
Value *ThreadLimit = nullptr, Value *IfExpr = nullptr);

/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
Expand Down
21 changes: 19 additions & 2 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
Value *NumTeamsUpper, Value *ThreadLimit) {
Value *NumTeamsUpper, Value *ThreadLimit,
Value *IfExpr) {
if (!updateToLocation(Loc))
return InsertPointTy();

Expand Down Expand Up @@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

// Push num_teams
if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
"if lowerbound is non-null, then upperbound must also be non-null "
"for bounds on num_teams");
Expand All @@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
if (NumTeamsLower == nullptr)
NumTeamsLower = NumTeamsUpper;

if (IfExpr) {
assert(IfExpr->getType()->isIntegerTy() &&
"argument to if clause must be an integer value");

// upper = ifexpr ? upper : 1
if (IfExpr->getType() != Int1)
IfExpr = Builder.CreateICmpNE(IfExpr,
ConstantInt::get(IfExpr->getType(), 0));
NumTeamsUpper = Builder.CreateSelect(
IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

// lower = ifexpr ? lower : 1
NumTeamsLower = Builder.CreateSelect(
IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
}

if (ThreadLimit == nullptr)
ThreadLimit = Builder.getInt32(0);

Expand Down
146 changes: 140 additions & 6 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
};

OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));

OMPBuilder.finalize();
Builder.CreateRetVoid();
Expand Down Expand Up @@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/F->arg_begin()));
/*ThreadLimit=*/F->arg_begin(),
/*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/F->arg_begin()));
/*NumTeamsUpper=*/F->arg_begin(),
/*ThreadLimit=*/nullptr,
/*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
};

OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper, ThreadLimit));
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
Builder.CreateAlloca(Builder.getInt1Ty()));

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/nullptr, IfExpr));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);
Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);

// Check the lower_bound
ASSERT_NE(NumTeamsLower, nullptr);
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));

// Check the upper_bound
ASSERT_NE(NumTeamsUpper, nullptr);
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));

// Check thread_limit
EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Value *IfExpr = Builder.CreateLoad(
Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
Value *NumTeamsUpper =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper, ThreadLimit, IfExpr));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);
Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);

// Get the boolean conversion of if expression
ASSERT_EQ(IfExpr->getNumUses(), 1U);
User *IfExprInst = IfExpr->user_back();
ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
ASSERT_NE(IfExprCmpInst, nullptr);
EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));

// Check the lower_bound
ASSERT_NE(NumTeamsLowerArg, nullptr);
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));

// Check the upper_bound
ASSERT_NE(NumTeamsUpperArg, nullptr);
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));

// Check thread_limit
EXPECT_EQ(ThreadLimitArg, ThreadLimit);
}

/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>
Expand Down

0 comments on commit 9922aad

Please sign in to comment.