From 060ad8bc78056ba191d8ec78cd81d5fb5ff6f04a Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 14 Aug 2024 01:07:45 +0300 Subject: [PATCH] Differentiate the RHS in multiplication instead of cloning by introducing placeholders --- .../clad/Differentiator/ReverseModeVisitor.h | 14 ++- lib/Differentiator/ReverseModeVisitor.cpp | 112 +++++++++++------- test/ErrorEstimation/LoopsAndArrays.C | 8 +- test/ErrorEstimation/LoopsAndArraysExec.C | 3 +- test/Gradient/NonDifferentiable.C | 2 +- test/Gradient/Pointers.C | 10 +- 6 files changed, 89 insertions(+), 60 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 176b5cf71..4030af314 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -273,13 +273,14 @@ namespace clad { bool isInsideLoop; bool isFnScope; bool needsUpdate; + clang::Expr* Placeholder; DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult, - clang::VarDecl* pDeclaration, bool pIsConstant, - bool pIsInsideLoop, bool pIsFnScope, - bool pNeedsUpdate = false) + clang::VarDecl* pDeclaration, bool pIsInsideLoop, + bool pIsFnScope, bool pNeedsUpdate = false, + clang::Expr* pPlaceholder = nullptr) : V(pV), Result(pResult), Declaration(pDeclaration), - isConstant(pIsConstant), isInsideLoop(pIsInsideLoop), - isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {} + isInsideLoop(pIsInsideLoop), isFnScope(pIsFnScope), + needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {} void Finalize(clang::Expr* New); }; @@ -292,7 +293,8 @@ namespace clad { /// This is what DelayedGlobalStoreAndRef does. E is expected to be the /// original (uncloned) expression. DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E, - llvm::StringRef prefix = "_t"); + llvm::StringRef prefix = "_t", + bool forceNoRecompute = false); struct CladTapeResult { ReverseModeVisitor& V; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9c7cb542b..7f43a6057 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2296,25 +2296,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // to reduce cloning complexity and only clones once. Storing it in a // global variable allows to save current result and make it accessible // in the reverse pass. - std::unique_ptr RDelayed; - StmtDiff RResult; - // If R has no side effects, it can be just cloned - // (no need to store it). - - // Check if the local variable declaration is reference type, since it is - // moved to the global scope and the right side should be recomputed - bool promoteToFnScope = false; - if (auto* RDeclRef = dyn_cast(R->IgnoreImplicit())) - promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() && - !getCurrentScope()->isFunctionScope(); - - if (!ShouldRecompute(R) || promoteToFnScope) { - RDelayed = std::unique_ptr( - new DelayedStoreResult(DelayedGlobalStoreAndRef(R))); - RResult = RDelayed->Result; - } else { - RResult = StmtDiff(Clone(R)); - } + DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff& RResult = RDelayed.Result; Expr* dl = nullptr; if (dfdx()) @@ -2336,30 +2319,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /*force=*/true); Stmt* LPop = endBlock(direction::reverse); Expr::EvalResult dummy; - if (RDelayed || - !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) || + RDelayed.needsUpdate) { Expr* dr = nullptr; if (dfdx()) dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx()); Rdiff = Visit(R, dr); // Assign right multiplier's variable with R. - if (RDelayed) - RDelayed->Finalize(Rdiff.getExpr()); + RDelayed.Finalize(Rdiff.getExpr()); } addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse); - std::tie(Ldiff, Rdiff) = - std::make_pair(LStored.getExpr(), RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult); } else if (opCode == BO_Div) { // xi = xl / xr // dxi/xl = 1 / xr // df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr) - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; - Expr* RStored = - StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); + auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", + /*forceNoRecompute=*/true); + StmtDiff& RResult = RDelayed.Result; Expr* dl = nullptr; if (dfdx()) - dl = BuildOp(BO_Div, dfdx(), RStored); + dl = BuildOp(BO_Div, dfdx(), RResult.getExpr()); Ldiff = Visit(L, dl); StmtDiff LStored = Ldiff; // Catch the pop statement and emit it after @@ -2377,10 +2357,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) // Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is // produced instead of 1 / (R * R). - if (!RDelayed.isConstant) { + Expr::EvalResult dummy; + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) || + RDelayed.needsUpdate) { Expr* dr = nullptr; if (dfdx()) { - Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); + Expr* RxR = BuildParens( + BuildOp(BO_Mul, RResult.getExpr(), RResult.getExpr())); dr = BuildOp(BO_Mul, dfdx(), BuildOp(UO_Minus, BuildParens(BuildOp( @@ -2391,8 +2374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, RDelayed.Finalize(Rdiff.getExpr()); } addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse); - std::tie(Ldiff, Rdiff) = - std::make_pair(LStored.getExpr(), RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult); } else if (BinOp->isAssignmentOp()) { if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { diag(DiagnosticsEngine::Warning, @@ -2588,14 +2570,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* zero = getZeroInit(ResultRef->getType()); addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero), direction::reverse); - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; + auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", + /*forceNoRecompute=*/true); + StmtDiff& RResult = RDelayed.Result; Expr* RStored = StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef, BuildOp(BO_Div, oldValue, RStored)), direction::reverse); - if (!RDelayed.isConstant) { + Expr::EvalResult dummy; + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) || + RDelayed.needsUpdate) { if (isInsideLoop) addToCurrentBlock(LCloned, direction::forward); Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); @@ -2607,7 +2592,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(), Ldiff.getRevSweepAsExpr()); - std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult); } else llvm_unreachable("unknown assignment opCode"); if (m_ExternalSource) @@ -3368,8 +3353,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) { - if (isConstant || !needsUpdate) + class PlaceholderReplacer + : public RecursiveASTVisitor { + public: + const Expr* placeholder; + Expr* newExpr{nullptr}; + PlaceholderReplacer(const Expr* Placeholder) : placeholder(Placeholder) {} + + bool VisitExpr(Expr* E) { + for (auto iter = E->child_begin(), e = E->child_end(); iter != e; + ++iter) + if (*iter == placeholder) + *iter = newExpr; + return true; + } + }; + + if (!needsUpdate) return; + + if (Placeholder) { + PlaceholderReplacer repl(Placeholder); + repl.newExpr = New; + for (Stmt* S : V.getCurrentBlock(direction::forward)) + repl.TraverseStmt(S); + for (Stmt* S : V.getCurrentBlock(direction::reverse)) + repl.TraverseStmt(S); + Result = New; + return; + } + if (isInsideLoop) { auto* Push = cast(Result.getExpr()); unsigned lastArg = Push->getNumArgs() - 1; @@ -3385,22 +3398,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } ReverseModeVisitor::DelayedStoreResult - ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, - llvm::StringRef prefix) { + ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, llvm::StringRef prefix, + bool forceNoRecompute) { assert(E && "must be provided"); if (!UsefulToStore(E)) { StmtDiff Ediff = Visit(E); Expr::EvalResult evalRes; - bool isConst = - clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context); return DelayedStoreResult{*this, Ediff, /*Declaration=*/nullptr, - /*isConstant=*/isConst, /*isInsideLoop=*/false, /*isFnScope=*/false, /*pNeedsUpdate=*/false}; } + if (!forceNoRecompute && ShouldRecompute(E)) { + // The value of the literal has no. It's given a very particular value for + // easier debugging. + Expr* PH = ConstantFolder::synthesizeLiteral(E->getType(), m_Context, 31); + return DelayedStoreResult{*this, + StmtDiff{PH, nullptr, nullptr, PH}, + /*Declaration=*/nullptr, + /*isInsideLoop*/ false, + /*isFnScope=*/false, + /*pNeedsUpdate=*/true, + /*pPlaceholder=*/PH}; + } if (isInsideLoop) { Expr* dummy = E; auto CladTape = MakeCladTapeFor(dummy); @@ -3409,7 +3431,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop}, /*Declaration=*/nullptr, - /*isConstant=*/false, /*isInsideLoop=*/true, /*isFnScope=*/false, /*pNeedsUpdate=*/true}; @@ -3425,7 +3446,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref}, /*Declaration=*/VD, - /*isConstant=*/false, /*isInsideLoop=*/false, /*isFnScope=*/isFnScope, /*pNeedsUpdate=*/true}; diff --git a/test/ErrorEstimation/LoopsAndArrays.C b/test/ErrorEstimation/LoopsAndArrays.C index 98e1ba0bd..0265ea1d6 100644 --- a/test/ErrorEstimation/LoopsAndArrays.C +++ b/test/ErrorEstimation/LoopsAndArrays.C @@ -226,8 +226,8 @@ double func5(double* x, double* y, double* output) { //CHECK: void func5_grad(double *x, double *y, double *output, double *_d_x, double *_d_y, double *_d_output, double &_final_error) { //CHECK-NEXT: unsigned {{int|long}} output_size = 0; -//CHECK-NEXT: unsigned {{int|long}} x_size = 0; //CHECK-NEXT: unsigned {{int|long}} y_size = 0; +//CHECK-NEXT: unsigned {{int|long}} x_size = 0; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: double _t0 = output[0]; //CHECK-NEXT: output[0] = x[1] * y[2] - x[2] * y[1]; @@ -249,10 +249,12 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: output[2] = _t2; //CHECK-NEXT: double _r_d2 = _d_output[2]; //CHECK-NEXT: _d_output[2] = 0; +//CHECK-NEXT: y_size = std::max(y_size, 1); //CHECK-NEXT: _d_x[0] += _r_d2 * y[1]; //CHECK-NEXT: x_size = std::max(x_size, 0); //CHECK-NEXT: _d_y[1] += x[0] * _r_d2; //CHECK-NEXT: y_size = std::max(y_size, 1); +//CHECK-NEXT: x_size = std::max(x_size, 1); //CHECK-NEXT: _d_y[0] += -_r_d2 * x[1]; //CHECK-NEXT: y_size = std::max(y_size, 0); //CHECK-NEXT: _d_x[1] += y[0] * -_r_d2; @@ -264,10 +266,12 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: output[1] = _t1; //CHECK-NEXT: double _r_d1 = _d_output[1]; //CHECK-NEXT: _d_output[1] = 0; +//CHECK-NEXT: y_size = std::max(y_size, 0); //CHECK-NEXT: _d_x[2] += _r_d1 * y[0]; //CHECK-NEXT: x_size = std::max(x_size, 2); //CHECK-NEXT: _d_y[0] += x[2] * _r_d1; //CHECK-NEXT: y_size = std::max(y_size, 0); +//CHECK-NEXT: y_size = std::max(y_size, 2); //CHECK-NEXT: _d_x[0] += -_r_d1 * y[2]; //CHECK-NEXT: x_size = std::max(x_size, 0); //CHECK-NEXT: _d_y[2] += x[0] * -_r_d1; @@ -279,10 +283,12 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: output[0] = _t0; //CHECK-NEXT: double _r_d0 = _d_output[0]; //CHECK-NEXT: _d_output[0] = 0; +//CHECK-NEXT: y_size = std::max(y_size, 2); //CHECK-NEXT: _d_x[1] += _r_d0 * y[2]; //CHECK-NEXT: x_size = std::max(x_size, 1); //CHECK-NEXT: _d_y[2] += x[1] * _r_d0; //CHECK-NEXT: y_size = std::max(y_size, 2); +//CHECK-NEXT: y_size = std::max(y_size, 1); //CHECK-NEXT: _d_x[2] += -_r_d0 * y[1]; //CHECK-NEXT: x_size = std::max(x_size, 2); //CHECK-NEXT: _d_y[1] += x[2] * -_r_d0; diff --git a/test/ErrorEstimation/LoopsAndArraysExec.C b/test/ErrorEstimation/LoopsAndArraysExec.C index 4460855bf..4647d55b4 100644 --- a/test/ErrorEstimation/LoopsAndArraysExec.C +++ b/test/ErrorEstimation/LoopsAndArraysExec.C @@ -72,8 +72,8 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: int _d_j = 0; //CHECK-NEXT: int j = 0; //CHECK-NEXT: clad::tape _t3 = {}; -//CHECK-NEXT: unsigned {{int|long}} a_size = 0; //CHECK-NEXT: unsigned {{int|long}} b_size = 0; +//CHECK-NEXT: unsigned {{int|long}} a_size = 0; //CHECK-NEXT: double _d_sum = 0; //CHECK-NEXT: double sum = 0; //CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; @@ -111,6 +111,7 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t3); //CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: b_size = std::max(b_size, j); //CHECK-NEXT: _d_a[i] += _r_d0 * b[j]; //CHECK-NEXT: a_size = std::max(a_size, i); //CHECK-NEXT: _d_b[j] += a[i] * _r_d0; diff --git a/test/Gradient/NonDifferentiable.C b/test/Gradient/NonDifferentiable.C index 8ad0ea115..230c9a2b1 100644 --- a/test/Gradient/NonDifferentiable.C +++ b/test/Gradient/NonDifferentiable.C @@ -154,7 +154,7 @@ int main() { // CHECK-NEXT: SimpleFunctions1 _d_obj({}); // CHECK-NEXT: SimpleFunctions1 obj(2, 3); // CHECK-NEXT: { - // CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer); + // CHECK-NEXT: *_d_obj.x_pointer += 1 * *obj.y_pointer; // CHECK-NEXT: *_d_i += 1 * j; // CHECK-NEXT: *_d_j += i * 1; // CHECK-NEXT: } diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 3232bdf42..a6eb48d52 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -28,13 +28,13 @@ double minimalPointer(double x) { // CHECK-NEXT: double *_d_p = &*_d_x; // CHECK-NEXT: double *const p = &x; // CHECK-NEXT: double _t0 = *p; -// CHECK-NEXT: *p = *p * (*p); +// CHECK-NEXT: *p = *p * *p; // CHECK-NEXT: *_d_p += 1; // CHECK-NEXT: { // CHECK-NEXT: *p = _t0; // CHECK-NEXT: double _r_d0 = *_d_p; // CHECK-NEXT: *_d_p = 0; -// CHECK-NEXT: *_d_p += _r_d0 * (*p); +// CHECK-NEXT: *_d_p += _r_d0 * *p; // CHECK-NEXT: *_d_p += *p * _r_d0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -87,7 +87,7 @@ double arrayPointer(const double* arr) { // CHECK-NEXT: _d_p = _d_p - 2; // CHECK-NEXT: p = p - 2; // CHECK-NEXT: double _t11 = sum; -// CHECK-NEXT: sum += 5 * (*p); +// CHECK-NEXT: sum += 5 * *p; // CHECK-NEXT: _d_sum += 1; // CHECK-NEXT: { // CHECK-NEXT: sum = _t11; @@ -170,7 +170,7 @@ double pointerParam(const double* arr, size_t n) { // CHECK-NEXT: clad::push(_t1, _d_j); // CHECK-NEXT: clad::push(_t3, j) , j = &i; // CHECK-NEXT: clad::push(_t4, sum); -// CHECK-NEXT: sum += arr[0] * (*j); +// CHECK-NEXT: sum += arr[0] * *j; // CHECK-NEXT: clad::push(_t5, arr); // CHECK-NEXT: clad::push(_t6, _d_arr); // CHECK-NEXT: _d_arr = _d_arr + 1; @@ -191,7 +191,7 @@ double pointerParam(const double* arr, size_t n) { // CHECK-NEXT: { // CHECK-NEXT: sum = clad::pop(_t4); // CHECK-NEXT: double _r_d0 = _d_sum; -// CHECK-NEXT: _d_arr[0] += _r_d0 * (*j); +// CHECK-NEXT: _d_arr[0] += _r_d0 * *j; // CHECK-NEXT: *_t2 += arr[0] * _r_d0; // CHECK-NEXT: } // CHECK-NEXT: j = clad::pop(_t3);