Skip to content

Commit

Permalink
Add support for && operator
Browse files Browse the repository at this point in the history
Add support for differentiation of expressions which include && operator.
Check whether then/else block of if stmt is empty before adding it to
reverse or forward block.
  • Loading branch information
rohanjulka19 authored and vgvassilev committed Jun 22, 2024
1 parent 1267d57 commit 216201a
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 11 deletions.
24 changes: 23 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
endScope();
return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock),
utils::unwrapIfSingleStmt(ReverseBlock));
utils::unwrapIfSingleStmt(ReverseBlock),
/*forwSweepDiff=*/nullptr,
/*valueForRevSweep=*/condDiffStored);
}

StmtDiff ReverseModeVisitor::VisitConditionalOperator(
Expand Down Expand Up @@ -2382,6 +2384,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dfdx());
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
VarDecl* condVar = GlobalStoreImpl(m_Context.BoolTy, "_cond");
VarDecl* derivedCondVar = GlobalStoreImpl(
m_Context.DoubleTy, "_d" + condVar->getNameAsString());
Expr* condVarRef = BuildDeclRef(condVar);
Expr* assignExpr = BuildOp(BO_Assign, condVarRef, Clone(R));
m_Variables.emplace(condVar, BuildDeclRef(derivedCondVar));
auto* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/L, /*LPL=*/noLoc, /*RPL=*/noLoc, /*Then=*/assignExpr,
/*EL=*/noLoc,
/*Else=*/nullptr);

