Skip to content

Commit

Permalink
Merge pull request #1858 from borglab/discrete-errorTree
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Oct 16, 2024
2 parents db353a5 + f42a297 commit 77422d4
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 36 deletions.
16 changes: 0 additions & 16 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,6 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);

// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
5 changes: 1 addition & 4 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ namespace gtsam {
}

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
double error(const DiscreteValues& values) const override;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
Expand Down Expand Up @@ -292,9 +292,6 @@ namespace gtsam {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}

private:
Expand Down
16 changes: 16 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);

// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}

/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
Expand Down
11 changes: 5 additions & 6 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual double operator()(const DiscreteValues&) const = 0;

/// Error is just -log(value)
double error(const DiscreteValues& values) const;
virtual double error(const DiscreteValues& values) const;

/**
* The Factor::error simply extracts the \class DiscreteValues from the
Expand All @@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
double error(const HybridValues& c) const override;

/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
virtual AlgebraicDecisionTree<Key> errorTree() const;

/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
Expand Down Expand Up @@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
// DiscreteFactor

// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};

template <>
struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};

/**
* @brief Normalize a set of log probabilities.
Expand All @@ -177,7 +177,6 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);

std::vector<double> expNormalize(const std::vector<double>& logProbs);

} // namespace gtsam
5 changes: 0 additions & 5 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
Expand Down
5 changes: 1 addition & 4 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double operator()(const DiscreteValues& values) const override;

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
double error(const DiscreteValues& values) const override;

/// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const {
Expand Down Expand Up @@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}
};

Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ struct Switching {
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));

// Add "motion models" ϕ(X(k),X(k+1)).
// Add "motion models" ϕ(X(k),X(k+1),M(k)).
for (size_t k = 0; k < K - 1; k++) {
auto motion_models = motionModels(k, between_sigma);
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
Expand Down

0 comments on commit 77422d4

Please sign in to comment.