-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Support bufferization of external functions #113999
[mlir][bufferization] Support bufferization of external functions #113999
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value. This commit is in preparation of removing the deprecated Also update a few comments. Full diff: https://github.com/llvm/llvm-project/pull/113999.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index aceb9d059b95f3..4866e31b19d5de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -60,7 +60,8 @@ struct AliasingValue {
bool isDefinite;
};
-template <typename T> class AliasList {
+template <typename T>
+class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
@@ -259,7 +260,7 @@ struct BufferizationOptions {
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
- /// Parameters: Value, memory space, func op, bufferization options
+ /// Parameters: tensor type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
@@ -344,9 +345,9 @@ struct BufferizationOptions {
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
/// Type converter from tensors to memrefs. This type converter is used to
- /// determine bufferized function argument types. By default, a type
- /// converter that returns a memref type with a fully dynamic layout map is
- /// used.
+ /// determine bufferized function argument and result types. By default, a
+ /// type converter that returns a memref type with a fully dynamic layout map
+ /// is used.
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..a372e87d8335f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
@@ -392,11 +393,11 @@ struct FuncOpInterface
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
- // Construct the bufferized function type.
+ // Construct the bufferized function type. Compute the argument types.
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (dyn_cast<TensorType>(argType)) {
+ if (isa<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -404,24 +405,33 @@ struct FuncOpInterface
argTypes.push_back(argType);
}
- // Bodiless functions are assumed opaque and we cannot know the
- // bufferization contract they want to enforce. As a consequence, only
- // support functions that don't return any tensors atm.
- if (funcOp.isExternal()) {
- SmallVector<Type> retTypes;
- for (Type resultType : funcType.getResults()) {
- if (isa<TensorType>(resultType))
- return funcOp->emitError() << "cannot bufferize bodiless function "
- << "that returns a tensor";
+ // Compute the result types.
+ SmallVector<Type> retTypes;
+ for (Type resultType : funcType.getResults()) {
+ if (auto tensorType = dyn_cast<TensorType>(resultType)) {
+ BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options);
retTypes.push_back(resultType);
+ continue;
}
- funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
+ retTypes.push_back(resultType);
+ }
+
+ // Compute the new function type.
+ auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
+
+ // If the function has no body, set the new function type and we are done.
+ if (funcOp.isExternal()) {
+ funcOp.setType(newFuncType);
return success();
}
// TODO: Support functions with multiple returns.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
+ assert(returnOp->getNumOperands() == retTypes.size() &&
+ "incorrect number of return values");
Location loc = returnOp.getLoc();
// 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
options)))
return failure();
- // 2. For each result, keep track of which inplace argument it reuses.
+ // 2. Bufferize all operands of the return op.
SmallVector<Value> returnValues;
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- Value returnVal = returnOperand.get();
+ for (auto [returnVal, bufferizedType] :
+ llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
@@ -443,23 +453,17 @@ struct FuncOpInterface
continue;
}
- // Note: If `inferFunctionResultLayout = true`, cast are later folded
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
- BaseMemRefType resultType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, resultType, returnVal);
+ loc, bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}
- // 3. Rewrite the terminator without the in-place bufferizable values.
returnOp.getOperandsMutable().assign(returnValues);
- // 4. Rewrite the FuncOp type to buffer form.
- funcOp.setType(FunctionType::get(op->getContext(), argTypes,
- ValueRange(returnValues).getTypes()));
-
+ // 3. Set the new function type.
+ funcOp.setType(newFuncType);
return success();
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..2829eafb7c1c59 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
-// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
-// expected-error @+1 {{failed to bufferize op}}
-func.func private @foo() -> tensor<?xf32>
-
-// -----
-
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
@@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
// -----
-// expected-error @+2 {{failed to bufferize op}}
-// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
-func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-
-func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
- call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
- return
-}
-
-// -----
-
func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
%idx2 : index) -> f32 {
%1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..d31b43477beb9f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
// -----
+// Bufferization of bodiless function that returns a tensor.
+
+// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+// CHECK: func.func @call_to_unknown_tensor_returning_func(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+ // CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+ call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+ return
+}
+
+// -----
+
// A function that returns a non-equivalent tensor with layout map.
// CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>
|
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
Show resolved
Hide resolved
This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value. This commit is in preparation of removing the `func-bufferize` pass.
a9f607f
to
11089cc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
As Matthias mentions, there is a follow up PR #114017
This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.
This commit is in preparation of removing the deprecated
func-bufferize
pass. That pass can bufferize external functions.Also update a few comments.