Skip to content

Commit

Permalink
[OpenMP]Support for lowering masked op (#98401)
Browse files Browse the repository at this point in the history
  • Loading branch information
anchuraj authored Jul 12, 2024
1 parent 83845b1 commit db41a30
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,43 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
llvm_unreachable("Unknown ClauseProcBindKind kind");
}

/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
auto maskedOp = cast<omp::MaskedOp>(opInst);
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
// relying on captured variables.
LogicalResult bodyGenStatus = success();

auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// MaskedOp has only one region associated with it.
auto &region = maskedOp.getRegion();
builder.restoreIP(codeGenIP);
convertOmpOpRegions(region, "omp.masked.region", builder, moduleTranslation,
bodyGenStatus);
};

// TODO: Perform finalization actions for variables. This has to be
// called for variables which have destructors/finalizers.
auto finiCB = [&](InsertPointTy codeGenIP) {};

llvm::Value *filterVal = nullptr;
if (auto filterVar = maskedOp.getFilteredThreadId()) {
filterVal = moduleTranslation.lookupValue(filterVar);
} else {
llvm::LLVMContext &llvmContext = builder.getContext();
filterVal =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
}
assert(filterVal != nullptr);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMasked(
ompLoc, bodyGenCB, finiCB, filterVal));
return success();
}

/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
Expand Down Expand Up @@ -3414,6 +3451,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::ParallelOp op) {
return convertOmpParallel(op, builder, moduleTranslation);
})
.Case([&](omp::MaskedOp) {
return convertOmpMasked(*op, builder, moduleTranslation);
})
.Case([&](omp::MasterOp) {
return convertOmpMaster(*op, builder, moduleTranslation);
})
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,26 @@ llvm.func @test_omp_master() -> () {

// -----

// CHECK-LABEL: define void @test_omp_masked({{.*}})
llvm.func @test_omp_masked(%arg0: i32)-> () {
// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @{{.*}})
// CHECK: omp.par.region1:
omp.parallel {
omp.masked filter(%arg0: i32) {
// CHECK: [[OMP_THREAD_3_4:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}})
// CHECK: {{[0-9]+}} = call i32 @__kmpc_masked(ptr @{{[0-9]+}}, i32 [[OMP_THREAD_3_4]], i32 %{{[0-9]+}})
// CHECK: omp.masked.region
// CHECK: call void @__kmpc_end_masked(ptr @{{[0-9]+}}, i32 [[OMP_THREAD_3_4]])
// CHECK: br label %omp_region.end
omp.terminator
}
omp.terminator
}
llvm.return
}

// -----

// CHECK: %struct.ident_t = type
// CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00"
// CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}}
Expand Down

0 comments on commit db41a30

Please sign in to comment.