From 7d330c35d6f1e16bdd0d52c15a402895842d1057 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Fri, 7 Jun 2024 15:46:47 +0300 Subject: [PATCH] Ensure proper ending of the reverse block for conditional statements. Fixes #922. --- lib/Differentiator/ReverseModeVisitor.cpp | 18 +++++++++--------- test/Gradient/Gradients.C | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 195ccbf3e..0814072d9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -945,20 +945,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .get(); // If result is a glvalue, we should keep it as it can potentially be // assigned as in (c ? a : b) = x; + Expr* ResultRef = nullptr; if ((CO->isModifiableLvalue(m_Context) == Expr::MLV_Valid) && ifTrueExprDiff.getExpr_dx() && ifFalseExprDiff.getExpr_dx()) { - Expr* ResultRef = m_Sema - .ActOnConditionalOp(noLoc, noLoc, condStored, - ifTrueExprDiff.getExpr_dx(), - ifFalseExprDiff.getExpr_dx()) - .get(); + ResultRef = m_Sema + .ActOnConditionalOp(noLoc, noLoc, condStored, + ifTrueExprDiff.getExpr_dx(), + ifFalseExprDiff.getExpr_dx()) + .get(); if (ResultRef->isModifiableLvalue(m_Context) != Expr::MLV_Valid) ResultRef = nullptr; - Stmt* revBlock = utils::unwrapIfSingleStmt(endBlock(direction::reverse)); - addToCurrentBlock(revBlock, direction::reverse); - return StmtDiff(condExpr, ResultRef); } - return StmtDiff(condExpr); + Stmt* revBlock = utils::unwrapIfSingleStmt(endBlock(direction::reverse)); + addToCurrentBlock(revBlock, direction::reverse); + return StmtDiff(condExpr, ResultRef); } StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) { diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index 900edd071..7ac10f9f4 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -913,6 +913,18 @@ double fn_null_stmts(double x) { //CHECK-NEXT: *_d_x += 1; //CHECK-NEXT: } +double fn_const_cond_op(double x) { + return x + (x > 0 ? 1.0 : 0.0); +} + +//CHECK: void fn_const_cond_op_grad(double x, double *_d_x) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: _cond0 = x > 0; +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_x += 1; +//CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -985,4 +997,7 @@ int main() { TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00 INIT_GRADIENT(fn_null_stmts); + + INIT_GRADIENT(fn_const_cond_op); + TEST_GRADIENT(fn_const_cond_op, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00 }