Merge pull request #1836 from borglab/improved-api

varunagrawal authored Sep 20, 2024
2 parents 08967d1 + 245f3e0 commit 017044e

Showing 3 changed files with 68 additions and 42 deletions.
48 changes: 24 additions & 24 deletions gtsam/hybrid/HybridGaussianConditional.cpp
@@ -28,23 +28,37 @@
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
+ HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
+     const HybridGaussianConditional::Conditionals &conditionals) {
+   auto func = [](const GaussianConditional::shared_ptr &conditional)
+       -> GaussianFactorValuePair {
+     double value = 0.0;
+     // Check if conditional is pruned
+     if (conditional) {
+       // Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
+       value = -conditional->logNormalizationConstant();
+     }
+     return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
+   };
+   return HybridGaussianFactor::FactorValuePairs(conditionals, func);
+ }

HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
- discreteParents),
+ discreteParents, GetFactorValuePairs(conditionals)),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {
- // Calculate logConstant_ as the maximum of the log constants of the
+ // Calculate logConstant_ as the minimum of the log normalizers of the
// conditionals, by visiting the decision tree:
- logConstant_ = -std::numeric_limits<double>::infinity();
+ logConstant_ = std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
- this->logConstant_ = std::max(
-     this->logConstant_, conditional->logNormalizationConstant());
+ this->logConstant_ = std::min(
+     this->logConstant_, -conditional->logNormalizationConstant());
}
});
}
@@ -64,29 +78,14 @@ HybridGaussianConditional::HybridGaussianConditional(
DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

- /* *******************************************************************************/
- // TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
- // derived from HybridGaussianFactor, no?
- GaussianFactorGraphTree HybridGaussianConditional::add(
-     const GaussianFactorGraphTree &sum) const {
-   using Y = GaussianFactorGraph;
-   auto add = [](const Y &graph1, const Y &graph2) {
-     auto result = graph1;
-     result.push_back(graph2);
-     return result;
-   };
-   const auto tree = asGaussianFactorGraphTree();
-   return sum.empty() ? tree : sum.apply(tree, add);
- }

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm =
- this->logConstant_ - gc->logNormalizationConstant();
+ -this->logConstant_ - gc->logNormalizationConstant();
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
@@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << std::endl
<< " logNormalizationConstant: " << logConstant_ << std::endl
<< " logNormalizationConstant: " << logNormalizationConstant()
<< std::endl
<< std::endl;
conditionals_.print(
"", [&](Key k) { return formatter(k); },
@@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm =
- logConstant_ - conditional->logNormalizationConstant();
+ -logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
return {likelihood_m, 0.0};
} else {
@@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
- logConstant_ - conditional->logNormalizationConstant();
+ -logConstant_ - conditional->logNormalizationConstant();
} else {
// If not a valid pointer, it means this conditional was pruned,
// so we return maximum error.
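A minimal LaTeX sketch of the neg-log convention these changes adopt (read off the comments above; the symbol K_m is notation introduced here for illustration, not a GTSAM identifier). For a Gaussian conditional with covariance \Sigma_m in mode m:

\[
\text{logNormalizationConstant}_m = \log\frac{1}{\sqrt{|2\pi\Sigma_m|}}, \qquad
K_m := -\text{logNormalizationConstant}_m = \log\sqrt{|2\pi\Sigma_m|},
\]
\[
\text{logConstant\_} = \min_m K_m, \qquad
\text{error}(x, m) = E_m(x) + K_m - \min_{m'} K_{m'} \;\ge\; E_m(x),
\]

where E_m(x) is the conditional's Mahalanobis error. The mode-dependent offset K_m - min_{m'} K_{m'} is never negative, which is why the code above only attaches an extra constant-error term when Cgm_Kgcm > 0.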
29 changes: 12 additions & 17 deletions gtsam/hybrid/HybridGaussianConditional.h
@@ -51,20 +51,22 @@ class HybridValues;
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridGaussianConditional
- : public HybridFactor,
-   public Conditional<HybridFactor, HybridGaussianConditional> {
+ : public HybridGaussianFactor,
+   public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
public:
using This = HybridGaussianConditional;
- using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
- using BaseFactor = HybridFactor;
- using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
+ using shared_ptr = std::shared_ptr<This>;
+ using BaseFactor = HybridGaussianFactor;
+ using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;

/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
- double logConstant_;  ///< log of the normalization constant.
+ ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
+ ///< Take advantage of the neg-log space so everything is a minimization
+ double logConstant_;

/**
* @brief Convert a HybridGaussianConditional of conditionals into
@@ -107,8 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
- * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
- * The DecisionTree-based constructor is preferred over this one.
+ * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
+ * conditionals. The DecisionTree-based constructor is preferred over this
+ * one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
@@ -149,7 +152,7 @@

/// The log normalization constant is the max of the individual
/// log-normalization constants.
- double logNormalizationConstant() const override { return logConstant_; }
+ double logNormalizationConstant() const override { return -logConstant_; }

/**
* Create a likelihood factor for a hybrid Gaussian conditional,
@@ -232,14 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
*/
void prune(const DecisionTreeFactor &discreteProbs);

- /**
-  * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
-  * maintaining the decision tree structure.
-  *
-  * @param sum Decision Tree of Gaussian Factor Graphs
-  * @return GaussianFactorGraphTree
-  */
- GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}

private:
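A toy C++ analogue of the design change in this header (illustrative only, not GTSAM code): once the conditional type derives from the factor type, a generic container of factor pointers accepts it through the shared base, which is why the bespoke add(GaussianFactorGraphTree) override above could be deleted.

#include <iostream>
#include <memory>
#include <vector>

// Stand-in for HybridGaussianFactor.
struct Factor {
  virtual ~Factor() = default;
  virtual double error() const = 0;
};

// Stand-in for HybridGaussianConditional: it now is-a Factor.
struct Conditional : Factor {
  double error() const override { return 0.5; }
};

int main() {
  // Stand-in for a factor graph holding factors by base-class pointer.
  std::vector<std::shared_ptr<Factor>> graph;
  graph.push_back(std::make_shared<Conditional>());  // accepted via the base class
  std::cout << graph.front()->error() << "\n";       // prints 0.5
  return 0;
}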
33 changes: 32 additions & 1 deletion gtsam/hybrid/tests/testHybridGaussianConditional.cpp
@@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
auto actual = hybrid_conditional.errorTree(vv);

// Check result.
- std::vector<DiscreteKey> discrete_keys = {mode};
+ DiscreteKeys discrete_keys{mode};
std::vector<double> leaves = {conditionals[0]->error(vv),
conditionals[1]->error(vv)};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
@@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
EXPECT(continuousParentKeys[0] == X(0));
}

+ /* ************************************************************************* */
+ /// Check error with mode dependent constants.
+ TEST(HybridGaussianConditional, Error2) {
+   using namespace mode_dependent_constants;
+   auto actual = hybrid_conditional.errorTree(vv);
+
+   // Check result.
+   DiscreteKeys discrete_keys{mode};
+   double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
+   double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
+   double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
+
+   // Expected error is e(X) + log(√|2πΣ|).
+   // We normalize log(√|2πΣ|) with min(logNormalizers) so it is non-negative.
+   std::vector<double> leaves = {
+       conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
+       conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
+   AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
+
+   EXPECT(assert_equal(expected, actual, 1e-6));
+
+   // Check for non-tree version.
+   for (size_t mode : {0, 1}) {
+     const HybridValues hv{vv, {{M(0), mode}}};
+     EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
+                              conditionals[mode]->logNormalizationConstant() -
+                              minLogNormalizer,
+                          hybrid_conditional.error(hv), 1e-8);
+   }
+ }

/* ************************************************************************* */
/// Check that the likelihood is proportional to the conditional density given
/// the measurements.