From dc2c98383738b977b986caacd7468b85dc5d6890 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 11:53:29 -0700 Subject: [PATCH 01/11] Document and fix bug in modes --- gtsam/hybrid/tests/Switching.h | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 82876fd2c4..547facce9a 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -114,11 +114,11 @@ inline std::pair> makeBinaryOrdering( return {new_order, levels}; } -/* *************************************************************************** - */ +/* ****************************************************************************/ using MotionModel = BetweenFactor; // Test fixture with switching network. +/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1)) struct Switching { size_t K; DiscreteKeys modes; @@ -140,8 +140,8 @@ struct Switching { : K(K) { using noiseModel::Isotropic; - // Create DiscreteKeys for binary K modes. - for (size_t k = 0; k < K; k++) { + // Create DiscreteKeys for K-1 binary modes. + for (size_t k = 0; k < K - 1; k++) { modes.emplace_back(M(k), 2); } @@ -153,25 +153,26 @@ struct Switching { } // Create hybrid factor graph. - // Add a prior on X(0). + + // Add a prior ϕ(X(0)) on X(0). nonlinearFactorGraph.emplace_shared>( X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma)); - // Add "motion models". + // Add "motion models" ϕ(X(k),X(k+1)). for (size_t k = 0; k < K - 1; k++) { auto motion_models = motionModels(k, between_sigma); nonlinearFactorGraph.emplace_shared(modes[k], motion_models); } - // Add measurement factors + // Add measurement factors ϕ(X(k);z_k). auto measurement_noise = Isotropic::Sigma(1, prior_sigma); for (size_t k = 1; k < K; k++) { nonlinearFactorGraph.emplace_shared>( X(k), measurements.at(k), measurement_noise); } - // Add "mode chain" + // Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2)) addModeChain(&nonlinearFactorGraph, discrete_transition_prob); // Create the linearization point. @@ -179,8 +180,6 @@ struct Switching { linearizationPoint.insert(X(k), static_cast(k + 1)); } - // The ground truth is robot moving forward - // and one less than the linearization point linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint); } @@ -196,7 +195,7 @@ struct Switching { } /** - * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). + * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1). * E.g. if K=4, we want M0, M1 and M2. * * @param fg The factor graph to which the mode chain is added. From 28a2cd347514a79c3cc9c6044fa451bd959832eb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 13:25:34 -0700 Subject: [PATCH 02/11] Check invariants --- gtsam/hybrid/tests/testHybridGaussianConditional.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 02163df9e4..803d42f034 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -168,6 +168,9 @@ TEST(HybridGaussianConditional, ContinuousParents) { // Check that the continuous parent keys are correct: EXPECT(continuousParentKeys.size() == 1); EXPECT(continuousParentKeys[0] == X(0)); + + EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv0)); + EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv1)); } /* ************************************************************************* */ From 8675dc62df1277f84f6b689a2afbd64df5fb6b5e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 15:41:17 -0700 Subject: [PATCH 03/11] Throw if sampling conditions not satisfied --- gtsam/discrete/DiscreteConditional.cpp | 18 +++++++++++++----- gtsam/discrete/DiscreteConditional.h | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 3f0c9f5118..5ab0c59ec4 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); + // throw if more than one frontal: + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sampleInPlace can only be called on single " + "variable conditionals"); + } + Key j = firstFrontalKey(); + // throw if values already contains j: + if (values->count(j) > 0) { + throw std::invalid_argument( + "DiscreteConditional::sampleInPlace: values already contains j"); + } size_t sampled = sample(*values); // Sample variable given parents (*values)[j] = sampled; // store result in partial solution } @@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const { } /* ************************************************************************* */ -double DiscreteConditional::negLogConstant() const { - return 0.0; -} +double DiscreteConditional::negLogConstant() const { return 0.0; } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index e16100d0aa..f59e292856 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional static_cast(this)->print(s, formatter); } - /// Evaluate, just look up in AlgebraicDecisonTree + /// Evaluate, just look up in AlgebraicDecisionTree double evaluate(const DiscreteValues& values) const { return ADT::operator()(values); } From 80a4cd1bfca60ce28667f000dce397435a5eaeda Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 15:42:02 -0700 Subject: [PATCH 04/11] Only sample if not provided --- gtsam/discrete/DiscreteBayesNet.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 1c5c81e456..56265b0a4a 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // sample each node in turn in topological sort order (parents first) for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { - (*it)->sampleInPlace(&result); + const DiscreteConditional::shared_ptr& conditional = *it; + // Sample the conditional only if value for j not already in result + const Key j = conditional->firstFrontalKey(); + if (result.count(j) == 0) { + conditional->sampleInPlace(&result); + } } return result; } From 846c7a1a99b7f573dd7f6ec23c9d91c2e63e7f9a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 22:56:16 -0700 Subject: [PATCH 05/11] Make testable --- gtsam/hybrid/HybridGaussianFactorGraph.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 923f48e380..7e3aac663d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -231,4 +231,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph GaussianFactorGraph operator()(const DiscreteValues& assignment) const; }; +// traits +template <> +struct traits + : public Testable {}; + } // namespace gtsam From 21171b3a9a59926d2bdaf7545f6b9bc56565f62a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:57:37 -0700 Subject: [PATCH 06/11] Fix equality --- gtsam/hybrid/HybridGaussianConditional.cpp | 9 ++++----- gtsam/hybrid/HybridGaussianFactor.cpp | 7 +++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 1712e06a9f..fb89f72fc9 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -192,11 +192,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf, // Check the base and the factors: return BaseFactor::equals(*e, tol) && - conditionals_.equals(e->conditionals_, - [tol](const GaussianConditional::shared_ptr &f1, - const GaussianConditional::shared_ptr &f2) { - return f1->equals(*(f2), tol); - }); + conditionals_.equals( + e->conditionals_, [tol](const auto &f1, const auto &f2) { + return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); + }); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 7f8e808bf4..6a9fc5c352 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -154,10 +154,9 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, - [tol](const sharedFactor &f1, const sharedFactor &f2) { - return f1->equals(*f2, tol); - }); + factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) { + return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); + }); } /* *******************************************************************************/ From d042359a997949255a4ddd2e34e1d7d01bd8eba7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:57:48 -0700 Subject: [PATCH 07/11] choose method --- gtsam/hybrid/HybridGaussianConditional.cpp | 2 +- gtsam/hybrid/HybridGaussianConditional.h | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index fb89f72fc9..1db13e95b3 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -168,7 +168,7 @@ size_t HybridGaussianConditional::nrComponents() const { } /* *******************************************************************************/ -GaussianConditional::shared_ptr HybridGaussianConditional::operator()( +GaussianConditional::shared_ptr HybridGaussianConditional::choose( const DiscreteValues &discreteValues) const { auto &ptr = conditionals_(discreteValues); if (!ptr) return nullptr; diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 827b7f309d..68c63e7bd7 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -159,9 +159,15 @@ class GTSAM_EXPORT HybridGaussianConditional /// @{ /// @brief Return the conditional Gaussian for the given discrete assignment. - GaussianConditional::shared_ptr operator()( + GaussianConditional::shared_ptr choose( const DiscreteValues &discreteValues) const; + /// @brief Syntactic sugar for choose. + GaussianConditional::shared_ptr operator()( + const DiscreteValues &discreteValues) const { + return choose(discreteValues); + } + /// Returns the total number of continuous components size_t nrComponents() const; From 5b909e3c2918a74228d5ee3ad50f7d4afe13cbd6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 23:58:21 -0700 Subject: [PATCH 08/11] Clarify choose method --- gtsam/hybrid/HybridBayesNet.cpp | 2 +- gtsam/hybrid/HybridBayesNet.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 6e96afb257..3c77e3f9aa 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -206,7 +206,7 @@ GaussianBayesNet HybridBayesNet::choose( for (auto &&conditional : *this) { if (auto gm = conditional->asHybrid()) { // If conditional is hybrid, select based on assignment. - gbn.push_back((*gm)(assignment)); + gbn.push_back(gm->choose(assignment)); } else if (auto gc = conditional->asGaussian()) { // If continuous only, add Gaussian conditional. gbn.push_back(gc); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 22a43f3bd3..62688e8b20 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -127,6 +127,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete * value assignment. * + * @note Any pure discrete factors are ignored. + * * @param assignment The discrete value assignment for the discrete keys. * @return GaussianBayesNet */ From 12349b9201d2fbdce6e245d8acd0da5acf85f24a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 00:32:06 -0700 Subject: [PATCH 09/11] Helper method can be static --- gtsam/hybrid/HybridGaussianFactor.cpp | 17 ++++++++--------- gtsam/hybrid/HybridGaussianFactor.h | 4 ---- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 6a9fc5c352..b04db4977e 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -212,16 +212,15 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() } /* *******************************************************************************/ -double HybridGaussianFactor::potentiallyPrunedComponentError( - const sharedFactor &gf, const VectorValues &values) const { +/// Helper method to compute the error of a component. +static double PotentiallyPrunedComponentError( + const GaussianFactor::shared_ptr &gf, const VectorValues &values) { // Check if valid pointer if (gf) { return gf->error(values); } else { - // If not valid, pointer, it means this component was pruned, - // so we return maximum error. - // This way the negative exponential will give - // a probability value close to 0.0. + // If nullptr this component was pruned, so we return maximum error. This + // way the negative exponential will give a probability value close to 0.0. return std::numeric_limits::max(); } } @@ -230,8 +229,8 @@ double HybridGaussianFactor::potentiallyPrunedComponentError( AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [this, &continuousValues](const sharedFactor &gf) { - return this->potentiallyPrunedComponentError(gf, continuousValues); + auto errorFunc = [&continuousValues](const sharedFactor &gf) { + return PotentiallyPrunedComponentError(gf, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); return error_tree; @@ -241,7 +240,7 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( double HybridGaussianFactor::error(const HybridValues &values) const { // Directly index to get the component, no need to build the whole tree. const sharedFactor gf = factors_(values.discrete()); - return potentiallyPrunedComponentError(gf, values.continuous()); + return PotentiallyPrunedComponentError(gf, values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index f23d065b6e..e5a5754094 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -189,10 +189,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { */ static Factors augment(const FactorValuePairs &factors); - /// Helper method to compute the error of a component. - double potentiallyPrunedComponentError( - const sharedFactor &gf, const VectorValues &continuousValues) const; - /// Helper struct to assist private constructor below. struct ConstructorHelper; From 44fb786b7aac9227424c785465e98dda037f18f6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 01:38:06 -0700 Subject: [PATCH 10/11] Much more comprehensive tests --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 118 +++++++++++++++++++--- 1 file changed, 102 insertions(+), 16 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index b555f6bd9f..79979ac83a 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -62,32 +62,117 @@ TEST(HybridBayesNet, Add) { } /* ****************************************************************************/ -// Test evaluate for a pure discrete Bayes net P(Asia). +// Test API for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.emplace_shared(Asia, "4/6"); - HybridValues values; - values.insert(asiaKey, 0); - EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9); + const auto pAsia = std::make_shared(Asia, "4/6"); + bayesNet.push_back(pAsia); + HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; + + // choose + GaussianBayesNet empty; + EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); + + // evaluate + EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); + + // optimize + EXPECT(assert_equal(one, bayesNet.optimize())); + EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete()))); + + // sample + std::mt19937_64 rng(42); + EXPECT(assert_equal(zero, bayesNet.sample(&rng))); + EXPECT(assert_equal(one, bayesNet.sample(one, &rng))); + EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng))); + + // error + EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + + // toFactorGraph + HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({}); + EXPECT(assert_equal(expectedFG, fg)); + + // prune, imperative :-( + EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); + EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); } /* ****************************************************************************/ // Test creation of a tiny hybrid Bayes net. TEST(HybridBayesNet, Tiny) { - auto bn = tiny::createHybridBayesNet(); - EXPECT_LONGS_EQUAL(3, bn.size()); + auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode) + EXPECT_LONGS_EQUAL(3, bayesNet.size()); const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; - auto fg = bn.toFactorGraph(vv); + HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; + + // Check Invariants for components + HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); + GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()), + gc1 = hgc->choose(one.discrete()); + GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian(); + GaussianConditional::CheckInvariants(*gc0, vv); + GaussianConditional::CheckInvariants(*gc1, vv); + GaussianConditional::CheckInvariants(*px, vv); + HybridGaussianConditional::CheckInvariants(*hgc, zero); + HybridGaussianConditional::CheckInvariants(*hgc, one); + + // choose + GaussianBayesNet expectedChosen; + expectedChosen.push_back(gc0); + expectedChosen.push_back(px); + auto chosen0 = bayesNet.choose(zero.discrete()); + auto chosen1 = bayesNet.choose(one.discrete()); + EXPECT(assert_equal(expectedChosen, chosen0, 1e-9)); + + // logProbability + const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior + const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior + EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); + + // evaluate + EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9); + + // optimize + EXPECT(assert_equal(one, bayesNet.optimize())); + EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete()))); + + // sample + std::mt19937_64 rng(42); + EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete())); + + // error + const double error0 = chosen0.error(vv) + gc0->negLogConstant() - + px->negLogConstant() - log(0.4); + const double error1 = chosen1.error(vv) + gc1->negLogConstant() - + px->negLogConstant() - log(0.6); + EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9); + EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9); + + // toFactorGraph + auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}}); EXPECT_LONGS_EQUAL(3, fg.size()); // Check that the ratio of probPrime to evaluate is the same for all modes. std::vector ratio(2); - for (size_t mode : {0, 1}) { - const HybridValues hv{vv, {{M(0), mode}}}; - ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv); - } + ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero); + ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); + + // prune, imperative :-( + auto pruned = bayesNet.prune(1); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); + EXPECT(!pruned.equals(bayesNet)); + } /* ****************************************************************************/ @@ -223,12 +308,15 @@ TEST(HybridBayesNet, Optimize) { /* ****************************************************************************/ // Test Bayes net error TEST(HybridBayesNet, Pruning) { + // Create switching network with three continuous variables and two discrete: + // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) Switching s(3); HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph.eliminateSequential(); EXPECT_LONGS_EQUAL(5, posterior->size()); + // Optimize HybridValues delta = posterior->optimize(); auto actualTree = posterior->evaluate(delta.continuous()); @@ -254,7 +342,6 @@ TEST(HybridBayesNet, Pruning) { logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues); - // NOTE(dellaert): the discrete errors were not added in logProbability tree! logProbability += posterior->at(3)->asDiscrete()->logProbability(hybridValues); logProbability += @@ -316,10 +403,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { #endif // regression - DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; DecisionTreeFactor::ADT potentials( - dkeys, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); - DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); + s.modes, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); + DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); // Prune! posterior->prune(maxNrLeaves); From 3cd816341ccbc2b0b3e617d9723431c1c4775ec9 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 30 Sep 2024 12:49:29 -0700 Subject: [PATCH 11/11] refactor printErrors --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 92 ++++++++-------------- 1 file changed, 34 insertions(+), 58 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c107aa8a8f..8a2a7fd158 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) { index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); } +/* ************************************************************************ */ +static void printFactor(const std::shared_ptr &factor, + const DiscreteValues &assignment, + const KeyFormatter &keyFormatter) { + if (auto hgf = std::dynamic_pointer_cast(factor)) { + hgf->operator()(assignment) + ->print("HybridGaussianFactor, component:", keyFormatter); + } else if (auto gf = std::dynamic_pointer_cast(factor)) { + factor->print("GaussianFactor:\n", keyFormatter); + } else if (auto df = std::dynamic_pointer_cast(factor)) { + factor->print("DiscreteFactor:\n", keyFormatter); + } else if (auto hc = std::dynamic_pointer_cast(factor)) { + if (hc->isContinuous()) { + factor->print("GaussianConditional:\n", keyFormatter); + } else if (hc->isDiscrete()) { + factor->print("DiscreteConditional:\n", keyFormatter); + } else { + hc->asHybrid() + ->choose(assignment) + ->print("HybridConditional, component:\n", keyFormatter); + } + } else { + factor->print("Unknown factor type\n", keyFormatter); + } +} + /* ************************************************************************ */ void HybridGaussianFactorGraph::printErrors( const HybridValues &values, const std::string &str, @@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors( &printCondition) const { std::cout << str << "size: " << size() << std::endl << std::endl; - std::stringstream ss; - for (size_t i = 0; i < factors_.size(); i++) { auto &&factor = factors_[i]; - std::cout << "Factor " << i << ": "; - - // Clear the stringstream - ss.str(std::string()); - - if (auto hgf = std::dynamic_pointer_cast(factor)) { - if (factor == nullptr) { - std::cout << "nullptr" - << "\n"; - } else { - hgf->operator()(values.discrete())->print(ss.str(), keyFormatter); - std::cout << "error = " << factor->error(values) << std::endl; - } - } else if (auto hc = std::dynamic_pointer_cast(factor)) { - if (factor == nullptr) { - std::cout << "nullptr" - << "\n"; - } else { - if (hc->isContinuous()) { - factor->print(ss.str(), keyFormatter); - std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; - } else if (hc->isDiscrete()) { - factor->print(ss.str(), keyFormatter); - std::cout << "error = " << hc->asDiscrete()->error(values.discrete()) - << "\n"; - } else { - // Is hybrid - auto conditionalComponent = - hc->asHybrid()->operator()(values.discrete()); - conditionalComponent->print(ss.str(), keyFormatter); - std::cout << "error = " << conditionalComponent->error(values) - << "\n"; - } - } - } else if (auto gf = std::dynamic_pointer_cast(factor)) { - const double errorValue = (factor != nullptr ? gf->error(values) : .0); - if (!printCondition(factor.get(), errorValue, i)) - continue; // User-provided filter did not pass - - if (factor == nullptr) { - std::cout << "nullptr" - << "\n"; - } else { - factor->print(ss.str(), keyFormatter); - std::cout << "error = " << errorValue << "\n"; - } - } else if (auto df = std::dynamic_pointer_cast(factor)) { - if (factor == nullptr) { - std::cout << "nullptr" - << "\n"; - } else { - factor->print(ss.str(), keyFormatter); - std::cout << "error = " << df->error(values.discrete()) << std::endl; - } - - } else { + if (factor == nullptr) { + std::cout << "Factor " << i << ": nullptr\n"; continue; } + const double errorValue = factor->error(values); + if (!printCondition(factor.get(), errorValue, i)) + continue; // User-provided filter did not pass + // Print the factor + std::cout << "Factor " << i << ", error = " << errorValue << "\n"; + printFactor(factor, values.discrete(), keyFormatter); std::cout << "\n"; } std::cout.flush();