From 8b3dfd85e70099f43d18260c72020116b7495e64 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 2 Oct 2024 08:39:18 -0700 Subject: [PATCH] New product factor class --- gtsam/hybrid/HybridGaussianFactor.cpp | 10 +- gtsam/hybrid/HybridGaussianFactor.h | 12 +- gtsam/hybrid/HybridGaussianProductFactor.cpp | 89 +++++++++ gtsam/hybrid/HybridGaussianProductFactor.h | 117 +++++++++++ .../tests/testHybridGaussianProductFactor.cpp | 185 ++++++++++++++++++ 5 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 gtsam/hybrid/HybridGaussianProductFactor.cpp create mode 100644 gtsam/hybrid/HybridGaussianProductFactor.h create mode 100644 gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index f1e8a8498a..aa88ded30b 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -18,17 +18,17 @@ * @date Mar 12, 2022 */ +#include #include #include #include #include #include +#include #include #include #include -#include "gtsam/base/types.h" - namespace gtsam { /* *******************************************************************************/ @@ -215,6 +215,12 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() return {factors_, wrap}; } +/* *******************************************************************************/ +HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const { + return {{factors_, + [](const auto &pair) { return GaussianFactorGraph{pair.first}; }}}; +} + /* *******************************************************************************/ /// Helper method to compute the error of a component. static double PotentiallyPrunedComponentError( diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 15993f5823..d160798b6b 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -164,6 +165,14 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { sum = factor.add(sum); return sum; } + + /** + * @brief Helper function to return factors and functional to create a + * DecisionTree of Gaussian Factor Graphs. + * + * @return HybridGaussianProductFactor + */ + virtual HybridGaussianProductFactor asProductFactor() const; /// @} protected: @@ -175,7 +184,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { */ GaussianFactorGraphTree asGaussianFactorGraphTree() const; - private: + +private: /** * @brief Helper function to augment the [A|b] matrices in the factor * components with the additional scalar values. This is done by storing the diff --git a/gtsam/hybrid/HybridGaussianProductFactor.cpp b/gtsam/hybrid/HybridGaussianProductFactor.cpp new file mode 100644 index 0000000000..c9b4c07dd6 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianProductFactor.cpp @@ -0,0 +1,89 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianProductFactor.h + * @date Oct 2, 2024 + * @author Frank Dellaert + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + +static GaussianFactorGraph add(const GaussianFactorGraph &graph1, + const GaussianFactorGraph &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; +}; + +HybridGaussianProductFactor operator+(const HybridGaussianProductFactor &a, + const HybridGaussianProductFactor &b) { + return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add)); +} + +HybridGaussianProductFactor HybridGaussianProductFactor::operator+( + const HybridGaussianFactor &factor) const { + return *this + factor.asProductFactor(); +} + +HybridGaussianProductFactor HybridGaussianProductFactor::operator+( + const GaussianFactor::shared_ptr &factor) const { + return *this + HybridGaussianProductFactor(factor); +} + +HybridGaussianProductFactor &HybridGaussianProductFactor::operator+=( + const GaussianFactor::shared_ptr &factor) { + *this = *this + factor; + return *this; +} + +HybridGaussianProductFactor & +HybridGaussianProductFactor::operator+=(const HybridGaussianFactor &factor) { + *this = *this + factor; + return *this; +} + +void HybridGaussianProductFactor::print(const std::string &s, + const KeyFormatter &formatter) const { + KeySet keys; + auto printer = [&](const Y &graph) { + if (keys.size() == 0) + keys = graph.keys(); + return "Graph of size " + std::to_string(graph.size()); + }; + Base::print(s, formatter, printer); + if (keys.size() > 0) { + std::stringstream ss; + ss << s << " Keys:"; + for (auto &&key : keys) + ss << " " << formatter(key); + std::cout << ss.str() << "." << std::endl; + } +} + +HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const { + auto emptyGaussian = [](const GaussianFactorGraph &graph) { + bool hasNull = + std::any_of(graph.begin(), graph.end(), + [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); + return hasNull ? GaussianFactorGraph() : graph; + }; + return {Base(*this, emptyGaussian)}; +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianProductFactor.h b/gtsam/hybrid/HybridGaussianProductFactor.h new file mode 100644 index 0000000000..f1bd8bc3c5 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianProductFactor.h @@ -0,0 +1,117 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianProductFactor.h + * @date Oct 2, 2024 + * @author Frank Dellaert + * @author Varun Agrawal + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + +class HybridGaussianFactor; + +/// Alias for DecisionTree of GaussianFactorGraphs +class HybridGaussianProductFactor : public DecisionTree { + public: + using Y = GaussianFactorGraph; + using Base = DecisionTree; + + /// @name Constructors + /// @{ + + /// Default constructor + HybridGaussianProductFactor() = default; + + /** + * @brief Construct from a single factor + * @tparam FACTOR Factor type + * @param factor Shared pointer to the factor + */ + template + HybridGaussianProductFactor(const std::shared_ptr& factor) : Base(Y{factor}) {} + + /** + * @brief Construct from DecisionTree + * @param tree Decision tree to construct from + */ + HybridGaussianProductFactor(Base&& tree) : Base(std::move(tree)) {} + + ///@} + + /// @name Operators + ///@{ + + /// Add GaussianFactor into HybridGaussianProductFactor + HybridGaussianProductFactor operator+(const GaussianFactor::shared_ptr& factor) const; + + /// Add HybridGaussianFactor into HybridGaussianProductFactor + HybridGaussianProductFactor operator+(const HybridGaussianFactor& factor) const; + + /// Add-assign operator for GaussianFactor + HybridGaussianProductFactor& operator+=(const GaussianFactor::shared_ptr& factor); + + /// Add-assign operator for HybridGaussianFactor + HybridGaussianProductFactor& operator+=(const HybridGaussianFactor& factor); + + ///@} + + /// @name Testable + /// @{ + + /** + * @brief Print the HybridGaussianProductFactor + * @param s Optional string to prepend + * @param formatter Optional key formatter + */ + void print(const std::string& s = "", const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /** + * @brief Check if this HybridGaussianProductFactor is equal to another + * @param other The other HybridGaussianProductFactor to compare with + * @param tol Tolerance for floating point comparisons + * @return true if equal, false otherwise + */ + bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const { + return Base::equals(other, [tol](const Y& a, const Y& b) { return a.equals(b, tol); }); + } + + /// @} + + /// @name Other methods + ///@{ + + /** + * @brief Remove empty GaussianFactorGraphs from the decision tree + * @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed + * + * If any GaussianFactorGraph in the decision tree contains a nullptr, convert + * that leaf to an empty GaussianFactorGraph. This is needed because the DecisionTree + * will otherwise create a GaussianFactorGraph with a single (null) factor, + * which doesn't register as null. + */ + HybridGaussianProductFactor removeEmpty() const; + + ///@} +}; + +// Testable traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp b/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp new file mode 100644 index 0000000000..bd830794ae --- /dev/null +++ b/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp @@ -0,0 +1,185 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridGaussianProductFactor.cpp + * @brief Unit tests for HybridGaussianProductFactor + * @author Frank Dellaert + * @date October 2024 + */ + +#include "gtsam/inference/Key.h" +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +#include + +using namespace std; +using namespace gtsam; +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +/* ************************************************************************* */ +namespace examples { +static const DiscreteKey m1(M(1), 2), m2(M(2), 3); + +auto A1 = Matrix::Zero(2, 1); +auto A2 = Matrix::Zero(2, 2); +auto b = Matrix::Zero(2, 1); + +auto f10 = std::make_shared(X(1), A1, X(2), A2, b); +auto f11 = std::make_shared(X(1), A1, X(2), A2, b); + +auto A3 = Matrix::Zero(2, 3); +auto f20 = std::make_shared(X(1), A1, X(3), A3, b); +auto f21 = std::make_shared(X(1), A1, X(3), A3, b); +auto f22 = std::make_shared(X(1), A1, X(3), A3, b); + +HybridGaussianFactor hybridFactorA(m1, {f10, f11}); +HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22}); +// Simulate a pruned hybrid factor, in this case m2==1 is nulled out. +HybridGaussianFactor prunedFactorB(m2, {f20, nullptr, f22}); +} // namespace examples + +/* ************************************************************************* */ +// Constructor +TEST(HybridGaussianProductFactor, Construct) { + HybridGaussianProductFactor product; +} + +/* ************************************************************************* */ +// Add two Gaussian factors and check only one leaf in tree +TEST(HybridGaussianProductFactor, AddTwoGaussianFactors) { + using namespace examples; + + HybridGaussianProductFactor product; + product += f10; + product += f11; + + // Check that the product has only one leaf and no discrete variables. + EXPECT_LONGS_EQUAL(1, product.nrLeaves()); + EXPECT(product.labels().empty()); + + // Retrieve the single leaf + auto leaf = product(Assignment()); + + // Check that the leaf contains both factors + EXPECT_LONGS_EQUAL(2, leaf.size()); + EXPECT(leaf.at(0) == f10); + EXPECT(leaf.at(1) == f11); +} + +/* ************************************************************************* */ +// Add two GaussianConditionals and check the resulting tree +TEST(HybridGaussianProductFactor, AddTwoGaussianConditionals) { + // Create two GaussianConditionals + Vector1 d(1.0); + Matrix11 R = I_1x1, S = I_1x1; + auto gc1 = std::make_shared(X(1), d, R, X(2), S); + auto gc2 = std::make_shared(X(2), d, R); + + // Create a HybridGaussianProductFactor and add the conditionals + HybridGaussianProductFactor product; + product += std::static_pointer_cast(gc1); + product += std::static_pointer_cast(gc2); + + // Check that the product has only one leaf and no discrete variables + EXPECT_LONGS_EQUAL(1, product.nrLeaves()); + EXPECT(product.labels().empty()); + + // Retrieve the single leaf + auto leaf = product(Assignment()); + + // Check that the leaf contains both conditionals + EXPECT_LONGS_EQUAL(2, leaf.size()); + EXPECT(leaf.at(0) == gc1); + EXPECT(leaf.at(1) == gc2); +} + +/* ************************************************************************* */ +// Check AsProductFactor +TEST(HybridGaussianProductFactor, AsProductFactor) { + using namespace examples; + auto product = hybridFactorA.asProductFactor(); + + // Let's check that this worked: + Assignment mode; + mode[m1.first] = 1; + auto actual = product(mode); + EXPECT(actual.at(0) == f11); +} + +/* ************************************************************************* */ +// "Add" one hybrid factors together. +TEST(HybridGaussianProductFactor, AddOne) { + using namespace examples; + HybridGaussianProductFactor product; + product += hybridFactorA; + + // Let's check that this worked: + Assignment mode; + mode[m1.first] = 1; + auto actual = product(mode); + EXPECT(actual.at(0) == f11); +} + +/* ************************************************************************* */ +// "Add" two HFG together. +TEST(HybridGaussianProductFactor, AddTwo) { + using namespace examples; + + // Create product of two hybrid factors: it will be a decision tree now on + // both discrete variables m1 and m2: + HybridGaussianProductFactor product; + product += hybridFactorA; + product += hybridFactorB; + + // Let's check that this worked: + auto actual00 = product({{M(1), 0}, {M(2), 0}}); + EXPECT(actual00.at(0) == f10); + EXPECT(actual00.at(1) == f20); + + auto actual12 = product({{M(1), 1}, {M(2), 2}}); + EXPECT(actual12.at(0) == f11); + EXPECT(actual12.at(1) == f22); +} + +/* ************************************************************************* */ +// "Add" two HFG together. +TEST(HybridGaussianProductFactor, AddPruned) { + using namespace examples; + + // Create product of two hybrid factors: it will be a decision tree now on + // both discrete variables m1 and m2: + HybridGaussianProductFactor product; + product += hybridFactorA; + product += prunedFactorB; + EXPECT_LONGS_EQUAL(6, product.nrLeaves()); + + auto pruned = product.removeEmpty(); + EXPECT_LONGS_EQUAL(5, pruned.nrLeaves()); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */