Skip to content

Commit

Permalink
Move code to cpp file.
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Sep 26, 2024
1 parent 71d5a6c commit 3e6227f
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 155 deletions.
51 changes: 32 additions & 19 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,47 @@

namespace gtsam {
/* *******************************************************************************/
HybridGaussianConditional::ConstructorHelper::ConstructorHelper(
const HybridGaussianConditional::Conditionals &conditionals) {
negLogConstant = std::numeric_limits<double>::infinity();

auto func =
[&](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (frontals.empty()) {
frontals = KeyVector(c->frontals().begin(), c->frontals().end());
parents = KeyVector(c->parents().begin(), c->parents().end());
struct HybridGaussianConditional::ConstructorHelper {
KeyVector frontals, parents;
HybridGaussianFactor::FactorValuePairs pairs;
double negLogConstant;
/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals) {
negLogConstant = std::numeric_limits<double>::infinity();

auto func = [&](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (frontals.empty()) {
frontals = KeyVector(c->frontals().begin(), c->frontals().end());
parents = KeyVector(c->parents().begin(), c->parents().end());
}
value = c->negLogConstant();
negLogConstant = std::min(negLogConstant, value);
}
value = c->negLogConstant();
negLogConstant = std::min(negLogConstant, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
}
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(helper.frontals.size()),
conditionals_(conditionals),
negLogConstant_(helper.negLogConstant) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
Expand Down
14 changes: 2 additions & 12 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,13 @@ class GTSAM_EXPORT HybridGaussianConditional

private:
/// Helper struct for private constructor.
struct ConstructorHelper {
KeyVector frontals, parents;
HybridGaussianFactor::FactorValuePairs pairs;
double negLogConstant;
/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals);
};
struct ConstructorHelper;

/// Private constructor that uses helper struct above.
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(helper.frontals.size()),
conditionals_(conditionals),
negLogConstant_(helper.negLogConstant) {}
const ConstructorHelper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
Expand Down
99 changes: 61 additions & 38 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,53 +70,76 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
}

/* *******************************************************************************/
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &factor : factors) {
if (factor && continuousKeys.empty()) {
continuousKeys = factor->keys();
break;
struct HybridGaussianFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty
Factors factorsTree;

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &factor : factors) {
if (factor && continuousKeys.empty()) {
continuousKeys = factor->keys();
break;
}
}
}

// Build the DecisionTree from the factor vector
factorsTree = Factors(discreteKeys, factors);
}
// Build the DecisionTree from the factor vector
factorsTree = Factors(discreteKeys, factors);
}

/* *******************************************************************************/
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
}
}

// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(discreteKeys, factorPairs);
}

// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(discreteKeys, factorPairs);
}
ConstructorHelper(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
}
});

// Build the FactorValuePairs DecisionTree
pairs = factorPairs;
}
};

/* *******************************************************************************/
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKeys &discreteKeys, const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
}
});
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}

// Build the FactorValuePairs DecisionTree
pairs = factorPairs;
}
HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}

/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
Expand Down
30 changes: 5 additions & 25 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factors Vector of gaussian factors, one for each mode.
*/
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
const std::vector<GaussianFactor::shared_ptr> &factors);

/**
* @brief Construct a new HybridGaussianFactor on a single discrete key,
Expand All @@ -102,8 +101,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factorPairs Vector of gaussian factor-scalar pairs, one per mode.
*/
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
const std::vector<GaussianFactorValuePair> &factorPairs);

/**
* @brief Construct a new HybridGaussianFactor on a several discrete keys M,
Expand All @@ -115,8 +113,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factors The decision tree of Gaussian factor/scalar pairs.
*/
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
const FactorValuePairs &factors);

/// @}
/// @name Testable
Expand Down Expand Up @@ -197,27 +194,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const sharedFactor &gf, const VectorValues &continuousValues) const;

/// Helper struct to assist private constructor below.
struct ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty
Factors factorsTree;

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors);

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs);

ConstructorHelper(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs);
};
struct ConstructorHelper;

// Private constructor using ConstructorHelper above.
HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}
HybridGaussianFactor(const ConstructorHelper &helper);

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
Expand Down
93 changes: 56 additions & 37 deletions gtsam/hybrid/HybridNonlinearFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,52 +25,71 @@
namespace gtsam {

/* *******************************************************************************/
static void CopyOrCheckContinuousKeys(const NonlinearFactor::shared_ptr& factor,
KeyVector* continuousKeys) {
if (!factor) return;
if (continuousKeys->empty()) {
*continuousKeys = factor->keys();
} else if (factor->keys() != *continuousKeys) {
throw std::runtime_error(
"HybridNonlinearFactor: all factors should have the same keys!");
struct HybridNonlinearFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs factorTree;

void copyOrCheckContinuousKeys(const NonlinearFactor::shared_ptr& factor) {
if (!factor) return;
if (continuousKeys.empty()) {
continuousKeys = factor->keys();
} else if (factor->keys() != continuousKeys) {
throw std::runtime_error(
"HybridNonlinearFactor: all factors should have the same keys!");
}
}
}

ConstructorHelper(const DiscreteKey& discreteKey,
const std::vector<NonlinearFactor::shared_ptr>& factors)
: discreteKeys({discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
// Extract continuous keys from the first non-null factor
for (const auto& factor : factors) {
pairs.emplace_back(factor, 0.0);
copyOrCheckContinuousKeys(factor);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}

ConstructorHelper(const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& pairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto& pair : pairs) {
copyOrCheckContinuousKeys(pair.first);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}

ConstructorHelper(const DiscreteKeys& discreteKeys,
const FactorValuePairs& factorPairs)
: discreteKeys(discreteKeys), factorTree(factorPairs) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const NonlinearFactorValuePair& pair) {
copyOrCheckContinuousKeys(pair.first);
});
}
};

/* *******************************************************************************/
HybridNonlinearFactor::ConstructorHelper::ConstructorHelper(
HybridNonlinearFactor::HybridNonlinearFactor(const ConstructorHelper& helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorTree) {}

HybridNonlinearFactor::HybridNonlinearFactor(
const DiscreteKey& discreteKey,
const std::vector<NonlinearFactor::shared_ptr>& factors)
: discreteKeys({discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
// Extract continuous keys from the first non-null factor
for (const auto& factor : factors) {
pairs.emplace_back(factor, 0.0);
CopyOrCheckContinuousKeys(factor, &continuousKeys);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}
: HybridNonlinearFactor(ConstructorHelper(discreteKey, factors)) {}

/* *******************************************************************************/
HybridNonlinearFactor::ConstructorHelper::ConstructorHelper(
HybridNonlinearFactor::HybridNonlinearFactor(
const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& pairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto& pair : pairs) {
CopyOrCheckContinuousKeys(pair.first, &continuousKeys);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}
: HybridNonlinearFactor(ConstructorHelper(discreteKey, pairs)) {}

/* *******************************************************************************/
HybridNonlinearFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKeys& discreteKeys, const FactorValuePairs& factorPairs)
: discreteKeys(discreteKeys), factorTree(factorPairs) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const NonlinearFactorValuePair& pair) {
CopyOrCheckContinuousKeys(pair.first, &continuousKeys);
});
}
HybridNonlinearFactor::HybridNonlinearFactor(const DiscreteKeys& discreteKeys,
const FactorValuePairs& factors)
: HybridNonlinearFactor(ConstructorHelper(discreteKeys, factors)) {}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
Expand Down
Loading

0 comments on commit 3e6227f

Please sign in to comment.