Skip to content

Commit

Permalink
Shift error values before exponentiating
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Oct 9, 2024
1 parent 19fdb43 commit 34bb1d0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
42 changes: 27 additions & 15 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
#include <utility>
#include <vector>

#include "gtsam/discrete/DecisionTreeFactor.h"

namespace gtsam {

/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
Expand Down Expand Up @@ -226,6 +228,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second};
}

/* ************************************************************************ */
/// Convert a tree of negative log-values (errors) into a DecisionTreeFactor
/// of (unnormalized) potentials. The errors are first shifted so that the
/// smallest one maps to zero — i.e. the largest potential becomes exp(0)=1 —
/// which guards against floating-point underflow when errors are large.
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
    const DiscreteKeys &discreteKeys,
    const AlgebraicDecisionTree<Key> &errors) {
  const double minError = errors.min();
  // exp(minError - e) == exp(-(e - minError)): shift, negate, exponentiate.
  auto toPotential = [minError](const double e) { return exp(minError - e); };
  AlgebraicDecisionTree<Key> potentials =
      DecisionTree<Key, double>(errors, toPotential);
  return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
Expand All @@ -237,15 +251,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities.
auto potential = [&](const auto &pair) -> double {
// In this case, compute a discrete factor from the remaining error.
auto calculateError = [&](const auto &pair) -> double {
auto [factor, scalar] = pair;
// If factor is null, it has been pruned, hence return potential zero
if (!factor) return 0.0;
return exp(-scalar - factor->error(kEmpty));
// If factor is null, it has been pruned, hence return infinite error
if (!factor) return std::numeric_limits<double>::infinity();
return scalar + factor->error(kEmpty);
};
DecisionTree<Key, double> potentials(gmf->factors(), potential);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), potentials);
DecisionTree<Key, double> errors(gmf->factors(), calculateError);
dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));

} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
Expand Down Expand Up @@ -275,7 +289,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
static std::shared_ptr<Factor> createDiscreteFactor(
const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) {
auto potential = [&](const auto &pair) -> double {
auto calculateError = [&](const auto &pair) -> double {
const auto &[conditional, factor] = pair.first;
const double scalar = pair.second;
if (conditional && factor) {
Expand All @@ -284,19 +298,17 @@ static std::shared_ptr<Factor> createDiscreteFactor(
// - factor->error(kempty) is the error remaining after elimination
// - negLogK is what is given to the conditional to normalize
const double negLogK = conditional->negLogConstant();
const double error = scalar + factor->error(kEmpty) - negLogK;
return exp(-error);
return scalar + factor->error(kEmpty) - negLogK;
} else if (!conditional && !factor) {
// If the factor is null, it has been pruned, hence return potential of
// zero
return 0.0;
// If the factor has been pruned, return infinite error
return std::numeric_limits<double>::infinity();
} else {
throw std::runtime_error("createDiscreteFactor has mixed NULLs");
}
};

DecisionTree<Key, double> potentials(eliminationResults, potential);
return std::make_shared<DecisionTreeFactor>(discreteSeparator, potentials);
DecisionTree<Key, double> errors(eliminationResults, calculateError);
return DiscreteFactorFromErrors(discreteSeparator, errors);
}

/* *******************************************************************************/
Expand Down
16 changes: 2 additions & 14 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
}

/* ************************************************************************* */
Expand Down Expand Up @@ -333,19 +333,7 @@ TEST(HybridBayesNet, Switching) {
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because
// the continuos factor whose error(kEmpty) we need is not available..

// However, we can still check the total error for the clique factors_x1 and
// the elimination results are equal, modulo -again- the negative log constant
// of the conditional.
for (auto &&mode : {modeZero, modeOne}) {
auto gc_x1 = (*p_x1_given_m)(mode);
double originalError_x1 = factors_x1.error({continuousValues, mode});
const double actualError = gc_x1->negLogConstant() +
gc_x1->error(continuousValues) +
phi_x1->error(mode);
EXPECT_DOUBLES_EQUAL(originalError_x1, actualError, 1e-9);
}
// the continuous factor whose error(kEmpty) we need is not available..

// Now test full elimination of the graph:
auto hybridBayesNet = graph.eliminateSequential();
Expand Down

0 comments on commit 34bb1d0

Please sign in to comment.