Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Support bufferization of external functions #113999

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ struct AliasingValue {
bool isDefinite;
};

template <typename T> class AliasList {
template <typename T>
class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
Expand Down Expand Up @@ -259,7 +260,7 @@ struct BufferizationOptions {
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, func op, bufferization options
/// Parameters: tensor type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
Expand Down Expand Up @@ -344,9 +345,9 @@ struct BufferizationOptions {
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);

/// Type converter from tensors to memrefs. This type converter is used to
/// determine bufferized function argument types. By default, a type
/// converter that returns a memref type with a fully dynamic layout map is
/// used.
/// determine bufferized function argument and result types. By default, a
/// type converter that returns a memref type with a fully dynamic layout map
/// is used.
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
Expand Down Expand Up @@ -392,36 +393,45 @@ struct FuncOpInterface
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();

// Construct the bufferized function type.
// Compute the argument types.
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
if (dyn_cast<TensorType>(argType)) {
if (isa<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
}
argTypes.push_back(argType);
}

// Bodiless functions are assumed opaque and we cannot know the
// bufferization contract they want to enforce. As a consequence, only
// support functions that don't return any tensors atm.
if (funcOp.isExternal()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (isa<TensorType>(resultType))
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
BaseMemRefType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options);
retTypes.push_back(resultType);
continue;
}
funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
retTypes.push_back(resultType);
}

// Compute the new function type.
auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

// If the function has no body, set the new function type and we are done.
if (funcOp.isExternal()) {
funcOp.setType(newFuncType);
return success();
}

// TODO: Support functions with multiple returns.
matthias-springer marked this conversation as resolved.
Show resolved Hide resolved
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
Location loc = returnOp.getLoc();

// 1. Bufferize every block.
Expand All @@ -430,10 +440,10 @@ struct FuncOpInterface
options)))
return failure();

// 2. For each result, keep track of which inplace argument it reuses.
// 2. Bufferize all operands of the return op.
SmallVector<Value> returnValues;
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Value returnVal = returnOperand.get();
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);

Expand All @@ -443,23 +453,17 @@ struct FuncOpInterface
continue;
}

// Note: If `inferFunctionResultLayout = true`, cast are later folded
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
BaseMemRefType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, resultType, returnVal);
loc, bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}

// 3. Rewrite the terminator without the in-place bufferizable values.
returnOp.getOperandsMutable().assign(returnValues);

// 4. Rewrite the FuncOp type to buffer form.
funcOp.setType(FunctionType::get(op->getContext(), argTypes,
ValueRange(returnValues).getTypes()));

// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics

// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
// expected-error @+1 {{failed to bufferize op}}
func.func private @foo() -> tensor<?xf32>

// -----

// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
Expand Down Expand Up @@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {

// -----

// expected-error @+2 {{failed to bufferize op}}
// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)

func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
return
}

// -----

func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
%idx2 : index) -> f32 {
%1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32

// -----

// Bufferization of bodiless function that returns a tensor.

// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)

// CHECK: func.func @call_to_unknown_tensor_returning_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
// CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
return
}

// -----

// A function that returns a non-equivalent tensor with layout map.

// CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>
Expand Down
Loading