From 0f4f8d15c234459ad8e9e15cc85ce135776811d1 Mon Sep 17 00:00:00 2001 From: chriselrod Date: Mon, 4 Sep 2023 22:27:07 -0400 Subject: [PATCH] LDL factorization --- include/Math/LinearAlgebra.hpp | 59 +++++++++++++++++++++++++++++++--- test/linear_algebra_test.cpp | 20 ++++++++++-- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/include/Math/LinearAlgebra.hpp b/include/Math/LinearAlgebra.hpp index 79a00fb..aa101b8 100644 --- a/include/Math/LinearAlgebra.hpp +++ b/include/Math/LinearAlgebra.hpp @@ -6,7 +6,9 @@ #include "Math/Rational.hpp" #include "Utilities/Invariant.hpp" #include -namespace poly::math::LU { +#include +namespace poly::math { +namespace LU { [[nodiscard]] constexpr auto ldivrat(SquarePtrMatrix F, PtrVector ipiv, MutPtrMatrix rhs) -> bool { @@ -236,8 +238,7 @@ template constexpr auto factImpl(MutSquarePtrMatrix A) { S invAkk = 1.0 / A(k, k); for (ptrdiff_t i = k + 1; i < M; ++i) A(i, k) = A(i, k) * invAkk; for (ptrdiff_t i = k + 1; i < M; ++i) - for (ptrdiff_t j = k + 1; j < M; ++j) - A(i, j) = A(i, j) - A(i, k) * A(k, j); + for (ptrdiff_t j = k + 1; j < M; ++j) A(i, j) -= A(i, k) * A(k, j); } return ipiv; } @@ -271,4 +272,54 @@ constexpr void rdiv(MutSquarePtrMatrix A, MutPtrMatrix B) { rdiv(A, ipiv, B); } -} // namespace poly::math::LU +} // namespace LU + +/// factorizes symmetric full-rank (but not necessarilly positive-definite) +/// matrix A into LDL^T, where L is lower-triangular with 1s on the diagonal +/// Only uses the lower triangle of A, overwriting it. +/// `D` is stored into the diagonal of `A`. +namespace LDL { + +/// NOT OWNING +/// TODO: make the API consistent between LU and LDL +template class Fact { + MutSquarePtrMatrix fact; + +public: + constexpr Fact(MutSquarePtrMatrix A) : fact{A} { + Row M = fact.numRow(); + invariant(M == fact.numCol()); + for (ptrdiff_t k = 0; k < M; ++k) { + T Akk = fact(k, k); + T invAkk = 1.0 / Akk; + for (ptrdiff_t i = k + 1; i < M; ++i) fact(i, k) = fact(i, k) * invAkk; + for (ptrdiff_t i = k + 1; i < M; ++i) { + T Aik = fact(i, k) * Akk; + for (ptrdiff_t j = k + 1; j <= i; ++j) fact(i, j) -= Aik * fact(j, k); + } + } + } + + constexpr void ldiv(MutPtrMatrix rhs) { + auto [M, N] = rhs.size(); + invariant(ptrdiff_t(fact.numRow()), ptrdiff_t(M)); + // LDL' x = rhs + // L y = rhs // L is UnitLowerTriangular + for (ptrdiff_t m = 0; m < M; ++m) + for (ptrdiff_t k = 0; k < m; ++k) rhs(m, _) -= fact(m, k) * rhs(k, _); + // D L' x = y + // L' x = D^-1 y + for (ptrdiff_t m = ptrdiff_t(M); m--;) { + rhs(m, _) /= fact(m, m); + for (ptrdiff_t k = m + 1; k < M; ++k) rhs(m, _) -= fact(k, m) * rhs(k, _); + } + } +}; + +template +constexpr void ldiv(MutSquarePtrMatrix A, MutPtrMatrix B) { + Fact(A).ldiv(B); +} + +} // namespace LDL +} // namespace poly::math diff --git a/test/linear_algebra_test.cpp b/test/linear_algebra_test.cpp index d134d54..08faab3 100644 --- a/test/linear_algebra_test.cpp +++ b/test/linear_algebra_test.cpp @@ -47,14 +47,15 @@ TEST(LinearAlgebraTest, BasicAssertions) { } // NOLINTNEXTLINE(modernize-use-trailing-return-type) -TEST(DoubleLU, BasicAssertions) { - SquareMatrix A(4), B(4), C(4), D(4); +TEST(DoubleFactorization, BasicAssertions) { + SquareMatrix A(7), B(7), C(7), D(7); std::mt19937 gen(0); std::uniform_real_distribution dist(-1, 1); - for (ptrdiff_t i = 0; i < 100; ++i) { + for (ptrdiff_t i = 0; i < 10; ++i) { for (auto &a : A) a = dist(gen); for (auto &b : B) b = dist(gen); C << B; + // LU // B = A \ B // C == A*B == A * (A \ B) LU::fact(A).ldiv(MutPtrMatrix(B)); @@ -63,5 +64,18 @@ TEST(DoubleLU, BasicAssertions) { D << A; LU::ldiv(A, MutPtrMatrix(B)); EXPECT_TRUE(norm2(D * B - C) < 1e-10); + + // LDL; make `A` symmetric + D << A + A.transpose(); + A << D; + B << C; + // B = A \ B + // C == A*B == A * (A \ B) + LDL::Fact(D).ldiv(MutPtrMatrix(B)); + EXPECT_TRUE(norm2(A * B - C) < 1e-10); + B << C; + D << A; + LDL::ldiv(A, MutPtrMatrix(B)); + EXPECT_TRUE(norm2(D * B - C) < 1e-10); } }