Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HybridGaussianConditional inherits from HybridGaussianFactor #1836

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,37 @@
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
const HybridGaussianConditional::Conditionals &conditionals) {
auto func = [](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
double value = 0.0;
// Check if conditional is pruned
if (conditional) {
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = -conditional->logNormalizationConstant();
}
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
};
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
}

HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
discreteParents, GetFactorValuePairs(conditionals)),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {
// Calculate logConstant_ as the maximum of the log constants of the
// Calculate logConstant_ as the minimum of the log normalizers of the
// conditionals, by visiting the decision tree:
logConstant_ = -std::numeric_limits<double>::infinity();
logConstant_ = std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->logConstant_ = std::max(
this->logConstant_, conditional->logNormalizationConstant());
this->logConstant_ = std::min(
this->logConstant_, -conditional->logNormalizationConstant());
}
});
}
Expand All @@ -64,29 +78,14 @@ HybridGaussianConditional::HybridGaussianConditional(
DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
// derived from HybridGaussianFactor, no?
GaussianFactorGraphTree HybridGaussianConditional::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm =
this->logConstant_ - gc->logNormalizationConstant();
-this->logConstant_ - gc->logNormalizationConstant();
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
Expand Down Expand Up @@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << std::endl
<< " logNormalizationConstant: " << logConstant_ << std::endl
<< " logNormalizationConstant: " << logNormalizationConstant()
<< std::endl
<< std::endl;
conditionals_.print(
"", [&](Key k) { return formatter(k); },
Expand Down Expand Up @@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm =
logConstant_ - conditional->logNormalizationConstant();
-logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
return {likelihood_m, 0.0};
} else {
Expand Down Expand Up @@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
-logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
Expand Down
29 changes: 12 additions & 17 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,22 @@ class HybridValues;
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridGaussianConditional
: public HybridFactor,
public Conditional<HybridFactor, HybridGaussianConditional> {
: public HybridGaussianFactor,
public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
public:
using This = HybridGaussianConditional;
using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
using shared_ptr = std::shared_ptr<This>;
using BaseFactor = HybridGaussianFactor;
using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;

/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
double logConstant_; ///< log of the normalization constant.
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
///< Take advantage of the neg-log space so everything is a minimization
double logConstant_;

/**
* @brief Convert a HybridGaussianConditional of conditionals into
Expand Down Expand Up @@ -107,8 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
Expand Down Expand Up @@ -149,7 +152,7 @@ class GTSAM_EXPORT HybridGaussianConditional

/// The log normalization constant is max of the the individual
/// log-normalization constants.
double logNormalizationConstant() const override { return logConstant_; }
double logNormalizationConstant() const override { return -logConstant_; }

/**
* Create a likelihood factor for a hybrid Gaussian conditional,
Expand Down Expand Up @@ -232,14 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
*/
void prune(const DecisionTreeFactor &discreteProbs);

/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}

private:
Expand Down
33 changes: 32 additions & 1 deletion gtsam/hybrid/tests/testHybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
auto actual = hybrid_conditional.errorTree(vv);

// Check result.
std::vector<DiscreteKey> discrete_keys = {mode};
DiscreteKeys discrete_keys{mode};
std::vector<double> leaves = {conditionals[0]->error(vv),
conditionals[1]->error(vv)};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
Expand Down Expand Up @@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
EXPECT(continuousParentKeys[0] == X(0));
}

/* ************************************************************************* */
/// Check error with mode dependent constants.
TEST(HybridGaussianConditional, Error2) {
using namespace mode_dependent_constants;
auto actual = hybrid_conditional.errorTree(vv);

// Check result.
DiscreteKeys discrete_keys{mode};
double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);

// Expected error is e(X) + log(|2πΣ|).
// We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
std::vector<double> leaves = {
conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);

EXPECT(assert_equal(expected, actual, 1e-6));

// Check for non-tree version.
for (size_t mode : {0, 1}) {
const HybridValues hv{vv, {{M(0), mode}}};
EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
conditionals[mode]->logNormalizationConstant() -
minLogNormalizer,
hybrid_conditional.error(hv), 1e-8);
}
}

/* ************************************************************************* */
/// Check that the likelihood is proportional to the conditional density given
/// the measurements.
Expand Down
Loading