Skip to content

Commit

Permalink
Differentiate the RHS in multiplication instead of cloning by introdu…
Browse files Browse the repository at this point in the history
…cing placeholders

Currently, when differentiating multiplication in the reverse mode, we need to pass the differentiated LHS when visiting the RHS and vice versa (i.e. ``Visit(R, dR)`` and ``Visit(L, dL)`` where ``dR = LDiff * dfdx`` and ``dL = dfdx * RDiff``). This creates a loop that we break in one of two ways:
1) Create a variable to represent the result of the RHS visitation, use it to visit the LHS, visit the RHS, and then set the value of the variable introduced before to the result of the RHS visitation.
2) Clone the RHS, use it to visit the LHS, then visit the RHS.

The 1st approach is bad because it introduces a new variable that is usually unnecessary.

The 2nd approach is used more frequently, but its downside is that cloning is not the same as visiting; some expressions cannot be cloned, e.g.:
a) ``x > 0 ? a : b`` is cloned as ``x > 0 ? a : b`` but differentiated as ``_cond ? a : b``, where ``_cond`` is a variable to save the condition ``x > 0``.
b) References are often turned into pointers in the reverse mode. Because of that, ``x`` should be differentiated as ``*x`` and not as just ``x``, which is what ``Clone`` produces. We already have an exception to handle ref-type decl refs as RHS of multiplication, which is removed in this PR.

This commit solves this problem in a third way: by passing a literal expression when visiting LHS and replacing it when the differentiated RHS is known. All of the new logic is sunk into existing functions ``DelayedGlobalStoreAndRef`` and ``Finalize``.
  • Loading branch information
PetroZarytskyi committed Aug 16, 2024
1 parent 135f6f9 commit cf187dc
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 65 deletions.
16 changes: 9 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,15 @@ namespace clad {
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
clang::VarDecl* pDeclaration, bool pIsInsideLoop,
bool pIsFnScope, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
isInsideLoop(pIsInsideLoop), isFnScope(pIsFnScope),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(clang::Expr* New, clang::Sema& S);
};

/// Sometimes (e.g. when visiting multiplication/division operator), we
Expand All @@ -292,7 +293,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceNoRecompute = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand Down
133 changes: 84 additions & 49 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,25 +2296,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).

// Check if the local variable declaration is reference type, since it is
// moved to the global scope and the right side should be recomputed
bool promoteToFnScope = false;
if (auto* RDeclRef = dyn_cast<DeclRefExpr>(R->IgnoreImplicit()))
promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() &&
!getCurrentScope()->isFunctionScope();

if (!ShouldRecompute(R) || promoteToFnScope) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}
DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff& RResult = RDelayed.Result;

