Skip to content

Commit

Permalink
Merge pull request #1854 from borglab/feature/more_testing
Browse files Browse the repository at this point in the history
More tests, some small bugfixes
  • Loading branch information
dellaert authored Sep 30, 2024
2 parents ffb2829 + 3cd8163 commit caa3821
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 117 deletions.
7 changes: 6 additions & 1 deletion gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
(*it)->sampleInPlace(&result);
const DiscreteConditional::shared_ptr& conditional = *it;
// Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) {
conditional->sampleInPlace(&result);
}
}
return result;
}
Expand Down
18 changes: 13 additions & 5 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {

/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
// throw if more than one frontal:
if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace can only be called on single "
"variable conditionals");
}
Key j = firstFrontalKey();
// throw if values already contains j:
if (values->count(j) > 0) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j");
}
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
Expand Down Expand Up @@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
}

/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const {
return 0.0;
}
double DiscreteConditional::negLogConstant() const { return 0.0; }

/* ************************************************************************* */

Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter);
}

/// Evaluate, just look up in AlgebraicDecisonTree
/// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ GaussianBayesNet HybridBayesNet::choose(
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment.
gbn.push_back((*gm)(assignment));
gbn.push_back(gm->choose(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back(gc);
Expand Down
2 changes: 2 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.
*
* @note Any pure discrete factors are ignored.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesNet
*/
Expand Down
11 changes: 5 additions & 6 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ size_t HybridGaussianConditional::nrComponents() const {
}

/* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr;
Expand All @@ -192,11 +192,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
conditionals_.equals(
e->conditionals_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
Expand Down
8 changes: 7 additions & 1 deletion gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,15 @@ class GTSAM_EXPORT HybridGaussianConditional
/// @{

/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
GaussianConditional::shared_ptr choose(
const DiscreteValues &discreteValues) const;

/// @brief Syntactic sugar for choose.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const {
return choose(discreteValues);
}

/// Returns the total number of continuous components
size_t nrComponents() const;

Expand Down
24 changes: 11 additions & 13 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,9 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
Expand Down Expand Up @@ -213,16 +212,15 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
}

/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
/// Helper method to compute the error of a component.
static double PotentiallyPrunedComponentError(
const GaussianFactor::shared_ptr &gf, const VectorValues &values) {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
// If nullptr this component was pruned, so we return maximum error. This
// way the negative exponential will give a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
Expand All @@ -231,8 +229,8 @@ double HybridGaussianFactor::potentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return PotentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
Expand All @@ -242,7 +240,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return potentiallyPrunedComponentError(gf, values.continuous());
return PotentiallyPrunedComponentError(gf, values.continuous());
}

} // namespace gtsam
4 changes: 0 additions & 4 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
static Factors augment(const FactorValuePairs &factors);

/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;

/// Helper struct to assist private constructor below.
struct ConstructorHelper;

Expand Down
92 changes: 34 additions & 58 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
static void printFactor(const std::shared_ptr<Factor> &factor,
const DiscreteValues &assignment,
const KeyFormatter &keyFormatter) {
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
hgf->operator()(assignment)
->print("HybridGaussianFactor, component:", keyFormatter);
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
factor->print("GaussianFactor:\n", keyFormatter);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
factor->print("DiscreteFactor:\n", keyFormatter);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (hc->isContinuous()) {
factor->print("GaussianConditional:\n", keyFormatter);
} else if (hc->isDiscrete()) {
factor->print("DiscreteConditional:\n", keyFormatter);
} else {
hc->asHybrid()
->choose(assignment)
->print("HybridConditional, component:\n", keyFormatter);
}
} else {
factor->print("Unknown factor type\n", keyFormatter);
}
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
Expand All @@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = " << factor->error(values) << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
<< "\n";
} else {
// Is hybrid
auto conditionalComponent =
hc->asHybrid()->operator()(values.discrete());
conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << df->error(values.discrete()) << std::endl;
}

} else {
if (factor == nullptr) {
std::cout << "Factor " << i << ": nullptr\n";
continue;
}
const double errorValue = factor->error(values);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

// Print the factor
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
printFactor(factor, values.discrete(), keyFormatter);
std::cout << "\n";
}
std::cout.flush();
Expand Down
5 changes: 5 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
};

// traits
template <>
struct traits<HybridGaussianFactorGraph>
: public Testable<HybridGaussianFactorGraph> {};

} // namespace gtsam
21 changes: 10 additions & 11 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
return {new_order, levels};
}

/* ***************************************************************************
*/
/* ****************************************************************************/
using MotionModel = BetweenFactor<double>;

// Test fixture with switching network.
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1))
struct Switching {
size_t K;
DiscreteKeys modes;
Expand All @@ -140,8 +140,8 @@ struct Switching {
: K(K) {
using noiseModel::Isotropic;

// Create DiscreteKeys for binary K modes.
for (size_t k = 0; k < K; k++) {
// Create DiscreteKeys for K-1 binary modes.
for (size_t k = 0; k < K - 1; k++) {
modes.emplace_back(M(k), 2);
}

Expand All @@ -153,34 +153,33 @@ struct Switching {
}

// Create hybrid factor graph.
// Add a prior on X(0).

// Add a prior ϕ(X(0)) on X(0).
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));

// Add "motion models".
// Add "motion models" ϕ(X(k),X(k+1)).
for (size_t k = 0; k < K - 1; k++) {
auto motion_models = motionModels(k, between_sigma);
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
motion_models);
}

// Add measurement factors
// Add measurement factors ϕ(X(k);z_k).
auto measurement_noise = Isotropic::Sigma(1, prior_sigma);
for (size_t k = 1; k < K; k++) {
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(k), measurements.at(k), measurement_noise);
}

// Add "mode chain"
// Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2))
addModeChain(&nonlinearFactorGraph, discrete_transition_prob);

// Create the linearization point.
for (size_t k = 0; k < K; k++) {
linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
}

// The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
}

Expand All @@ -196,7 +195,7 @@ struct Switching {
}

/**
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1).
* E.g. if K=4, we want M0, M1 and M2.
*
* @param fg The factor graph to which the mode chain is added.
Expand Down
Loading

0 comments on commit caa3821

Please sign in to comment.