-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Restructure code in SPIRVConversion.cpp
. NFC.
#99393
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesPatch is 31.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99393.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e3a09ef1ff684..710e39692471a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -40,11 +40,13 @@
using namespace mlir;
+namespace {
+
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
-static int getComputeVectorSize(int64_t size) {
+int getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
@@ -52,7 +54,7 @@ static int getComputeVectorSize(int64_t size) {
return 1;
}
-static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
if (vecType.isScalable()) {
LLVM_DEBUG(llvm::dbgs()
@@ -88,7 +90,7 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
template <typename LabelT>
-static LogicalResult checkExtensionRequirements(
+LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
for (const auto &ors : candidates) {
@@ -116,7 +118,7 @@ static LogicalResult checkExtensionRequirements(
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
template <typename LabelT>
-static LogicalResult checkCapabilityRequirements(
+LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
for (const auto &ors : candidates) {
@@ -139,7 +141,7 @@ static LogicalResult checkCapabilityRequirements(
/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
-static bool needsExplicitLayout(spirv::StorageClass storageClass) {
+bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
case spirv::StorageClass::PushConstant:
@@ -153,8 +155,8 @@ static bool needsExplicitLayout(spirv::StorageClass storageClass) {
/// Wraps the given `elementType` in a struct and gets the pointer to the
/// struct. This is used to satisfy Vulkan interface requirements.
-static spirv::PointerType
-wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
+spirv::PointerType wrapInStructAndGetPointer(Type elementType,
+ spirv::StorageClass storageClass) {
auto structType = needsExplicitLayout(storageClass)
? spirv::StructType::get(elementType, /*offsetInfo=*/0)
: spirv::StructType::get(elementType);
@@ -165,28 +167,16 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
// Type Conversion
//===----------------------------------------------------------------------===//
-static spirv::ScalarType getIndexType(MLIRContext *ctx,
- const SPIRVConversionOptions &options) {
+spirv::ScalarType getIndexType(MLIRContext *ctx,
+ const SPIRVConversionOptions &options) {
return cast<spirv::ScalarType>(
IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
}
-Type SPIRVTypeConverter::getIndexType() const {
- return ::getIndexType(getContext(), options);
-}
-
-MLIRContext *SPIRVTypeConverter::getContext() const {
- return targetEnv.getAttr().getContext();
-}
-
-bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
- return targetEnv.allows(capability);
-}
-
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
-static std::optional<int64_t>
-getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+std::optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
+ Type type) {
if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -266,10 +256,10 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
}
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
-static Type
-convertScalarType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options, spirv::ScalarType type,
- std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertScalarType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options,
+ spirv::ScalarType type,
+ std::optional<spirv::StorageClass> storageClass = {}) {
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -311,8 +301,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
/// the above given that these sub-byte types are not supported at all in
/// SPIR-V; there are no compute/storage capability for them like other
/// supported integer types.
-static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
- IntegerType type) {
+Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
+ IntegerType type) {
if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
return nullptr;
@@ -333,9 +323,8 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
-static ShapedType
-convertIndexElementType(ShapedType type,
- const SPIRVConversionOptions &options) {
+ShapedType convertIndexElementType(ShapedType type,
+ const SPIRVConversionOptions &options) {
Type indexType = dyn_cast<IndexType>(type.getElementType());
if (!indexType)
return type;
@@ -344,10 +333,9 @@ convertIndexElementType(ShapedType type,
}
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
-static Type
-convertVectorType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options, VectorType type,
- std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertVectorType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, VectorType type,
+ std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
@@ -401,10 +389,9 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
-static Type
-convertComplexType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options, ComplexType type,
- std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertComplexType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, ComplexType type,
+ std::optional<spirv::StorageClass> storageClass = {}) {
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -431,9 +418,8 @@ convertComplexType(const spirv::TargetEnv &targetEnv,
/// create composite constants with OpConstantComposite to embed relative large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
-static Type convertTensorType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options,
- TensorType type) {
+Type convertTensorType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, TensorType type) {
// TODO: Handle dynamic shapes.
if (!type.hasStaticShape()) {
LLVM_DEBUG(llvm::dbgs()
@@ -478,10 +464,9 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
return spirv::ArrayType::get(arrayElemType, arrayElemCount);
}
-static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options,
- MemRefType type,
- spirv::StorageClass storageClass) {
+Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options,
+ MemRefType type, spirv::StorageClass storageClass) {
unsigned numBoolBits = options.boolNumBits;
if (numBoolBits != 8) {
LLVM_DEBUG(llvm::dbgs()
@@ -531,10 +516,10 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
-static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options,
- MemRefType type,
- spirv::StorageClass storageClass) {
+Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options,
+ MemRefType type,
+ spirv::StorageClass storageClass) {
IntegerType elementType = cast<IntegerType>(type.getElementType());
Type arrayElemType = convertSubByteIntegerType(options, elementType);
if (!arrayElemType)
@@ -569,9 +554,8 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
-static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
- const SPIRVConversionOptions &options,
- MemRefType type) {
+Type convertMemrefType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, MemRefType type) {
auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!attr) {
LLVM_DEBUG(
@@ -731,73 +715,134 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
}
//===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
//===----------------------------------------------------------------------===//
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
- const SPIRVConversionOptions &options)
- : targetEnv(targetAttr), options(options) {
- // Add conversions. The order matters here: later ones will be tried earlier.
+spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+ spirv::BuiltIn builtin) {
+ // Look through all global variables in the given `body` block and check if
+ // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+ spirv::SPIRVDialect::getAttributeName(
+ spirv::Decoration::BuiltIn))) {
+ auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+ if (varBuiltIn && *varBuiltIn == builtin) {
+ return varOp;
+ }
+ }
+ }
+ return nullptr;
+}
- // Allow all SPIR-V dialect specific types. This assumes all builtin types
- // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
- // were tried before.
- //
- // TODO: This assumes that the SPIR-V types are valid to use in the given
- // target environment, which should be the case if the whole pipeline is
- // driven by the same target environment. Still, we probably still want to
- // validate and convert to be safe.
- addConversion([](spirv::SPIRVType type) { return type; });
+/// Gets name of global variable for a builtin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+ StringRef suffix) {
+ return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
- addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+ Type integerType, OpBuilder &builder,
+ StringRef prefix, StringRef suffix) {
+ if (auto varOp = getBuiltinVariable(body, builtin))
+ return varOp;
- addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- if (intType.getWidth() < 8)
- return convertSubByteIntegerType(this->options, intType);
- return Type();
- });
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&body);
- addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- return Type();
- });
+ spirv::GlobalVariableOp newVarOp;
+ switch (builtin) {
+ case spirv::BuiltIn::NumWorkgroups:
+ case spirv::BuiltIn::WorkgroupSize:
+ case spirv::BuiltIn::WorkgroupId:
+ case spirv::BuiltIn::LocalInvocationId:
+ case spirv::BuiltIn::GlobalInvocationId: {
+ auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+ spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ case spirv::BuiltIn::SubgroupId:
+ case spirv::BuiltIn::NumSubgroups:
+ case spirv::BuiltIn::SubgroupSize: {
+ auto ptrType =
+ spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ default:
+ emitError(loc, "unimplemented builtin variable generation for ")
+ << stringifyBuiltIn(builtin);
+ }
+ return newVarOp;
+}
- addConversion([this](ComplexType complexType) {
- return convertComplexType(this->targetEnv, this->options, complexType);
- });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
- addConversion([this](VectorType vectorType) {
- return convertVectorType(this->targetEnv, this->options, vectorType);
- });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+ Builder &builder,
+ Type indexType) {
+ auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+ /*stride=*/4);
+ auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+ return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
- addConversion([this](TensorType tensorType) {
- return convertTensorType(this->targetEnv, this->options, tensorType);
- });
+/// Returns the push constant varible containing `elementCount` 32-bit integer
+/// values in `body`. Returns null op if such an op does not exit.
+spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+ unsigned elementCount) {
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+ if (!ptrType)
+ continue;
- addConversion([this](MemRefType memRefType) {
- return convertMemrefType(this->targetEnv, this->options, memRefType);
- });
+ // Note that Vulkan requires "There must be no more than one push constant
+ // block statically used per shader entry point." So we should always reuse
+ // the existing one.
+ if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
+ .getNumElements();
+ if (numElements == elementCount)
+ return varOp;
+ }
+ }
+ return nullptr;
+}
- // Register some last line of defense casting logic.
- addSourceMaterialization(
- [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
- return castToSourceType(this->targetEnv, builder, type, inputs, loc);
- });
- addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
- Location loc) {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
- return std::optional<Value>(cast.getResult(0));
- });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc,
+ Block &block,
+ unsigned elementCount,
+ OpBuilder &b,
+ Type indexType) {
+ if (auto varOp = getPushConstantVariable(block, elementCount))
+ return varOp;
+
+ auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+ auto type = getPushConstantStorageType(elementCount, builder, indexType);
+ const char *name = "__push_constant_var__";
+ return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+ /*initializer=*/nullptr);
}
//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
-namespace {
/// A pattern for rewriting function signature to convert arguments of functions
/// to be of valid SPIR-V types.
class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
@@ -808,7 +853,6 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
-} // namespace
LogicalResult
FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
@@ -855,16 +899,6 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
return success();
}
-void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// func::FuncOp Conversion Patterns
-//===----------------------------------------------------------------------===//
-
-namespace {
/// A pattern for rewriting function signature to convert vect...
[truncated]
|
kuhar
reviewed
Jul 17, 2024
kuhar
changed the title
[mlir][spirv] Restructure code in
[mlir][spirv] Restructure code in Jul 17, 2024
SPIRVConversion.cpp
SPIRVConversion.cpp
. NFC.
angelz913
force-pushed
the
spirv-conversion-code-refactor
branch
from
July 18, 2024 02:31
9947252
to
5125a74
Compare
kuhar
reviewed
Jul 18, 2024
angelz913
force-pushed
the
spirv-conversion-code-refactor
branch
from
July 18, 2024 13:14
5125a74
to
0fe1360
Compare
kuhar
approved these changes
Jul 18, 2024
sgundapa
pushed a commit
to sgundapa/upstream_effort
that referenced
this pull request
Jul 23, 2024
yuxuanchen1997
pushed a commit
that referenced
this pull request
Jul 25, 2024
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251412
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.