Skip to content

Commit

Permalink
Add support for initializer_list in forward mode AD
Browse files Browse the repository at this point in the history
This commit adds primitive support for initializer_list in the forward
mode AD.
  • Loading branch information
parth-07 authored and vgvassilev committed Jun 27, 2024
1 parent 4f487e3 commit 4c824ed
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class BaseForwardModeVisitor
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
17 changes: 17 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define CLAD_STL_BUILTINS_H

#include <vector>
#include "clad/Differentiator/BuiltinDerivatives.h"

namespace clad {
namespace custom_derivatives {
Expand All @@ -26,6 +27,22 @@ void resize_pushforward(::std::vector<T>* v, unsigned sz, U val,
d_v->resize(sz, d_val);
v->resize(sz, val);
}

template <typename T>
ValueAndPushforward<typename ::std::initializer_list<T>::iterator,
typename ::std::initializer_list<T>::iterator>
begin_pushforward(::std::initializer_list<T>* il,
::std::initializer_list<T>* d_il) {
return {il->begin(), d_il->begin()};
}

template <typename T>
ValueAndPushforward<typename ::std::initializer_list<T>::iterator,
typename ::std::initializer_list<T>::iterator>
end_pushforward(const ::std::initializer_list<T>* il,
const ::std::initializer_list<T>* d_il) {
return {il->end(), d_il->end()};
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2159,4 +2159,9 @@ BaseForwardModeVisitor::DifferentiateStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return DeclDiff<StaticAssertDecl>();
}

StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr());
}
} // end namespace clad
38 changes: 38 additions & 0 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/STLBuiltins.h"
#include <initializer_list>
#include <cmath>
#include "../TestUtils.h"

double f1(double x, int y) {
double r = 1.0;
Expand Down Expand Up @@ -548,6 +551,38 @@ double fn17_darg0(double x);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn18(double u, double v) {
auto dl = {u, v, u*v};
double res = 0;
auto dl_end = dl.end();
for (auto i = dl.begin(); i != dl_end; ++i)
res += *i;
return res;
}

// CHECK: double fn18_darg0(double u, double v) {
// CHECK-NEXT: double _d_u = 1;
// CHECK-NEXT: double _d_v = 0;
// CHECK-NEXT: {{.*}}initializer_list<double> _d_dl = {_d_u, _d_v, _d_u * v + u * _d_v};
// CHECK-NEXT: {{.*}}initializer_list<double> dl = {u, v, u * v};
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t0 = {{.*}}end_pushforward(&dl, &_d_dl);
// CHECK-NEXT: {{.*}}_d_dl_end = _t0.pushforward;
// CHECK-NEXT: {{.*}}dl_end = _t0.value;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t1 = {{.*}}begin_pushforward(&dl, &_d_dl);
// CHECK-NEXT: {{.*}}_d_i = _t1.pushforward;
// CHECK-NEXT: for ({{.*}}i = _t1.value; i != dl_end; ++_d_i , ++i) {
// CHECK-NEXT: _d_res += *_d_i;
// CHECK-NEXT: res += *i;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }



#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
printf("%.2f\n", d_##fn.execute(3, 5));
Expand Down Expand Up @@ -614,4 +649,7 @@ int main() {

clad::differentiate(fn17, 0);
printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0

INIT_DIFFERENTIATE(fn18, "u");
TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00}
}

0 comments on commit 4c824ed

Please sign in to comment.