Skip to content

Commit

Permalink
[mlir][bufferization] Support bufferization of external functions
Browse files Browse the repository at this point in the history
This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.

This commit is in preparation for removing the `func-bufferize` pass.
  • Loading branch information
matthias-springer committed Oct 29, 2024
1 parent 00ca207 commit 11089cc
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ struct AliasingValue {
bool isDefinite;
};

template <typename T> class AliasList {
template <typename T>
class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
Expand Down Expand Up @@ -259,7 +260,7 @@ struct BufferizationOptions {
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, func op, bufferization options
/// Parameters: tensor type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
Expand Down Expand Up @@ -344,9 +345,9 @@ struct BufferizationOptions {
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);

/// Type converter from tensors to memrefs. This type converter is used to
/// determine bufferized function argument types. By default, a type
/// converter that returns a memref type with a fully dynamic layout map is
/// used.
/// determine bufferized function argument and result types. By default, a
/// type converter that returns a memref type with a fully dynamic layout map
/// is used.
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
Expand Down Expand Up @@ -392,36 +393,45 @@ struct FuncOpInterface
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();

// Construct the bufferized function type.
// Compute the argument types.
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
if (dyn_cast<TensorType>(argType)) {
if (isa<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
}
argTypes.push_back(argType);
}

// Bodiless functions are assumed opaque and we cannot know the
// bufferization contract they want to enforce. As a consequence, only
// support functions that don't return any tensors atm.
if (funcOp.isExternal()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (isa<TensorType>(resultType))
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
BaseMemRefType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options);
retTypes.push_back(resultType);
continue;
}
funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
retTypes.push_back(resultType);
}

// Compute the new function type.
auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

// If the function has no body, set the new function type and we are done.
if (funcOp.isExternal()) {
funcOp.setType(newFuncType);
return success();
}

// TODO: Support functions with multiple returns.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
Location loc = returnOp.getLoc();

// 1. Bufferize every block.
Expand All @@ -430,10 +440,10 @@ struct FuncOpInterface
options)))
return failure();

// 2. For each result, keep track of which inplace argument it reuses.
// 2. Bufferize all operands of the return op.
SmallVector<Value> returnValues;
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Value returnVal = returnOperand.get();
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);

Expand All @@ -443,23 +453,17 @@ struct FuncOpInterface
continue;
}

// Note: If `inferFunctionResultLayout = true`, cast are later folded
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
BaseMemRefType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, resultType, returnVal);
loc, bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}

// 3. Rewrite the terminator without the in-place bufferizable values.
returnOp.getOperandsMutable().assign(returnValues);

// 4. Rewrite the FuncOp type to buffer form.
funcOp.setType(FunctionType::get(op->getContext(), argTypes,
ValueRange(returnValues).getTypes()));

// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics

// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
// expected-error @+1 {{failed to bufferize op}}
func.func private @foo() -> tensor<?xf32>

// -----

// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
Expand Down Expand Up @@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {

// -----

// expected-error @+2 {{failed to bufferize op}}
// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)

func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
return
}

// -----

func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
%idx2 : index) -> f32 {
%1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32

// -----

// Bufferization of a bodiless function that returns a tensor.

// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)

// CHECK: func.func @call_to_unknown_tensor_returning_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
// CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
return
}

// -----

// A function that returns a non-equivalent tensor with layout map.

// CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>
Expand Down

0 comments on commit 11089cc

Please sign in to comment.