diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index fd89d27db74dc0..682ca06cabd6f6 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1470,15 +1470,19 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter, llvm::SmallVector &async, llvm::SmallVector &asyncDeviceTypes, llvm::SmallVector &asyncOnlyDeviceTypes, - mlir::acc::DeviceTypeAttr deviceTypeAttr, + llvm::SmallVector &deviceTypeAttrs, Fortran::lower::StatementContext &stmtCtx) { const auto &asyncClauseValue = asyncClause->v; if (asyncClauseValue) { // async has a value. - async.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx))); - asyncDeviceTypes.push_back(deviceTypeAttr); + mlir::Value asyncValue = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); + for (auto deviceTypeAttr : deviceTypeAttrs) { + async.push_back(asyncValue); + asyncDeviceTypes.push_back(deviceTypeAttr); + } } else { - asyncOnlyDeviceTypes.push_back(deviceTypeAttr); + for (auto deviceTypeAttr : deviceTypeAttrs) + asyncOnlyDeviceTypes.push_back(deviceTypeAttr); } } @@ -1504,10 +1508,9 @@ getDeviceType(Fortran::common::OpenACCDeviceType device) { } static void gatherDeviceTypeAttrs( - fir::FirOpBuilder &builder, mlir::Location clauseLocation, + fir::FirOpBuilder &builder, const Fortran::parser::AccClause::DeviceType *deviceTypeClause, - llvm::SmallVector &deviceTypes, - Fortran::lower::StatementContext &stmtCtx) { + llvm::SmallVector &deviceTypes) { const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = deviceTypeClause->v; for (const auto &deviceTypeExpr : deviceTypeExprList.v) @@ -1560,20 +1563,25 @@ genWaitClause(Fortran::lower::AbstractConverter &converter, llvm::SmallVector &waitOperandsDeviceTypes, llvm::SmallVector &waitOnlyDeviceTypes, llvm::SmallVector &waitOperandsSegments, - mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr, + mlir::Value &waitDevnum, + llvm::SmallVector deviceTypeAttrs, Fortran::lower::StatementContext &stmtCtx) { const auto &waitClauseValue = waitClause->v; if (waitClauseValue) { // wait has a value. const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; const auto &waitList = std::get>(waitArg.t); - auto crtWaitOperands = waitOperands.size(); + llvm::SmallVector waitValues; for (const Fortran::parser::ScalarIntExpr &value : waitList) { - waitOperands.push_back(fir::getBase(converter.genExprValue( + waitValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(value), stmtCtx))); } - waitOperandsDeviceTypes.push_back(deviceTypeAttr); - waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands); + for (auto deviceTypeAttr : deviceTypeAttrs) { + for (auto value : waitValues) + waitOperands.push_back(value); + waitOperandsDeviceTypes.push_back(deviceTypeAttr); + waitOperandsSegments.push_back(waitValues.size()); + } // TODO: move to device_type model. const auto &waitDevnumValue = @@ -1582,7 +1590,8 @@ genWaitClause(Fortran::lower::AbstractConverter &converter, waitDevnum = fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } else { - waitOnlyDeviceTypes.push_back(deviceTypeAttr); + for (auto deviceTypeAttr : deviceTypeAttrs) + waitOnlyDeviceTypes.push_back(deviceTypeAttr); } } @@ -1610,91 +1619,112 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, // device_type attribute is set to `none` until a device_type clause is // encountered. - auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), mlir::acc::DeviceType::None); + llvm::SmallVector crtDeviceTypes; + crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None)); for (const Fortran::parser::AccClause &clause : accClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (const auto *gangClause = std::get_if(&clause.u)) { if (gangClause->v) { - auto crtGangOperands = gangOperands.size(); const Fortran::parser::AccGangArgList &x = *gangClause->v; + mlir::SmallVector gangValues; + mlir::SmallVector gangArgs; for (const Fortran::parser::AccGangArg &gangArg : x.v) { if (const auto *num = std::get_if(&gangArg.u)) { - gangOperands.push_back(fir::getBase(converter.genExprValue( + gangValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(num->v), stmtCtx))); - gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( builder.getContext(), mlir::acc::GangArgType::Num)); } else if (const auto *staticArg = std::get_if( &gangArg.u)) { const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v; if (sizeExpr.v) { - gangOperands.push_back(fir::getBase(converter.genExprValue( + gangValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx))); } else { // * was passed as value and will be represented as a special // constant. - gangOperands.push_back(builder.createIntegerConstant( + gangValues.push_back(builder.createIntegerConstant( clauseLocation, builder.getIndexType(), starCst)); } - gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( builder.getContext(), mlir::acc::GangArgType::Static)); } else if (const auto *dim = std::get_if( &gangArg.u)) { - gangOperands.push_back(fir::getBase(converter.genExprValue( + gangValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(dim->v), stmtCtx))); - gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get( + gangArgs.push_back(mlir::acc::GangArgTypeAttr::get( builder.getContext(), mlir::acc::GangArgType::Dim)); } } - gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands); - gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + for (const auto &pair : llvm::zip(gangValues, gangArgs)) { + gangOperands.push_back(std::get<0>(pair)); + gangArgTypes.push_back(std::get<1>(pair)); + } + gangOperandsSegments.push_back(gangValues.size()); + gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } } else { - gangDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + gangDeviceTypes.push_back(crtDeviceTypeAttr); } } else if (const auto *workerClause = std::get_if(&clause.u)) { if (workerClause->v) { - workerNumOperands.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx))); - workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Value workerNumValue = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + workerNumOperands.push_back(workerNumValue); + workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } } else { - workerNumDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + workerNumDeviceTypes.push_back(crtDeviceTypeAttr); } } else if (const auto *vectorClause = std::get_if(&clause.u)) { if (vectorClause->v) { - vectorOperands.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx))); - vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Value vectorValue = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + vectorOperands.push_back(vectorValue); + vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + } } else { - vectorDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + vectorDeviceTypes.push_back(crtDeviceTypeAttr); } } else if (const auto *tileClause = std::get_if(&clause.u)) { const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; - auto crtTileOperands = tileOperands.size(); + llvm::SmallVector tileValues; for (const auto &accTileExpr : accTileExprList.v) { const auto &expr = std::get>( accTileExpr.t); if (expr) { - tileOperands.push_back(fir::getBase(converter.genExprValue( + tileValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(*expr), stmtCtx))); } else { // * was passed as value and will be represented as a special // constant. mlir::Value tileStar = builder.createIntegerConstant( clauseLocation, builder.getIntegerType(32), starCst); - tileOperands.push_back(tileStar); + tileValues.push_back(tileStar); } } - tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr); - tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + for (auto value : tileValues) + tileOperands.push_back(value); + tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + tileOperandsSegments.push_back(tileValues.size()); + } } else if (const auto *privateClause = std::get_if( &clause.u)) { @@ -1707,21 +1737,20 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, reductionOperands, reductionRecipes); } else if (std::get_if(&clause.u)) { - seqDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + seqDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if( &clause.u)) { - independentDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + independentDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if(&clause.u)) { - autoDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + autoDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = - deviceTypeClause->v; - assert(deviceTypeExprList.v.size() == 1 && - "expect only one device_type expr"); - crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); + crtDeviceTypes.clear(); + gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); } else if (const auto *collapseClause = std::get_if( &clause.u)) { @@ -1729,14 +1758,18 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, const auto &force = std::get(arg.t); if (force) TODO(clauseLocation, "OpenACC collapse force modifier"); + const auto &intExpr = std::get(arg.t); const auto *expr = Fortran::semantics::GetExpr(intExpr); const std::optional collapseValue = Fortran::evaluate::ToInt64(*expr); assert(collapseValue && "expect integer value for the collapse clause"); - collapseValues.push_back(*collapseValue); - collapseDeviceTypes.push_back(crtDeviceTypeAttr); + + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + collapseValues.push_back(*collapseValue); + collapseDeviceTypes.push_back(crtDeviceTypeAttr); + } } } @@ -1923,45 +1956,56 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, // device_type attribute is set to `none` until a device_type clause is // encountered. + llvm::SmallVector crtDeviceTypes; auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( builder.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); - // Lower clauses values mapped to operands. - // Keep track of each group of operands separatly as clauses can appear + // Lower clauses values mapped to operands and array attributes. + // Keep track of each group of operands separately as clauses can appear // more than once. for (const Fortran::parser::AccClause &clause : accClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (const auto *asyncClause = std::get_if(&clause.u)) { genAsyncClause(converter, asyncClause, async, asyncDeviceTypes, - asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx); + asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { genWaitClause(converter, waitClause, waitOperands, waitOperandsDeviceTypes, waitOnlyDeviceTypes, - waitOperandsSegments, waitDevnum, crtDeviceTypeAttr, - stmtCtx); + waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); } else if (const auto *numGangsClause = std::get_if( &clause.u)) { - auto crtNumGangs = numGangs.size(); + llvm::SmallVector numGangValues; for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v) - numGangs.push_back(fir::getBase(converter.genExprValue( + numGangValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(expr), stmtCtx))); - numGangsDeviceTypes.push_back(crtDeviceTypeAttr); - numGangsSegments.push_back(numGangs.size() - crtNumGangs); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + for (auto value : numGangValues) + numGangs.push_back(value); + numGangsDeviceTypes.push_back(crtDeviceTypeAttr); + numGangsSegments.push_back(numGangValues.size()); + } } else if (const auto *numWorkersClause = std::get_if( &clause.u)) { - numWorkers.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx))); - numWorkersDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Value numWorkerValue = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + numWorkers.push_back(numWorkerValue); + numWorkersDeviceTypes.push_back(crtDeviceTypeAttr); + } } else if (const auto *vectorLengthClause = std::get_if( &clause.u)) { - vectorLength.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx))); - vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + vectorLength.push_back(vectorLengthValue); + vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr); + } } else if (const auto *ifClause = std::get_if(&clause.u)) { genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); @@ -2115,12 +2159,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = - deviceTypeClause->v; - assert(deviceTypeExprList.v.size() == 1 && - "expect only one device_type expr"); - crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); + crtDeviceTypes.clear(); + gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); } } @@ -2239,10 +2279,11 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, // device_type attribute is set to `none` until a device_type clause is // encountered. - auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), mlir::acc::DeviceType::None); + llvm::SmallVector crtDeviceTypes; + crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None)); - // Lower clauses values mapped to operands. + // Lower clauses values mapped to operands and array attributes. // Keep track of each group of operands separately as clauses can appear // more than once. for (const Fortran::parser::AccClause &clause : accClauseList.v) { @@ -2323,19 +2364,23 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *asyncClause = std::get_if(&clause.u)) { genAsyncClause(converter, asyncClause, async, asyncDeviceTypes, - asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx); + asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { genWaitClause(converter, waitClause, waitOperands, waitOperandsDeviceTypes, waitOnlyDeviceTypes, - waitOperandsSegments, waitDevnum, crtDeviceTypeAttr, - stmtCtx); + waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); } else if(const auto *defaultClause = std::get_if(&clause.u)) { if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none) hasDefaultNone = true; else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present) hasDefaultPresent = true; + } else if (const auto *deviceTypeClause = + std::get_if( + &clause.u)) { + crtDeviceTypes.clear(); + gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); } } @@ -2727,8 +2772,7 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause, - deviceTypes, stmtCtx); + gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); } } @@ -2777,8 +2821,7 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause, - deviceTypes, stmtCtx); + gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); } } @@ -2835,8 +2878,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause, - deviceTypes, stmtCtx); + gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); } else if (const auto *hostClause = std::get_if(&clause.u)) { genDataOperandOperations( @@ -3592,15 +3634,16 @@ void Fortran::lower::genOpenACCRoutineConstruct( // device_type attribute is set to `none` until a device_type clause is // encountered. - auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), mlir::acc::DeviceType::None); + llvm::SmallVector crtDeviceTypes; + crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None)); for (const Fortran::parser::AccClause &clause : clauses.v) { if (std::get_if(&clause.u)) { - seqDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + seqDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto *gangClause = std::get_if(&clause.u)) { - if (gangClause->v) { const Fortran::parser::AccGangArgList &x = *gangClause->v; for (const Fortran::parser::AccGangArg &gangArg : x.v) { @@ -3611,27 +3654,36 @@ void Fortran::lower::genOpenACCRoutineConstruct( if (!dimValue) mlir::emitError(loc, "dim value must be a constant positive integer"); - gangDimValues.push_back( - builder.getIntegerAttr(builder.getI64Type(), *dimValue)); - gangDimDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Attribute gangDimAttr = + builder.getIntegerAttr(builder.getI64Type(), *dimValue); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + gangDimValues.push_back(gangDimAttr); + gangDimDeviceTypes.push_back(crtDeviceTypeAttr); + } } } } else { - gangDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + gangDeviceTypes.push_back(crtDeviceTypeAttr); } } else if (std::get_if(&clause.u)) { - vectorDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + vectorDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if(&clause.u)) { - workerDeviceTypes.push_back(crtDeviceTypeAttr); + for (auto crtDeviceTypeAttr : crtDeviceTypes) + workerDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if(&clause.u)) { hasNohost = true; } else if (const auto *bindClause = std::get_if(&clause.u)) { if (const auto *name = std::get_if(&bindClause->v.u)) { - bindNames.push_back( - builder.getStringAttr(converter.mangleName(*name->symbol))); - bindNameDeviceTypes.push_back(crtDeviceTypeAttr); + mlir::Attribute bindNameAttr = + builder.getStringAttr(converter.mangleName(*name->symbol)); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + bindNames.push_back(bindNameAttr); + bindNameDeviceTypes.push_back(crtDeviceTypeAttr); + } } else if (const auto charExpr = std::get_if( &bindClause->v.u)) { @@ -3640,18 +3692,18 @@ void Fortran::lower::genOpenACCRoutineConstruct( *charExpr); if (!name) mlir::emitError(loc, "Could not retrieve the bind name"); - bindNames.push_back(builder.getStringAttr(*name)); - bindNameDeviceTypes.push_back(crtDeviceTypeAttr); + + mlir::Attribute bindNameAttr = builder.getStringAttr(*name); + for (auto crtDeviceTypeAttr : crtDeviceTypes) { + bindNames.push_back(bindNameAttr); + bindNameDeviceTypes.push_back(crtDeviceTypeAttr); + } } } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = - deviceTypeClause->v; - assert(deviceTypeExprList.v.size() == 1 && - "expect only one device_type expr"); - crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); + crtDeviceTypes.clear(); + gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); } } diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90 index 871dbc95f60fcb..ae01d0dc5fcde3 100644 --- a/flang/test/Lower/OpenACC/acc-device-type.f90 +++ b/flang/test/Lower/OpenACC/acc-device-type.f90 @@ -40,5 +40,9 @@ subroutine sub1() ! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type]) + !$acc parallel device_type(nvidia, default) num_gangs(1, 1, 1) + !$acc end parallel + +! CHECK: acc.parallel num_gangs({%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type]) end subroutine diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90 index 42e14afb35f522..59c2513332a976 100644 --- a/flang/test/Lower/OpenACC/acc-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-loop.f90 @@ -326,4 +326,10 @@ program acc_loop ! CHECK: acc.loop gang([#acc.device_type], {num=%c8{{.*}} : i32} [#acc.device_type]) + !$acc loop device_type(nvidia, default) gang + DO i = 1, n + END DO + +! CHECK: acc.loop gang([#acc.device_type, #acc.device_type]) { + end program diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90 index 2fe150e70b0cfb..1170af18bc3341 100644 --- a/flang/test/Lower/OpenACC/acc-routine.f90 +++ b/flang/test/Lower/OpenACC/acc-routine.f90 @@ -2,6 +2,7 @@ ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s +! CHECK: acc.routine @acc_routine_17 func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type], "_QPacc_routine17" [#acc.device_type], "_QPacc_routine16" [#acc.device_type]) ! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type], "_QPacc_routine16" [#acc.device_type]) ! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type]) vector ([#acc.device_type]) ! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type]) seq ([#acc.device_type]) @@ -120,3 +121,7 @@ subroutine acc_routine17() subroutine acc_routine18() !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) end subroutine + +subroutine acc_routine19() + !$acc routine device_type(host,default) bind(acc_routine17) dtype(multicore) bind(acc_routine16) +end subroutine diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 20465f6bb86ed1..bc03adbcae64df 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1449,7 +1449,8 @@ void printGangClause(OpAsmPrinter &p, Operation *op, std::optional segments, std::optional gangOnlyDeviceTypes) { - if (operands.begin() == operands.end() && gangOnlyDeviceTypes && + if (operands.begin() == operands.end() && + hasDeviceTypeValues(gangOnlyDeviceTypes) && gangOnlyDeviceTypes->size() == 1) { auto deviceTypeAttr = mlir::dyn_cast((*gangOnlyDeviceTypes)[0]); @@ -1464,7 +1465,7 @@ void printGangClause(OpAsmPrinter &p, Operation *op, hasDeviceTypeValues(deviceTypes)) p << ", "; - if (deviceTypes) { + if (hasDeviceTypeValues(deviceTypes)) { unsigned opIdx = 0; llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { p << "{";