Skip to content

Commit

Permalink
Add support for Kokkos::fence in the fwd mode
Browse files Browse the repository at this point in the history
Although this function doesn't need to be differentiated and
is correctly used by Clad automatically, this custom pushforward
prevents Clad from throwing a warning during that.
  • Loading branch information
gojakuch authored and vgvassilev committed Aug 21, 2024
1 parent 66b6af4 commit 397e3b0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0,
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

/// Fence
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
::Kokkos::fence(s);
}

/// Parallel for
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
Expand Down
20 changes: 20 additions & 0 deletions unittests/Kokkos/ParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ template <typename View> struct Foo {
void operator()(const int i) const { res(i) = x * i; }
};

double parallel_for_functor_simplest_case_fence(double x) {
Kokkos::View<double[5], Kokkos::HostSpace> res("res");

// Kokkos::fence("named fence"); // Does not work on some versions of Kokkos.

Foo<Kokkos::View<double[5], Kokkos::HostSpace>> f(res, x);

f(0); // FIXME: this is a workaround to put Foo::operator() into the
// differentiation plan. This needs to be solved in clad.

Kokkos::parallel_for(5, f);
Kokkos::fence();

return res(3);
}

double parallel_for_functor_simplest_case_intpol(double x) {
Kokkos::View<double[5], Kokkos::HostSpace> res("res");

Expand Down Expand Up @@ -191,6 +207,10 @@ double parallel_for_functor_simplest_case_mdpol_space_and_anon(double x) {
TEST(ParallelFor, FunctorSimplestCases) {
const double eps = 1e-8;

auto df0 = clad::differentiate(parallel_for_functor_simplest_case_fence, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df0.execute(x), 3, eps);

auto df1 = clad::differentiate(parallel_for_functor_simplest_case_intpol, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df1.execute(x), 3, eps);
Expand Down

0 comments on commit 397e3b0

Please sign in to comment.