Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Support bufferization of external functions #113999

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.

This commit is in preparation of removing the deprecated func-bufferize pass. That pass can bufferize external functions.

Also update a few comments.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Oct 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.

This commit is in preparation of removing the deprecated func-bufferize pass. That pass can bufferize external functions.

Also update a few comments.


Full diff: https://github.com/llvm/llvm-project/pull/113999.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+6-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+30-26)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (-17)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index aceb9d059b95f3..4866e31b19d5de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -60,7 +60,8 @@ struct AliasingValue {
   bool isDefinite;
 };
 
-template <typename T> class AliasList {
+template <typename T>
+class AliasList {
 public:
   /// Create an empty list of aliases.
   AliasList() = default;
@@ -259,7 +260,7 @@ struct BufferizationOptions {
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, func op, bufferization options
+  /// Parameters: tensor type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
@@ -344,9 +345,9 @@ struct BufferizationOptions {
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
 
   /// Type converter from tensors to memrefs. This type converter is used to
-  /// determine bufferized function argument types. By default, a type
-  /// converter that returns a memref type with a fully dynamic layout map is
-  /// used.
+  /// determine bufferized function argument and result types. By default, a
+  /// type converter that returns a memref type with a fully dynamic layout map
+  /// is used.
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..a372e87d8335f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
@@ -392,11 +393,11 @@ struct FuncOpInterface
     auto funcOp = cast<FuncOp>(op);
     FunctionType funcType = funcOp.getFunctionType();
 
-    // Construct the bufferized function type.
+    // Construct the bufferized function type. Compute the argument types.
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (dyn_cast<TensorType>(argType)) {
+      if (isa<TensorType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -404,24 +405,33 @@ struct FuncOpInterface
       argTypes.push_back(argType);
     }
 
-    // Bodiless functions are assumed opaque and we cannot know the
-    // bufferization contract they want to enforce. As a consequence, only
-    // support functions that don't return any tensors atm.
-    if (funcOp.isExternal()) {
-      SmallVector<Type> retTypes;
-      for (Type resultType : funcType.getResults()) {
-        if (isa<TensorType>(resultType))
-          return funcOp->emitError() << "cannot bufferize bodiless function "
-                                     << "that returns a tensor";
+    // Compute the result types.
+    SmallVector<Type> retTypes;
+    for (Type resultType : funcType.getResults()) {
+      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
+        BaseMemRefType resultType = options.functionArgTypeConverterFn(
+            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+            options);
         retTypes.push_back(resultType);
+        continue;
       }
-      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
+      retTypes.push_back(resultType);
+    }
+
+    // Compute the new function type.
+    auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
+
+    // If the function has no body, set the new function type and we are done.
+    if (funcOp.isExternal()) {
+      funcOp.setType(newFuncType);
       return success();
     }
 
     // TODO: Support functions with multiple returns.
     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
     assert(returnOp && "expected func with single return op");
+    assert(returnOp->getNumOperands() == retTypes.size() &&
+           "incorrect number of return values");
     Location loc = returnOp.getLoc();
 
     // 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
                                                         options)))
         return failure();
 
-    // 2. For each result, keep track of which inplace argument it reuses.
+    // 2. Bufferize all operands of the return op.
     SmallVector<Value> returnValues;
-    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-      Value returnVal = returnOperand.get();
+    for (auto [returnVal, bufferizedType] :
+         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
       auto tensorType = dyn_cast<TensorType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);
 
@@ -443,23 +453,17 @@ struct FuncOpInterface
         continue;
       }
 
-      // Note: If `inferFunctionResultLayout = true`, cast are later folded
+      // Note: If `inferFunctionResultLayout = true`, casts are later folded
       // away.
-      BaseMemRefType resultType = options.functionArgTypeConverterFn(
-          tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-          options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, resultType, returnVal);
+          loc, bufferizedType, returnVal);
       returnValues.push_back(toMemrefOp);
     }
 
-    // 3. Rewrite the terminator without the in-place bufferizable values.
     returnOp.getOperandsMutable().assign(returnValues);
 
-    // 4. Rewrite the FuncOp type to buffer form.
-    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
-                                     ValueRange(returnValues).getTypes()));
-
+    // 3. Set the new function type.
+    funcOp.setType(newFuncType);
     return success();
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..2829eafb7c1c59 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,11 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
-// expected-error @+1 {{failed to bufferize op}}
-func.func private @foo() -> tensor<?xf32>
-
-// -----
-
 // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
 func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
     -> (tensor<f32>, tensor<f32>)
@@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
 
 // -----
 
-// expected-error @+2 {{failed to bufferize op}}
-// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
-func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-
-func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
-  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-  return
-}
-
-// -----
-
 func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
                                         %idx2 : index) -> f32 {
   %1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..d31b43477beb9f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
 
 // -----
 
+// Bufferization of bodiless function that returns a tensor.
+
+// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+// CHECK: func.func @call_to_unknown_tensor_returning_func(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
 // A function that returns a non-equivalent tensor with layout map.
 
 // CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>

This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.

This commit is in preparation of removing the `func-bufferize` pass.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_bodiless branch from a9f607f to 11089cc Compare October 29, 2024 22:16
Copy link
Contributor

@javedabsar1 javedabsar1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
As Matthias mentions, there is a follow up PR #114017

@matthias-springer matthias-springer merged commit 217700b into main Oct 30, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/bufferization_bodiless branch October 30, 2024 12:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants