Skip to content

Commit

Permalink
Add support for assignments in while-loops
Browse files Browse the repository at this point in the history
This PR adds support for assignments in while-loops. It also enables to combine
multiple assignments in the while condition by adding supprot for some logoical
operators.

Fixes: #913
  • Loading branch information
Max Andriychuk authored and vgvassilev committed Jun 19, 2024
1 parent 2a3e9bc commit 6168843
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 36 deletions.
74 changes: 42 additions & 32 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
// Visit(cond)
auto* condBO = dyn_cast<BinaryOperator>(cond);
auto* condUO = dyn_cast<UnaryOperator>(cond);
// FIXME: Currently we only support logical and assignment operators.
if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) ||
condUO) {
condDiff = Visit(cond);
Expand Down Expand Up @@ -1650,20 +1651,52 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
const VarDecl* condVar = WS->getConditionVariable();
VarDecl* condVarClone = nullptr;
DeclDiff<VarDecl> condVarRes;

StmtDiff condDiff = Clone(WS->getCond());
Expr* cond = condDiff.getExpr();

// Check if the condition contais a variable declaration and create a
// declaration of both the variable and it's adjoint before the while-loop.
if (condVar) {
condVarRes = DifferentiateVarDecl(condVar);
condVarRes = DifferentiateVarDecl(condVar, /*ignoreInit=*/true);
condVarClone = condVarRes.getDecl();
if (condVarRes.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarRes.getDecl_dx()));
auto* condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}
// Assignments in the condition are allowed, differentiate.
if (cond) {
cond = cond->IgnoreParenImpCasts();
auto* condBO = dyn_cast<BinaryOperator>(cond);
auto* condUO = dyn_cast<UnaryOperator>(cond);
// FIXME: Currently we only support logical and assignment operators.
if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) ||
condUO) {
StmtDiff condDiff = Visit(cond);
// After Visit(cond) is called the derivative could either be recorded in
// condDiff.getExpr() or condDiff.getExpr_dx(), hence we should build cond
// differently which is implemented below visiting statements like "(x=0)"
// records the differentiated statement in condDiff.getExpr_dx(), meaning
// we have to build in the form ((cond_dx), (cond)), wrapping cond_dx and
// cond into parentheses.
//
// Visiting statements like "(x=0) || false" records the result in
// condDiff.getExpr(), meaning the differentiated condition is already.
if (condDiff.getExpr_dx() &&
(!isUnusedResult(condDiff.getExpr_dx()) || condUO))
cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()),
BuildParens(condDiff.getExpr()));
else
cond = condDiff.getExpr();
}
}
Expr* condClone = WS->getCond() ? Clone(WS->getCond()) : nullptr;

Sema::ConditionResult condRes;
if (condVarClone) {
condRes = m_Sema.ActOnConditionVariable(condVarClone, noLoc,
Sema::ConditionKind::Boolean);
} else {
condRes = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone,
Sema::ConditionKind::Boolean);
}
condRes = m_Sema.ActOnCondition(getCurrentScope(), noLoc, cond,
Sema::ConditionKind::Boolean);

const Stmt* body = WS->getBody();
Stmt* bodyResult = nullptr;
Expand All @@ -1679,29 +1712,6 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
endScope();
bodyResult = Block;
}
// Since condition variable is created and initialized at each iteration,
// derivative of condition variable should also get created and initialized
// at each iteratrion. Therefore, we need to insert declaration statement
// of derivative of condition variable, if any, on top of the derived body
// of the while loop.
//
// while (double b = a) {
// ...
// ...
// }
//
// gets differentiated to,
//
// while (double b = a) {
// double _d_b = _d_a;
// ...
// ...
// }
if (condVarClone) {
bodyResult = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), cast<CompoundStmt>(bodyResult),
BuildDeclStmt(condVarRes.getDecl_dx()));
}

Stmt* WSDiff =
clad_compat::Sema_ActOnWhileStmt(m_Sema, condRes, bodyResult).get();
Expand Down
45 changes: 41 additions & 4 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ double fn7(double i, double j) {
// CHECK-NEXT: int b = 3;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: while (double a = b)
// CHECK-NEXT: double _d_a;
// CHECK-NEXT: double a;
// CHECK-NEXT: while ((_d_a = _d_b) , (a = b))
// CHECK-NEXT: {
// CHECK-NEXT: double _d_a = _d_b;
// CHECK-NEXT: _d_a += _d_i;
// CHECK-NEXT: a += i;
// CHECK-NEXT: _d_res += _d_a;
Expand Down Expand Up @@ -343,9 +344,10 @@ double fn9(double i, double j) {
// CHECK-NEXT: int counter = 4;
// CHECK-NEXT: double _d_a = _d_i * j + i * _d_j;
// CHECK-NEXT: double a = i * j;
// CHECK-NEXT: while (int num = counter)
// CHECK-NEXT: int _d_num;
// CHECK-NEXT: int num;
// CHECK-NEXT: while ((_d_num = _d_counter) , (num = counter))
// CHECK-NEXT: {
// CHECK-NEXT: int _d_num = _d_counter;
// CHECK-NEXT: _d_counter -= 0;
// CHECK-NEXT: counter -= 1;
// CHECK-NEXT: if (num == 2)
Expand Down Expand Up @@ -516,6 +518,35 @@ double fn15_darg0(double u, double v);
// CHECK-NEXT: return 0 * res + 2 * _d_res;
// CHECK-NEXT: }

double fn16(double x) {
while (double t = (x = 0)) {}
return x;
} // = 0

double fn16_darg0(double x);
// CHECK: double fn16_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_t;
// CHECK-NEXT: double t;
// CHECK-NEXT: while ((_d_t = (_d_x = 0)) , (t = (x = 0)))
// CHECK-NEXT: {
// CHECK-NEXT: }
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn17(double x) {
while ((x = 0) || false) {}
return x;
} // = 0

double fn17_darg0(double x);
// CHECK: double fn17_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: while (((_d_x = 0) , (x = 0)) || false)
// CHECK-NEXT: {
// CHECK-NEXT: }
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
Expand Down Expand Up @@ -577,4 +608,10 @@ int main() {

clad::differentiate(fn15, 0);
printf("Result is = %.2f\n", fn15_darg0(7, 3)); // CHECK-EXEC: Result is = 6.00

clad::differentiate(fn16, 0);
printf("Result is = %.2f\n", fn16_darg0(5)); // CHECK-EXEC: Result is = 0

clad::differentiate(fn17, 0);
printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0
}

0 comments on commit 6168843

Please sign in to comment.