From 4c824edf6519086f41b0b456a1045a095a1206d5 Mon Sep 17 00:00:00 2001 From: parth-07 Date: Wed, 12 Jun 2024 00:33:50 +0530 Subject: [PATCH] Add support for initializer_list in forward mode AD This commit adds primitive support for initializer_list in the forward mode AD. --- .../Differentiator/BaseForwardModeVisitor.h | 2 + include/clad/Differentiator/STLBuiltins.h | 17 +++++++++ lib/Differentiator/BaseForwardModeVisitor.cpp | 5 +++ test/FirstDerivative/Loops.C | 38 +++++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 17e6ba6a4..484d17d4b 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -112,6 +112,8 @@ class BaseForwardModeVisitor StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; + StmtDiff + VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 41fc05518..39e786aba 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -2,6 +2,7 @@ #define CLAD_STL_BUILTINS_H #include +#include "clad/Differentiator/BuiltinDerivatives.h" namespace clad { namespace custom_derivatives { @@ -26,6 +27,22 @@ void resize_pushforward(::std::vector* v, unsigned sz, U val, d_v->resize(sz, d_val); v->resize(sz, val); } + +template +ValueAndPushforward::iterator, + typename ::std::initializer_list::iterator> +begin_pushforward(::std::initializer_list* il, + ::std::initializer_list* d_il) { + return {il->begin(), d_il->begin()}; +} + +template +ValueAndPushforward::iterator, + typename ::std::initializer_list::iterator> +end_pushforward(const ::std::initializer_list* il, + const ::std::initializer_list* d_il) { + return {il->end(), d_il->end()}; +} } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 48e1cf35c..067fd77d1 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -2159,4 +2159,9 @@ BaseForwardModeVisitor::DifferentiateStaticAssertDecl( const clang::StaticAssertDecl* SAD) { return DeclDiff(); } + +StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr( + const clang::CXXStdInitializerListExpr* ILE) { + return Visit(ILE->getSubExpr()); +} } // end namespace clad diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 152984bcd..5534291b5 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -3,7 +3,10 @@ // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include #include +#include "../TestUtils.h" double f1(double x, int y) { double r = 1.0; @@ -548,6 +551,38 @@ double fn17_darg0(double x); // CHECK-NEXT: return _d_x; // CHECK-NEXT: } +double fn18(double u, double v) { + auto dl = {u, v, u*v}; + double res = 0; + auto dl_end = dl.end(); + for (auto i = dl.begin(); i != dl_end; ++i) + res += *i; + return res; +} + +// CHECK: double fn18_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: {{.*}}initializer_list _d_dl = {_d_u, _d_v, _d_u * v + u * _d_v}; +// CHECK-NEXT: {{.*}}initializer_list dl = {u, v, u * v}; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t0 = {{.*}}end_pushforward(&dl, &_d_dl); +// CHECK-NEXT: {{.*}}_d_dl_end = _t0.pushforward; +// CHECK-NEXT: {{.*}}dl_end = _t0.value; +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t1 = {{.*}}begin_pushforward(&dl, &_d_dl); +// CHECK-NEXT: {{.*}}_d_i = _t1.pushforward; +// CHECK-NEXT: for ({{.*}}i = _t1.value; i != dl_end; ++_d_i , ++i) { +// CHECK-NEXT: _d_res += *_d_i; +// CHECK-NEXT: res += *i; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + + + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -614,4 +649,7 @@ int main() { clad::differentiate(fn17, 0); printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0 + + INIT_DIFFERENTIATE(fn18, "u"); + TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00} }