Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] [bufferize] fix bufferize deallocation error in nest symbol table #98476

Merged
merged 1 commit into from
Jul 15, 2024

Conversation

cxy-1993
Copy link
Contributor

In nested symbols, the dealloc_helper function generated by lower deallocations pass was incorrectly positioned, causing calls fail. This patch fixes this issue.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Jul 11, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jul 11, 2024

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

Changes

In nested symbols, the dealloc_helper function generated by lower deallocations pass was incorrectly positioned, causing calls fail. This patch fixes this issue.


Full diff: https://github.com/llvm/llvm-project/pull/98476.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+2-1)
  • (modified) mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp (+25-16)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir (+41)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..298b2165f0e82 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
 /// given pattern set for use in other transformation passes.
 void populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap);
 
 /// Construct the library function needed for the fully generic
 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f51b940..4de204994f519 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb46918ab1e8..17987f7322144 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
         retainCondsMemref);
 
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     rewriter.create<func::CallOp>(
-        op.getLoc(), deallocHelperFunc,
+        op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
         SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                            castedCondsMemref, castedDeallocCondsMemref,
                            castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
   }
 
 public:
-  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+  DeallocOpConversion(
+      MLIRContext *context,
+      const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap)
       : OpConversionPattern<bufferization::DeallocOp>(context),
-        deallocHelperFunc(deallocHelperFunc) {}
+        deallocHelperFuncMap(deallocHelperFuncMap) {}
 
   LogicalResult
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
-    if (!deallocHelperFunc)
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    if (!deallocHelperFuncMap.contains(symtableOp))
       return op->emitError(
           "library function required for generic lowering, but cannot be "
           "automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
   }
 
 private:
-  func::FuncOp deallocHelperFunc;
+  const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap;
 };
 } // namespace
 
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
 }
 
 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
-  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap) {
+  patterns.add<DeallocOpConversion>(patterns.getContext(),
+                                    deallocHelperFuncMap);
 }
 
 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45555fcd..2d83a2a1ec28d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 // CHECK-NEXT:     memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
+
+// -----
+
+// This test check dealloc_helper function is generated on each nested symbol
+// table operation when needed and only generate once.
+module @conversion_nest_module_dealloc_helper {
+  func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+    %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+    func.return %0#0, %0#1 : i1, i1
+  }
+  module @nested_module_not_need_dealloc_helper {
+    func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+      return %0#0, %0#1 : i1, i1
+    }
+  }
+  module @nested_module_need_dealloc_helper {
+    func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+    func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+  }
+}
+
+// CHECK:     module @conversion_nest_module_dealloc_helper {
+// CHECK:       func.func @top_level_func
+// CHECK:         call @dealloc_helper
+// CHECK:       module @nested_module_not_need_dealloc_helper {
+// CHECK:         func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT:       @dealloc_helper
+// CHECK:       module @nested_module_need_dealloc_helper {
+// CHECK:         func.func @nested_module_need_dealloc_helper_func0
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func @nested_module_need_dealloc_helper_func1
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func private @dealloc_helper
+// CHECK:       func.func private @dealloc_helper

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 11, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: donald chen (cxy-1993)

Changes

In nested symbols, the dealloc_helper function generated by lower deallocations pass was incorrectly positioned, causing calls fail. This patch fixes this issue.


Full diff: https://github.com/llvm/llvm-project/pull/98476.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+2-1)
  • (modified) mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp (+25-16)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir (+41)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..298b2165f0e82 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
 /// given pattern set for use in other transformation passes.
 void populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap);
 
 /// Construct the library function needed for the fully generic
 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f51b940..4de204994f519 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb46918ab1e8..17987f7322144 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
         retainCondsMemref);
 
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     rewriter.create<func::CallOp>(
-        op.getLoc(), deallocHelperFunc,
+        op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
         SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                            castedCondsMemref, castedDeallocCondsMemref,
                            castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
   }
 
 public:
-  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+  DeallocOpConversion(
+      MLIRContext *context,
+      const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap)
       : OpConversionPattern<bufferization::DeallocOp>(context),
-        deallocHelperFunc(deallocHelperFunc) {}
+        deallocHelperFuncMap(deallocHelperFuncMap) {}
 
   LogicalResult
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
-    if (!deallocHelperFunc)
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    if (!deallocHelperFuncMap.contains(symtableOp))
       return op->emitError(
           "library function required for generic lowering, but cannot be "
           "automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
   }
 
 private:
-  func::FuncOp deallocHelperFunc;
+  const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap;
 };
 } // namespace
 
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
 }
 
 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
-  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap) {
+  patterns.add<DeallocOpConversion>(patterns.getContext(),
+                                    deallocHelperFuncMap);
 }
 
 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45555fcd..2d83a2a1ec28d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 // CHECK-NEXT:     memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
+
+// -----
+
+// This test check dealloc_helper function is generated on each nested symbol
+// table operation when needed and only generate once.
+module @conversion_nest_module_dealloc_helper {
+  func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+    %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+    func.return %0#0, %0#1 : i1, i1
+  }
+  module @nested_module_not_need_dealloc_helper {
+    func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+      return %0#0, %0#1 : i1, i1
+    }
+  }
+  module @nested_module_need_dealloc_helper {
+    func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+    func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+  }
+}
+
+// CHECK:     module @conversion_nest_module_dealloc_helper {
+// CHECK:       func.func @top_level_func
+// CHECK:         call @dealloc_helper
+// CHECK:       module @nested_module_not_need_dealloc_helper {
+// CHECK:         func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT:       @dealloc_helper
+// CHECK:       module @nested_module_need_dealloc_helper {
+// CHECK:         func.func @nested_module_need_dealloc_helper_func0
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func @nested_module_need_dealloc_helper_func1
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func private @dealloc_helper
+// CHECK:       func.func private @dealloc_helper

Copy link

github-actions bot commented Jul 15, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

In nested symbols, the dealloc_helper function generated by lower deallocations
pass was incorrectly positioned, causing calls fail. This patch fixes this issue.
@cxy-1993 cxy-1993 merged commit 662c6fc into llvm:main Jul 15, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants