diff --git a/include/Math/Dual.hpp b/include/Math/Dual.hpp index 2057f0d..f6b210f 100644 --- a/include/Math/Dual.hpp +++ b/include/Math/Dual.hpp @@ -1,4 +1,5 @@ #pragma once +#include "Math/MatrixDimensions.hpp" #include "Math/Vector.hpp" #include #include @@ -25,10 +26,14 @@ template class Dual { constexpr Dual(std::floating_point auto v) : val(v) {} constexpr auto value() -> T & { return val; } constexpr auto gradient() -> SVector & { return partials; } + constexpr auto gradient(ptrdiff_t i) -> T & { return partials[i]; } [[nodiscard]] constexpr auto value() const -> const T & { return val; } [[nodiscard]] constexpr auto gradient() const -> const SVector & { return partials; } + [[nodiscard]] constexpr auto gradient(ptrdiff_t i) const -> const T & { + return partials[i]; + } constexpr auto operator-() const & -> Dual { return {-val, -partials}; } constexpr auto operator+(const Dual &other) const & -> Dual { return {val + other.val, partials + other.partials}; @@ -142,12 +147,48 @@ template constexpr auto exp(Dual x) -> Dual { return {expx, expx * x.gradient()}; } -template -constexpr auto gradient(utils::Arena<> *arena, PtrVector x, const auto &f) { - constexpr ptrdiff_t U = 7; - using D = Dual; +class GradientResult { + double x; + MutPtrVector grad; + +public: + [[nodiscard]] constexpr auto value() const -> double { return x; } + [[nodiscard]] constexpr auto gradient() const -> MutPtrVector { + return grad; + } +}; +class HessianResult { + double x; + double *ptr; + unsigned dim; + +public: + [[nodiscard]] constexpr auto value() const -> double { return x; } + [[nodiscard]] constexpr auto gradient() const -> MutPtrVector { + return {ptr, dim}; + } + [[nodiscard]] constexpr auto hessian() const -> MutSquarePtrMatrix { + return {ptr + dim, dim}; + } +}; + +struct Assign { + constexpr auto operator()(double &x, double y) const { x = y; } +}; +struct Increment { + constexpr auto operator()(double &x, double y) const { x += y; } +}; +struct ScaledIncrement { + double scale; + constexpr auto operator()(double &x, double y) const { x += scale * y; } +}; + +constexpr auto gradient(utils::Arena<> *arena, PtrVector x, + const auto &f) { + constexpr ptrdiff_t U = 8; + using D = Dual; ptrdiff_t N = x.size(); - MutPtrVector grad = vector(arena, N); + MutPtrVector grad = vector(arena, N); auto p = arena->scope(); MutPtrVector dx = vector(arena, N); for (ptrdiff_t i = 0; i < N; ++i) dx[i] = x[i]; @@ -163,66 +204,77 @@ constexpr auto gradient(utils::Arena<> *arena, PtrVector x, const auto &f) { } } // only computes the upper triangle blocks -template -constexpr auto hessian(utils::Arena<> *arena, PtrVector x, const auto &f) { - constexpr ptrdiff_t Ui = 7; - constexpr ptrdiff_t Uj = 2; - using D = Dual; +constexpr auto extractDualValRecurse(std::floating_point auto x) { return x; } +template +constexpr auto extractDualValRecurse(const Dual &x) { + return extractDualValRecurse(x.value()); +} + +template +constexpr auto hessian(MutPtrVector grad, MutArray hess, + MutPtrVector, Uj>> dx, + const auto &f, auto update) -> double { + using D = Dual; using DD = Dual; - ptrdiff_t N = x.size(); - MutPtrVector grad = vector(arena, N); - MutSquarePtrMatrix hess = matrix(arena, N); - auto p = arena->scope(); - MutPtrVector
dx = vector
(arena, N); - for (ptrdiff_t i = 0; i < N; ++i) dx[i] = x[i]; + ptrdiff_t N = dx.size(); + invariant(N == grad.size()); + invariant(N == hess.numCol()); + invariant(N == hess.numRow()); for (ptrdiff_t j = 0;; j += Uj) { - for (ptrdiff_t i = j;; i += Ui) { + bool jbr = j + Uj >= N; + for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) + dx[j + k].gradient(k).value() = 1.0; + + for (ptrdiff_t i = 0;; i += Ui) { // df^2/dx_i dx_j + bool ibr = i + Ui - Uj >= j; // we want to copy into both regions _(j, j+Uj) and _(i, i+Ui) - // these regions overlap for the first `i` iteration only - if (i == j) - for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) - dx[j + k] = DD(D(x[j + k], k), k); - for (ptrdiff_t k = (i == j) ? Uj : 0; ((k < Ui) && (i + k < N)); ++k) - dx[i + k] = D(x[i + k], k); + // these regions overlap for the last `i` iteration only + for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k) + dx[i + k].value().gradient(k) = 1.0; - DD fx = utils::call(*arena, f, dx); + DD fx = f(dx); + // DD fx = utils::call(arena, f, dx); for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) for (ptrdiff_t l = 0; ((l < Ui) && (i + l < N)); ++l) - hess(j + k, i + l) = fx.gradient()[k].gradient()[l]; - if (i == j) + update(hess(j + k, i + l), fx.gradient()[k].gradient()[l]); + if (jbr) for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k) grad[i + k] = fx.value().gradient()[k]; - - if (i + Ui >= N) { - if (j + Uj >= N) return std::make_tuple(fx.value().value(), grad, hess); - // we have another `j` iteration, so we reset - // if `i != j`, the `i` and `j` blocks aren't contiguous, we reset both - // if `i == j`, we have one block; we only bother resetting - // the lower `j` subsection, because we're about to overwrite - // the upper `i` subsection anyway - for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) - dx[j + k] = x[j + k]; - if (i != j) - for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k) - dx[i + k] = x[i + k]; - break; - } - // if we're here, we have another `i` iteration - // if we're in the first `i` iteration, we set the first Uj iter - if (i == j) - for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) - dx[j + k] = DD(x[j + k], k); - for (ptrdiff_t k = (i == j ? Uj : 0); ((k < Ui) && (i + k < N)); ++k) - dx[i + k] = x[i + k]; // reset `i` block + if constexpr (!Preserve) + if (ibr && jbr) return fx.value().value(); + for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k) + dx[i + k].value().gradient(k) = 0.0; + if (!ibr) continue; + for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k) + dx[j + k].gradient(k).value() = 0.0; + if constexpr (Preserve) + if (jbr) return fx.value().value(); + break; } } } - -constexpr auto extractDualValRecurse(std::floating_point auto x) { return x; } -template -constexpr auto extractDualValRecurse(const Dual &x) { - return extractDualValRecurse(x.value()); +template +constexpr auto hessian(utils::Arena<> arena, MutPtrVector grad, + MutArray hess, PtrVector x, + const auto &f, auto update) -> double { + constexpr ptrdiff_t Ui = 8; + constexpr ptrdiff_t Uj = 2; + using D = Dual; + using DD = Dual; + ptrdiff_t N = x.size(); + MutPtrVector
dx = vector
(&arena, N); + for (ptrdiff_t i = 0; i < N; ++i) dx[i] = x[i]; + return hessian(grad, hess, dx, f, update); +} +constexpr auto hessian(utils::Arena<> *arena, PtrVector x, + const auto &f) { + ptrdiff_t N = x.size(); + MutPtrVector grad = vector(arena, N); + MutSquarePtrMatrix hess = matrix(arena, N); + Assign assign{}; + return std::make_tuple(hessian(*arena, grad, hess, x, f, assign), grad, hess); } +static_assert(MatrixDimension); } // namespace poly::math diff --git a/include/Math/Math.hpp b/include/Math/Math.hpp index 9172ad8..64e8a92 100644 --- a/include/Math/Math.hpp +++ b/include/Math/Math.hpp @@ -61,6 +61,8 @@ template concept TrivialCompatibile = Trivial && Compatible; template concept TrivialVecOrMat = Trivial && VecOrMat; +template +concept TrivialDataMatrix = Trivial && DataMatrix; // // TODO: binary func invocable trait? // template diff --git a/include/Math/MatrixDimensions.hpp b/include/Math/MatrixDimensions.hpp index 565bc07..8881232 100644 --- a/include/Math/MatrixDimensions.hpp +++ b/include/Math/MatrixDimensions.hpp @@ -188,5 +188,6 @@ concept MatrixDimension = requires(D d) { static_assert(MatrixDimension); static_assert(MatrixDimension); static_assert(MatrixDimension); +static_assert(!MatrixDimension); } // namespace poly::math diff --git a/test/dual_test.cpp b/test/dual_test.cpp index 4a750fb..edf18d7 100644 --- a/test/dual_test.cpp +++ b/test/dual_test.cpp @@ -34,9 +34,11 @@ TEST(DualTest, BasicAssertions) { EXPECT_TRUE(std::abs(fxx - f) < 1e-10); EXPECT_TRUE(norm2(g - gx) < 1e-10); EXPECT_TRUE(norm2(g - gxx) < 1e-10); - for (ptrdiff_t i = 1; i < hxx.numRow(); ++i) - for (ptrdiff_t j = 0; j < i; ++j) hxx(i, j) = hxx(j, i); - // std::cout << "B = " << B << "\nhxx = " << hxx << std::endl; + std::cout << "g = " << g << "\ngxx = " << gxx << std::endl; + std::cout << "B = " << B << "\nhxx = " << hxx << std::endl; + for (ptrdiff_t i = 0; i < hxx.numRow(); ++i) + for (ptrdiff_t j = i + 1; j < hxx.numCol(); ++j) hxx(i, j) = hxx(j, i); + std::cout << "hxx = " << hxx << std::endl; EXPECT_TRUE(norm2(B - hxx) < 1e-10); };