Skip to content

Commit

Permalink
discretePosterior in HNFG
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Oct 2, 2024
1 parent 14d1594 commit 1bb5b95
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ static double PotentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [this, &continuousValues](const auto &pair) {
auto errorFunc = [&continuousValues](const auto &pair) {
return PotentiallyPrunedComponentError(pair.first, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
Expand Down
18 changes: 15 additions & 3 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,19 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
const Values& values) const {
const Values& continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each factor.
for (auto& factor : factors_) {
if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
// Compute factor error and add it.
result = result + hnf->errorTree(values);
result = result + hnf->errorTree(continuousValues);

} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
// If continuous only, get the (double) error
// and add it to every leaf of the result
result = result + nf->error(values);
result = result + nf->error(continuousValues);

} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// If discrete, just add its errorTree as well
Expand All @@ -210,4 +210,16 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
return result;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
const Values& continuousValues) const {
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return p / p.sum();
}

/* ************************************************************************ */
} // namespace gtsam
17 changes: 15 additions & 2 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,23 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*
* @note: Gaussian and hybrid Gaussian factors are not considered!
*
* @param values Manifold values at which to compute the error.
* @param continuousValues Manifold values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> errorTree(const Values& values) const;
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const;

/**
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
* This is efficient as this simply takes -exp(.) of errorTree and normalizes.
*
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
* which we would need, are hard to recover.
*
* @param continuousValues Continuous values x to condition on.
* @return DecisionTreeFactor
*/
AlgebraicDecisionTree<Key> discretePosterior(
const Values& continuousValues) const;

/// @}
};
Expand Down

0 comments on commit 1bb5b95

Please sign in to comment.