From 33d9441969ba4e9643543cf8ed6037227780ec3d Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 1 Aug 2024 19:50:22 +0300 Subject: [PATCH] Support bitwise, shift, comparison, remainder, not operators. This commit adds support for bitwise, shift, comparison, remainder, and bitwise not operators. Shift operators are considered differentiable since they essentially represent multiplication by ``2^n`` or ``2^-n``, where ``n`` is the RHS of the shift operators ``<<`` and ``>>``. Not operators are considered differentiable as well because they represent ``2^n - 1 - x`` or ``- 1 - x`` (depending on whether the type is signed) so the derivative is ``-_d_x``. Other operators have unclear differentiable effects and so they are considered non-differentiable. Fixes #381. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 24 ++++++++++++++----- test/FirstDerivative/UnsupportedOpsWarn.C | 17 +++++-------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f06120c6c..c53757261 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -801,8 +801,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && - (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx()))) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1381,7 +1380,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { - return StmtDiff(op, diff.getExpr_dx()); + Expr* zero = getZeroInit(UnOp->getType()); + if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx())) + return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero}; + return {op, zero}; + } else if (opKind == UnaryOperatorKind::UO_Not) { + // ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones. + // Either way, taking a derivative gives us -_d_x. + Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx()); + return {op, derivedOp}; } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = @@ -1497,7 +1504,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); - } else if (BinOp->isLogicalOp()) { + } else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() || + BinOp->isComparisonOp() || opCode == BO_Rem) { // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and // correct derivative execution. auto buildOneSide = [this](StmtDiff& Xdiff) { @@ -1514,8 +1522,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { // Since the both parts are included in the opDiff, there's no point in // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left - // nullptr is used for treating expressions like ((A && B) && C) correctly. - return StmtDiff(opDiff, nullptr); + // zero is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, getZeroInit(BinOp->getType())); + } else if (BinOp->isShiftOp()) { + // Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS). + // We should do the same to the derivarive. + opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr()); } else { // FIXME: add support for other binary operators unsupportedOpWarn(BinOp->getEndLoc()); diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 0f59ac961..551391407 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -6,12 +6,10 @@ //CHECK-NOT: {{.*error|warning|note:.*}} int binOpWarn_0(int x){ - return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int binOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void binOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } @@ -23,17 +21,14 @@ int binOpWarn_1(int x){ // CHECK-NEXT: } int unOpWarn_0(int x){ - return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int unOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void unOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } int main(){ - - clad::differentiate(binOpWarn_0, 0); + clad::gradient(binOpWarn_0); clad::gradient(binOpWarn_1); - clad::differentiate(unOpWarn_0, 0); + clad::gradient(unOpWarn_0); }