StmtDiff IfStmtDiff = VisitIfStmt(IfStmt);
addToCurrentBlock(utils::unwrapIfSingleStmt(IfStmtDiff.getStmt()));
addToCurrentBlock(utils::unwrapIfSingleStmt(IfStmtDiff.getStmt_dx()),
direction::reverse);
auto* condDiffStored = IfStmtDiff.getRevSweepAsExpr();
return BuildOp(BO_LAnd, condDiffStored, condVarRef);
} else {
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
Expand Down
40 changes: 30 additions & 10 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -966,16 +966,36 @@ double sq_defined_later(double x) {
// CHECK-NEXT: }

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: double _d_cond0;
// CHECK-NEXT: bool _cond1;
// CHECK-NEXT: bool _t0;
// CHECK-NEXT: bool _cond2;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: _cond1 = c == 'a';
// CHECK-NEXT: if (_cond1) {
// CHECK-NEXT: _t0 = _cond0;
// CHECK-NEXT: _cond0 = s[0] == 'a';
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _cond2 = _cond1 && _cond0;
// CHECK-NEXT: if (_cond2)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond2)
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond1) {
// CHECK-NEXT: _cond0 = _t0;
// CHECK-NEXT: double _r_d0 = _d_cond0;
// CHECK-NEXT: _d_cond0 -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: void custom_max_pullback(const double &a, const double &b, double _d_y, double *_d_a, double *_d_b) {
// CHECK-NEXT: bool _cond0;
Expand Down
155 changes: 155 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,154 @@ double fn_empty_if_else(double x) {
//CHECK-NEXT: }
//CHECK-NEXT:}

double fn_cond_false(double i, double j) {
double res = 0;
if (i*j && res > 0) {
res = 6 * i * j;
}
return res;
}

// CHECK: void fn_cond_false_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: double _d_cond0;
// CHECK-NEXT: bool _cond1;
// CHECK-NEXT: bool _t0;
// CHECK-NEXT: bool _cond2;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: _cond1 = i * j;
// CHECK-NEXT: if (_cond1) {
// CHECK-NEXT: _t0 = _cond0;
// CHECK-NEXT: _cond0 = res > 0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _cond2 = _cond1 && _cond0;
// CHECK-NEXT: if (_cond2) {
// CHECK-NEXT: _t1 = res;
// CHECK-NEXT: res = 6 * i * j;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond2) {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t1;
// CHECK-NEXT: double _r_d1 = _d_res;
// CHECK-NEXT: _d_res -= _r_d1;
// CHECK-NEXT: *_d_i += 6 * _r_d1 * j;
// CHECK-NEXT: *_d_j += 6 * i * _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond1) {
// CHECK-NEXT: _cond0 = _t0;
// CHECK-NEXT: double _r_d0 = _d_cond0;
// CHECK-NEXT: _d_cond0 -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}

double fn_cond_add_assign(double i, double j) {
double res = 0;
if ((res = 2 * i * j) && (res += 3 * i * j) && (res += 5 * i * j)) {
res += 6 * i * j;
}
return res;
}

// CHECK: void fn_cond_add_assign_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: double _d_cond0;
// CHECK-NEXT: bool _cond1;
// CHECK-NEXT: double _d_cond1;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: bool _cond2;
// CHECK-NEXT: bool _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: bool _cond3;
// CHECK-NEXT: bool _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: bool _cond4;
// CHECK-NEXT: double _t5;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: _t0 = res;
// CHECK-NEXT: _cond2 = (res = 2 * i * j);
// CHECK-NEXT: if (_cond2) {
// CHECK-NEXT: _t1 = _cond1;
// CHECK-NEXT: _t2 = res;
// CHECK-NEXT: _cond1 = (res += 3 * i * j);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _cond3 = _cond2 && _cond1;
// CHECK-NEXT: if (_cond3) {
// CHECK-NEXT: _t3 = _cond0;
// CHECK-NEXT: _t4 = res;
// CHECK-NEXT: _cond0 = (res += 5 * i * j);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _cond4 = _cond3 && _cond0;
// CHECK-NEXT: if (_cond4) {
// CHECK-NEXT: _t5 = res;
// CHECK-NEXT: res += 6 * i * j;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond4) {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t5;
// CHECK-NEXT: double _r_d5 = _d_res;
// CHECK-NEXT: *_d_i += 6 * _r_d5 * j;
// CHECK-NEXT: *_d_j += 6 * i * _r_d5;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond3) {
// CHECK-NEXT: _cond0 = _t3;
// CHECK-NEXT: double _r_d3 = _d_cond0;
// CHECK-NEXT: _d_cond0 -= _r_d3;
// CHECK-NEXT: _d_res += _r_d3;
// CHECK-NEXT: res = _t4;
// CHECK-NEXT: double _r_d4 = _d_res;
// CHECK-NEXT: *_d_i += 5 * _r_d4 * j;
// CHECK-NEXT: *_d_j += 5 * i * _r_d4;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: if (_cond2) {
// CHECK-NEXT: _cond1 = _t1;
// CHECK-NEXT: double _r_d1 = _d_cond1;
// CHECK-NEXT: _d_cond1 -= _r_d1;
// CHECK-NEXT: _d_res += _r_d1;
// CHECK-NEXT: res = _t2;
// CHECK-NEXT: double _r_d2 = _d_res;
// CHECK-NEXT: *_d_i += 3 * _r_d2 * j;
// CHECK-NEXT: *_d_j += 3 * i * _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: res = _t0;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res -= _r_d0;
// CHECK-NEXT: *_d_i += 2 * _r_d0 * j;
// CHECK-NEXT: *_d_j += 2 * i * _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -1006,4 +1154,11 @@ int main() {
INIT_GRADIENT(fn_empty_if_else);
TEST_GRADIENT(fn_empty_if_else, /*numOfDerivativeArgs=*/1, 1, &dx); // CHECK-EXEC: 5.00

INIT_GRADIENT(fn_cond_false);
TEST_GRADIENT(fn_cond_false, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {0.00, 0.00}

INIT_GRADIENT(fn_cond_add_assign);
TEST_GRADIENT(fn_cond_add_assign, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {80.00, 48.00}


}

0 comments on commit 216201a

Please sign in to comment.