Skip to content

Commit

Permalink
New product factor class
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Oct 4, 2024
1 parent 1bb5b95 commit 8b3dfd8
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 3 deletions.
10 changes: 8 additions & 2 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
* @date Mar 12, 2022
*/

#include <gtsam/base/types.h>
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

#include "gtsam/base/types.h"

namespace gtsam {

/* *******************************************************************************/
Expand Down Expand Up @@ -215,6 +215,12 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}

/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
return {{factors_,
[](const auto &pair) { return GaussianFactorGraph{pair.first}; }}};
}

/* *******************************************************************************/
/// Helper method to compute the error of a component.
static double PotentiallyPrunedComponentError(
Expand Down
12 changes: 11 additions & 1 deletion gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

Expand Down Expand Up @@ -164,6 +165,14 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
sum = factor.add(sum);
return sum;
}

/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return HybridGaussianProductFactor
*/
virtual HybridGaussianProductFactor asProductFactor() const;
/// @}

protected:
Expand All @@ -175,7 +184,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

private:

private:
/**
* @brief Helper function to augment the [A|b] matrices in the factor
* components with the additional scalar values. This is done by storing the
Expand Down
89 changes: 89 additions & 0 deletions gtsam/hybrid/HybridGaussianProductFactor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridGaussianProductFactor.h
* @date Oct 2, 2024
* @author Frank Dellaert
* @author Varun Agrawal
*/

#include <gtsam/base/types.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

static GaussianFactorGraph add(const GaussianFactorGraph &graph1,
const GaussianFactorGraph &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};

HybridGaussianProductFactor operator+(const HybridGaussianProductFactor &a,
const HybridGaussianProductFactor &b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
}

HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor &factor) const {
return *this + factor.asProductFactor();
}

HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr &factor) const {
return *this + HybridGaussianProductFactor(factor);
}

HybridGaussianProductFactor &HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr &factor) {
*this = *this + factor;
return *this;
}

HybridGaussianProductFactor &
HybridGaussianProductFactor::operator+=(const HybridGaussianFactor &factor) {
*this = *this + factor;
return *this;
}

void HybridGaussianProductFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
KeySet keys;
auto printer = [&](const Y &graph) {
if (keys.size() == 0)
keys = graph.keys();
return "Graph of size " + std::to_string(graph.size());
};
Base::print(s, formatter, printer);
if (keys.size() > 0) {
std::stringstream ss;
ss << s << " Keys:";
for (auto &&key : keys)
ss << " " << formatter(key);
std::cout << ss.str() << "." << std::endl;
}
}

HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
std::any_of(graph.begin(), graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : graph;
};
return {Base(*this, emptyGaussian)};
}

} // namespace gtsam
117 changes: 117 additions & 0 deletions gtsam/hybrid/HybridGaussianProductFactor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridGaussianProductFactor.h
* @date Oct 2, 2024
* @author Frank Dellaert
* @author Varun Agrawal
*/

#pragma once

#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

class HybridGaussianFactor;

/// Alias for DecisionTree of GaussianFactorGraphs
class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph> {
public:
using Y = GaussianFactorGraph;
using Base = DecisionTree<Key, Y>;

/// @name Constructors
/// @{

/// Default constructor
HybridGaussianProductFactor() = default;

/**
* @brief Construct from a single factor
* @tparam FACTOR Factor type
* @param factor Shared pointer to the factor
*/
template <class FACTOR>
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor) : Base(Y{factor}) {}

/**
* @brief Construct from DecisionTree
* @param tree Decision tree to construct from
*/
HybridGaussianProductFactor(Base&& tree) : Base(std::move(tree)) {}

///@}

/// @name Operators
///@{

/// Add GaussianFactor into HybridGaussianProductFactor
HybridGaussianProductFactor operator+(const GaussianFactor::shared_ptr& factor) const;

/// Add HybridGaussianFactor into HybridGaussianProductFactor
HybridGaussianProductFactor operator+(const HybridGaussianFactor& factor) const;

/// Add-assign operator for GaussianFactor
HybridGaussianProductFactor& operator+=(const GaussianFactor::shared_ptr& factor);

/// Add-assign operator for HybridGaussianFactor
HybridGaussianProductFactor& operator+=(const HybridGaussianFactor& factor);

///@}

/// @name Testable
/// @{

/**
* @brief Print the HybridGaussianProductFactor
* @param s Optional string to prepend
* @param formatter Optional key formatter
*/
void print(const std::string& s = "", const KeyFormatter& formatter = DefaultKeyFormatter) const;

/**
* @brief Check if this HybridGaussianProductFactor is equal to another
* @param other The other HybridGaussianProductFactor to compare with
* @param tol Tolerance for floating point comparisons
* @return true if equal, false otherwise
*/
bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const {
return Base::equals(other, [tol](const Y& a, const Y& b) { return a.equals(b, tol); });
}

/// @}

/// @name Other methods
///@{

/**
* @brief Remove empty GaussianFactorGraphs from the decision tree
* @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed
*
* If any GaussianFactorGraph in the decision tree contains a nullptr, convert
* that leaf to an empty GaussianFactorGraph. This is needed because the DecisionTree
* will otherwise create a GaussianFactorGraph with a single (null) factor,
* which doesn't register as null.
*/
HybridGaussianProductFactor removeEmpty() const;

///@}
};

// Testable traits
template <>
struct traits<HybridGaussianProductFactor> : public Testable<HybridGaussianProductFactor> {};

} // namespace gtsam
Loading

0 comments on commit 8b3dfd8

Please sign in to comment.