From 573929a3bea79322bfca937d56e6f336ec14791d Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Mon, 2 Sep 2024 19:38:12 +0200 Subject: [PATCH] [EmitC] Fix API usage in dialect conversion (#18411) This should fix a bug triggered by https://github.com/llvm/llvm-project/pull/106760. Signed-off-by: Simon Camphausen --- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 28 ++++--------------- .../VM/Conversion/VMToEmitC/EmitCBuilders.cpp | 9 ++++++ .../VM/Conversion/VMToEmitC/EmitCBuilders.h | 4 +++ 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index ae1a93c13c83..d103c924a197 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -1485,19 +1485,6 @@ ResultOpTy lookupSymbolRef(Operation *accessOp, StringRef attrName) { return globalOp; } -void updateResultUses(Operation *op, ConversionPatternRewriter &rewriter, - SmallVector &resultOperands) { - for (auto [result, resultOperand] : - llvm::zip(op->getResults(), resultOperands)) { - if (!llvm::isa(result.getType())) { - auto operand = cast>(resultOperand); - auto operandRValue = - emitc_builders::asRValue(rewriter, op->getLoc(), operand); - result.replaceAllUsesWith(operandRValue); - } - } -} - template class EmitCConversionPattern : public OpConversionPattern { public: @@ -2647,9 +2634,8 @@ class CallOpConversion : public EmitCConversionPattern { /*rewriter=*/rewriter, /*location=*/loc, /*callee=*/funcOp, /*operands=*/updatedOperands, this->getModuleAnalysis()); - updateResultUses(op, rewriter, resultOperands); - - rewriter.eraseOp(op); + emitc_builders::asRValues(rewriter, loc, resultOperands); + rewriter.replaceOp(op, resultOperands); return success(); } @@ -2714,9 +2700,8 @@ class CallOpConversion : public EmitCConversionPattern { returnIfError(rewriter, loc, callee, updatedOperands, this->getModuleAnalysis()); - updateResultUses(op, rewriter, resultOperands); - - rewriter.eraseOp(op); + emitc_builders::asRValues(rewriter, loc, resultOperands); + rewriter.replaceOp(op, resultOperands); return success(); } @@ -3833,9 +3818,8 @@ class ContainerOpConversion : public EmitCConversionPattern { /*operands=*/ArrayRef(unwrappedOperands), this->getModuleAnalysis()); - updateResultUses(op.getOperation(), rewriter, resultOperands); - - rewriter.eraseOp(op); + emitc_builders::asRValues(rewriter, loc, resultOperands); + rewriter.replaceOp(op, resultOperands); } else { rewriter.replaceOpWithNewOp( /*op=*/op, diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp index 5a859726764e..ca7756d4c607 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp @@ -78,6 +78,15 @@ Value asRValue(OpBuilder builder, Location loc, value); } +void asRValues(OpBuilder builder, Location location, + SmallVector &values) { + for (auto &value : values) { + if (auto lvalue = llvm::dyn_cast>(value)) { + value = emitc_builders::asRValue(builder, location, lvalue); + } + } +} + TypedValue addressOf(OpBuilder builder, Location location, TypedValue operand) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h index 2b26fa60e011..6a907ee8f9f4 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h @@ -62,6 +62,10 @@ TypedValue asLValue(OpBuilder builder, Location loc, Value asRValue(OpBuilder builder, Location loc, TypedValue value); +/// Replace values of lvalue type with rvalues. +void asRValues(OpBuilder builder, Location location, + SmallVector &values); + TypedValue addressOf(OpBuilder builder, Location location, TypedValue operand);