Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simpler HybridGaussianFoo constructors #1848

Merged
merged 23 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
}
if (auto p = std::dynamic_pointer_cast<HybridFactor>(factor)) {
} else if (auto p = std::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
Expand Down
89 changes: 49 additions & 40 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,57 +27,66 @@
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

#include <cstddef>

namespace gtsam {
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
const HybridGaussianConditional::Conditionals &conditionals) {
auto func = [](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
double value = 0.0;
// Check if conditional is pruned
if (conditional) {
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = conditional->negLogConstant();
/* *******************************************************************************/
struct HybridGaussianConditional::ConstructorHelper {
std::optional<size_t> nrFrontals;
HybridGaussianFactor::FactorValuePairs pairs;
double minNegLogConstant;

/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals)
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
nrFrontals = c->nrFrontals();
}
value = c->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we store the value, we should never have to call negLogConstant in this file again!
And that’s why we currently need to keep conditionals :-( at least until the hiding is removed.

};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable.");
}
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
};
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
}
}
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(*helper.nrFrontals),
conditionals_(conditionals),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These things are here twice! Once in BaseFactor, and once in conditionals_.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I thought about that too. There was a reason I couldn't get rid of conditionals_, which I think was to do with negLogConstant. I'll think about this again so we can remove it.

negLogConstant_(helper.minNegLogConstant) {}

HybridGaussianConditional::HybridGaussianConditional(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents, GetFactorValuePairs(conditionals)),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {
// Calculate negLogConstant_ as the minimum of the negative-log normalizers of
// the conditionals, by visiting the decision tree:
negLogConstant_ = std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->negLogConstant_ =
std::min(this->negLogConstant_, conditional->negLogConstant());
}
});
}
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

HybridGaussianConditional::HybridGaussianConditional(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(continuousFrontals, continuousParents,
DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
Expand Down Expand Up @@ -222,8 +231,8 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
return {likelihood_m, Cgm_Kgcm};
}
});
return std::make_shared<HybridGaussianFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods);
}

/* ************************************************************************* */
Expand Down
70 changes: 31 additions & 39 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,11 @@ class GTSAM_EXPORT HybridGaussianConditional

private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.

///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_;

/**
* @brief Convert a HybridGaussianConditional of conditionals into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/**
* @brief Helper function to get the pruner functor.
*
* @param discreteProbs The pruned discrete probabilities.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &discreteProbs);

public:
/// @name Constructors
/// @{
Expand All @@ -93,37 +77,28 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional() = default;

/**
* @brief Construct a new HybridGaussianConditional object.
* @brief Construct from one discrete key and vector of conditionals.
*
* @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents.
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
HybridGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParent Single discrete parent variable
* @param conditionals Vector of conditionals with the same size as the
* cardinality of the discrete parent.
*/
HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/**
* @brief Construct from multiple discrete keys and conditional tree.
*
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/// @}
/// @name Testable
/// @{
Expand Down Expand Up @@ -207,6 +182,23 @@ class GTSAM_EXPORT HybridGaussianConditional
/// @}

private:
/// Helper struct for private constructor.
struct ConstructorHelper;
dellaert marked this conversation as resolved.
Show resolved Hide resolved

/// Private constructor that uses helper struct above.
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

//// Get the pruner functor from pruned discrete probabilities.
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &prunedProbabilities);

/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

Expand Down
96 changes: 77 additions & 19 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,20 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the additional scalar values.
* This is done by storing the value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
static HybridGaussianFactor::Factors augment(
const HybridGaussianFactor::FactorValuePairs &factors) {
/* *******************************************************************************/
HybridGaussianFactor::Factors HybridGaussianFactor::augment(
const FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
HybridGaussianFactor::Factors gaussianFactors;
Factors gaussianFactors;
AlgebraicDecisionTree<Key> valueTree;
std::tie(gaussianFactors, valueTree) = unzip(factors);

Expand Down Expand Up @@ -73,22 +65,88 @@ static HybridGaussianFactor::Factors augment(
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return HybridGaussianFactor::Factors(factors, update);
return Factors(factors, update);
}

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
struct HybridGaussianFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty
Factors factorsTree;

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &factor : factors) {
if (factor && continuousKeys.empty()) {
continuousKeys = factor->keys();
break;
}
}

// Build the DecisionTree from the factor vector
factorsTree = Factors(discreteKeys, factors);
}

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
}
}

// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(discreteKeys, factorPairs);
}

ConstructorHelper(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
}
});

// Build the FactorValuePairs DecisionTree
pairs = factorPairs;
}
};

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All three of the above constructors need the divider
/* *******************************************************************************/

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more for me when I add it.

const FactorValuePairs &factors)
: Base(continuousKeys, discreteKeys), factors_(augment(factors)) {}
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}

/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false;

// This will return false if either factors_ is empty or e->factors_ is empty,
// but not if both are empty or both are not empty:
// This will return false if either factors_ is empty or e->factors_ is
// empty, but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;

// Check the base and the factors:
Expand Down
Loading
Loading