diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index e4955fec80b4f3..0eefe06055b7db 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -81,6 +81,10 @@ struct DoacrossClauseOps { IntegerAttr doacrossNumLoopsAttr; }; +struct FilterClauseOps { + Value filteredThreadIdVar; +}; + struct FinalClauseOps { Value finalVar; }; @@ -254,8 +258,7 @@ using DistributeClauseOps = using LoopNestClauseOps = detail::Clauses; -// TODO `filter` clause. -using MaskedClauseOps = detail::Clauses<>; +using MaskedClauseOps = detail::Clauses; using OrderedOpClauseOps = detail::Clauses; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 1fa6edb28a288f..99150bc5dff393 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1204,4 +1204,32 @@ class OpenMP_UseDevicePtrClauseSkip< def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [10.5.1] `filter` clause +//===----------------------------------------------------------------------===// + +class OpenMP_FilterClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + Optional:$filtered_thread_id + ); + + let assemblyFormat = [{ + `filter` `(` $filtered_thread_id `:` type($filtered_thread_id) `)` + }]; + + let description = [{ + If `filter` is specified, the masked construct masks the execution of + the region to only the thread id filtered. Other threads executing the + parallel region are not expected to execute the region specified within + the `masked` directive. If `filter` is not specified, master thread is + expected to execute the region enclosed within `masked` directive. + }]; +} + +def OpenMP_FilterClause : OpenMP_FilterClauseSkip<>; + #endif // OPENMP_CLAUSES diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 99e14cd1b7b483..1a1ca5e71b3e2e 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1577,4 +1577,21 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, let hasRegionVerifier = 1; } +//===----------------------------------------------------------------------===// +// [Spec 5.2] 10.5 masked Construct +//===----------------------------------------------------------------------===// +def MaskedOp : OpenMP_Op<"masked", clauses = [ + OpenMP_FilterClause + ], singleRegion = 1> { + let summary = "masked construct"; + let description = [{ + Masked construct allows to specify a structured block to be executed by a subset of + threads of the current team. + }] # clausesDescription; + + let builders = [ + OpBuilder<(ins CArg<"const MaskedClauseOps &">:$clauses)> + ]; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index abbd857dad67ac..23f291bfc22329 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2578,6 +2578,15 @@ LogicalResult PrivateClauseOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Spec 5.2: Masked construct (10.5) +//===----------------------------------------------------------------------===// + +void MaskedOp::build(OpBuilder &builder, OperationState &state, + const MaskedClauseOps &clauses) { + MaskedOp::build(builder, state, clauses.filteredThreadIdVar); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 2915963f704d37..6a04b9ead746c6 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2358,3 +2358,21 @@ func.func @byref_in_private(%arg0: index) { return } + +// ----- +func.func @masked_arg_type_mismatch(%arg0: f32) { + // expected-error @below {{'omp.masked' op operand #0 must be integer or index, but got 'f32'}} + "omp.masked"(%arg0) ({ + omp.terminator + }) : (f32) -> () + return +} + +// ----- +func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) { + // expected-error @below {{'omp.masked' op operand group starting at #0 requires 0 or 1 element, but found 2}} + "omp.masked"(%arg0, %arg1) ({ + omp.terminator + }) : (i32, i32) -> () + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index eb283840aa7ee5..d6b655dd20ef81 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -16,6 +16,20 @@ func.func @omp_master() -> () { return } +// CHECK-LABEL: omp_masked +func.func @omp_masked(%filtered_thread_id : i32) -> () { + // CHECK: omp.masked filter(%{{.*}} : i32) + "omp.masked" (%filtered_thread_id) ({ + omp.terminator + }) : (i32) -> () + + // CHECK: omp.masked + "omp.masked" () ({ + omp.terminator + }) : () -> () + return +} + func.func @omp_taskwait() -> () { // CHECK: omp.taskwait omp.taskwait