Expr* dl = nullptr;
if (dfdx())
Expand All @@ -2336,30 +2319,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff.getExpr(), m_Sema);
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (opCode == BO_Div) {
// xi = xl / xr
// dxi/xl = 1 / xr
// df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr)
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceNoRecompute=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* dl = nullptr;
if (dfdx())
dl = BuildOp(BO_Div, dfdx(), RStored);
dl = BuildOp(BO_Div, dfdx(), RResult.getExpr());
Ldiff = Visit(L, dl);
StmtDiff LStored = Ldiff;
// Catch the pop statement and emit it after
Expand All @@ -2377,22 +2357,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// df/dxr += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). Otherwise, code like `1 / R * R` is
// produced instead of `1 / (R * R)`.
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* RxR = BuildParens(
BuildOp(BO_Mul, RResult.getExpr(), RResult.getExpr()));
dr = BuildOp(BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildParens(BuildOp(
BO_Div, LStored.getRevSweepAsExpr(), RxR))));
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff.getExpr(), m_Sema);
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2588,26 +2570,29 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceNoRecompute=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* dr = BuildOp(BO_Mul, oldValue,
BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR)));
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff.getExpr(), m_Sema);
}
valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult);
} else
llvm_unreachable("unknown assignment opCode");
if (m_ExternalSource)
Expand Down Expand Up @@ -3367,9 +3352,51 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {Store, Restore};
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
if (isConstant || !needsUpdate)
void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New, Sema& S) {
/// AST visitor that substitutes a placeholder expression with `newExpr`.
/// Every child of a visited expression that is pointer-identical to
/// `placeholder` is overwritten in place; `newExpr` must be set by the
/// caller before the traversal starts.
class PlaceholderReplacer
: public RecursiveASTVisitor<PlaceholderReplacer> {
public:
/// The expression to look for (compared by pointer identity).
const Expr* placeholder;
Sema& m_Sema;
ASTContext& m_Context;
/// The expression that takes the placeholder's position.
Expr* newExpr{nullptr};
PlaceholderReplacer(const Expr* Placeholder, Sema& S)
: placeholder(Placeholder), m_Sema(S), m_Context(S.getASTContext()) {}

/// Scans the direct children of `E` and replaces each occurrence of the
/// placeholder. Always returns true so that the traversal continues.
bool VisitExpr(Expr* E) const {
for (auto iter = E->child_begin(), e = E->child_end(); iter != e;
++iter)
if (*iter == placeholder) {
Expr* castExpr = newExpr;
// Since we are manually replacing the statement, implicit casts are
// not generated automatically.
// Only the integral -> floating conversion is handled here; any other
// type mismatch leaves `newExpr` inserted as is.
QualType exprTy = newExpr->getType();
QualType targetTy = cast<Expr>(*iter)->getType();
if (exprTy->isIntegralType(m_Context) && targetTy->isFloatingType())
castExpr = m_Sema
.ImpCastExprToType(newExpr, targetTy,
CK_IntegralToFloating)
.get();
// Overwrite the child slot of `E` in place.
*iter = castExpr;
}
return true;
}
};

if (!needsUpdate)
return;

if (Placeholder) {
PlaceholderReplacer repl(Placeholder, S);
repl.newExpr = New;
for (Stmt* S : V.getCurrentBlock(direction::forward))
repl.TraverseStmt(S);
for (Stmt* S : V.getCurrentBlock(direction::reverse))
repl.TraverseStmt(S);
Result = New;
return;
}

if (isInsideLoop) {
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Expand All @@ -3385,22 +3412,32 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

ReverseModeVisitor::DelayedStoreResult
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E,
llvm::StringRef prefix) {
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, llvm::StringRef prefix,
bool forceNoRecompute) {
assert(E && "must be provided");
if (!UsefulToStore(E)) {
StmtDiff Ediff = Visit(E);
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this,
Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
}
if (!forceNoRecompute && ShouldRecompute(E)) {
// The value of the literal has no meaning. It's given a very particular
// value for easier debugging.
Expr* PH = ConstantFolder::synthesizeLiteral(E->getType(), m_Context,
/*val=*/31);
return DelayedStoreResult{*this,
StmtDiff{PH, nullptr, nullptr, PH},
/*Declaration=*/nullptr,
/*isInsideLoop*/ false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expand All @@ -3409,7 +3446,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
Expand All @@ -3425,7 +3461,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
Expand Down
8 changes: 7 additions & 1 deletion test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ double func5(double* x, double* y, double* output) {

//CHECK: void func5_grad(double *x, double *y, double *output, double *_d_x, double *_d_y, double *_d_output, double &_final_error) {
//CHECK-NEXT: unsigned {{int|long}} output_size = 0;
//CHECK-NEXT: unsigned {{int|long}} x_size = 0;
//CHECK-NEXT: unsigned {{int|long}} y_size = 0;
//CHECK-NEXT: unsigned {{int|long}} x_size = 0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: double _t0 = output[0];
//CHECK-NEXT: output[0] = x[1] * y[2] - x[2] * y[1];
Expand All @@ -249,10 +249,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[2] = _t2;
//CHECK-NEXT: double _r_d2 = _d_output[2];
//CHECK-NEXT: _d_output[2] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[0] += _r_d2 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[1] += x[0] * _r_d2;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[0] += -_r_d2 * x[1];
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[1] += y[0] * -_r_d2;
Expand All @@ -264,10 +266,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[1] = _t1;
//CHECK-NEXT: double _r_d1 = _d_output[1];
//CHECK-NEXT: _d_output[1] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[2] += _r_d1 * y[0];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[0] += x[2] * _r_d1;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[0] += -_r_d1 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[2] += x[0] * -_r_d1;
Expand All @@ -279,10 +283,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_output[0];
//CHECK-NEXT: _d_output[0] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[1] += _r_d0 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[2] += x[1] * _r_d0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[2] += -_r_d0 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[1] += x[2] * -_r_d0;
Expand Down
3 changes: 2 additions & 1 deletion test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: int _d_j = 0;
//CHECK-NEXT: int j = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: unsigned {{int|long}} b_size = 0;
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
Expand Down Expand Up @@ -111,6 +111,7 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}});
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: b_size = std::max(b_size, j);
//CHECK-NEXT: _d_a[i] += _r_d0 * b[j];
//CHECK-NEXT: a_size = std::max(a_size, i);
//CHECK-NEXT: _d_b[j] += a[i] * _r_d0;
Expand Down
26 changes: 25 additions & 1 deletion test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,29 @@ double fn_cond_add_assign(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT:}

// Multiplication whose right-hand side is a conditional expression. The
// expected gradient below stores the condition in `_cond0` and reuses it for
// the RHS instead of repeating `i < 10`.
double f_mult3(double i, double j) {
i = (i + j) * (i < 10 ? i : j);
return i;
}

//CHECK: void f_mult3_grad(double i, double j, double *_d_i, double *_d_j) {
//CHECK-NEXT: double _t0 = i;
//CHECK-NEXT: bool _cond0 = i < 10;
//CHECK-NEXT: i = (i + j) * (_cond0 ? i : j);
//CHECK-NEXT: *_d_i += 1;
//CHECK-NEXT: {
//CHECK-NEXT: i = _t0;
//CHECK-NEXT: double _r_d0 = *_d_i;
//CHECK-NEXT: *_d_i = 0;
//CHECK-NEXT: *_d_i += _r_d0 * (_cond0 ? i : j);
//CHECK-NEXT: *_d_j += _r_d0 * (_cond0 ? i : j);
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: *_d_i += (i + j) * _r_d0;
//CHECK-NEXT: else
//CHECK-NEXT: *_d_j += (i + j) * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -1158,5 +1181,6 @@ int main() {
INIT_GRADIENT(fn_cond_add_assign);
TEST_GRADIENT(fn_cond_add_assign, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {80.00, 48.00}


INIT_GRADIENT(f_mult3);
TEST_GRADIENT(f_mult3, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {11.00, 3.00}
}
2 changes: 1 addition & 1 deletion test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ int main() {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer);
// CHECK-NEXT: *_d_obj.x_pointer += 1 * *obj.y_pointer;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
Expand Down
Loading

0 comments on commit cf187dc

Please sign in to comment.