Skip to content

Commit

Permalink
LDL factorization
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Sep 5, 2023
1 parent 76aafc7 commit 0f4f8d1
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 7 deletions.
59 changes: 55 additions & 4 deletions include/Math/LinearAlgebra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#include "Math/Rational.hpp"
#include "Utilities/Invariant.hpp"
#include <concepts>
namespace poly::math::LU {
#include <cstddef>
namespace poly::math {
namespace LU {
[[nodiscard]] constexpr auto ldivrat(SquarePtrMatrix<Rational> F,
PtrVector<unsigned> ipiv,
MutPtrMatrix<Rational> rhs) -> bool {
Expand Down Expand Up @@ -236,8 +238,7 @@ template <typename S> constexpr auto factImpl(MutSquarePtrMatrix<S> A) {
S invAkk = 1.0 / A(k, k);
for (ptrdiff_t i = k + 1; i < M; ++i) A(i, k) = A(i, k) * invAkk;
for (ptrdiff_t i = k + 1; i < M; ++i)
for (ptrdiff_t j = k + 1; j < M; ++j)
A(i, j) = A(i, j) - A(i, k) * A(k, j);
for (ptrdiff_t j = k + 1; j < M; ++j) A(i, j) -= A(i, k) * A(k, j);
}
return ipiv;
}
Expand Down Expand Up @@ -271,4 +272,54 @@ constexpr void rdiv(MutSquarePtrMatrix<T> A, MutPtrMatrix<T> B) {
rdiv(A, ipiv, B);
}

} // namespace poly::math::LU
} // namespace LU

/// factorizes symmetric full-rank (but not necessarilly positive-definite)
/// matrix A into LDL^T, where L is lower-triangular with 1s on the diagonal
/// Only uses the lower triangle of A, overwriting it.
/// `D` is stored into the diagonal of `A`.
namespace LDL {

/// NOT OWNING
/// TODO: make the API consistent between LU and LDL
template <typename T> class Fact {
MutSquarePtrMatrix<T> fact;

public:
constexpr Fact(MutSquarePtrMatrix<T> A) : fact{A} {
Row M = fact.numRow();
invariant(M == fact.numCol());
for (ptrdiff_t k = 0; k < M; ++k) {
T Akk = fact(k, k);
T invAkk = 1.0 / Akk;
for (ptrdiff_t i = k + 1; i < M; ++i) fact(i, k) = fact(i, k) * invAkk;
for (ptrdiff_t i = k + 1; i < M; ++i) {
T Aik = fact(i, k) * Akk;
for (ptrdiff_t j = k + 1; j <= i; ++j) fact(i, j) -= Aik * fact(j, k);
}
}
}

constexpr void ldiv(MutPtrMatrix<T> rhs) {
auto [M, N] = rhs.size();
invariant(ptrdiff_t(fact.numRow()), ptrdiff_t(M));
// LDL' x = rhs
// L y = rhs // L is UnitLowerTriangular
for (ptrdiff_t m = 0; m < M; ++m)
for (ptrdiff_t k = 0; k < m; ++k) rhs(m, _) -= fact(m, k) * rhs(k, _);
// D L' x = y
// L' x = D^-1 y
for (ptrdiff_t m = ptrdiff_t(M); m--;) {
rhs(m, _) /= fact(m, m);
for (ptrdiff_t k = m + 1; k < M; ++k) rhs(m, _) -= fact(k, m) * rhs(k, _);
}
}
};

template <typename T>
constexpr void ldiv(MutSquarePtrMatrix<T> A, MutPtrMatrix<T> B) {
Fact(A).ldiv(B);
}

} // namespace LDL
} // namespace poly::math
20 changes: 17 additions & 3 deletions test/linear_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ TEST(LinearAlgebraTest, BasicAssertions) {
}

// NOLINTNEXTLINE(modernize-use-trailing-return-type)
TEST(DoubleLU, BasicAssertions) {
SquareMatrix<double> A(4), B(4), C(4), D(4);
TEST(DoubleFactorization, BasicAssertions) {
SquareMatrix<double> A(7), B(7), C(7), D(7);
std::mt19937 gen(0);
std::uniform_real_distribution<double> dist(-1, 1);
for (ptrdiff_t i = 0; i < 100; ++i) {
for (ptrdiff_t i = 0; i < 10; ++i) {
for (auto &a : A) a = dist(gen);
for (auto &b : B) b = dist(gen);
C << B;
// LU
// B = A \ B
// C == A*B == A * (A \ B)
LU::fact(A).ldiv(MutPtrMatrix<double>(B));
Expand All @@ -63,5 +64,18 @@ TEST(DoubleLU, BasicAssertions) {
D << A;
LU::ldiv(A, MutPtrMatrix<double>(B));
EXPECT_TRUE(norm2(D * B - C) < 1e-10);

// LDL; make `A` symmetric
D << A + A.transpose();
A << D;
B << C;
// B = A \ B
// C == A*B == A * (A \ B)
LDL::Fact(D).ldiv(MutPtrMatrix<double>(B));
EXPECT_TRUE(norm2(A * B - C) < 1e-10);
B << C;
D << A;
LDL::ldiv(A, MutPtrMatrix<double>(B));
EXPECT_TRUE(norm2(D * B - C) < 1e-10);
}
}

0 comments on commit 0f4f8d1

Please sign in to comment.