Skip to content

Commit

Permalink
Add support for Kokkos::resize in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch authored and vgvassilev committed Oct 16, 2024
1 parent 183719a commit 5a02b1a
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 24 deletions.
72 changes: 52 additions & 20 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ template <typename View, int Rank> struct iterate_over_all_view_elements {
/// Rank-1 specialization of iterate_over_all_view_elements: hands `func` to
/// ::Kokkos::parallel_for with `v.extent(0)` as the work range.
template <typename View> struct iterate_over_all_view_elements<View, 1> {
/// \param[in] v view whose `extent(0)` is passed as the iteration range.
/// \param[in] func functor forwarded (by value) to ::Kokkos::parallel_for.
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for("iterate_over_all_view_elements", v.extent(0), func);
// Global fence issued right after the dispatch; presumably so the kernel has
// completed before run() returns -- confirm against the Kokkos fence docs.
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 2> {
Expand All @@ -348,6 +349,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 2> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<2>>(
{0, 0}, {v.extent(0), v.extent(1)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 3> {
Expand All @@ -357,6 +359,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 3> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<3>>(
{0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 4> {
Expand All @@ -366,6 +369,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 4> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<4>>(
{0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 5> {
Expand All @@ -376,6 +380,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 5> {
{0, 0, 0, 0, 0},
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 6> {
Expand All @@ -386,6 +391,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 6> {
{0, 0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2),
v.extent(3), v.extent(4), v.extent(5)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 7> {
Expand All @@ -397,6 +403,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 7> {
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4),
v.extent(5), v.extent(6)}),
func);
::Kokkos::fence();
}
};
template <typename... ViewArgs>
Expand Down Expand Up @@ -450,32 +457,57 @@ inline void deep_copy_pullback(
});
}

template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7>
inline void
resize_pushforward(View& v, const Idx0 n0, const Idx1 n1, const Idx2 n2,
const Idx3 n3, const Idx4 n4, const Idx5 n5, const Idx6 n6,
const Idx7 n7, View& d_v, const Idx0 /*d_n*/,
const Idx1 /*d_n*/, const Idx2 /*d_n*/, const Idx3 /*d_n*/,
const Idx4 /*d_n*/, const Idx5 /*d_n*/, const Idx6 /*d_n*/,
const Idx7 /*d_n*/) {
/// Custom pushforward for `Kokkos::resize(v, n0, ..., n7)`.
///
/// Applies the same resize request to the primal view `v` first and then to
/// its derivative view `d_v`. The derivative counterparts of the extents are
/// accepted only to match the expected signature and are never read.
template <typename View>
void resize_pushforward(
    View& v, const ::std::size_t n0, const ::std::size_t n1,
    const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4,
    const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7,
    View& d_v, const ::std::size_t /*d_n0*/, const ::std::size_t /*d_n1*/,
    const ::std::size_t /*d_n2*/, const ::std::size_t /*d_n3*/,
    const ::std::size_t /*d_n4*/, const ::std::size_t /*d_n5*/,
    const ::std::size_t /*d_n6*/, const ::std::size_t /*d_n7*/) {
  // One helper so that both views receive exactly the same extents.
  const auto resize_to_request = [&](View& target) {
    ::Kokkos::resize(target, n0, n1, n2, n3, n4, n5, n6, n7);
  };
  resize_to_request(v);
  resize_to_request(d_v);
}
template <class I, class dI, class View, typename Idx0, typename Idx1,
typename Idx2, typename Idx3, typename Idx4, typename Idx5,
typename Idx6, typename Idx7>
inline void
resize_pushforward(const I& arg, View& v, const Idx0 n0, const Idx1 n1,
const Idx2 n2, const Idx3 n3, const Idx4 n4, const Idx5 n5,
const Idx6 n6, const Idx7 n7, const dI& /*d_arg*/, View& d_v,
const Idx0 /*d_n*/, const Idx1 /*d_n*/, const Idx2 /*d_n*/,
const Idx3 /*d_n*/, const Idx4 /*d_n*/, const Idx5 /*d_n*/,
const Idx6 /*d_n*/, const Idx7 /*d_n*/) {
/// Custom pushforward for `Kokkos::resize(arg, v, n0, ..., n7)`, i.e. the
/// overload that takes a leading argument `arg` ahead of the view.
///
/// `arg` is forwarded unchanged to both resize calls: first for the primal
/// view `v`, then for its derivative view `d_v`. The derivative of `arg` and
/// the derivative counterparts of the extents are never read.
template <class I, class dI, class View>
void resize_pushforward(
    const I& arg, View& v, const ::std::size_t n0, const ::std::size_t n1,
    const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4,
    const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7,
    const dI& /*d_arg*/, View& d_v, const ::std::size_t /*d_n0*/,
    const ::std::size_t /*d_n1*/, const ::std::size_t /*d_n2*/,
    const ::std::size_t /*d_n3*/, const ::std::size_t /*d_n4*/,
    const ::std::size_t /*d_n5*/, const ::std::size_t /*d_n6*/,
    const ::std::size_t /*d_n7*/) {
  // One helper so that both views receive the same `arg` and extents.
  const auto resize_to_request = [&](View& target) {
    ::Kokkos::resize(arg, target, n0, n1, n2, n3, n4, n5, n6, n7);
  };
  resize_to_request(v);
  resize_to_request(d_v);
}
/// Reverse-mode forward-pass counterpart of `Kokkos::resize(v, n0, ..., n7)`.
///
/// Resizes the primal view `v` and then the derivative view `d_v` to the same
/// requested extents. The derivative counterparts of the extents are unused.
template <class View>
void resize_reverse_forw(
    View& v, const ::std::size_t n0, const ::std::size_t n1,
    const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4,
    const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7,
    View& d_v, const ::std::size_t /*d_n0*/, const ::std::size_t /*d_n1*/,
    const ::std::size_t /*d_n2*/, const ::std::size_t /*d_n3*/,
    const ::std::size_t /*d_n4*/, const ::std::size_t /*d_n5*/,
    const ::std::size_t /*d_n6*/, const ::std::size_t /*d_n7*/) {
  // Primal first, derivative second -- same order as two explicit calls.
  View* const targets[] = {&v, &d_v};
  for (View* const target : targets) {
    ::Kokkos::resize(*target, n0, n1, n2, n3, n4, n5, n6, n7);
  }
}
/// Reverse-mode forward-pass counterpart of
/// `Kokkos::resize(arg, v, n0, ..., n7)`, i.e. the overload that takes a
/// leading argument `arg` ahead of the view.
///
/// `arg` is forwarded unchanged while the primal view `v` and then the
/// derivative view `d_v` are resized to the same extents. The derivative of
/// `arg` and the derivative counterparts of the extents are unused.
template <class I, class dI, class View>
void resize_reverse_forw(
    const I& arg, View& v, const ::std::size_t n0, const ::std::size_t n1,
    const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4,
    const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7,
    const dI& /*d_arg*/, View& d_v, const ::std::size_t /*d_n0*/,
    const ::std::size_t /*d_n1*/, const ::std::size_t /*d_n2*/,
    const ::std::size_t /*d_n3*/, const ::std::size_t /*d_n4*/,
    const ::std::size_t /*d_n5*/, const ::std::size_t /*d_n6*/,
    const ::std::size_t /*d_n7*/) {
  // Primal first, derivative second -- same order as two explicit calls.
  View* const targets[] = {&v, &d_v};
  for (View* const target : targets) {
    ::Kokkos::resize(arg, *target, n0, n1, n2, n3, n4, n5, n6, n7);
  }
}
/// Pullback for resize: accepts any argument list (by value) and does nothing.
// NOTE(review): presumably a no-op because resizing contributes no adjoint of
// its own -- confirm against clad's custom-derivative conventions.
template <class... Args> void resize_pullback(Args... /*args*/) {}

/// Fence
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
Expand Down
6 changes: 3 additions & 3 deletions unittests/Kokkos/ViewAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ TEST(ViewAccess, Test2) {
EXPECT_NEAR(f_3_y.execute(3, 4), dy_f_3_FD, tolerance * dy_f_3_FD);

auto f_grad_exe = clad::gradient(f);
double dx, dy;
double dx = 0, dy = 0;
f_grad_exe.execute(3., 4., &dx, &dy);
EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx);

double dx_2, dy_2;
double dx_2 = 0, dy_2 = 0;
auto f_2_grad_exe = clad::gradient(f_2);
f_2_grad_exe.execute(3., 4., &dx_2, &dy_2);
EXPECT_NEAR(f_2_x.execute(3, 4), dx_2, tolerance * dx_2);

double dx_3, dy_3;
double dx_3 = 0, dy_3 = 0;
auto f_3_grad_exe = clad::gradient(f_3);
f_3_grad_exe.execute(3., 4., &dx_3, &dy_3);
EXPECT_NEAR(f_3_y.execute(3, 4), dy_3, tolerance * dy_3);
Expand Down
39 changes: 38 additions & 1 deletion unittests/Kokkos/ViewBasics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ double f_basics_resize_3(double x, double y) {
2);
Kokkos::deep_copy(a, 3 * x + y);

Kokkos::resize(Kokkos::WithoutInitializing, a, 5, 5);
Kokkos::resize(Kokkos::WithoutInitializing, a, 5,
5); // FIXME: this signature for the resize function is not yet
// supported in the reverse mode

a(4, 4, 0) = x * y;

Expand Down Expand Up @@ -250,6 +252,41 @@ TEST(ViewBasics, TestResize4) {
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

// Test function differentiated in both forward and reverse mode by
// TestResize5: view `a` is filled, resized with Kokkos::resize, and then
// overwritten from view `b`, so the returned element comes from `b`.
double f_basics_resize_5_both_modes(double x, double y) {
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> a("a", 3,
2);
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> b("b", 5,
5);

// `b` holds the value that is eventually returned: 2 * x * y.
b(4, 4, 0) = x * y * 2;
b(2, 1, 0) = 0;

// Values written to `a` before the resize; a(2, 1, 0) is later replaced by
// the deep_copy from `b` below.
Kokkos::deep_copy(a, 3 * x + y);
a(2, 1, 0) = x * y;

// Resize `a` with the same runtime extents `b` was constructed with (5, 5),
// then copy `b` into it.
Kokkos::resize(a, 5, 5);
Kokkos::deep_copy(a, b);

return a(4, 4, 0);
}

// Checks f_basics_resize_5_both_modes in forward mode (d/dx) and in reverse
// mode (full gradient) on a 3x3 grid of inputs, and cross-checks the two.
TEST(ViewBasics, TestResize5) {
  const double eps = 1e-8;

  // Analytic d/dx of the function under test.
  auto df_true_x = [](double x, double y) { return y * 2; };

  auto forward_mode = clad::differentiate(f_basics_resize_5_both_modes, 0);
  auto reverse_mode = clad::gradient(f_basics_resize_5_both_modes);

  for (double x = 3; x <= 5; x += 1) {
    for (double y = 3; y <= 5; y += 1) {
      // Forward mode against the analytic derivative.
      double dfdx = forward_mode.execute(x, y);
      EXPECT_NEAR(dfdx, df_true_x(x, y), eps);

      // Reverse mode: adjoints must start from zero.
      double dx = 0;
      double dy = 0;
      reverse_mode.execute(x, y, &dx, &dy);
      EXPECT_NEAR(dfdx, dx, eps);
      EXPECT_NEAR(2 * x, dy, eps);
    }
  }
}

template <typename View> struct FooModifier {
double x;

Expand Down

0 comments on commit 5a02b1a

Please sign in to comment.