From 5a02b1a0a9b84a1bef9453ef59104ceaf1ef8b9a Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Mon, 14 Oct 2024 20:53:15 +0200 Subject: [PATCH] Add support for `Kokkos::resize` in the rvs mode --- include/clad/Differentiator/KokkosBuiltins.h | 72 ++++++++++++++------ unittests/Kokkos/ViewAccess.cpp | 6 +- unittests/Kokkos/ViewBasics.cpp | 39 ++++++++++- 3 files changed, 93 insertions(+), 24 deletions(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index 51824a004..4a46dae1e 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -340,6 +340,7 @@ template struct iterate_over_all_view_elements { template struct iterate_over_all_view_elements { template static void run(const View& v, F func) { ::Kokkos::parallel_for("iterate_over_all_view_elements", v.extent(0), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -348,6 +349,7 @@ template struct iterate_over_all_view_elements { ::Kokkos::MDRangePolicy<::Kokkos::Rank<2>>( {0, 0}, {v.extent(0), v.extent(1)}), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -357,6 +359,7 @@ template struct iterate_over_all_view_elements { ::Kokkos::MDRangePolicy<::Kokkos::Rank<3>>( {0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2)}), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -366,6 +369,7 @@ template struct iterate_over_all_view_elements { ::Kokkos::MDRangePolicy<::Kokkos::Rank<4>>( {0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3)}), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -376,6 +380,7 @@ template struct iterate_over_all_view_elements { {0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4)}), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -386,6 +391,7 @@ template struct iterate_over_all_view_elements { {0, 0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4), v.extent(5)}), func); + ::Kokkos::fence(); } }; template struct iterate_over_all_view_elements { @@ -397,6 +403,7 @@ template struct iterate_over_all_view_elements { {v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4), v.extent(5), v.extent(6)}), func); + ::Kokkos::fence(); } }; template @@ -450,32 +457,57 @@ inline void deep_copy_pullback( }); } -template -inline void -resize_pushforward(View& v, const Idx0 n0, const Idx1 n1, const Idx2 n2, - const Idx3 n3, const Idx4 n4, const Idx5 n5, const Idx6 n6, - const Idx7 n7, View& d_v, const Idx0 /*d_n*/, - const Idx1 /*d_n*/, const Idx2 /*d_n*/, const Idx3 /*d_n*/, - const Idx4 /*d_n*/, const Idx5 /*d_n*/, const Idx6 /*d_n*/, - const Idx7 /*d_n*/) { +template +void resize_pushforward( + View& v, const ::std::size_t n0, const ::std::size_t n1, + const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4, + const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7, + View& d_v, const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/) { ::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7); ::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7); } -template -inline void -resize_pushforward(const I& arg, View& v, const Idx0 n0, const Idx1 n1, - const Idx2 n2, const Idx3 n3, const Idx4 n4, const Idx5 n5, - const Idx6 n6, const Idx7 n7, const dI& /*d_arg*/, View& d_v, - const Idx0 /*d_n*/, const Idx1 /*d_n*/, const Idx2 /*d_n*/, - const Idx3 /*d_n*/, const Idx4 /*d_n*/, const Idx5 /*d_n*/, - const Idx6 /*d_n*/, const Idx7 /*d_n*/) { +template +void resize_pushforward( + const I& arg, View& v, const ::std::size_t n0, const ::std::size_t n1, + const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4, + const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7, + const dI& /*d_arg*/, View& d_v, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/) { + ::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7); + ::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7); +} +template +void resize_reverse_forw( + View& v, const ::std::size_t n0, const ::std::size_t n1, + const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4, + const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7, + View& d_v, const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/) { + ::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7); + ::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7); +} +template +void resize_reverse_forw( + const I& arg, View& v, const ::std::size_t n0, const ::std::size_t n1, + const ::std::size_t n2, const ::std::size_t n3, const ::std::size_t n4, + const ::std::size_t n5, const ::std::size_t n6, const ::std::size_t n7, + const dI& /*d_arg*/, View& d_v, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/, const ::std::size_t /*d_n*/, + const ::std::size_t /*d_n*/) { ::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7); ::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7); } +template void resize_pullback(Args... /*args*/) {} /// Fence template void fence_pushforward(const S& s, const S& /*d_s*/) { diff --git a/unittests/Kokkos/ViewAccess.cpp b/unittests/Kokkos/ViewAccess.cpp index e42475ccd..0281eb229 100644 --- a/unittests/Kokkos/ViewAccess.cpp +++ b/unittests/Kokkos/ViewAccess.cpp @@ -82,16 +82,16 @@ TEST(ViewAccess, Test2) { EXPECT_NEAR(f_3_y.execute(3, 4), dy_f_3_FD, tolerance * dy_f_3_FD); auto f_grad_exe = clad::gradient(f); - double dx, dy; + double dx = 0, dy = 0; f_grad_exe.execute(3., 4., &dx, &dy); EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx); - double dx_2, dy_2; + double dx_2 = 0, dy_2 = 0; auto f_2_grad_exe = clad::gradient(f_2); f_2_grad_exe.execute(3., 4., &dx_2, &dy_2); EXPECT_NEAR(f_2_x.execute(3, 4), dx_2, tolerance * dx_2); - double dx_3, dy_3; + double dx_3 = 0, dy_3 = 0; auto f_3_grad_exe = clad::gradient(f_3); f_3_grad_exe.execute(3., 4., &dx_3, &dy_3); EXPECT_NEAR(f_3_y.execute(3, 4), dy_3, tolerance * dy_3); diff --git a/unittests/Kokkos/ViewBasics.cpp b/unittests/Kokkos/ViewBasics.cpp index 4f996d921..e44737e78 100644 --- a/unittests/Kokkos/ViewBasics.cpp +++ b/unittests/Kokkos/ViewBasics.cpp @@ -211,7 +211,9 @@ double f_basics_resize_3(double x, double y) { 2); Kokkos::deep_copy(a, 3 * x + y); - Kokkos::resize(Kokkos::WithoutInitializing, a, 5, 5); + Kokkos::resize(Kokkos::WithoutInitializing, a, 5, + 5); // FIXME: this signature for the resize function is not yet + // supported in the reverse mode a(4, 4, 0) = x * y; @@ -250,6 +252,41 @@ TEST(ViewBasics, TestResize4) { EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps); } +double f_basics_resize_5_both_modes(double x, double y) { + Kokkos::View a("a", 3, + 2); + Kokkos::View b("b", 5, + 5); + + b(4, 4, 0) = x * y * 2; + b(2, 1, 0) = 0; + + Kokkos::deep_copy(a, 3 * x + y); + a(2, 1, 0) = x * y; + + Kokkos::resize(a, 5, 5); + Kokkos::deep_copy(a, b); + + return a(4, 4, 0); +} + +TEST(ViewBasics, TestResize5) { + const double eps = 1e-8; + + auto df = clad::differentiate(f_basics_resize_5_both_modes, 0); + auto gradf = clad::gradient(f_basics_resize_5_both_modes); + auto df_true_x = [](double x, double y) { return y * 2; }; + for (double x = 3; x <= 5; x += 1) + for (double y = 3; y <= 5; y += 1) { + double dfdx = df.execute(x, y); + EXPECT_NEAR(dfdx, df_true_x(x, y), eps); + double dx = 0, dy = 0; + gradf.execute(x, y, &dx, &dy); + EXPECT_NEAR(dfdx, dx, eps); + EXPECT_NEAR(2 * x, dy, eps); + } +} + template struct FooModifier { double x;