Skip to content

Commit

Permalink
transpose hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Sep 5, 2023
1 parent 70097f8 commit 76aafc7
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 56 deletions.
158 changes: 105 additions & 53 deletions include/Math/Dual.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include "Math/MatrixDimensions.hpp"
#include "Math/Vector.hpp"
#include <Math/Array.hpp>
#include <Math/Constructors.hpp>
Expand All @@ -25,10 +26,14 @@ template <class T, ptrdiff_t N> class Dual {
constexpr Dual(std::floating_point auto v) : val(v) {}
constexpr auto value() -> T & { return val; }
constexpr auto gradient() -> SVector<T, N> & { return partials; }
constexpr auto gradient(ptrdiff_t i) -> T & { return partials[i]; }
[[nodiscard]] constexpr auto value() const -> const T & { return val; }
[[nodiscard]] constexpr auto gradient() const -> const SVector<T, N> & {
return partials;
}
[[nodiscard]] constexpr auto gradient(ptrdiff_t i) const -> const T & {
return partials[i];
}
constexpr auto operator-() const & -> Dual { return {-val, -partials}; }
constexpr auto operator+(const Dual &other) const & -> Dual {
return {val + other.val, partials + other.partials};
Expand Down Expand Up @@ -142,12 +147,48 @@ template <class T, ptrdiff_t N> constexpr auto exp(Dual<T, N> x) -> Dual<T, N> {
return {expx, expx * x.gradient()};
}

template <typename T>
constexpr auto gradient(utils::Arena<> *arena, PtrVector<T> x, const auto &f) {
constexpr ptrdiff_t U = 7;
using D = Dual<T, U>;
// Read-only bundle of a scalar function value and the gradient that goes
// with it. Both accessors return copies of the stored members.
// NOTE(review): no constructor is declared in this block and the data members
// are private, so aggregate initialization is unavailable -- confirm how
// instances are meant to be populated.
class GradientResult {
public:
  // The stored scalar value.
  [[nodiscard]] constexpr auto value() const -> double { return val_; }
  // The stored gradient handle, returned by value.
  [[nodiscard]] constexpr auto gradient() const -> MutPtrVector<double> {
    return grad_;
  }

private:
  double val_;
  MutPtrVector<double> grad_;
};
// Read-only bundle of a scalar function value plus gradient and Hessian
// handles that are both derived from one `double` buffer: the gradient is
// built from the buffer start with length `n_`, and the Hessian is built
// from the address `n_` elements later, also with dimension `n_`.
// NOTE(review): no constructor is declared in this block and the data members
// are private -- confirm how instances are meant to be populated.
class HessianResult {
public:
  // The stored scalar value.
  [[nodiscard]] constexpr auto value() const -> double { return val_; }
  // Gradient handle: starts at the buffer base, length `n_`.
  [[nodiscard]] constexpr auto gradient() const -> MutPtrVector<double> {
    return {data_, n_};
  }
  // Hessian handle: starts right after the `n_` gradient entries.
  [[nodiscard]] constexpr auto hessian() const -> MutSquarePtrMatrix<double> {
    return {data_ + n_, n_};
  }

private:
  double val_;
  double *data_;
  unsigned n_;
};

// Update policy: overwrite the destination with the incoming value.
struct Assign {
  constexpr auto operator()(double &dst, double src) const -> void {
    dst = src;
  }
};
// Update policy: add the incoming value onto the destination.
struct Increment {
  constexpr auto operator()(double &dst, double src) const -> void {
    dst += src;
  }
};
// Update policy: add `scale` times the incoming value onto the destination.
struct ScaledIncrement {
  // Factor applied to every incoming value before it is accumulated.
  double scale;
  constexpr auto operator()(double &dst, double src) const -> void {
    dst += scale * src;
  }
};

constexpr auto gradient(utils::Arena<> *arena, PtrVector<double> x,
const auto &f) {
constexpr ptrdiff_t U = 8;
using D = Dual<double, U>;
ptrdiff_t N = x.size();
MutPtrVector<T> grad = vector<T>(arena, N);
MutPtrVector<double> grad = vector<double>(arena, N);
auto p = arena->scope();
MutPtrVector<D> dx = vector<D>(arena, N);
for (ptrdiff_t i = 0; i < N; ++i) dx[i] = x[i];
Expand All @@ -163,66 +204,77 @@ constexpr auto gradient(utils::Arena<> *arena, PtrVector<T> x, const auto &f) {
}
}
// only computes the lower triangle blocks
template <typename T>
constexpr auto hessian(utils::Arena<> *arena, PtrVector<T> x, const auto &f) {
constexpr ptrdiff_t Ui = 7;
constexpr ptrdiff_t Uj = 2;
using D = Dual<T, Ui>;
constexpr auto extractDualValRecurse(std::floating_point auto x) { return x; }
template <class T, ptrdiff_t N>
constexpr auto extractDualValRecurse(const Dual<T, N> &x) {
return extractDualValRecurse(x.value());
}

template <bool Preserve = false, MatrixDimension S, ptrdiff_t Ui, ptrdiff_t Uj>
constexpr auto hessian(MutPtrVector<double> grad, MutArray<double, S> hess,
MutPtrVector<Dual<Dual<double, Ui>, Uj>> dx,
const auto &f, auto update) -> double {
using D = Dual<double, Ui>;
using DD = Dual<D, Uj>;
ptrdiff_t N = x.size();
MutPtrVector<T> grad = vector<T>(arena, N);
MutSquarePtrMatrix<T> hess = matrix<T>(arena, N);
auto p = arena->scope();
MutPtrVector<DD> dx = vector<DD>(arena, N);
for (ptrdiff_t i = 0; i < N; ++i) dx[i] = x[i];
ptrdiff_t N = dx.size();
invariant(N == grad.size());
invariant(N == hess.numCol());
invariant(N == hess.numRow());
for (ptrdiff_t j = 0;; j += Uj) {
for (ptrdiff_t i = j;; i += Ui) {
bool jbr = j + Uj >= N;
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
dx[j + k].gradient(k).value() = 1.0;

for (ptrdiff_t i = 0;; i += Ui) {
// df^2/dx_i dx_j
bool ibr = i + Ui - Uj >= j;
// we want to copy into both regions _(j, j+Uj) and _(i, i+Ui)
// these regions overlap for the first `i` iteration only
if (i == j)
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
dx[j + k] = DD(D(x[j + k], k), k);
for (ptrdiff_t k = (i == j) ? Uj : 0; ((k < Ui) && (i + k < N)); ++k)
dx[i + k] = D(x[i + k], k);
// these regions overlap for the last `i` iteration only
for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k)
dx[i + k].value().gradient(k) = 1.0;

DD fx = utils::call(*arena, f, dx);
DD fx = f(dx);
// DD fx = utils::call(arena, f, dx);
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
for (ptrdiff_t l = 0; ((l < Ui) && (i + l < N)); ++l)
hess(j + k, i + l) = fx.gradient()[k].gradient()[l];
if (i == j)
update(hess(j + k, i + l), fx.gradient()[k].gradient()[l]);
if (jbr)
for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k)
grad[i + k] = fx.value().gradient()[k];

if (i + Ui >= N) {
if (j + Uj >= N) return std::make_tuple(fx.value().value(), grad, hess);
// we have another `j` iteration, so we reset
// if `i != j`, the `i` and `j` blocks aren't contiguous, we reset both
// if `i == j`, we have one block; we only bother resetting
// the lower `j` subsection, because we're about to overwrite
// the upper `i` subsection anyway
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
dx[j + k] = x[j + k];
if (i != j)
for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k)
dx[i + k] = x[i + k];
break;
}
// if we're here, we have another `i` iteration
// if we're in the first `i` iteration, we set the first Uj iter
if (i == j)
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
dx[j + k] = DD(x[j + k], k);
for (ptrdiff_t k = (i == j ? Uj : 0); ((k < Ui) && (i + k < N)); ++k)
dx[i + k] = x[i + k]; // reset `i` block
if constexpr (!Preserve)
if (ibr && jbr) return fx.value().value();
for (ptrdiff_t k = 0; ((k < Ui) && (i + k < N)); ++k)
dx[i + k].value().gradient(k) = 0.0;
if (!ibr) continue;
for (ptrdiff_t k = 0; ((k < Uj) && (j + k < N)); ++k)
dx[j + k].gradient(k).value() = 0.0;
if constexpr (Preserve)
if (jbr) return fx.value().value();
break;
}
}
}

constexpr auto extractDualValRecurse(std::floating_point auto x) { return x; }
template <class T, ptrdiff_t N>
constexpr auto extractDualValRecurse(const Dual<T, N> &x) {
return extractDualValRecurse(x.value());
// Convenience overload: builds the nested-dual working vector from plain
// `double` inputs and forwards to the dual-vector `hessian` overload with
// `Preserve == false`.
//
// `arena` is taken by value, and the working vector is allocated from that
// local copy (presumably so the scratch allocation does not outlive this
// call in the caller's arena -- confirm against `utils::Arena`).
// `grad`, `hess`, `f` and `update` are passed through unchanged; the returned
// `double` is whatever the forwarded call returns.
template <MatrixDimension S>
constexpr auto hessian(utils::Arena<> arena, MutPtrVector<double> grad,
                       MutArray<double, S> hess, PtrVector<double> x,
                       const auto &f, auto update) -> double {
  // Partial counts of the inner and outer dual layers.
  constexpr ptrdiff_t innerWidth = 8;
  constexpr ptrdiff_t outerWidth = 2;
  using Inner = Dual<double, innerWidth>;
  using Nested = Dual<Inner, outerWidth>;
  ptrdiff_t len = x.size();
  MutPtrVector<Nested> seeds = vector<Nested>(&arena, len);
  // Each entry is assigned from the corresponding plain input value.
  for (ptrdiff_t idx = 0; idx < len; ++idx) {
    seeds[idx] = x[idx];
  }
  return hessian<false>(grad, hess, seeds, f, update);
}
// Allocating entry point: reserves a length-N gradient and an N-by-N square
// Hessian from `arena` (N = x.size()), fills them through the
// output-parameter `hessian` overload using the `Assign` update policy, and
// returns `std::make_tuple(value, grad, hess)`.
constexpr auto hessian(utils::Arena<> *arena, PtrVector<double> x,
                       const auto &f) {
  ptrdiff_t len = x.size();
  MutPtrVector<double> grad = vector<double>(arena, len);
  MutSquarePtrMatrix<double> hess = matrix<double>(arena, len);
  // `*arena` is handed to an overload that takes its arena by value.
  double fx = hessian(*arena, grad, hess, x, f, Assign{});
  return std::make_tuple(fx, grad, hess);
}
static_assert(MatrixDimension<SquareDims>);

} // namespace poly::math
2 changes: 2 additions & 0 deletions include/Math/Math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ template <typename T, typename C>
concept TrivialCompatibile = Trivial<T> && Compatible<T, C>;
// Satisfied by types that model both `Trivial` and `VecOrMat`.
template <typename T>
concept TrivialVecOrMat = Trivial<T> && VecOrMat<T>;
// Satisfied by types that model both `Trivial` and `DataMatrix`.
template <typename T>
concept TrivialDataMatrix = Trivial<T> && DataMatrix<T>;

// // TODO: binary func invocable trait?
// template <typename Op, typename T, typename S>
Expand Down
1 change: 1 addition & 0 deletions include/Math/MatrixDimensions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,6 @@ concept MatrixDimension = requires(D d) {
static_assert(MatrixDimension<SquareDims>);
static_assert(MatrixDimension<DenseDims>);
static_assert(MatrixDimension<StridedDims>);
static_assert(!MatrixDimension<unsigned>);

} // namespace poly::math
8 changes: 5 additions & 3 deletions test/dual_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ TEST(DualTest, BasicAssertions) {
EXPECT_TRUE(std::abs(fxx - f) < 1e-10);
EXPECT_TRUE(norm2(g - gx) < 1e-10);
EXPECT_TRUE(norm2(g - gxx) < 1e-10);
for (ptrdiff_t i = 1; i < hxx.numRow(); ++i)
for (ptrdiff_t j = 0; j < i; ++j) hxx(i, j) = hxx(j, i);
// std::cout << "B = " << B << "\nhxx = " << hxx << std::endl;
std::cout << "g = " << g << "\ngxx = " << gxx << std::endl;
std::cout << "B = " << B << "\nhxx = " << hxx << std::endl;
for (ptrdiff_t i = 0; i < hxx.numRow(); ++i)
for (ptrdiff_t j = i + 1; j < hxx.numCol(); ++j) hxx(i, j) = hxx(j, i);
std::cout << "hxx = " << hxx << std::endl;
EXPECT_TRUE(norm2(B - hxx) < 1e-10);
};

Expand Down

0 comments on commit 76aafc7

Please sign in to comment.