Skip to content

Commit

Permalink
[mlir] [bufferize] fix bufferize deallocation error in nest symbol ta…
Browse files Browse the repository at this point in the history
…ble (#98476)

In nested symbols, the dealloc_helper function generated by lower
deallocations pass was incorrectly positioned, causing calls fail. This
patch fixes this issue.
  • Loading branch information
cxy-1993 authored Jul 15, 2024
1 parent 3698453 commit 662c6fc
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 26 deletions.
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class FuncOp;
namespace bufferization {
struct OneShotBufferizationOptions;

/// Maps from symbol table to its corresponding dealloc helper function.
using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -46,7 +49,7 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
/// given pattern set for use in other transformation passes.
void populateBufferizationDeallocLoweringPattern(
RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap);

/// Construct the library function needed for the fully generic
/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
return;
}

func::FuncOp helperFuncOp;
bufferization::DeallocHelperMap deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
SymbolTable symbolTable(module);

// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
if (deallocOp.getMemrefs().size() > 1) {
helperFuncOp = bufferization::buildDeallocationLibraryFunction(
builder, getOperation()->getLoc(), symbolTable);
return WalkResult::interrupt();
Operation *symtableOp =
deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
if (deallocOp.getMemrefs().size() > 1 &&
!deallocHelperFuncMap.contains(symtableOp)) {
SymbolTable symbolTable(symtableOp);
func::FuncOp helperFuncOp =
bufferization::buildDeallocationLibraryFunction(
builder, getOperation()->getLoc(), symbolTable);
deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
return WalkResult::advance();
});
}

RewritePatternSet patterns(&getContext());
patterns.add<CloneOpConversion>(patterns.getContext());
bufferization::populateBufferizationDeallocLoweringPattern(patterns,
helperFuncOp);
bufferization::populateBufferizationDeallocLoweringPattern(
patterns, deallocHelperFuncMap);

ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
Expand Down
41 changes: 25 additions & 16 deletions mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ class DeallocOpConversion
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
retainCondsMemref);

Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
rewriter.create<func::CallOp>(
op.getLoc(), deallocHelperFunc,
op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
castedCondsMemref, castedDeallocCondsMemref,
castedRetainCondsMemref});
Expand Down Expand Up @@ -338,9 +339,11 @@ class DeallocOpConversion
}

public:
DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
DeallocOpConversion(
MLIRContext *context,
const bufferization::DeallocHelperMap &deallocHelperFuncMap)
: OpConversionPattern<bufferization::DeallocOp>(context),
deallocHelperFunc(deallocHelperFunc) {}
deallocHelperFuncMap(deallocHelperFuncMap) {}

LogicalResult
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
Expand All @@ -360,7 +363,8 @@ class DeallocOpConversion
if (adaptor.getMemrefs().size() == 1)
return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);

if (!deallocHelperFunc)
Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
if (!deallocHelperFuncMap.contains(symtableOp))
return op->emitError(
"library function required for generic lowering, but cannot be "
"automatically inserted when operating on functions");
Expand All @@ -369,7 +373,7 @@ class DeallocOpConversion
}

private:
func::FuncOp deallocHelperFunc;
const bufferization::DeallocHelperMap &deallocHelperFuncMap;
};
} // namespace

Expand All @@ -385,26 +389,29 @@ struct LowerDeallocationsPass
return;
}

func::FuncOp helperFuncOp;
bufferization::DeallocHelperMap deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
SymbolTable symbolTable(module);

// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
if (deallocOp.getMemrefs().size() > 1) {
helperFuncOp = bufferization::buildDeallocationLibraryFunction(
builder, getOperation()->getLoc(), symbolTable);
return WalkResult::interrupt();
Operation *symtableOp =
deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
if (deallocOp.getMemrefs().size() > 1 &&
!deallocHelperFuncMap.contains(symtableOp)) {
SymbolTable symbolTable(symtableOp);
func::FuncOp helperFuncOp =
bufferization::buildDeallocationLibraryFunction(
builder, getOperation()->getLoc(), symbolTable);
deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
return WalkResult::advance();
});
}

RewritePatternSet patterns(&getContext());
bufferization::populateBufferizationDeallocLoweringPattern(patterns,
helperFuncOp);
bufferization::populateBufferizationDeallocLoweringPattern(
patterns, deallocHelperFuncMap);

ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
Expand Down Expand Up @@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
}

void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
RewritePatternSet &patterns,
const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
patterns.add<DeallocOpConversion>(patterns.getContext(),
deallocHelperFuncMap);
}

std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
// CHECK-NEXT: }
// CHECK-NEXT: return

// -----

// This test check dealloc_helper function is generated on each nested symbol
// table operation when needed and only generated once.
module @conversion_nest_module_dealloc_helper {
func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
func.return %0#0, %0#1 : i1, i1
}
module @nested_module_not_need_dealloc_helper {
func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
%0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
return %0#0, %0#1 : i1, i1
}
}
module @nested_module_need_dealloc_helper {
func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
func.return %0#0, %0#1 : i1, i1
}
func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
func.return %0#0, %0#1 : i1, i1
}
}
}

// CHECK: module @conversion_nest_module_dealloc_helper {
// CHECK: func.func @top_level_func
// CHECK: call @dealloc_helper
// CHECK: module @nested_module_not_need_dealloc_helper {
// CHECK: func.func @nested_module_not_need_dealloc_helper_func
// CHECK-NOT: @dealloc_helper
// CHECK: module @nested_module_need_dealloc_helper {
// CHECK: func.func @nested_module_need_dealloc_helper_func0
// CHECK: call @dealloc_helper
// CHECK: func.func @nested_module_need_dealloc_helper_func1
// CHECK: call @dealloc_helper
// CHECK: func.func private @dealloc_helper
// CHECK: func.func private @dealloc_helper

0 comments on commit 662c6fc

Please sign in to comment.