From bcf79116d442e057c40f1318f640040195d08085 Mon Sep 17 00:00:00 2001 From: Infinite Void Date: Sat, 10 Aug 2024 20:37:41 +0530 Subject: [PATCH] Add support for custom `_reverse_forw` functions This commit adds support for custom (user-provided) `_forw` functions. A `_forw` function, if available, is called in place of the actual function. For example, if the primal code contains: ```cpp someFn(u, v, w); ``` and user has defined a custom `_reverse_forw` function for `someFn` as follows: ```cpp namespace clad { namespace custom_derivatives { void someFn_reverse_forw(double u, double v, double w, double *d_u, double *d_v, double *dw) { // ... // ... } } } ``` Then clad will generate the derivative function as follows: ```cpp // forward-pass clad::custom_derivatives::someFn_reverse_forw(u, v, w, d_u, d_v, d_w); // ... // reverse-pass; no change in reverse-pass someFn_pullback(u, v, w, d_u, d_v, d_w); // ... ``` But more importantly, why do we need such a functionality? Two reasons: - Supporting reference/pointer return types in the reverse-mode. This has been discussed at great length here: https://github.com/vgvassilev/clad/pull/425 (#425) - Supporting types whose elements grows dynamically, such as `std::vector` and `std::map`. The issue is that we correctly need to update the size/property of the adjoint variable when a function call updates the size/property of the corresponding primal variable. For example: a call to `vec.push_back(...)` should update the size of `_d_vec` as well. However, the actual function call does not modify the adjoint variable in any way. Here comes `_forw` functions to the rescue. `_forw` functions makes it possible to adjust the adjoint variable size/properties along with executing the actual function call. Please note that `_reverse_forw` function signature takes adjoint variables as arguments and return `clad::ValueAndAdjoint` to support the reference/pointer return type. --- .../clad/Differentiator/BuiltinDerivatives.h | 5 + include/clad/Differentiator/CladUtils.h | 3 + include/clad/Differentiator/Differentiator.h | 8 +- .../clad/Differentiator/ReverseModeVisitor.h | 5 + include/clad/Differentiator/STLBuiltins.h | 29 ++ lib/Differentiator/CladUtils.cpp | 5 + lib/Differentiator/ReverseModeVisitor.cpp | 36 +- lib/Differentiator/TBRAnalyzer.cpp | 10 +- test/Gradient/FunctionCalls.C | 8 +- test/Gradient/STLCustomDerivatives.C | 345 ++++++++++++++++++ 10 files changed, 437 insertions(+), 17 deletions(-) create mode 100644 test/Gradient/STLCustomDerivatives.C diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index 92b9094ea..ac47237bc 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -30,6 +30,11 @@ template struct ValueAndPushforward { } }; +template struct ValueAndAdjoint { + T value; + U adjoint; +}; + /// It is used to identify constructor custom pushforwards. For /// constructor custom pushforward functions, we cannot use the same /// strategy which we use for custom pushforward for member diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 05899cad7..fa14e5629 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -321,6 +321,9 @@ namespace clad { bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); + + /// Returns true if QT is a non-const reference type. + bool isNonConstReferenceType(clang::QualType QT); } // namespace utils } // namespace clad diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 4a089c095..647428423 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -22,12 +22,8 @@ #include namespace clad { -template struct ValueAndAdjoint { - T value; - U adjoint; -}; - /// \returns the size of a c-style string +/// \returns the size of a c-style string inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { const char* code_copy = code; #ifdef __CUDACC__ @@ -507,7 +503,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { // Gradient Structure for Reverse Mode Enzyme template struct EnzymeGradient { double d_arr[N]; }; -} + } // namespace clad #endif // CLAD_DIFFERENTIATOR // Enable clad after the header was included. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 8d899a4ac..15effe2fa 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -97,6 +97,11 @@ namespace clad { // Function to Differentiate with Enzyme as Backend void DifferentiateWithEnzyme(); + /// Tries to find and build call to user-provided `_forw` function. + clang::Expr* BuildCallToCustomForwPassFn( + const clang::FunctionDecl* FD, llvm::ArrayRef primalArgs, + llvm::ArrayRef derivedArgs, clang::Expr* baseExpr); + public: using direction = rmv::direction; clang::Expr* dfdx() { diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 415aa446c..3077eb886 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -202,6 +202,35 @@ void fill_pushforward(::std::array* a, const T& u, d_a->fill(d_u); } +template +void push_back_reverse_forw(::std::vector* v, U val, ::std::vector* d_v, + U* d_val) { + v->push_back(val); + d_v->push_back(0); +} + +template +void push_back_pullback(::std::vector* v, U val, ::std::vector* d_v, + U* d_val) { + *d_val += d_v->back(); + d_v->pop_back(); +} + +template +clad::ValueAndAdjoint operator_subscript_reverse_forw( + ::std::vector* vec, typename ::std::vector::size_type idx, + ::std::vector* d_vec, typename ::std::vector::size_type* d_idx) { + return {(*vec)[idx], (*d_vec)[idx]}; +} + +template +void operator_subscript_pullback(::std::vector* vec, + typename ::std::vector::size_type idx, + P d_y, ::std::vector* d_vec, + typename ::std::vector::size_type* d_idx) { + (*d_vec)[idx] += d_y; +} + } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 3b6a379e0..ce092882c 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -705,5 +705,10 @@ namespace clad { return FD->getNameAsString() == "free"; #endif } + + bool isNonConstReferenceType(clang::QualType QT) { + return QT->isReferenceType() && + !QT.getNonReferenceType().isConstQualified(); + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 8308da4b7..9d58434b6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1799,6 +1799,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Stores differentiation result of implicit `this` object, if any. StmtDiff baseDiff; + Expr* baseExpr = nullptr; // If it has more args or f_darg0 was not found, we look for its pullback // function. const auto* MD = dyn_cast(FD); @@ -1822,6 +1823,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); + baseExpr = baseDiff.getExpr(); Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); baseDiff.updateStmt(baseDiffStore); Expr* baseDerivative = baseDiff.getExpr_dx(); @@ -2007,8 +2009,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* call = nullptr; QualType returnType = FD->getReturnType(); - if (returnType->isReferenceType() && - !returnType.getNonReferenceType().isConstQualified()) { + if (Expr* customForwardPassCE = BuildCallToCustomForwPassFn( + FD, CallArgs, DerivedCallOutputArgs, baseExpr)) { + if (!utils::isNonConstReferenceType(returnType)) + return StmtDiff{customForwardPassCE}; + auto* callRes = StoreAndRef(customForwardPassCE); + auto* resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto* resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } + if (utils::isNonConstReferenceType(returnType)) { DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; @@ -4260,4 +4272,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diffParams.end()); return params; } + + Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn( + const FunctionDecl* FD, llvm::ArrayRef primalArgs, + llvm::ArrayRef derivedArgs, Expr* baseExpr) { + std::string forwPassFnName = + clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw"; + llvm::SmallVector args; + if (baseExpr) { + baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr, + m_DiffReq->getLocation()); + args.push_back(baseExpr); + } + args.append(primalArgs.begin(), primalArgs.end()); + args.append(derivedArgs.begin(), derivedArgs.end()); + Expr* customForwPassCE = + m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + forwPassFnName, args, getCurrentScope(), + const_cast(FD->getDeclContext())); + return customForwPassCE; + } } // end namespace clad diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 115b16bb8..a7eca94cd 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -203,9 +203,13 @@ TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) { elemType = pointerType->getPointeeType().getTypePtrOrNull(); else elemType = QT->getArrayElementTypeNoTypeQual(); - ProfileID nonConstIdxID; - auto& idxData = (*m_Val.m_ArrData)[nonConstIdxID]; - idxData = VarData(QualType::getFromOpaquePtr(elemType)); + // FIXME: In some cases for Mac, 'elemType' is nullptr for std::vector + // internal members. + if (elemType) { + ProfileID nonConstIdxID; + auto& idxData = (*m_Val.m_ArrData)[nonConstIdxID]; + idxData = VarData(QualType::getFromOpaquePtr(elemType)); + } } else if (QT->isBuiltinType()) { m_Type = VarData::FUND_TYPE; m_Val.m_FundData = false; diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 40be00e74..c64069359 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -228,7 +228,7 @@ double& identity(double& i) { namespace clad{ namespace custom_derivatives{ - clad::ValueAndAdjoint custom_identity_forw(double &i, double *d_i) { + clad::ValueAndAdjoint custom_identity_reverse_forw(double &i, double *d_i) { return {i, *d_i}; } } // namespace custom_derivatives @@ -260,10 +260,6 @@ double fn7(double i, double j) { // CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i); -// CHECK: clad::ValueAndAdjoint custom_identity_forw(double &i, double *d_i) { -// CHECK-NEXT: return {i, *d_i}; -// CHECK-NEXT: } - // CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: double _t0 = i; // CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &*_d_i); @@ -274,7 +270,7 @@ double fn7(double i, double j) { // CHECK-NEXT: double &_d_l = _t3.adjoint; // CHECK-NEXT: double &l = _t3.value; // CHECK-NEXT: double _t4 = i; -// CHECK-NEXT: clad::ValueAndAdjoint _t5 = custom_identity_forw(i, &*_d_i); +// CHECK-NEXT: clad::ValueAndAdjoint _t5 = {{.*}}custom_derivatives::custom_identity_reverse_forw(i, &*_d_i); // CHECK-NEXT: double &_d_temp = _t5.adjoint; // CHECK-NEXT: double &temp = _t5.value; // CHECK-NEXT: double _t6 = k; diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C new file mode 100644 index 000000000..1a2db8b81 --- /dev/null +++ b/test/Gradient/STLCustomDerivatives.C @@ -0,0 +1,345 @@ +// XFAIL: asserts +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oSTLCustomDerivatives.out 2>&1 | %filecheck %s +// RUN: ./STLCustomDerivatives.out | %filecheck_exec %s +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oSTLCustomDerivativesWithTBR.out +// RUN: ./STLCustomDerivativesWithTBR.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include "../TestUtils.h" +#include "../PrintOverloads.h" + +#include + +double fn10(double u, double v) { + std::vector vec; + vec.push_back(u); + vec.push_back(v); + return vec[0] + vec[1]; +} + +double fn11(double u, double v) { + std::vector vec; + vec.push_back(u); + vec.push_back(v); + double &ref = vec[0]; + ref += u; + return vec[0] + vec[1]; +} + +namespace clad { + namespace custom_derivatives { + namespace class_functions { + ::std::vector size_update_stack; + + template + void resize_reverse_forw(::std::vector *v, + typename ::std::vector::size_type sz, + ::std::vector *d_v, + typename ::std::vector::size_type *d_sz) { + size_update_stack.push_back(v->size()); + v->resize(sz); + d_v->resize(sz, 0); + } + + template + void resize_pullback(::std::vector *v, + typename ::std::vector::size_type sz, + ::std::vector *d_v, + typename ::std::vector::size_type *d_sz) { + size_t prevSz = size_update_stack.back(); + size_update_stack.pop_back(); + d_v->resize(prevSz); + } + + template + void clear_reverse_forw(::std::vector *v, ::std::vector *d_v) { + size_update_stack.push_back(v->size()); + v->clear(); + d_v->clear(); + } + + template + void clear_pullback(::std::vector *v, ::std::vector *d_v) { + size_t prevSz = size_update_stack.back(); + size_update_stack.pop_back(); + d_v->resize(prevSz, 0); + } + } + } +} + +double fn12(double u, double v) { + double res = 0; + std::vector vec; + vec.resize(3); + { + double &ref0 = vec[0]; + double &ref1 = vec[1]; + double &ref2 = vec[2]; + ref0 = u; + ref1 = v; + ref2 = u + v; + } + res = vec[0] + vec[1] + vec[2]; + vec.clear(); + vec.resize(2); + { + double &ref0 = vec[0]; + double &ref1 = vec[1]; + ref0 = u; + ref1 = u; + } + res += vec[0] + vec[1]; + return res; +} + +int main() { + double d_i, d_j; + INIT_GRADIENT(fn10); + INIT_GRADIENT(fn11); + INIT_GRADIENT(fn12); + + TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} + TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} + TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00} +} + +// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: std::vector _d_vec({}); +// CHECK-NEXT: std::vector vec; +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: std::vector _t1 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_reverse_forw(&vec, u, &_d_vec, &*_d_u); +// CHECK-NEXT: double _t2 = v; +// CHECK-NEXT: std::vector _t3 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_reverse_forw(&vec, v, &_d_vec, &*_d_v); +// CHECK-NEXT: std::vector _t4 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r0); +// CHECK-NEXT: std::vector _t6 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t7 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r1); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 1, &_d_vec, &_r0); +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t6, 1, 1, &_d_vec, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: v = _t2; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u); +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK-NEXT: void fn11_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: std::vector _d_vec({}); +// CHECK-NEXT: std::vector vec; +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: std::vector _t1 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_reverse_forw(&vec, u, &_d_vec, &*_d_u); +// CHECK-NEXT: double _t2 = v; +// CHECK-NEXT: std::vector _t3 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_reverse_forw(&vec, v, &_d_vec, &*_d_v); +// CHECK-NEXT: std::vector _t4 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r0); +// CHECK-NEXT: double &_d_ref = _t5.adjoint; +// CHECK-NEXT: double &ref = _t5.value; +// CHECK-NEXT: double _t6 = ref; +// CHECK-NEXT: ref += u; +// CHECK-NEXT: std::vector _t7 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t8 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r1); +// CHECK-NEXT: std::vector _t9 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t10 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r2); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t7, 0, 1, &_d_vec, &_r1); +// CHECK-NEXT: {{.*}} _r2 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t9, 1, 1, &_d_vec, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ref = _t6; +// CHECK-NEXT: double _r_d0 = _d_ref; +// CHECK-NEXT: *_d_u += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 0, &_d_vec, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: v = _t2; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u); +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void fn12_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: std::vector _t1; +// CHECK-NEXT: double *_d_ref0 = 0; +// CHECK-NEXT: double *ref0 = {}; +// CHECK-NEXT: std::vector _t3; +// CHECK-NEXT: double *_d_ref1 = 0; +// CHECK-NEXT: double *ref1 = {}; +// CHECK-NEXT: std::vector _t5; +// CHECK-NEXT: double *_d_ref2 = 0; +// CHECK-NEXT: double *ref2 = {}; +// CHECK-NEXT: double _t7; +// CHECK-NEXT: double _t8; +// CHECK-NEXT: double _t9; +// CHECK-NEXT: std::vector _t19; +// CHECK-NEXT: double *_d_ref00 = 0; +// CHECK-NEXT: double *ref00 = {}; +// CHECK-NEXT: std::vector _t21; +// CHECK-NEXT: double *_d_ref10 = 0; +// CHECK-NEXT: double *ref10 = {}; +// CHECK-NEXT: double _t23; +// CHECK-NEXT: double _t24; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: std::vector _d_vec({}); +// CHECK-NEXT: std::vector vec; +// CHECK-NEXT: std::vector _t0 = vec; +// CHECK-NEXT: {{.*}}class_functions::resize_reverse_forw(&vec, 3, &_d_vec, &_r0); +// CHECK-NEXT: { +// CHECK-NEXT: _t1 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r1); +// CHECK-NEXT: _d_ref0 = &_t2.adjoint; +// CHECK-NEXT: ref0 = &_t2.value; +// CHECK-NEXT: _t3 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t4 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r2); +// CHECK-NEXT: _d_ref1 = &_t4.adjoint; +// CHECK-NEXT: ref1 = &_t4.value; +// CHECK-NEXT: _t5 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t6 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 2, &_d_vec, &_r3); +// CHECK-NEXT: _d_ref2 = &_t6.adjoint; +// CHECK-NEXT: ref2 = &_t6.value; +// CHECK-NEXT: _t7 = *ref0; +// CHECK-NEXT: *ref0 = u; +// CHECK-NEXT: _t8 = *ref1; +// CHECK-NEXT: *ref1 = v; +// CHECK-NEXT: _t9 = *ref2; +// CHECK-NEXT: *ref2 = u + v; +// CHECK-NEXT: } +// CHECK-NEXT: double _t10 = res; +// CHECK-NEXT: std::vector _t11 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t12 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r4); +// CHECK-NEXT: std::vector _t13 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t14 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r5); +// CHECK-NEXT: std::vector _t15 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t16 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 2, &_d_vec, &_r6); +// CHECK-NEXT: res = _t12.value + _t14.value + _t16.value; +// CHECK-NEXT: std::vector _t17 = vec; +// CHECK-NEXT: {{.*}}class_functions::clear_reverse_forw(&vec, &_d_vec); +// CHECK-NEXT: std::vector _t18 = vec; +// CHECK-NEXT: {{.*}}class_functions::resize_reverse_forw(&vec, 2, &_d_vec, &_r7); +// CHECK-NEXT: { +// CHECK-NEXT: _t19 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t20 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r8); +// CHECK-NEXT: _d_ref00 = &_t20.adjoint; +// CHECK-NEXT: ref00 = &_t20.value; +// CHECK-NEXT: _t21 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t22 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r9); +// CHECK-NEXT: _d_ref10 = &_t22.adjoint; +// CHECK-NEXT: ref10 = &_t22.value; +// CHECK-NEXT: _t23 = *ref00; +// CHECK-NEXT: *ref00 = u; +// CHECK-NEXT: _t24 = *ref10; +// CHECK-NEXT: *ref10 = u; +// CHECK-NEXT: } +// CHECK-NEXT: double _t25 = res; +// CHECK-NEXT: std::vector _t26 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t27 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 0, &_d_vec, &_r10); +// CHECK-NEXT: std::vector _t28 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t29 = {{.*}}class_functions::operator_subscript_reverse_forw(&vec, 1, &_d_vec, &_r11); +// CHECK-NEXT: res += _t27.value + _t29.value; +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: res = _t25; +// CHECK-NEXT: double _r_d6 = _d_res; +// CHECK-NEXT: {{.*}} _r10 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t26, 0, _r_d6, &_d_vec, &_r10); +// CHECK-NEXT: {{.*}} _r11 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t28, 1, _r_d6, &_d_vec, &_r11); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: *ref10 = _t24; +// CHECK-NEXT: double _r_d5 = *_d_ref10; +// CHECK-NEXT: *_d_ref10 = 0; +// CHECK-NEXT: *_d_u += _r_d5; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: *ref00 = _t23; +// CHECK-NEXT: double _r_d4 = *_d_ref00; +// CHECK-NEXT: *_d_ref00 = 0; +// CHECK-NEXT: *_d_u += _r_d4; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r9 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t21, 1, 0, &_d_vec, &_r9); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r8 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t19, 0, 0, &_d_vec, &_r8); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r7 = 0; +// CHECK-NEXT: {{.*}}class_functions::resize_pullback(&_t18, 2, &_d_vec, &_r7); +// CHECK-NEXT: } +// CHECK-NEXT: {{.*}}class_functions::clear_pullback(&_t17, &_d_vec); +// CHECK-NEXT: { +// CHECK-NEXT: res = _t10; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: _d_res = 0; +// CHECK-NEXT: {{.*}} _r4 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t11, 0, _r_d3, &_d_vec, &_r4); +// CHECK-NEXT: {{.*}} _r5 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t13, 1, _r_d3, &_d_vec, &_r5); +// CHECK-NEXT: {{.*}} _r6 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t15, 2, _r_d3, &_d_vec, &_r6); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: *ref2 = _t9; +// CHECK-NEXT: double _r_d2 = *_d_ref2; +// CHECK-NEXT: *_d_ref2 = 0; +// CHECK-NEXT: *_d_u += _r_d2; +// CHECK-NEXT: *_d_v += _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: *ref1 = _t8; +// CHECK-NEXT: double _r_d1 = *_d_ref1; +// CHECK-NEXT: *_d_ref1 = 0; +// CHECK-NEXT: *_d_v += _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: *ref0 = _t7; +// CHECK-NEXT: double _r_d0 = *_d_ref0; +// CHECK-NEXT: *_d_ref0 = 0; +// CHECK-NEXT: *_d_u += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r3 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t5, 2, 0, &_d_vec, &_r3); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r2 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t3, 1, 0, &_d_vec, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t1, 0, 0, &_d_vec, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}class_functions::resize_pullback(&_t0, 3, &_d_vec, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file