From efc079b050d551fef950f42d4ecc21bf9479eff8 Mon Sep 17 00:00:00 2001 From: cxy Date: Thu, 11 Jul 2024 12:56:41 +0000 Subject: [PATCH] [mlir] [bufferize] fix bufferize deallocation error in nested symbol table In nested symbol tables, the dealloc_helper function generated by the lower deallocations pass was incorrectly positioned, causing calls to it to fail. This patch fixes this issue. --- .../Dialect/Bufferization/Transforms/Passes.h | 3 +- .../BufferizationToMemRef.cpp | 21 ++++++---- .../Transforms/LowerDeallocations.cpp | 41 +++++++++++-------- .../Transforms/lower-deallocations.mlir | 41 +++++++++++++++++++ 4 files changed, 80 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index e053e6c97e1430..298b2165f0e820 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -46,7 +46,8 @@ std::unique_ptr createLowerDeallocationsPass(); /// Adds the conversion pattern of the `bufferization.dealloc` operation to the /// given pattern set for use in other transformation passes. void populateBufferizationDeallocLoweringPattern( - RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc); + RewritePatternSet &patterns, + const llvm::DenseMap &deallocHelperFuncMap); /// Construct the library function needed for the fully generic /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass. 
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 2aae39f51b9409..4de204994f5196 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -132,27 +132,30 @@ struct BufferizationToMemRefPass return; } - func::FuncOp helperFuncOp; + llvm::DenseMap deallocHelperFuncMap; if (auto module = dyn_cast(getOperation())) { OpBuilder builder = OpBuilder::atBlockBegin(&module.getBodyRegion().front()); - SymbolTable symbolTable(module); // Build dealloc helper function if there are deallocs. getOperation()->walk([&](bufferization::DeallocOp deallocOp) { - if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = bufferization::buildDeallocationLibraryFunction( - builder, getOperation()->getLoc(), symbolTable); - return WalkResult::interrupt(); + Operation *symtableOp = + deallocOp->getParentWithTrait(); + if (deallocOp.getMemrefs().size() > 1 && + !deallocHelperFuncMap.contains(symtableOp)) { + SymbolTable symbolTable(symtableOp); + func::FuncOp helperFuncOp = + bufferization::buildDeallocationLibraryFunction( + builder, getOperation()->getLoc(), symbolTable); + deallocHelperFuncMap[symtableOp] = helperFuncOp; } - return WalkResult::advance(); }); } RewritePatternSet patterns(&getContext()); patterns.add(patterns.getContext()); - bufferization::populateBufferizationDeallocLoweringPattern(patterns, - helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern( + patterns, deallocHelperFuncMap); ConversionTarget target(getContext()); target.addLegalDialectgetParentWithTrait(); rewriter.create( - op.getLoc(), deallocHelperFunc, + op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), SmallVector{castedDeallocMemref, castedRetainMemref, castedCondsMemref, castedDeallocCondsMemref, castedRetainCondsMemref}); @@ -338,9 +339,11 @@ class DeallocOpConversion } public: 
- DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc) + DeallocOpConversion( + MLIRContext *context, + const llvm::DenseMap &deallocHelperFuncMap) : OpConversionPattern(context), - deallocHelperFunc(deallocHelperFunc) {} + deallocHelperFuncMap(deallocHelperFuncMap) {} LogicalResult matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor, @@ -360,7 +363,8 @@ class DeallocOpConversion if (adaptor.getMemrefs().size() == 1) return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter); - if (!deallocHelperFunc) + Operation *symtableOp = op->getParentWithTrait(); + if (!deallocHelperFuncMap.contains(symtableOp)) return op->emitError( "library function required for generic lowering, but cannot be " "automatically inserted when operating on functions"); @@ -369,7 +373,7 @@ class DeallocOpConversion } private: - func::FuncOp deallocHelperFunc; + const llvm::DenseMap &deallocHelperFuncMap; }; } // namespace @@ -385,26 +389,29 @@ struct LowerDeallocationsPass return; } - func::FuncOp helperFuncOp; + llvm::DenseMap deallocHelperFuncMap; if (auto module = dyn_cast(getOperation())) { OpBuilder builder = OpBuilder::atBlockBegin(&module.getBodyRegion().front()); - SymbolTable symbolTable(module); // Build dealloc helper function if there are deallocs. 
getOperation()->walk([&](bufferization::DeallocOp deallocOp) { - if (deallocOp.getMemrefs().size() > 1) { - helperFuncOp = bufferization::buildDeallocationLibraryFunction( - builder, getOperation()->getLoc(), symbolTable); - return WalkResult::interrupt(); + Operation *symtableOp = + deallocOp->getParentWithTrait(); + if (deallocOp.getMemrefs().size() > 1 && + !deallocHelperFuncMap.contains(symtableOp)) { + SymbolTable symbolTable(symtableOp); + func::FuncOp helperFuncOp = + bufferization::buildDeallocationLibraryFunction( + builder, getOperation()->getLoc(), symbolTable); + deallocHelperFuncMap[symtableOp] = helperFuncOp; } - return WalkResult::advance(); }); } RewritePatternSet patterns(&getContext()); - bufferization::populateBufferizationDeallocLoweringPattern(patterns, - helperFuncOp); + bufferization::populateBufferizationDeallocLoweringPattern( + patterns, deallocHelperFuncMap); ConversionTarget target(getContext()); target.addLegalDialect(patterns.getContext(), deallocLibraryFunc); + RewritePatternSet &patterns, + const llvm::DenseMap &deallocHelperFuncMap) { + patterns.add(patterns.getContext(), + deallocHelperFuncMap); } std::unique_ptr mlir::bufferization::createLowerDeallocationsPass() { diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir index 5fedd45555fcd8..2d83a2a1ec28db 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir @@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32> // CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]] // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// This test check dealloc_helper function is generated on each nested symbol +// table operation when needed and only generate once. 
+module @conversion_nest_module_dealloc_helper { + func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + module @nested_module_not_need_dealloc_helper { + func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>) + return %0#0, %0#1 : i1, i1 + } + } + module @nested_module_need_dealloc_helper { + func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) { + %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>) + func.return %0#0, %0#1 : i1, i1 + } + } +} + +// CHECK: module @conversion_nest_module_dealloc_helper { +// CHECK: func.func @top_level_func +// CHECK: call @dealloc_helper +// CHECK: module @nested_module_not_need_dealloc_helper { +// CHECK: func.func @nested_module_not_need_dealloc_helper_func +// CHECK-NOT: @dealloc_helper +// CHECK: module @nested_module_need_dealloc_helper { +// CHECK: func.func @nested_module_need_dealloc_helper_func0 +// CHECK: call @dealloc_helper +// CHECK: func.func 
@nested_module_need_dealloc_helper_func1 +// CHECK: call @dealloc_helper +// CHECK: func.func private @dealloc_helper +// CHECK: func.func private @dealloc_helper