From 34bb1d0f343adfa79a6a3d9a991f60e1777eca07 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 9 Oct 2024 20:03:30 +0900 Subject: [PATCH] Shift error values before exponentiating --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 42 ++++++++++++------- .../tests/testHybridGaussianFactorGraph.cpp | 16 +------ 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 167e53656e..5c83fe5146 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -48,6 +48,8 @@ #include #include +#include "gtsam/discrete/DecisionTreeFactor.h" + namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -226,6 +228,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors, return {std::make_shared(result.first), result.second}; } +/* ************************************************************************ */ +/// Take negative log-values, shift them so that the minimum value is 0, and +/// then exponentiate to create a DecisionTreeFactor (not normalized yet!). +static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( + const DiscreteKeys &discreteKeys, + const AlgebraicDecisionTree &errors) { + double min_log = errors.min(); + AlgebraicDecisionTree potentials = DecisionTree( + errors, [&min_log](const double x) { return exp(-(x - min_log)); }); + return std::make_shared(discreteKeys, potentials); +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -237,15 +251,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors, dfg.push_back(df); } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. - // In this case, compute discrete probabilities. - auto potential = [&](const auto &pair) -> double { + // In this case, compute a discrete factor from the remaining error. + auto calculateError = [&](const auto &pair) -> double { auto [factor, scalar] = pair; - // If factor is null, it has been pruned, hence return potential zero - if (!factor) return 0.0; - return exp(-scalar - factor->error(kEmpty)); + // If factor is null, it has been pruned, hence return infinite error + if (!factor) return std::numeric_limits::infinity(); + return scalar + factor->error(kEmpty); }; - DecisionTree potentials(gmf->factors(), potential); - dfg.emplace_shared(gmf->discreteKeys(), potentials); + DecisionTree errors(gmf->factors(), calculateError); + dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors)); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. @@ -275,7 +289,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, static std::shared_ptr createDiscreteFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { - auto potential = [&](const auto &pair) -> double { + auto calculateError = [&](const auto &pair) -> double { const auto &[conditional, factor] = pair.first; const double scalar = pair.second; if (conditional && factor) { @@ -284,19 +298,17 @@ static std::shared_ptr createDiscreteFactor( // - factor->error(kempty) is the error remaining after elimination // - negLogK is what is given to the conditional to normalize const double negLogK = conditional->negLogConstant(); - const double error = scalar + factor->error(kEmpty) - negLogK; - return exp(-error); + return scalar + factor->error(kEmpty) - negLogK; } else if (!conditional && !factor) { - // If the factor is null, it has been pruned, hence return potential of - // zero - return 0.0; + // If the factor has been pruned, return infinite error + return std::numeric_limits::infinity(); } else { throw std::runtime_error("createDiscreteFactor has mixed NULLs"); } }; - DecisionTree potentials(eliminationResults, potential); - return std::make_shared(discreteSeparator, potentials); + DecisionTree errors(eliminationResults, calculateError); + return DiscreteFactorFromErrors(discreteSeparator, errors); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 1aa4c8d492..01294b28c6 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -117,7 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { auto factor = std::dynamic_pointer_cast(result.second); CHECK(factor); // regression test - EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5)); + EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5)); } /* ************************************************************************* */ @@ -333,19 +333,7 @@ TEST(HybridBayesNet, Switching) { CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because - // the continuos factor whose error(kEmpty) we need is not available.. - - // However, we can still check the total error for the clique factors_x1 and - // the elimination results are equal, modulo -again- the negative log constant - // of the conditional. - for (auto &&mode : {modeZero, modeOne}) { - auto gc_x1 = (*p_x1_given_m)(mode); - double originalError_x1 = factors_x1.error({continuousValues, mode}); - const double actualError = gc_x1->negLogConstant() + - gc_x1->error(continuousValues) + - phi_x1->error(mode); - EXPECT_DOUBLES_EQUAL(originalError_x1, actualError, 1e-9); - } + // the continuous factor whose error(kEmpty) we need is not available.. // Now test full elimination of the graph: auto hybridBayesNet = graph.eliminateSequential();