diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f06120c6c..c53757261 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -801,8 +801,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && - (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx()))) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1381,7 +1380,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { - return StmtDiff(op, diff.getExpr_dx()); + Expr* zero = getZeroInit(UnOp->getType()); + if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx())) + return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero}; + return {op, zero}; + } else if (opKind == UnaryOperatorKind::UO_Not) { + // ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones. + // Either way, taking a derivative gives us -_d_x. + Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx()); + return {op, derivedOp}; } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = @@ -1497,7 +1504,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); - } else if (BinOp->isLogicalOp()) { + } else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() || + BinOp->isComparisonOp() || opCode == BO_Rem) { // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and // correct derivative execution. auto buildOneSide = [this](StmtDiff& Xdiff) { @@ -1514,8 +1522,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { // Since the both parts are included in the opDiff, there's no point in // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left - // nullptr is used for treating expressions like ((A && B) && C) correctly. - return StmtDiff(opDiff, nullptr); + // zero is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, getZeroInit(BinOp->getType())); + } else if (BinOp->isShiftOp()) { + // Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS). + // We should do the same to the derivarive. + opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr()); } else { // FIXME: add support for other binary operators unsupportedOpWarn(BinOp->getEndLoc()); diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 0f59ac961..551391407 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -6,12 +6,10 @@ //CHECK-NOT: {{.*error|warning|note:.*}} int binOpWarn_0(int x){ - return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int binOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void binOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } @@ -23,17 +21,14 @@ int binOpWarn_1(int x){ // CHECK-NEXT: } int unOpWarn_0(int x){ - return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int unOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void unOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } int main(){ - - clad::differentiate(binOpWarn_0, 0); + clad::gradient(binOpWarn_0); clad::gradient(binOpWarn_1); - clad::differentiate(unOpWarn_0, 0); + clad::gradient(unOpWarn_0); }