diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index 7a0bdc4d6..2f9505991 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -6,6 +6,7 @@ #define CLAD_DIFFERENTIATOR_KOKKOSBUILTINS_H #include +#include #include "clad/Differentiator/Differentiator.h" namespace clad::custom_derivatives { @@ -29,7 +30,7 @@ constructor_pushforward( } } // namespace class_functions -/// Kokkos functions +/// Kokkos functions (view utils) namespace Kokkos { template inline void deep_copy_pushforward(const View1& dst, const View2& src, T param, @@ -66,15 +67,235 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0, ::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7); } -template +/// Parallel for +template // range policy +inline void parallel_for_pushforward( + const ::std::string& str, + const ::Kokkos::RangePolicy& policy, + const FunctorType& functor, const ::std::string& /*d_str*/, + const ::Kokkos::RangePolicy& /*d_policy*/, + const FunctorType& d_functor) { + ::Kokkos::parallel_for(str, policy, functor); + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const int i) { + functor.operator_call_pushforward(i, &d_functor, 0); + }); +} + +// This structure is used to dispatch parallel for pushforward calls based on +// the rank and the work tag of the MDPolicy +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + assert(false && "Some parallel_for misuse happened during the compilation " + "(templates have not been matched properly)."); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward(args..., &d_functor, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, 0, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, 0, 0, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0); + }); + } +}; +template +struct diff_parallel_for_MDP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward( + args..., &d_functor, 0, 0, 0, 0, 0, 0); + }); + } +}; + +template // multi-dimensional policy +inline void parallel_for_pushforward( + const ::std::string& str, + const ::Kokkos::MDRangePolicy& policy, + const FunctorType& functor, const ::std::string& /*d_str*/, + const ::Kokkos::MDRangePolicy& /*d_policy*/, + const FunctorType& d_functor) { + ::Kokkos::parallel_for(str, policy, functor); + diff_parallel_for_MDP_call_dispatch< + ::Kokkos::MDRangePolicy, FunctorType, + typename ::Kokkos::MDRangePolicy::work_tag, + ::Kokkos::MDRangePolicy::rank>::run(str, policy, + functor, + d_functor); +} + +// This structure is used to dispatch parallel for pushforward calls based on +// the work tag of other types of policies +template +struct diff_parallel_for_OP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for("_diff_" + str, policy, + [&functor, &d_functor](const T x, auto&&... args) { + functor.operator_call_pushforward( + x, args..., &d_functor, &x, {}); + }); + } +}; +template +struct diff_parallel_for_OP_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](auto&&... args) { + functor.operator_call_pushforward(args..., &d_functor, {}); + }); + } +}; + +// This structure is used to dispatch parallel for pushforward calls for +// integral policies +template +struct diff_parallel_for_int_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + diff_parallel_for_OP_call_dispatch< + Policy, FunctorType, typename Policy::work_tag>::run(str, policy, + functor, + d_functor); + } +}; +template +struct diff_parallel_for_int_call_dispatch { + static void run(const ::std::string& str, const Policy& policy, + const FunctorType& functor, const FunctorType& d_functor) { + ::Kokkos::parallel_for( + "_diff_" + str, policy, [&functor, &d_functor](const int i) { + functor.operator_call_pushforward(i, &d_functor, 0); + }); + } +}; + +template // other policy type +inline void parallel_for_pushforward(const ::std::string& str, + const Policy& policy, + const FunctorType& functor, + const ::std::string& /*d_str*/, + const Policy& /*d_policy*/, + const FunctorType& d_functor) { + ::Kokkos::parallel_for(str, policy, functor); + diff_parallel_for_int_call_dispatch< + Policy, FunctorType, ::std::is_integral::value>::run(str, policy, + functor, + d_functor); +} + +template // anonymous loop inline void -parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy, - const FunctorType& functor, const ::std::string& d_str, - const ExecPolicy& d_policy, - const FunctorType& d_functor) { - // TODO: implement parallel_for_pushforward - return; +parallel_for_pushforward(const Policy& policy, const FunctorType& functor, + const Policy& d_policy, const FunctorType& d_functor) { + parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy, + functor, ::std::string(""), d_policy, d_functor); } + +template // anonymous loop +inline void parallel_for_pushforward( + const Policy& policy, const FunctorType& functor, + ::std::enable_if_t<::Kokkos::is_execution_policy::value>* /*param*/, + const Policy& d_policy, const FunctorType& d_functor, + ::std::enable_if_t< + ::Kokkos::is_execution_policy::value>* /*d_param*/) { + parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy, + functor, ::std::string(""), d_policy, d_functor); +} + } // namespace Kokkos } // namespace clad::custom_derivatives diff --git a/unittests/Kokkos/ParallelFor.cpp b/unittests/Kokkos/ParallelFor.cpp index 69494226b..6d1dcd0da 100644 --- a/unittests/Kokkos/ParallelFor.cpp +++ b/unittests/Kokkos/ParallelFor.cpp @@ -1,5 +1,6 @@ #include #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/KokkosBuiltins.h" #include "gtest/gtest.h" // #include "TestUtils.h" #include "ParallelAdd.h" @@ -89,4 +90,118 @@ TEST(ParallelFor, ParallelPolynomialReverse) { // f_grad.execute(x, &dx); // EXPECT_NEAR(dx_f_true, dx, abs(tau*dx)); // } +} + +template struct Foo { + View& res; + double& x; + + Foo(View& _res, double& _x) : res(_res), x(_x) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int i) const { res(i) = x * i; } +}; + +double parallel_for_functor_simplest_case_intpol(double x) { + Kokkos::View res("res"); + + Foo> f(res, x); + + f(0); + + Kokkos::parallel_for("polynomial", 5, f); + Kokkos::parallel_for(5, f); + + return res(3); +} + +double parallel_for_functor_simplest_case_rangepol(double x) { + Kokkos::View res("res"); + + Foo> f(res, x); + + f(0); + + Kokkos::parallel_for( + "polynomial", + Kokkos::RangePolicy(1, 5), f); + // Overwrite with another parallel_for (not named) + Kokkos::parallel_for( + Kokkos::RangePolicy(1, 5), f); + + return res(3); +} + +template struct Foo2 { + View& res; + double& x; + + Foo2(View& _res, double& _x) : res(_res), x(_x) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int i, const int j) const { res(i, j) = x * i * j; } +}; + +double parallel_for_functor_simplest_case_mdpol(double x) { + Kokkos::View res("res"); + + Foo2> f(res, x); + + f(0, 0); + + Kokkos::parallel_for( + "polynomial", + Kokkos::MDRangePolicy< + Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>( + {1, 1}, {5, 5}, {1, 1}), + f); + + return res(3, 4); +} + +double parallel_for_functor_simplest_case_mdpol_space_and_anon(double x) { + Kokkos::View res("res"); + + Foo2> f(res, x); + + f(0, 0); + + Kokkos::parallel_for( + "polynomial", + Kokkos::MDRangePolicy< + Kokkos::DefaultHostExecutionSpace, + Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>( + {1, 1}, {5, 5}, {1, 1}), + f); + // Overwrite with another parallel_for (not named) + Kokkos::parallel_for( + Kokkos::MDRangePolicy< + Kokkos::DefaultHostExecutionSpace, + Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>( + {1, 1}, {5, 5}, {1, 1}), + f); + + return res(3, 4); +} + +TEST(ParallelFor, FunctorSimplestCases) { + const double eps = 1e-8; + + auto df1 = clad::differentiate(parallel_for_functor_simplest_case_intpol, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df1.execute(x), 3, eps); + + auto df2 = + clad::differentiate(parallel_for_functor_simplest_case_rangepol, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df2.execute(x), 3, eps); + + auto df3 = clad::differentiate(parallel_for_functor_simplest_case_mdpol, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df3.execute(x), 12, eps); + + auto df4 = clad::differentiate( + parallel_for_functor_simplest_case_mdpol_space_and_anon, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df4.execute(x), 12, eps); } \ No newline at end of file