From 616884355094be0ee63b54ddc3aae5960d5b1976 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Sun, 16 Jun 2024 11:53:55 +0200 Subject: [PATCH] Add support for assignments in while-loops This PR adds support for assignments in while-loops. It also enables to combine multiple assignments in the while condition by adding supprot for some logoical operators. Fixes: #913 --- lib/Differentiator/BaseForwardModeVisitor.cpp | 74 +++++++++++-------- test/FirstDerivative/Loops.C | 45 ++++++++++- 2 files changed, 83 insertions(+), 36 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 16f07037a..75122c3e8 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -727,6 +727,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { // Visit(cond) auto* condBO = dyn_cast(cond); auto* condUO = dyn_cast(cond); + // FIXME: Currently we only support logical and assignment operators. if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); @@ -1650,20 +1651,52 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) { const VarDecl* condVar = WS->getConditionVariable(); VarDecl* condVarClone = nullptr; DeclDiff condVarRes; + + StmtDiff condDiff = Clone(WS->getCond()); + Expr* cond = condDiff.getExpr(); + + // Check if the condition contais a variable declaration and create a + // declaration of both the variable and it's adjoint before the while-loop. if (condVar) { - condVarRes = DifferentiateVarDecl(condVar); + condVarRes = DifferentiateVarDecl(condVar, /*ignoreInit=*/true); condVarClone = condVarRes.getDecl(); + if (condVarRes.getDecl_dx()) + addToCurrentBlock(BuildDeclStmt(condVarRes.getDecl_dx())); + auto* condInit = condVarClone->getInit(); + condVarClone->setInit(nullptr); + cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit); + addToCurrentBlock(BuildDeclStmt(condVarClone)); + } + // Assignments in the condition are allowed, differentiate. + if (cond) { + cond = cond->IgnoreParenImpCasts(); + auto* condBO = dyn_cast(cond); + auto* condUO = dyn_cast(cond); + // FIXME: Currently we only support logical and assignment operators. + if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || + condUO) { + StmtDiff condDiff = Visit(cond); + // After Visit(cond) is called the derivative could either be recorded in + // condDiff.getExpr() or condDiff.getExpr_dx(), hence we should build cond + // differently which is implemented below visiting statements like "(x=0)" + // records the differentiated statement in condDiff.getExpr_dx(), meaning + // we have to build in the form ((cond_dx), (cond)), wrapping cond_dx and + // cond into parentheses. + // + // Visiting statements like "(x=0) || false" records the result in + // condDiff.getExpr(), meaning the differentiated condition is already. + if (condDiff.getExpr_dx() && + (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), + BuildParens(condDiff.getExpr())); + else + cond = condDiff.getExpr(); + } } - Expr* condClone = WS->getCond() ? Clone(WS->getCond()) : nullptr; Sema::ConditionResult condRes; - if (condVarClone) { - condRes = m_Sema.ActOnConditionVariable(condVarClone, noLoc, - Sema::ConditionKind::Boolean); - } else { - condRes = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone, - Sema::ConditionKind::Boolean); - } + condRes = m_Sema.ActOnCondition(getCurrentScope(), noLoc, cond, + Sema::ConditionKind::Boolean); const Stmt* body = WS->getBody(); Stmt* bodyResult = nullptr; @@ -1679,29 +1712,6 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) { endScope(); bodyResult = Block; } - // Since condition variable is created and initialized at each iteration, - // derivative of condition variable should also get created and initialized - // at each iteratrion. Therefore, we need to insert declaration statement - // of derivative of condition variable, if any, on top of the derived body - // of the while loop. - // - // while (double b = a) { - // ... - // ... - // } - // - // gets differentiated to, - // - // while (double b = a) { - // double _d_b = _d_a; - // ... - // ... - // } - if (condVarClone) { - bodyResult = utils::PrependAndCreateCompoundStmt( - m_Sema.getASTContext(), cast(bodyResult), - BuildDeclStmt(condVarRes.getDecl_dx())); - } Stmt* WSDiff = clad_compat::Sema_ActOnWhileStmt(m_Sema, condRes, bodyResult).get(); diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 43a35c749..152984bcd 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -288,9 +288,10 @@ double fn7(double i, double j) { // CHECK-NEXT: int b = 3; // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: double res = 0; -// CHECK-NEXT: while (double a = b) +// CHECK-NEXT: double _d_a; +// CHECK-NEXT: double a; +// CHECK-NEXT: while ((_d_a = _d_b) , (a = b)) // CHECK-NEXT: { -// CHECK-NEXT: double _d_a = _d_b; // CHECK-NEXT: _d_a += _d_i; // CHECK-NEXT: a += i; // CHECK-NEXT: _d_res += _d_a; @@ -343,9 +344,10 @@ double fn9(double i, double j) { // CHECK-NEXT: int counter = 4; // CHECK-NEXT: double _d_a = _d_i * j + i * _d_j; // CHECK-NEXT: double a = i * j; -// CHECK-NEXT: while (int num = counter) +// CHECK-NEXT: int _d_num; +// CHECK-NEXT: int num; +// CHECK-NEXT: while ((_d_num = _d_counter) , (num = counter)) // CHECK-NEXT: { -// CHECK-NEXT: int _d_num = _d_counter; // CHECK-NEXT: _d_counter -= 0; // CHECK-NEXT: counter -= 1; // CHECK-NEXT: if (num == 2) @@ -516,6 +518,35 @@ double fn15_darg0(double u, double v); // CHECK-NEXT: return 0 * res + 2 * _d_res; // CHECK-NEXT: } +double fn16(double x) { + while (double t = (x = 0)) {} + return x; +} // = 0 + +double fn16_darg0(double x); +// CHECK: double fn16_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_t; +// CHECK-NEXT: double t; +// CHECK-NEXT: while ((_d_t = (_d_x = 0)) , (t = (x = 0))) +// CHECK-NEXT: { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } + +double fn17(double x) { + while ((x = 0) || false) {} + return x; +} // = 0 + +double fn17_darg0(double x); +// CHECK: double fn17_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: while (((_d_x = 0) , (x = 0)) || false) +// CHECK-NEXT: { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ @@ -577,4 +608,10 @@ int main() { clad::differentiate(fn15, 0); printf("Result is = %.2f\n", fn15_darg0(7, 3)); // CHECK-EXEC: Result is = 6.00 + + clad::differentiate(fn16, 0); + printf("Result is = %.2f\n", fn16_darg0(5)); // CHECK-EXEC: Result is = 0 + + clad::differentiate(fn17, 0); + printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0 }