From 58129cdd28ebbc66386adacd5139e0e8547db906 Mon Sep 17 00:00:00 2001 From: Infinite Void Date: Sat, 10 Aug 2024 20:37:41 +0530 Subject: [PATCH] Add support for custom `_forw` functions This commit adds support for custom (user-provided) `_forw` functions. A `_forw` function, if available, is called in place of the actual function. For example, if the primal code contains: ```cpp someFn(u, v, w); ``` and user has defined a custom `_forw` function for `someFn` as follows: ```cpp namespace clad { namespace custom_derivatives { void someFn_forw(double u, double v, double w, double *d_u, double *d_v, double *dw) { // ... // ... } } } ``` Then clad will generate the derivative function as follows: ```cpp // forward-pass clad::custom_derivatives::someFn_forw(u, v, w, d_u, d_v, d_w); // ... // reverse-pass; no change in reverse-pass someFn_pullback(u, v, w, d_u, d_v, d_w); // ... ``` But more importantly, why do we need such a functionality? Two reasons: - Supporting reference/pointer return types in the reverse-mode. This has been discussed at great length here: https://github.com/vgvassilev/clad/pull/425 (#425) - Supporting types whose elements grows dynamically, such as `std::vector` and `std::map`. The issue is that we correctly need to update the size/property of the adjoint variable when a function call updates the size/property of the corresponding primal variable. However, the actual function call does not modify the adjoint variable. Here comes `_forw` functions to the rescue. `_forw` functions makes it possible to adjust the adjoint variable size/properties along with executing the code of the actual function call. --- .../clad/Differentiator/BuiltinDerivatives.h | 5 + include/clad/Differentiator/CladUtils.h | 3 + include/clad/Differentiator/Differentiator.h | 8 +- .../clad/Differentiator/ReverseModeVisitor.h | 5 + include/clad/Differentiator/STLBuiltins.h | 29 +++++ lib/Differentiator/CladUtils.cpp | 5 + lib/Differentiator/ReverseModeVisitor.cpp | 36 ++++++- test/Gradient/FunctionCalls.C | 6 +- test/Gradient/UserDefinedTypes.C | 100 +++++++++++++++++- 9 files changed, 182 insertions(+), 15 deletions(-) diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index 92b9094ea..ac47237bc 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -30,6 +30,11 @@ template struct ValueAndPushforward { } }; +template struct ValueAndAdjoint { + T value; + U adjoint; +}; + /// It is used to identify constructor custom pushforwards. For /// constructor custom pushforward functions, we cannot use the same /// strategy which we use for custom pushforward for member diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 05899cad7..fa14e5629 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -321,6 +321,9 @@ namespace clad { bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); + + /// Returns true if QT is a non-const reference type. + bool isNonConstReferenceType(clang::QualType QT); } // namespace utils } // namespace clad diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 4a089c095..647428423 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -22,12 +22,8 @@ #include namespace clad { -template struct ValueAndAdjoint { - T value; - U adjoint; -}; - /// \returns the size of a c-style string +/// \returns the size of a c-style string inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { const char* code_copy = code; #ifdef __CUDACC__ @@ -507,7 +503,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { // Gradient Structure for Reverse Mode Enzyme template struct EnzymeGradient { double d_arr[N]; }; -} + } // namespace clad #endif // CLAD_DIFFERENTIATOR // Enable clad after the header was included. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a161f1f58..da1e363e7 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -97,6 +97,11 @@ namespace clad { // Function to Differentiate with Enzyme as Backend void DifferentiateWithEnzyme(); + /// Tries to find and build call to user-provided `_forw` function. + clang::Expr* BuildCallToCustomForwPassFn( + const clang::FunctionDecl* FD, llvm::ArrayRef primalArgs, + llvm::ArrayRef derivedArgs, clang::Expr* baseExpr); + public: using direction = rmv::direction; clang::Expr* dfdx() { diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 415aa446c..db031735c 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -202,6 +202,35 @@ void fill_pushforward(::std::array* a, const T& u, d_a->fill(d_u); } +template +void push_back_forw(::std::vector* v, U val, ::std::vector* d_v, + U* d_val) { + v->push_back(val); + d_v->push_back(0); +} + +template +void push_back_pullback(::std::vector* v, U val, ::std::vector* d_v, + U* d_val) { + *d_val += d_v->back(); + d_v->pop_back(); +} + +template +clad::ValueAndAdjoint operator_subscript_forw( + ::std::vector* vec, typename ::std::vector::size_type idx, + ::std::vector* d_vec, typename ::std::vector::size_type* d_idx) { + return {(*vec)[idx], (*d_vec)[idx]}; +} + +template +void operator_subscript_pullback(::std::vector* vec, + typename ::std::vector::size_type idx, + P d_y, ::std::vector* d_vec, + typename ::std::vector::size_type* d_idx) { + (*d_vec)[idx] += d_y; +} + } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 3b6a379e0..ce092882c 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -705,5 +705,10 @@ namespace clad { return FD->getNameAsString() == "free"; #endif } + + bool isNonConstReferenceType(clang::QualType QT) { + return QT->isReferenceType() && + !QT.getNonReferenceType().isConstQualified(); + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fb1102d66..a5127264d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1791,6 +1791,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Stores differentiation result of implicit `this` object, if any. StmtDiff baseDiff; + Expr* baseExpr = nullptr; // If it has more args or f_darg0 was not found, we look for its pullback // function. const auto* MD = dyn_cast(FD); @@ -1814,6 +1815,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); + baseExpr = baseDiff.getExpr(); Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); baseDiff.updateStmt(baseDiffStore); Expr* baseDerivative = baseDiff.getExpr_dx(); @@ -1999,8 +2001,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* call = nullptr; QualType returnType = FD->getReturnType(); - if (returnType->isReferenceType() && - !returnType.getNonReferenceType().isConstQualified()) { + if (Expr* customForwardPassCE = BuildCallToCustomForwPassFn( + FD, CallArgs, DerivedCallOutputArgs, baseExpr)) { + if (!utils::isNonConstReferenceType(returnType)) + return StmtDiff{customForwardPassCE}; + auto* callRes = StoreAndRef(customForwardPassCE); + auto* resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto* resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } + if (utils::isNonConstReferenceType(returnType)) { DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; @@ -4259,4 +4271,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diffParams.end()); return params; } + + Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn( + const FunctionDecl* FD, llvm::ArrayRef primalArgs, + llvm::ArrayRef derivedArgs, Expr* baseExpr) { + std::string forwPassFnName = + clad::utils::ComputeEffectiveFnName(FD) + "_forw"; + llvm::SmallVector args; + if (baseExpr) { + baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr, + m_DiffReq->getLocation()); + args.push_back(baseExpr); + } + args.append(primalArgs.begin(), primalArgs.end()); + args.append(derivedArgs.begin(), derivedArgs.end()); + Expr* customForwPassCE = + m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + forwPassFnName, args, getCurrentScope(), + const_cast(FD->getDeclContext())); + return customForwPassCE; + } } // end namespace clad diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 40be00e74..b7edec084 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -260,10 +260,6 @@ double fn7(double i, double j) { // CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i); -// CHECK: clad::ValueAndAdjoint custom_identity_forw(double &i, double *d_i) { -// CHECK-NEXT: return {i, *d_i}; -// CHECK-NEXT: } - // CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: double _t0 = i; // CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &*_d_i); @@ -274,7 +270,7 @@ double fn7(double i, double j) { // CHECK-NEXT: double &_d_l = _t3.adjoint; // CHECK-NEXT: double &l = _t3.value; // CHECK-NEXT: double _t4 = i; -// CHECK-NEXT: clad::ValueAndAdjoint _t5 = custom_identity_forw(i, &*_d_i); +// CHECK-NEXT: clad::ValueAndAdjoint _t5 = {{.*}}custom_derivatives::custom_identity_forw(i, &*_d_i); // CHECK-NEXT: double &_d_temp = _t5.adjoint; // CHECK-NEXT: double &temp = _t5.value; // CHECK-NEXT: double _t6 = k; diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index f5f954808..405b36cda 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -1,12 +1,14 @@ -// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s // RUN: ./UserDefinedTypes.out | %filecheck_exec %s -// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out // RUN: ./UserDefinedTypes.out | %filecheck_exec %s // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" #include +#include #include #include "../TestUtils.h" @@ -326,6 +328,22 @@ double fn9(Tangent t, dcomplex c) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn10(double u, double v) { + std::vector vec; + vec.push_back(u); + vec.push_back(v); + return vec[0] + vec[1]; +} + +double fn11(double u, double v) { + std::vector vec; + vec.push_back(u); + vec.push_back(v); + double &ref = vec[0]; + ref += u; + return vec[0] + vec[1]; +} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -334,6 +352,7 @@ void print(const Tangent& t) { } } + int main() { pairdd p(3, 5), d_p; double i = 3, d_i, d_j; @@ -351,6 +370,8 @@ int main() { INIT_GRADIENT(fn7); INIT_GRADIENT(fn8); INIT_GRADIENT(fn9); + INIT_GRADIENT(fn10); + INIT_GRADIENT(fn11); TEST_GRADIENT(fn1, /*numOfDerivativeArgs=*/2, p, i, &d_p, &d_i); // CHECK-EXEC: {1.00, 2.00, 3.00} TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, t, i, &d_t, &d_i); // CHECK-EXEC: {4.00, 2.00, 2.00, 2.00, 2.00, 1.00} @@ -364,8 +385,83 @@ int main() { TEST_GRADIENT(fn7, /*numOfDerivativeArgs=*/2, c1, c2, &d_c1, &d_c2);// CHECK-EXEC: {0.00, 3.00, 5.00, 1.00} TEST_GRADIENT(fn8, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {0.00, 0.00, 0.00, 0.00, 0.00, 5.00, 0.00} TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00} + TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} + TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} } +// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: std::vector _d_vec({}); +// CHECK-NEXT: std::vector vec; +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: std::vector _t1 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, u, &_d_vec, &*_d_u); +// CHECK-NEXT: double _t2 = v; +// CHECK-NEXT: std::vector _t3 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, v, &_d_vec, &*_d_v); +// CHECK-NEXT: std::vector _t4 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r0); +// CHECK-NEXT: std::vector _t6 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t7 = {{.*}}class_functions::operator_subscript_forw(&vec, 1, &_d_vec, &_r1); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 1, &_d_vec, &_r0); +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t6, 1, 1, &_d_vec, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: v = _t2; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u); +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK-NEXT: void fn11_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: std::vector _d_vec({}); +// CHECK-NEXT: std::vector vec; +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: std::vector _t1 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, u, &_d_vec, &*_d_u); +// CHECK-NEXT: double _t2 = v; +// CHECK-NEXT: std::vector _t3 = vec; +// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, v, &_d_vec, &*_d_v); +// CHECK-NEXT: std::vector _t4 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r0); +// CHECK-NEXT: double &_d_ref = _t5.adjoint; +// CHECK-NEXT: double &ref = _t5.value; +// CHECK-NEXT: double _t6 = ref; +// CHECK-NEXT: ref += u; +// CHECK-NEXT: std::vector _t7 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t8 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r1); +// CHECK-NEXT: std::vector _t9 = vec; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t10 = {{.*}}class_functions::operator_subscript_forw(&vec, 1, &_d_vec, &_r2); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t7, 0, 1, &_d_vec, &_r1); +// CHECK-NEXT: {{.*}} _r2 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t9, 1, 1, &_d_vec, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ref = _t6; +// CHECK-NEXT: double _r_d0 = _d_ref; +// CHECK-NEXT: *_d_u += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 0, &_d_vec, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: v = _t2; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u); +// CHECK-NEXT: } +// CHECK-NEXT: } + // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) { // CHECK-NEXT: int _d_i = 0; // CHECK-NEXT: int i = 0;