diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index aceb9d059b95f3..4866e31b19d5de 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -60,7 +60,8 @@ struct AliasingValue { bool isDefinite; }; -template class AliasList { +template +class AliasList { public: /// Create an empty list of aliases. AliasList() = default; @@ -259,7 +260,7 @@ struct BufferizationOptions { /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; /// Tensor -> MemRef type converter. - /// Parameters: Value, memory space, func op, bufferization options + /// Parameters: tensor type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = std::function; @@ -344,9 +345,9 @@ struct BufferizationOptions { void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); /// Type converter from tensors to memrefs. This type converter is used to - /// determine bufferized function argument types. By default, a type - /// converter that returns a memref type with a fully dynamic layout map is - /// used. + /// determine bufferized function argument and result types. By default, a + /// type converter that returns a memref type with a fully dynamic layout map + /// is used. /// /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 9fbe574ec392dc..6e91d3b89a7c79 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, /// Return the FuncOp called by `callOp`. static FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( @@ -392,11 +393,11 @@ struct FuncOpInterface auto funcOp = cast(op); FunctionType funcType = funcOp.getFunctionType(); - // Construct the bufferized function type. + // Compute the argument types. SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (dyn_cast(argType)) { + if (isa(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -404,24 +405,33 @@ struct FuncOpInterface argTypes.push_back(argType); } - // Bodiless functions are assumed opaque and we cannot know the - // bufferization contract they want to enforce. As a consequence, only - // support functions that don't return any tensors atm. - if (funcOp.isExternal()) { - SmallVector retTypes; - for (Type resultType : funcType.getResults()) { - if (isa(resultType)) - return funcOp->emitError() << "cannot bufferize bodiless function " - << "that returns a tensor"; + // Compute the result types. + SmallVector retTypes; + for (Type resultType : funcType.getResults()) { + if (auto tensorType = dyn_cast(resultType)) { + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + options); retTypes.push_back(resultType); + continue; } - funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); + retTypes.push_back(resultType); + } + + // Compute the new function type. + auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes); + + // If the function has no body, set the new function type and we are done. + if (funcOp.isExternal()) { + funcOp.setType(newFuncType); return success(); } // TODO: Support functions with multiple returns. func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); + assert(returnOp->getNumOperands() == retTypes.size() && + "incorrect number of return values"); Location loc = returnOp.getLoc(); // 1. Bufferize every block. @@ -430,10 +440,10 @@ struct FuncOpInterface options))) return failure(); - // 2. For each result, keep track of which inplace argument it reuses. + // 2. Bufferize all operands of the return op. SmallVector returnValues; - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Value returnVal = returnOperand.get(); + for (auto [returnVal, bufferizedType] : + llvm::zip_equal(returnOp->getOperands(), retTypes)) { auto tensorType = dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); @@ -443,23 +453,17 @@ struct FuncOpInterface continue; } - // Note: If `inferFunctionResultLayout = true`, cast are later folded + // Note: If `inferFunctionResultLayout = true`, casts are later folded // away. - BaseMemRefType resultType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, - options); Value toMemrefOp = rewriter.create( - loc, resultType, returnVal); + loc, bufferizedType, returnVal); returnValues.push_back(toMemrefOp); } - // 3. Rewrite the terminator without the in-place bufferizable values. returnOp.getOperandsMutable().assign(returnValues); - // 4. Rewrite the FuncOp type to buffer form. - funcOp.setType(FunctionType::get(op->getContext(), argTypes, - ValueRange(returnValues).getTypes())); - + // 3. Set the new function type. + funcOp.setType(newFuncType); return success(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir index ee0f71f668dc74..2829eafb7c1c59 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -1,11 +1,5 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics -// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}} -// expected-error @+1 {{failed to bufferize op}} -func.func private @foo() -> tensor - -// ----- - // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}} func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor, %t2 : tensor) -> (tensor, tensor) @@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref, %idx: index) -> (f32) { // ----- -// expected-error @+2 {{failed to bufferize op}} -// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} -func.func private @foo(%t : tensor) -> (f32, tensor, f32) - -func.func @call_to_unknown_tensor_returning_func(%t : tensor) { - call @foo(%t) : (tensor) -> (f32, tensor, f32) - return -} - -// ----- - func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index, %idx2 : index) -> f32 { %1 = bufferization.alloc_tensor(%idx) : tensor diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 0d5224514e3a02..d31b43477beb9f 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32 // ----- +// Bufferization of bodiless function that returns a tensor. + +// CHECK: func.func private @foo(memref>) -> (f32, memref>, f32) +func.func private @foo(%t : tensor) -> (f32, tensor, f32) + +// CHECK: func.func @call_to_unknown_tensor_returning_func( +// CHECK-SAME: %[[arg0:.*]]: memref>) { +func.func @call_to_unknown_tensor_returning_func(%t : tensor) { + // CHECK: call @foo(%[[arg0]]) : (memref>) -> (f32, memref>, f32) + call @foo(%t) : (tensor) -> (f32, tensor, f32) + return +} + +// ----- + // A function that returns a non-equivalent tensor with layout map. // CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>