Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Restructure code in SPIRVConversion.cpp. NFC. #99393

Merged
merged 1 commit into from
Jul 18, 2024

Conversation

angelz913
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

Changes

Patch is 31.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99393.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+245-244)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e3a09ef1ff684..710e39692471a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -40,11 +40,13 @@
 
 using namespace mlir;
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static int getComputeVectorSize(int64_t size) {
+int getComputeVectorSize(int64_t size) {
   for (int i : {4, 3, 2}) {
     if (size % i == 0)
       return i;
@@ -52,7 +54,7 @@ static int getComputeVectorSize(int64_t size) {
   return 1;
 }
 
-static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
   if (vecType.isScalable()) {
     LLVM_DEBUG(llvm::dbgs()
@@ -88,7 +90,7 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
 /// convention.
 template <typename LabelT>
-static LogicalResult checkExtensionRequirements(
+LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
@@ -116,7 +118,7 @@ static LogicalResult checkExtensionRequirements(
 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
 /// convention.
 template <typename LabelT>
-static LogicalResult checkCapabilityRequirements(
+LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
@@ -139,7 +141,7 @@ static LogicalResult checkCapabilityRequirements(
 
 /// Returns true if the given `storageClass` needs explicit layout when used in
 /// Shader environments.
-static bool needsExplicitLayout(spirv::StorageClass storageClass) {
+bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:
   case spirv::StorageClass::PushConstant:
@@ -153,8 +155,8 @@ static bool needsExplicitLayout(spirv::StorageClass storageClass) {
 
 /// Wraps the given `elementType` in a struct and gets the pointer to the
 /// struct. This is used to satisfy Vulkan interface requirements.
-static spirv::PointerType
-wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
+spirv::PointerType wrapInStructAndGetPointer(Type elementType,
+                                             spirv::StorageClass storageClass) {
   auto structType = needsExplicitLayout(storageClass)
                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
                         : spirv::StructType::get(elementType);
@@ -165,28 +167,16 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-static spirv::ScalarType getIndexType(MLIRContext *ctx,
-                                      const SPIRVConversionOptions &options) {
+spirv::ScalarType getIndexType(MLIRContext *ctx,
+                               const SPIRVConversionOptions &options) {
   return cast<spirv::ScalarType>(
       IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
 }
 
-Type SPIRVTypeConverter::getIndexType() const {
-  return ::getIndexType(getContext(), options);
-}
-
-MLIRContext *SPIRVTypeConverter::getContext() const {
-  return targetEnv.getAttr().getContext();
-}
-
-bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
-  return targetEnv.allows(capability);
-}
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
-static std::optional<int64_t>
-getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+std::optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
+                                       Type type) {
   if (isa<spirv::ScalarType>(type)) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -266,10 +256,10 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
 }
 
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
-static Type
-convertScalarType(const spirv::TargetEnv &targetEnv,
-                  const SPIRVConversionOptions &options, spirv::ScalarType type,
-                  std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertScalarType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options,
+                       spirv::ScalarType type,
+                       std::optional<spirv::StorageClass> storageClass = {}) {
   // Get extension and capability requirements for the given type.
   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -311,8 +301,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
 /// the above given that these sub-byte types are not supported at all in
 /// SPIR-V; there are no compute/storage capability for them like other
 /// supported integer types.
-static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
-                                      IntegerType type) {
+Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
+                               IntegerType type) {
   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
     return nullptr;
@@ -333,9 +323,8 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
-static ShapedType
-convertIndexElementType(ShapedType type,
-                        const SPIRVConversionOptions &options) {
+ShapedType convertIndexElementType(ShapedType type,
+                                   const SPIRVConversionOptions &options) {
   Type indexType = dyn_cast<IndexType>(type.getElementType());
   if (!indexType)
     return type;
@@ -344,10 +333,9 @@ convertIndexElementType(ShapedType type,
 }
 
 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
-static Type
-convertVectorType(const spirv::TargetEnv &targetEnv,
-                  const SPIRVConversionOptions &options, VectorType type,
-                  std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertVectorType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, VectorType type,
+                       std::optional<spirv::StorageClass> storageClass = {}) {
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
@@ -401,10 +389,9 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   return nullptr;
 }
 
-static Type
-convertComplexType(const spirv::TargetEnv &targetEnv,
-                   const SPIRVConversionOptions &options, ComplexType type,
-                   std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertComplexType(const spirv::TargetEnv &targetEnv,
+                        const SPIRVConversionOptions &options, ComplexType type,
+                        std::optional<spirv::StorageClass> storageClass = {}) {
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
@@ -431,9 +418,8 @@ convertComplexType(const spirv::TargetEnv &targetEnv,
 /// create composite constants with OpConstantComposite to embed relative large
 /// constant values and use OpCompositeExtract and OpCompositeInsert to
 /// manipulate, like what we do for vectors.
-static Type convertTensorType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVConversionOptions &options,
-                              TensorType type) {
+Type convertTensorType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, TensorType type) {
   // TODO: Handle dynamic shapes.
   if (!type.hasStaticShape()) {
     LLVM_DEBUG(llvm::dbgs()
@@ -478,10 +464,9 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   return spirv::ArrayType::get(arrayElemType, arrayElemCount);
 }
 
-static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
-                                  const SPIRVConversionOptions &options,
-                                  MemRefType type,
-                                  spirv::StorageClass storageClass) {
+Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
+                           const SPIRVConversionOptions &options,
+                           MemRefType type, spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
   if (numBoolBits != 8) {
     LLVM_DEBUG(llvm::dbgs()
@@ -531,10 +516,10 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
-static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
-                                     const SPIRVConversionOptions &options,
-                                     MemRefType type,
-                                     spirv::StorageClass storageClass) {
+Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
+                              const SPIRVConversionOptions &options,
+                              MemRefType type,
+                              spirv::StorageClass storageClass) {
   IntegerType elementType = cast<IntegerType>(type.getElementType());
   Type arrayElemType = convertSubByteIntegerType(options, elementType);
   if (!arrayElemType)
@@ -569,9 +554,8 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
-static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVConversionOptions &options,
-                              MemRefType type) {
+Type convertMemrefType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, MemRefType type) {
   auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
   if (!attr) {
     LLVM_DEBUG(
@@ -731,73 +715,134 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
 }
 
 //===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
 //===----------------------------------------------------------------------===//
 
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
-                                       const SPIRVConversionOptions &options)
-    : targetEnv(targetAttr), options(options) {
-  // Add conversions. The order matters here: later ones will be tried earlier.
+spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+                                           spirv::BuiltIn builtin) {
+  // Look through all global variables in the given `body` block and check if
+  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+            spirv::SPIRVDialect::getAttributeName(
+                spirv::Decoration::BuiltIn))) {
+      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+      if (varBuiltIn && *varBuiltIn == builtin) {
+        return varOp;
+      }
+    }
+  }
+  return nullptr;
+}
 
-  // Allow all SPIR-V dialect specific types. This assumes all builtin types
-  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
-  // were tried before.
-  //
-  // TODO: This assumes that the SPIR-V types are valid to use in the given
-  // target environment, which should be the case if the whole pipeline is
-  // driven by the same target environment. Still, we probably still want to
-  // validate and convert to be safe.
-  addConversion([](spirv::SPIRVType type) { return type; });
+/// Gets name of global variable for a builtin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+                              StringRef suffix) {
+  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
 
-  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+                           Type integerType, OpBuilder &builder,
+                           StringRef prefix, StringRef suffix) {
+  if (auto varOp = getBuiltinVariable(body, builtin))
+    return varOp;
 
-  addConversion([this](IntegerType intType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    if (intType.getWidth() < 8)
-      return convertSubByteIntegerType(this->options, intType);
-    return Type();
-  });
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&body);
 
-  addConversion([this](FloatType floatType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    return Type();
-  });
+  spirv::GlobalVariableOp newVarOp;
+  switch (builtin) {
+  case spirv::BuiltIn::NumWorkgroups:
+  case spirv::BuiltIn::WorkgroupSize:
+  case spirv::BuiltIn::WorkgroupId:
+  case spirv::BuiltIn::LocalInvocationId:
+  case spirv::BuiltIn::GlobalInvocationId: {
+    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+                                           spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  case spirv::BuiltIn::SubgroupId:
+  case spirv::BuiltIn::NumSubgroups:
+  case spirv::BuiltIn::SubgroupSize: {
+    auto ptrType =
+        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  default:
+    emitError(loc, "unimplemented builtin variable generation for ")
+        << stringifyBuiltIn(builtin);
+  }
+  return newVarOp;
+}
 
-  addConversion([this](ComplexType complexType) {
-    return convertComplexType(this->targetEnv, this->options, complexType);
-  });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
 
-  addConversion([this](VectorType vectorType) {
-    return convertVectorType(this->targetEnv, this->options, vectorType);
-  });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+                                              Builder &builder,
+                                              Type indexType) {
+  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+                                         /*stride=*/4);
+  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
 
-  addConversion([this](TensorType tensorType) {
-    return convertTensorType(this->targetEnv, this->options, tensorType);
-  });
+/// Returns the push constant varible containing `elementCount` 32-bit integer
+/// values in `body`. Returns null op if such an op does not exit.
+spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+                                                unsigned elementCount) {
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+    if (!ptrType)
+      continue;
 
-  addConversion([this](MemRefType memRefType) {
-    return convertMemrefType(this->targetEnv, this->options, memRefType);
-  });
+    // Note that Vulkan requires "There must be no more than one push constant
+    // block statically used per shader entry point." So we should always reuse
+    // the existing one.
+    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+      auto numElements = cast<spirv::ArrayType>(
+                             cast<spirv::StructType>(ptrType.getPointeeType())
+                                 .getElementType(0))
+                             .getNumElements();
+      if (numElements == elementCount)
+        return varOp;
+    }
+  }
+  return nullptr;
+}
 
-  // Register some last line of defense casting logic.
-  addSourceMaterialization(
-      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
-        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
-      });
-  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc,
+                                                        Block &block,
+                                                        unsigned elementCount,
+                                                        OpBuilder &b,
+                                                        Type indexType) {
+  if (auto varOp = getPushConstantVariable(block, elementCount))
+    return varOp;
+
+  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+  auto type = getPushConstantStorageType(elementCount, builder, indexType);
+  const char *name = "__push_constant_var__";
+  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+                                                 /*initializer=*/nullptr);
 }
 
 //===----------------------------------------------------------------------===//
 // func::FuncOp Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// A pattern for rewriting function signature to convert arguments of functions
 /// to be of valid SPIR-V types.
 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
@@ -808,7 +853,6 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
-} // namespace
 
 LogicalResult
 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
@@ -855,16 +899,6 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
   return success();
 }
 
-void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// func::FuncOp Conversion Patterns
-//===----------------------------------------------------------------------===//
-
-namespace {
 /// A pattern for rewriting function signature to convert vect...
[truncated]

@kuhar kuhar changed the title [mlir][spirv] Restructure code in SPIRVConversion.cpp [mlir][spirv] Restructure code in SPIRVConversion.cpp. NFC. Jul 17, 2024
@angelz913 angelz913 force-pushed the spirv-conversion-code-refactor branch from 9947252 to 5125a74 Compare July 18, 2024 02:31
@angelz913 angelz913 force-pushed the spirv-conversion-code-refactor branch from 5125a74 to 0fe1360 Compare July 18, 2024 13:14
@kuhar kuhar requested a review from Hardcode84 July 18, 2024 13:50
@kuhar kuhar merged commit 9527d77 into llvm:main Jul 18, 2024
5 of 6 checks passed
sgundapa pushed a commit to sgundapa/upstream_effort that referenced this pull request Jul 23, 2024
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary: 

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251412
@angelz913 angelz913 deleted the spirv-conversion-code-refactor branch July 25, 2024 21:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants