From bb22831662641ae9d99bc56456a8243a55c6604a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Oct 2024 12:14:48 -0400 Subject: [PATCH 1/2] implement errorTree in DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 16 ---------------- gtsam/discrete/DecisionTreeFactor.h | 5 +---- gtsam/discrete/DiscreteFactor.cpp | 16 ++++++++++++++++ gtsam/discrete/DiscreteFactor.h | 11 +++++------ gtsam/discrete/TableFactor.cpp | 5 ----- gtsam/discrete/TableFactor.h | 5 +---- 6 files changed, 23 insertions(+), 35 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index d1b68f4bfb..68c09295c9 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,22 +62,6 @@ namespace gtsam { return error(values.discrete()); } - /* ************************************************************************ */ - AlgebraicDecisionTree DecisionTreeFactor::errorTree() const { - // Get all possible assignments - DiscreteKeys dkeys = discreteKeys(); - // Reverse to make cartesian product output a more natural ordering. - DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); - const auto assignments = DiscreteValues::CartesianProduct(rdkeys); - - // Construct vector with error values - std::vector errors; - for (const auto& assignment : assignments) { - errors.push_back(error(assignment)); - } - return AlgebraicDecisionTree(dkeys, errors); - } - /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 784b11e518..07d2cac149 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -141,7 +141,7 @@ namespace gtsam { } /// Calculate error for DiscreteValues `x`, is -log(probability). - double error(const DiscreteValues& values) const; + double error(const DiscreteValues& values) const override; /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { @@ -292,9 +292,6 @@ namespace gtsam { */ double error(const HybridValues& values) const override; - /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree errorTree() const override; - /// @} private: diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index b44d4fce2e..2b11046f44 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const { return this->error(c.discrete()); } +/* ************************************************************************ */ +AlgebraicDecisionTree DiscreteFactor::errorTree() const { + // Get all possible assignments + DiscreteKeys dkeys = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rdkeys); + + // Construct vector with error values + std::vector errors; + for (const auto& assignment : assignments) { + errors.push_back(error(assignment)); + } + return AlgebraicDecisionTree(dkeys, errors); +} + /* ************************************************************************* */ std::vector expNormalize(const std::vector& logProbs) { double maxLogProb = -std::numeric_limits::infinity(); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 771efbe5b4..19af5bd131 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual double operator()(const DiscreteValues&) const = 0; /// Error is just -log(value) - double error(const DiscreteValues& values) const; + virtual double error(const DiscreteValues& values) const; /** * The Factor::error simply extracts the \class DiscreteValues from the @@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { double error(const HybridValues& c) const override; /// Compute error for each assignment and return as a tree - virtual AlgebraicDecisionTree errorTree() const = 0; + virtual AlgebraicDecisionTree errorTree() const; /// Multiply in a DecisionTreeFactor and return the result as /// DecisionTreeFactor @@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { // DiscreteFactor // traits -template<> struct traits : public Testable {}; - +template <> +struct traits : public Testable {}; /** * @brief Normalize a set of log probabilities. @@ -177,7 +177,6 @@ template<> struct traits : public Testable {}; * of the (unnormalized) log probabilities are either very large or very * small. */ -std::vector expNormalize(const std::vector &logProbs); - +std::vector expNormalize(const std::vector& logProbs); } // namespace gtsam diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index b360617f56..f4e023a4da 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const { return error(values.discrete()); } -/* ************************************************************************ */ -AlgebraicDecisionTree TableFactor::errorTree() const { - return toDecisionTreeFactor().errorTree(); -} - /* ************************************************************************ */ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 228b363376..f0ecd66a3f 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { double operator()(const DiscreteValues& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). - double error(const DiscreteValues& values) const; + double error(const DiscreteValues& values) const override; /// multiply two TableFactors TableFactor operator*(const TableFactor& f) const { @@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ double error(const HybridValues& values) const override; - /// Compute error for each assignment and return as a tree - AlgebraicDecisionTree errorTree() const override; - /// @} }; From f42a297a40ebe5af59e4eec113bac333cfd08a2c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Oct 2024 12:15:04 -0400 Subject: [PATCH 2/2] fix docstring --- gtsam/hybrid/tests/Switching.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 547facce9a..3d8174db7e 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -158,7 +158,7 @@ struct Switching { nonlinearFactorGraph.emplace_shared>( X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma)); - // Add "motion models" ϕ(X(k),X(k+1)). + // Add "motion models" ϕ(X(k),X(k+1),M(k)). for (size_t k = 0; k < K - 1; k++) { auto motion_models = motionModels(k, between_sigma); nonlinearFactorGraph.emplace_shared(modes[k],