Skip to content

Commit

Permalink
Add 3x3 special case to Solve()
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jun 28, 2024
1 parent ecc2624 commit daa4f13
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 23 deletions.
46 changes: 40 additions & 6 deletions src/autodiff/VariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,46 @@ VariableMatrix Solve(const VariableMatrix& A, const VariableMatrix& B) {

if (A.Rows() == 2 && A.Cols() == 2) {
// Compute optimal inverse instead of using Eigen's general solver
sleipnir::VariableMatrix Ainv{2, 2};
Ainv(0, 0) = A(1, 1);
Ainv(0, 1) = -A(0, 1);
Ainv(1, 0) = -A(1, 0);
Ainv(1, 1) = A(0, 0);
auto detA = A(0, 0) * A(1, 1) - A(0, 1) * A(1, 0);
//
// [a b]⁻¹ ___1___ [ d −b]
// [c d] = ad − bc [−c a]

const auto& a = A(0, 0);
const auto& b = A(0, 1);
const auto& c = A(1, 0);
const auto& d = A(1, 1);

sleipnir::VariableMatrix Ainv{{d, -b}, {-c, a}};
auto detA = a * d - b * c;
Ainv /= detA;

return Ainv * B;
} else if (A.Rows() == 3 && A.Cols() == 3) {
// Compute optimal inverse instead of using Eigen's general solver
//
// [a b c]⁻¹
// [d e f]
// [g h i]
// 1 [ei − fh ch − bi bf − ce]
// = --------------------------------- [fg − di ai − cg cd − af]
// aei − afh − bdi + bfg + cdh − ceg [dh − eg bg − ah ae − bd]

const auto& a = A(0, 0);
const auto& b = A(0, 1);
const auto& c = A(0, 2);
const auto& d = A(1, 0);
const auto& e = A(1, 1);
const auto& f = A(1, 2);
const auto& g = A(2, 0);
const auto& h = A(2, 1);
const auto& i = A(2, 2);

sleipnir::VariableMatrix Ainv{
{e * i - f * h, c * h - b * i, b * f - c * e},
{f * g - d * i, a * i - c * g, c * d - a * f},
{d * h - e * g, b * g - a * h, a * e - b * d}};
auto detA =
a * e * i - a * f * h - b * d * i + b * f * g + c * d * h - c * e * g;
Ainv /= detA;

return Ainv * B;
Expand Down
48 changes: 31 additions & 17 deletions test/src/autodiff/VariableMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,38 @@ TEST_CASE("VariableMatrix - Block() free function", "[VariableMatrix]") {
}

TEST_CASE("VariableMatrix - Solve() free function", "[VariableMatrix]") {
sleipnir::VariableMatrix A1{{1.0, 2.0}, {3.0, 4.0}};
sleipnir::VariableMatrix B1{{5.0}, {6.0}};
sleipnir::VariableMatrix X1 = sleipnir::Solve(A1, B1);
sleipnir::VariableMatrix A_22{{1.0, 2.0}, {3.0, 4.0}};
sleipnir::VariableMatrix B_21{{5.0}, {6.0}};
sleipnir::VariableMatrix X_21 = sleipnir::Solve(A_22, B_21);

Eigen::Matrix<double, 2, 1> expected1{{-4.0}, {4.5}};
CHECK(X1.Rows() == 2);
CHECK(X1.Cols() == 1);
CHECK(A1.Value() * X1.Value() == B1.Value());
CHECK(X1.Value() == expected1);
Eigen::Matrix<double, 2, 1> expected_21{{-4.0}, {4.5}};
CHECK(X_21.Rows() == 2);
CHECK(X_21.Cols() == 1);
CHECK(A_22.Value() * X_21.Value() == B_21.Value());
CHECK(X_21.Value() == expected_21);

sleipnir::VariableMatrix A2{
sleipnir::VariableMatrix A_33{
{1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}};
sleipnir::VariableMatrix B2{{10.0}, {11.0}, {12.0}};
sleipnir::VariableMatrix X2 = sleipnir::Solve(A2, B2);

Eigen::Matrix<double, 3, 1> expected2{{-7.5}, {6.0}, {11.0 / 6.0}};
CHECK(X2.Rows() == 3);
CHECK(X2.Cols() == 1);
CHECK((A2.Value() * X2.Value() - B2.Value()).norm() < 1e-12);
CHECK((X2.Value() - expected2).norm() < 1e-12);
sleipnir::VariableMatrix B_31{{10.0}, {11.0}, {12.0}};
sleipnir::VariableMatrix X_31 = sleipnir::Solve(A_33, B_31);

Eigen::Matrix<double, 3, 1> expected_31{{-7.5}, {6.0}, {11.0 / 6.0}};
CHECK(X_31.Rows() == 3);
CHECK(X_31.Cols() == 1);
CHECK((A_33.Value() * X_31.Value() - B_31.Value()).norm() < 1e-12);
CHECK((X_31.Value() - expected_31).norm() < 1e-12);

sleipnir::VariableMatrix A_44{{1.0, 2.0, 3.0, -4.0},
{-5.0, 6.0, 7.0, 8.0},
{9.0, 10.0, 11.0, 12.0},
{13.0, 14.0, 15.0, 16.0}};
sleipnir::VariableMatrix B_41{{17.0}, {18.0}, {19.0}, {20.0}};
sleipnir::VariableMatrix X_41 = sleipnir::Solve(A_44, B_41);

Eigen::Matrix<double, 4, 1> expected_41{
{4.44089e-16}, {-16.25}, {16.5}, {0.0}};
CHECK(X_41.Rows() == 4);
CHECK(X_41.Cols() == 1);
CHECK((A_44.Value() * X_41.Value() - B_41.Value()).norm() < 1e-12);
CHECK((X_41.Value() - expected_41).norm() < 1e-12);
}

0 comments on commit daa4f13

Please sign in to comment.