diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index e3a09ef1ff6846..bf5044437fd09d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -40,6 +40,8 @@ using namespace mlir; +namespace { + //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// @@ -171,18 +173,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx, IntegerType::get(ctx, options.use64bitIndex ? 64 : 32)); } -Type SPIRVTypeConverter::getIndexType() const { - return ::getIndexType(getContext(), options); -} - -MLIRContext *SPIRVTypeConverter::getContext() const { - return targetEnv.getAttr().getContext(); -} - -bool SPIRVTypeConverter::allows(spirv::Capability capability) const { - return targetEnv.allows(capability); -} - // TODO: This is a utility function that should probably be exposed by the // SPIR-V dialect. Keeping it local till the use case arises. static std::optional @@ -673,9 +663,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, /// This function is meant to handle the **compute** side; so it does not /// involve storage classes in its logic. The storage side is expected to be /// handled by MemRef conversion logic. -std::optional castToSourceType(const spirv::TargetEnv &targetEnv, - OpBuilder &builder, Type type, - ValueRange inputs, Location loc) { +static std::optional castToSourceType(const spirv::TargetEnv &targetEnv, + OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { // We can only cast one value in SPIR-V. if (inputs.size() != 1) { auto castOp = builder.create(loc, type, inputs); @@ -731,140 +721,185 @@ std::optional castToSourceType(const spirv::TargetEnv &targetEnv, } //===----------------------------------------------------------------------===// -// SPIRVTypeConverter +// Builtin Variables //===----------------------------------------------------------------------===// -SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, - const SPIRVConversionOptions &options) - : targetEnv(targetAttr), options(options) { - // Add conversions. The order matters here: later ones will be tried earlier. +static spirv::GlobalVariableOp getBuiltinVariable(Block &body, + spirv::BuiltIn builtin) { + // Look through all global variables in the given `body` block and check if + // there is a spirv.GlobalVariable that has the same `builtin` attribute. + for (auto varOp : body.getOps()) { + if (auto builtinAttr = varOp->getAttrOfType( + spirv::SPIRVDialect::getAttributeName( + spirv::Decoration::BuiltIn))) { + auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); + if (varBuiltIn && *varBuiltIn == builtin) { + return varOp; + } + } + } + return nullptr; +} - // Allow all SPIR-V dialect specific types. This assumes all builtin types - // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) - // were tried before. - // - // TODO: This assumes that the SPIR-V types are valid to use in the given - // target environment, which should be the case if the whole pipeline is - // driven by the same target environment. Still, we probably still want to - // validate and convert to be safe. - addConversion([](spirv::SPIRVType type) { return type; }); +/// Gets name of global variable for a builtin. +std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, + StringRef suffix) { + return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str(); +} - addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); +/// Gets or inserts a global variable for a builtin within `body` block. +static spirv::GlobalVariableOp +getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, + Type integerType, OpBuilder &builder, + StringRef prefix, StringRef suffix) { + if (auto varOp = getBuiltinVariable(body, builtin)) + return varOp; - addConversion([this](IntegerType intType) -> std::optional { - if (auto scalarType = dyn_cast(intType)) - return convertScalarType(this->targetEnv, this->options, scalarType); - if (intType.getWidth() < 8) - return convertSubByteIntegerType(this->options, intType); - return Type(); - }); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&body); - addConversion([this](FloatType floatType) -> std::optional { - if (auto scalarType = dyn_cast(floatType)) - return convertScalarType(this->targetEnv, this->options, scalarType); - return Type(); - }); + spirv::GlobalVariableOp newVarOp; + switch (builtin) { + case spirv::BuiltIn::NumWorkgroups: + case spirv::BuiltIn::WorkgroupSize: + case spirv::BuiltIn::WorkgroupId: + case spirv::BuiltIn::LocalInvocationId: + case spirv::BuiltIn::GlobalInvocationId: { + auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), + spirv::StorageClass::Input); + std::string name = getBuiltinVarName(builtin, prefix, suffix); + newVarOp = + builder.create(loc, ptrType, name, builtin); + break; + } + case spirv::BuiltIn::SubgroupId: + case spirv::BuiltIn::NumSubgroups: + case spirv::BuiltIn::SubgroupSize: { + auto ptrType = + spirv::PointerType::get(integerType, spirv::StorageClass::Input); + std::string name = getBuiltinVarName(builtin, prefix, suffix); + newVarOp = + builder.create(loc, ptrType, name, builtin); + break; + } + default: + emitError(loc, "unimplemented builtin variable generation for ") + << stringifyBuiltIn(builtin); + } + return newVarOp; +} - addConversion([this](ComplexType complexType) { - return convertComplexType(this->targetEnv, this->options, complexType); - }); +//===----------------------------------------------------------------------===// +// Push constant storage +//===----------------------------------------------------------------------===// - addConversion([this](VectorType vectorType) { - return convertVectorType(this->targetEnv, this->options, vectorType); - }); +/// Returns the pointer type for the push constant storage containing +/// `elementCount` 32-bit integer values. +static spirv::PointerType getPushConstantStorageType(unsigned elementCount, + Builder &builder, + Type indexType) { + auto arrayType = spirv::ArrayType::get(indexType, elementCount, + /*stride=*/4); + auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); + return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); +} - addConversion([this](TensorType tensorType) { - return convertTensorType(this->targetEnv, this->options, tensorType); - }); +/// Returns the push constant varible containing `elementCount` 32-bit integer +/// values in `body`. Returns null op if such an op does not exit. +static spirv::GlobalVariableOp getPushConstantVariable(Block &body, + unsigned elementCount) { + for (auto varOp : body.getOps()) { + auto ptrType = dyn_cast(varOp.getType()); + if (!ptrType) + continue; - addConversion([this](MemRefType memRefType) { - return convertMemrefType(this->targetEnv, this->options, memRefType); - }); + // Note that Vulkan requires "There must be no more than one push constant + // block statically used per shader entry point." So we should always reuse + // the existing one. + if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { + auto numElements = cast( + cast(ptrType.getPointeeType()) + .getElementType(0)) + .getNumElements(); + if (numElements == elementCount) + return varOp; + } + } + return nullptr; +} - // Register some last line of defense casting logic. - addSourceMaterialization( - [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - return castToSourceType(this->targetEnv, builder, type, inputs, loc); - }); - addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - auto cast = builder.create(loc, type, inputs); - return std::optional(cast.getResult(0)); - }); +/// Gets or inserts a global variable for push constant storage containing +/// `elementCount` 32-bit integer values in `block`. +static spirv::GlobalVariableOp +getOrInsertPushConstantVariable(Location loc, Block &block, + unsigned elementCount, OpBuilder &b, + Type indexType) { + if (auto varOp = getPushConstantVariable(block, elementCount)) + return varOp; + + auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); + auto type = getPushConstantStorageType(elementCount, builder, indexType); + const char *name = "__push_constant_var__"; + return builder.create(loc, type, name, + /*initializer=*/nullptr); } //===----------------------------------------------------------------------===// // func::FuncOp Conversion Patterns //===----------------------------------------------------------------------===// -namespace { /// A pattern for rewriting function signature to convert arguments of functions /// to be of valid SPIR-V types. -class FuncOpConversion final : public OpConversionPattern { -public: +struct FuncOpConversion final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; -} // namespace - -LogicalResult -FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto fnType = funcOp.getFunctionType(); - if (fnType.getNumResults() > 1) - return failure(); - - TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); - for (const auto &argType : enumerate(fnType.getInputs())) { - auto convertedType = getTypeConverter()->convertType(argType.value()); - if (!convertedType) - return failure(); - signatureConverter.addInputs(argType.index(), convertedType); - } - - Type resultType; - if (fnType.getNumResults() == 1) { - resultType = getTypeConverter()->convertType(fnType.getResult(0)); - if (!resultType) + ConversionPatternRewriter &rewriter) const override { + FunctionType fnType = funcOp.getFunctionType(); + if (fnType.getNumResults() > 1) return failure(); - } - - // Create the converted spirv.func op. - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), - rewriter.getFunctionType(signatureConverter.getConvertedTypes(), - resultType ? TypeRange(resultType) - : TypeRange())); - // Copy over all attributes other than the function name and type. - for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && - namedAttr.getName() != SymbolTable::getSymbolAttrName()) - newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); - } + TypeConverter::SignatureConversion signatureConverter( + fnType.getNumInputs()); + for (const auto &argType : enumerate(fnType.getInputs())) { + auto convertedType = getTypeConverter()->convertType(argType.value()); + if (!convertedType) + return failure(); + signatureConverter.addInputs(argType.index(), convertedType); + } - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - if (failed(rewriter.convertRegionTypes( - &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) - return failure(); - rewriter.eraseOp(funcOp); - return success(); -} + Type resultType; + if (fnType.getNumResults() == 1) { + resultType = getTypeConverter()->convertType(fnType.getResult(0)); + if (!resultType) + return failure(); + } -void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add(typeConverter, patterns.getContext()); -} + // Create the converted spirv.func op. + auto newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), + rewriter.getFunctionType(signatureConverter.getConvertedTypes(), + resultType ? TypeRange(resultType) + : TypeRange())); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } -//===----------------------------------------------------------------------===// -// func::FuncOp Conversion Patterns -//===----------------------------------------------------------------------===// + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + if (failed(rewriter.convertRegionTypes( + &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) + return failure(); + rewriter.eraseOp(funcOp); + return success(); + } +}; -namespace { /// A pattern for rewriting function signature to convert vector arguments of /// functions to be of valid types struct FuncOpVectorUnroll final : OpRewritePattern { @@ -1015,17 +1050,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern { return success(); } }; -} // namespace - -void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} //===----------------------------------------------------------------------===// // func::ReturnOp Conversion Patterns //===----------------------------------------------------------------------===// -namespace { /// A pattern for rewriting function signature and the return op to convert /// vectors to be of valid types. struct ReturnOpVectorUnroll final : OpRewritePattern { @@ -1097,81 +1126,13 @@ struct ReturnOpVectorUnroll final : OpRewritePattern { return success(); } }; -} // namespace -void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} +} // namespace //===----------------------------------------------------------------------===// -// Builtin Variables +// Public function for builtin variables //===----------------------------------------------------------------------===// -static spirv::GlobalVariableOp getBuiltinVariable(Block &body, - spirv::BuiltIn builtin) { - // Look through all global variables in the given `body` block and check if - // there is a spirv.GlobalVariable that has the same `builtin` attribute. - for (auto varOp : body.getOps()) { - if (auto builtinAttr = varOp->getAttrOfType( - spirv::SPIRVDialect::getAttributeName( - spirv::Decoration::BuiltIn))) { - auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); - if (varBuiltIn && *varBuiltIn == builtin) { - return varOp; - } - } - } - return nullptr; -} - -/// Gets name of global variable for a builtin. -static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, - StringRef suffix) { - return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str(); -} - -/// Gets or inserts a global variable for a builtin within `body` block. -static spirv::GlobalVariableOp -getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, - Type integerType, OpBuilder &builder, - StringRef prefix, StringRef suffix) { - if (auto varOp = getBuiltinVariable(body, builtin)) - return varOp; - - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&body); - - spirv::GlobalVariableOp newVarOp; - switch (builtin) { - case spirv::BuiltIn::NumWorkgroups: - case spirv::BuiltIn::WorkgroupSize: - case spirv::BuiltIn::WorkgroupId: - case spirv::BuiltIn::LocalInvocationId: - case spirv::BuiltIn::GlobalInvocationId: { - auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), - spirv::StorageClass::Input); - std::string name = getBuiltinVarName(builtin, prefix, suffix); - newVarOp = - builder.create(loc, ptrType, name, builtin); - break; - } - case spirv::BuiltIn::SubgroupId: - case spirv::BuiltIn::NumSubgroups: - case spirv::BuiltIn::SubgroupSize: { - auto ptrType = - spirv::PointerType::get(integerType, spirv::StorageClass::Input); - std::string name = getBuiltinVarName(builtin, prefix, suffix); - newVarOp = - builder.create(loc, ptrType, name, builtin); - break; - } - default: - emitError(loc, "unimplemented builtin variable generation for ") - << stringifyBuiltIn(builtin); - } - return newVarOp; -} - Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder, @@ -1190,60 +1151,9 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, } //===----------------------------------------------------------------------===// -// Push constant storage +// Public function for pushing constant storage //===----------------------------------------------------------------------===// -/// Returns the pointer type for the push constant storage containing -/// `elementCount` 32-bit integer values. -static spirv::PointerType getPushConstantStorageType(unsigned elementCount, - Builder &builder, - Type indexType) { - auto arrayType = spirv::ArrayType::get(indexType, elementCount, - /*stride=*/4); - auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); - return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); -} - -/// Returns the push constant varible containing `elementCount` 32-bit integer -/// values in `body`. Returns null op if such an op does not exit. -static spirv::GlobalVariableOp getPushConstantVariable(Block &body, - unsigned elementCount) { - for (auto varOp : body.getOps()) { - auto ptrType = dyn_cast(varOp.getType()); - if (!ptrType) - continue; - - // Note that Vulkan requires "There must be no more than one push constant - // block statically used per shader entry point." So we should always reuse - // the existing one. - if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { - auto numElements = cast( - cast(ptrType.getPointeeType()) - .getElementType(0)) - .getNumElements(); - if (numElements == elementCount) - return varOp; - } - } - return nullptr; -} - -/// Gets or inserts a global variable for push constant storage containing -/// `elementCount` 32-bit integer values in `block`. -static spirv::GlobalVariableOp -getOrInsertPushConstantVariable(Location loc, Block &block, - unsigned elementCount, OpBuilder &b, - Type indexType) { - if (auto varOp = getPushConstantVariable(block, elementCount)) - return varOp; - - auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); - auto type = getPushConstantStorageType(elementCount, builder, indexType); - const char *name = "__push_constant_var__"; - return builder.create(loc, type, name, - /*initializer=*/nullptr); -} - Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder) { @@ -1267,7 +1177,7 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, } //===----------------------------------------------------------------------===// -// Index calculation +// Public functions for index calculation //===----------------------------------------------------------------------===// Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef strides, @@ -1375,6 +1285,81 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, builder); } +//===----------------------------------------------------------------------===// +// SPIR-V TypeConverter +//===----------------------------------------------------------------------===// + +SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, + const SPIRVConversionOptions &options) + : targetEnv(targetAttr), options(options) { + // Add conversions. The order matters here: later ones will be tried earlier. + + // Allow all SPIR-V dialect specific types. This assumes all builtin types + // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) + // were tried before. + // + // TODO: This assumes that the SPIR-V types are valid to use in the given + // target environment, which should be the case if the whole pipeline is + // driven by the same target environment. Still, we probably still want to + // validate and convert to be safe. + addConversion([](spirv::SPIRVType type) { return type; }); + + addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); + + addConversion([this](IntegerType intType) -> std::optional { + if (auto scalarType = dyn_cast(intType)) + return convertScalarType(this->targetEnv, this->options, scalarType); + if (intType.getWidth() < 8) + return convertSubByteIntegerType(this->options, intType); + return Type(); + }); + + addConversion([this](FloatType floatType) -> std::optional { + if (auto scalarType = dyn_cast(floatType)) + return convertScalarType(this->targetEnv, this->options, scalarType); + return Type(); + }); + + addConversion([this](ComplexType complexType) { + return convertComplexType(this->targetEnv, this->options, complexType); + }); + + addConversion([this](VectorType vectorType) { + return convertVectorType(this->targetEnv, this->options, vectorType); + }); + + addConversion([this](TensorType tensorType) { + return convertTensorType(this->targetEnv, this->options, tensorType); + }); + + addConversion([this](MemRefType memRefType) { + return convertMemrefType(this->targetEnv, this->options, memRefType); + }); + + // Register some last line of defense casting logic. + addSourceMaterialization( + [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { + return castToSourceType(this->targetEnv, builder, type, inputs, loc); + }); + addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) { + auto cast = builder.create(loc, type, inputs); + return std::optional(cast.getResult(0)); + }); +} + +Type SPIRVTypeConverter::getIndexType() const { + return ::getIndexType(getContext(), options); +} + +MLIRContext *SPIRVTypeConverter::getContext() const { + return targetEnv.getAttr().getContext(); +} + +bool SPIRVTypeConverter::allows(spirv::Capability capability) const { + return targetEnv.allows(capability); +} + //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// @@ -1468,3 +1453,20 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { return true; } + +//===----------------------------------------------------------------------===// +// Public functions for populating patterns +//===----------------------------------------------------------------------===// + +void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(typeConverter, patterns.getContext()); +} + +void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +}