Skip to content

Commit

Permalink
Differentiate the RHS in multiplication instead of cloning by introdu…
Browse files Browse the repository at this point in the history
…cing placeholders

Currently, when differentiating multiplication in the reverse mode, we need to pass the differentiated LHS when visiting the RHS and vice versa (i.e. ``Visit(R, dR)`` and ``Visit(L, dL)`` where ``dR = LDiff * dfdx`` and ``dL = dfdx * RDiff``). This creates a loop that we break in one of two ways:
1) Create a variable to represent the result of the RHS visitation, use it to visit the LHS, visit the RHS, and then set the value of the variable introduced before to the result of the RHS visitation.
2) Clone the RHS, use it to visit the LHS, then visit the RHS.

The 1st approach is bad because it introduces a new variable that is usually unnecessary.

The 2nd approach is used more frequently but its downside is that cloning is not the same as visiting, some expressions cannot be cloned. e.g.
a) ``x > 0 ? a : b`` is cloned as ``x > 0 ? a : b`` but differentiated as ``_cond ? a : b``, where ``_cond`` is a variable to save the condition ``x > 0``.
b) References are often turned into pointers in the reverse mode. Because of that, ``x`` should be differentiated as ``*x``  and not as just ``x``, which ``Clone`` does. We already have an exception to handle ref-type decl refs as RHS of multiplication, which is removed in this PR.

This commit solves this problem in a third way: by passing a literal expression when visiting LHS and replacing it when the differentiated RHS is known. All of the new logic is sunk into existing functions ``DelayedGlobalStoreAndRef`` and ``Finalize``.
  • Loading branch information
PetroZarytskyi committed Aug 21, 2024
1 parent 293f1c4 commit f3bdb33
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 85 deletions.
14 changes: 8 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,14 @@ namespace clad {
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
clang::VarDecl* pDeclaration, bool pIsInsideLoop,
bool pIsFnScope, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
isInsideLoop(pIsInsideLoop), isFnScope(pIsFnScope),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(clang::Expr* New);
};

Expand All @@ -297,7 +298,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceStore = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand Down
173 changes: 103 additions & 70 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2302,25 +2302,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).

// Check if the local variable declaration is reference type, since it is
// moved to the global scope and the right side should be recomputed
bool promoteToFnScope = false;
if (auto* RDeclRef = dyn_cast<DeclRefExpr>(R->IgnoreImplicit()))
promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() &&
!getCurrentScope()->isFunctionScope();

if (!ShouldRecompute(R) || promoteToFnScope) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}
DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff& RResult = RDelayed.Result;

Expr* dl = nullptr;
if (dfdx())
Expand All @@ -2341,31 +2324,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
RDelayed.Finalize(Rdiff.getExpr());
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (opCode == BO_Div) {
// xi = xl / xr
// dxi/xl = 1 / xr
// df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr)
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceStore=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* dl = nullptr;
if (dfdx())
dl = BuildOp(BO_Div, dfdx(), RStored);
dl = BuildOp(BO_Div, dfdx(), RResult.getExpr());
Ldiff = Visit(L, dl);
StmtDiff LStored = Ldiff;
// Catch the pop statement and emit it after
Expand All @@ -2379,14 +2355,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
// dxi/xr = -xl / (xr * xr)
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
// produced instead of 1 / (R * R).
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
// dxi/xr = -xl / (xr * xr)
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
// produced instead of 1 / (R * R).
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* RxR = BuildParens(
BuildOp(BO_Mul, RResult.getExpr(), RResult.getExpr()));
dr = BuildOp(BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildParens(BuildOp(
Expand All @@ -2397,8 +2376,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
RDelayed.Finalize(Rdiff.getExpr());
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2594,26 +2572,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceStore=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
if (!RDelayed.isConstant) {
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* dr = BuildOp(BO_Mul, oldValue,
BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR)));
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* dr = BuildOp(BO_Mul, oldValue,
BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR)));
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult);
} else
llvm_unreachable("unknown assignment opCode");
if (m_ExternalSource)
Expand Down Expand Up @@ -3374,8 +3351,57 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
if (isConstant || !needsUpdate)
// Placeholders are used when we have to use an expr before we have that.
// For instance, this is necessary for multiplication and division when the
// RHS and LHS need the derivatives of each other to be differentiated. We
// need placeholders to break this loop.
class PlaceholderReplacer
: public RecursiveASTVisitor<PlaceholderReplacer> {
public:
const Expr* placeholder;
Sema& m_Sema;
ASTContext& m_Context;
Expr* newExpr{nullptr};
PlaceholderReplacer(const Expr* Placeholder, Sema& S)
: placeholder(Placeholder), m_Sema(S), m_Context(S.getASTContext()) {}

void Replace(ReverseModeVisitor& RMV, Expr* New, StmtDiff& Result) {
newExpr = New;
for (Stmt* S : RMV.getCurrentBlock(direction::forward))
TraverseStmt(S);
for (Stmt* S : RMV.getCurrentBlock(direction::reverse))
TraverseStmt(S);
Result = New;
}

// We chose iteration rather than visiting because we only do this for
// simple Expression subtrees and it is not worth it to implement an
// entire visitor infrastructure for simple replacements.
bool VisitExpr(Expr* E) const {
for (Stmt*& S : E->children())
if (S == placeholder) {
// Since we are manually replacing the statement, implicit casts are
// not generated automatically.
ExprResult newExprRes{newExpr};
QualType targetTy = cast<Expr>(S)->getType().withConst();
CastKind kind = m_Sema.PrepareScalarCast(newExprRes, targetTy);
S = m_Sema.ImpCastExprToType(newExpr, targetTy, kind).get();
}
return true;
}
PlaceholderReplacer(const PlaceholderReplacer&) = delete;
PlaceholderReplacer(PlaceholderReplacer&&) = delete;
};

if (!needsUpdate)
return;

if (Placeholder) {
PlaceholderReplacer repl(Placeholder, V.m_Sema);
repl.Replace(V, New, Result);
return;
}

if (isInsideLoop) {
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Expand All @@ -3391,21 +3417,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

ReverseModeVisitor::DelayedStoreResult
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E,
llvm::StringRef prefix) {
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, llvm::StringRef prefix,
bool forceStore) {
assert(E && "must be provided");
if (!UsefulToStore(E)) {
StmtDiff Ediff = Visit(E);
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this,
Ediff,
return DelayedStoreResult{*this, Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
/*isFnScope=*/false};
}
if (!forceStore && ShouldRecompute(E)) {
// The value of the literal has no. It's given a very particular value for
// easier debugging.
Expr* PH = ConstantFolder::synthesizeLiteral(E->getType(), m_Context,
/*val=*/~0U);
return DelayedStoreResult{
*this,
StmtDiff{PH, /*diff=*/nullptr, /*forwSweepDiff=*/nullptr, PH},
/*Declaration=*/nullptr,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
}
if (isInsideLoop) {
Expr* dummy = E;
Expand All @@ -3415,7 +3450,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
Expand All @@ -3431,7 +3465,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
Expand Down
8 changes: 7 additions & 1 deletion test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ double func5(double* x, double* y, double* output) {

//CHECK: void func5_grad(double *x, double *y, double *output, double *_d_x, double *_d_y, double *_d_output, double &_final_error) {
//CHECK-NEXT: unsigned {{int|long|long long}} output_size = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} x_size = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} y_size = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} x_size = 0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: double _t0 = output[0];
//CHECK-NEXT: output[0] = x[1] * y[2] - x[2] * y[1];
Expand All @@ -249,10 +249,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[2] = _t2;
//CHECK-NEXT: double _r_d2 = _d_output[2];
//CHECK-NEXT: _d_output[2] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[0] += _r_d2 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[1] += x[0] * _r_d2;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[0] += -_r_d2 * x[1];
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[1] += y[0] * -_r_d2;
Expand All @@ -264,10 +266,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[1] = _t1;
//CHECK-NEXT: double _r_d1 = _d_output[1];
//CHECK-NEXT: _d_output[1] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[2] += _r_d1 * y[0];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[0] += x[2] * _r_d1;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[0] += -_r_d1 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[2] += x[0] * -_r_d1;
Expand All @@ -279,10 +283,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_output[0];
//CHECK-NEXT: _d_output[0] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[1] += _r_d0 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[2] += x[1] * _r_d0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[2] += -_r_d0 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[1] += x[2] * -_r_d0;
Expand Down
3 changes: 2 additions & 1 deletion test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: int _d_j = 0;
//CHECK-NEXT: int j = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: unsigned {{int|long|long long}} a_size = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} b_size = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} a_size = 0;
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
Expand Down Expand Up @@ -111,6 +111,7 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}});
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: b_size = std::max(b_size, j);
//CHECK-NEXT: _d_a[i] += _r_d0 * b[j];
//CHECK-NEXT: a_size = std::max(a_size, i);
//CHECK-NEXT: _d_b[j] += a[i] * _r_d0;
Expand Down
Loading

0 comments on commit f3bdb33

Please sign in to comment.