diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index d3b26d4ef8..f1e8a8498a 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -233,7 +233,7 @@ static double PotentiallyPrunedComponentError( AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [this, &continuousValues](const auto &pair) { + auto errorFunc = [&continuousValues](const auto &pair) { return PotentiallyPrunedComponentError(pair.first, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 0e7e9c692d..56b75d15ee 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -181,19 +181,19 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( /* ************************************************************************* */ AlgebraicDecisionTree HybridNonlinearFactorGraph::errorTree( - const Values& values) const { + const Values& continuousValues) const { AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto& factor : factors_) { if (auto hnf = std::dynamic_pointer_cast(factor)) { // Compute factor error and add it. - result = result + hnf->errorTree(values); + result = result + hnf->errorTree(continuousValues); } else if (auto nf = std::dynamic_pointer_cast(factor)) { // If continuous only, get the (double) error // and add it to every leaf of the result - result = result + nf->error(values); + result = result + nf->error(continuousValues); } else if (auto df = std::dynamic_pointer_cast(factor)) { // If discrete, just add its errorTree as well @@ -210,4 +210,16 @@ AlgebraicDecisionTree HybridNonlinearFactorGraph::errorTree( return result; } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridNonlinearFactorGraph::discretePosterior( + const Values& continuousValues) const { + AlgebraicDecisionTree errors = this->errorTree(continuousValues); + AlgebraicDecisionTree p = errors.apply([](double error) { + // NOTE: The 0.5 term is handled by each factor + return exp(-error); + }); + return p / p.sum(); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 53920a4aad..dd18cfa601 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -98,10 +98,23 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { * * @note: Gaussian and hybrid Gaussian factors are not considered! * - * @param values Manifold values at which to compute the error. + * @param continuousValues Manifold values at which to compute the error. * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree errorTree(const Values& values) const; + AlgebraicDecisionTree errorTree(const Values& continuousValues) const; + + /** + * @brief Computer posterior P(M|X=x) when all continuous values X are given. + * This is efficient as this simply takes -exp(.) of errorTree and normalizes. + * + * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys, + * which we would need, are hard to recover. + * + * @param continuousValues Continuous values x to condition on. + * @return DecisionTreeFactor + */ + AlgebraicDecisionTree discretePosterior( + const Values& continuousValues) const; /// @} };