Skip to content

Commit

Permalink
Ensure proper ending of the reverse block for conditional statements. F…
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jun 7, 2024
1 parent 3cf8b33 commit 7d330c3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
18 changes: 9 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,20 +945,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.get();
// If result is a glvalue, we should keep it as it can potentially be
// assigned as in (c ? a : b) = x;
Expr* ResultRef = nullptr;
if ((CO->isModifiableLvalue(m_Context) == Expr::MLV_Valid) &&
ifTrueExprDiff.getExpr_dx() && ifFalseExprDiff.getExpr_dx()) {
Expr* ResultRef = m_Sema
.ActOnConditionalOp(noLoc, noLoc, condStored,
ifTrueExprDiff.getExpr_dx(),
ifFalseExprDiff.getExpr_dx())
.get();
ResultRef = m_Sema
.ActOnConditionalOp(noLoc, noLoc, condStored,
ifTrueExprDiff.getExpr_dx(),
ifFalseExprDiff.getExpr_dx())
.get();
if (ResultRef->isModifiableLvalue(m_Context) != Expr::MLV_Valid)
ResultRef = nullptr;
Stmt* revBlock = utils::unwrapIfSingleStmt(endBlock(direction::reverse));
addToCurrentBlock(revBlock, direction::reverse);
return StmtDiff(condExpr, ResultRef);
}
return StmtDiff(condExpr);
Stmt* revBlock = utils::unwrapIfSingleStmt(endBlock(direction::reverse));
addToCurrentBlock(revBlock, direction::reverse);
return StmtDiff(condExpr, ResultRef);
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
Expand Down
15 changes: 15 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,18 @@ double fn_null_stmts(double x) {
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: }

double fn_const_cond_op(double x) {
return x + (x > 0 ? 1.0 : 0.0);
}

//CHECK: void fn_const_cond_op_grad(double x, double *_d_x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: _cond0 = x > 0;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -985,4 +997,7 @@ int main() {
TEST_GRADIENT(fn_cond_init, /*numOfDerivativeArgs=*/1, -1, &dx); // CHECK-EXEC: 1.00

INIT_GRADIENT(fn_null_stmts);

INIT_GRADIENT(fn_const_cond_op);
TEST_GRADIENT(fn_const_cond_op, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00
}

0 comments on commit 7d330c3

Please sign in to comment.