Merge pull request #1590 from borglab/hybrid-tablefactor-3
varunagrawal authored Jul 27, 2023
2 parents 4b0f386 + e649fc6 commit c5740b2
Showing 9 changed files with 119 additions and 30 deletions.
71 changes: 55 additions & 16 deletions gtsam/discrete/DecisionTreeFactor.cpp
@@ -82,6 +82,22 @@ namespace gtsam {
ADT::print("", formatter);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
// Apply the unary operator
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// Apply the unary operator
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
@@ -101,14 +117,6 @@ namespace gtsam {
return DecisionTreeFactor(keys, result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
@@ -188,10 +196,45 @@ namespace gtsam {

/* ************************************************************************ */
std::vector<double> DecisionTreeFactor::probabilities() const {
// Set of all keys
std::set<Key> allKeys(keys().begin(), keys().end());

std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}

/* An operation that, for each leaf, computes how many assignments
 * (nrAssignments) the leaf stands for by taking the difference between
 * the keys in the factor and the keys in the leaf's assignment, and then
 * appends the leaf probability to the `probs` vector defined above that
 * many times.
 */
auto op = [&](const Assignment<Key>& a, double p) {
// Get all the keys in the current assignment
std::set<Key> assignment_keys;
for (auto&& [k, _] : a) {
assignment_keys.insert(k);
}

// Find the keys missing in the assignment
std::vector<Key> diff;
std::set_difference(allKeys.begin(), allKeys.end(),
assignment_keys.begin(), assignment_keys.end(),
std::back_inserter(diff));

// Compute the total number of assignments in the (pruned) subtree
size_t nrAssignments = 1;
for (auto&& k : diff) {
nrAssignments *= cardinalities_.at(k);
}
// Add p `nrAssignments` times to the probs vector.
probs.insert(probs.end(), nrAssignments, p);

return p;
};

// Go through the tree
this->apply(op);

return probs;
}

@@ -305,11 +348,7 @@ namespace gtsam {
const size_t N = maxNrAssignments;

// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}
std::vector<double> probabilities = this->probabilities();

// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
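Aside (not part of this commit): the nrAssignments bookkeeping in the new probabilities() above can be tried standalone. A minimal sketch using only the standard library and hypothetical keys and cardinalities; a pruned leaf whose assignment is missing some keys stands in for the product of the missing keys' cardinalities.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <vector>

int main() {
  // Hypothetical keys 0, 1, 2 with cardinalities 2, 3, 2.
  std::map<std::size_t, std::size_t> cardinalities{{0, 2}, {1, 3}, {2, 2}};
  std::set<std::size_t> allKeys{0, 1, 2};

  // A pruned leaf whose assignment only fixes key 0, with probability p.
  std::map<std::size_t, std::size_t> assignment{{0, 1}};
  double p = 0.25;

  // Keys present in the assignment.
  std::set<std::size_t> assignmentKeys;
  for (const auto& kv : assignment) assignmentKeys.insert(kv.first);

  // Keys of the factor that the assignment does not mention.
  std::vector<std::size_t> diff;
  std::set_difference(allKeys.begin(), allKeys.end(), assignmentKeys.begin(),
                      assignmentKeys.end(), std::back_inserter(diff));

  // The leaf represents one probability per assignment of the missing keys.
  std::size_t nrAssignments = 1;
  for (std::size_t k : diff) nrAssignments *= cardinalities.at(k);

  // Replicate the leaf probability that many times.
  std::vector<double> probs(nrAssignments, p);
  std::cout << "leaf expands to " << probs.size() << " probabilities\n";  // 6
  return 0;
}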
7 changes: 7 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
@@ -186,6 +186,13 @@ namespace gtsam {
* Apply unary operator `op` to *this
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::Unary op) const;

/**
* Apply unary operator `op` to *this
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes
* both the assignment and the value.
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;

/**
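A hedged usage sketch for the two unary apply overloads declared above. The lambda signatures are assumed from the ADT::Unary and ADT::UnaryAssignment typedefs used in the .cpp changes; they are not spelled out in this diff.

#include <gtsam/discrete/DecisionTreeFactor.h>

using namespace gtsam;

int main() {
  DiscreteKey A(0, 2), B(1, 2);
  DecisionTreeFactor f(A & B, "0.1 0.2 0.3 0.4");

  // Element-wise transform of every leaf value (ADT::Unary).
  DecisionTreeFactor doubled =
      f.apply([](const double& p) { return 2.0 * p; });

  // Variant that also sees each leaf's discrete assignment
  // (ADT::UnaryAssignment): here, zero out everything with A != 0.
  DecisionTreeFactor masked =
      f.apply([](const Assignment<Key>& a, const double& p) {
        return a.at(0) == 0 ? p : 0.0;
      });

  doubled.print("doubled");
  masked.print("masked");
  return 0;
}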
38 changes: 37 additions & 1 deletion gtsam/discrete/TableFactor.cpp
@@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTree<Key, double>& dtree)
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}

/**
 * @brief Compute the correct ordering of the leaves in the decision tree.
 *
 * This is done by taking all the values whose index is congruent to 0 modulo
 * the cardinality `n` of the innermost key, then those congruent to 1, and so
 * on up to n-1.
 *
 * @param dkeys The discrete keys of the factor.
 * @param dt The DecisionTreeFactor whose leaves are reordered.
 * @return std::vector<double>
 */
std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dt) {
std::vector<double> probs = dt.probabilities();
std::vector<double> ordered;

size_t n = dkeys[0].second;

for (size_t k = 0; k < n; ++k) {
for (size_t idx = 0; idx < probs.size(); ++idx) {
if (idx % n == k) {
ordered.push_back(probs[idx]);
}
}
}
return ordered;
}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dtf)
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
: TableFactor(c.discreteKeys(), c) {}

/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
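A toy illustration (standard library only) of the modulo reordering performed by ComputeLeafOrdering above, using the nine probabilities from the P(V|W) test added below. Normalizing the signature "1/2/3 5/6/7 9/10/11" column by column gives 1/6, 2/6, 3/6, 5/18, ..., 11/30; the leaf order shown is assumed to be the DecisionTree order, and regrouping indices by idx % 3 reproduces the expected_f5 values.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Cardinality of the innermost key (n = dkeys[0].second in the code above).
  const std::size_t n = 3;

  // Leaf probabilities in (assumed) DecisionTree order.
  std::vector<double> probs{1.0 / 6,  2.0 / 6,   3.0 / 6,    //
                            5.0 / 18, 6.0 / 18,  7.0 / 18,   //
                            9.0 / 30, 10.0 / 30, 11.0 / 30};

  // Regroup by residue class: indices 0,3,6, then 1,4,7, then 2,5,8.
  std::vector<double> ordered;
  for (std::size_t k = 0; k < n; ++k)
    for (std::size_t idx = 0; idx < probs.size(); ++idx)
      if (idx % n == k) ordered.push_back(probs[idx]);

  for (double p : ordered) std::cout << p << " ";
  std::cout << "\n";
  // Prints 0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889
  // 0.366667, matching expected_f5 in testTableFactor.cpp.
  return 0;
}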
8 changes: 7 additions & 1 deletion gtsam/discrete/TableFactor.h
@@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}

/// Constructor from DecisionTreeFactor
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);

/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);

/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);

@@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul);
};

/// multiple with DecisionTreeFactor
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

static double safe_div(const double& a, const double& b);
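A hedged usage sketch for the new TableFactor constructor declared above, based only on the signatures shown in this diff.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>

using namespace gtsam;

int main() {
  DiscreteKey X(0, 2), Y(1, 2);

  // An existing DecisionTreeFactor over X and Y ...
  DecisionTreeFactor dtf(X & Y, "0.1 0.9 0.4 0.6");

  // ... can now be converted directly into a TableFactor.
  TableFactor table(X & Y, dtf);
  table.print("table");
  return 0;
}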
11 changes: 11 additions & 0 deletions gtsam/discrete/tests/testTableFactor.cpp
@@ -19,6 +19,7 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableFactor.h>
@@ -131,6 +132,16 @@ TEST(TableFactor, constructors) {
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));

// Test for 9=3x3 values.
DiscreteKey V(0, 3), W(1, 3);
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
TableFactor f5(conditional5);
// GTSAM_PRINT(f5);
TableFactor expected_f5(
X & Y,
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
EXPECT(assert_equal(expected_f5, f5, 1e-6));
}

/* ************************************************************************* */
2 changes: 0 additions & 2 deletions gtsam/hybrid/GaussianMixture.cpp
@@ -286,8 +286,6 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(discreteProbs);
7 changes: 1 addition & 6 deletions gtsam/hybrid/HybridBayesNet.cpp
@@ -129,7 +129,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;

@@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
discrete_factor_idxs.push_back(i);
}
}

const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);

// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

@@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

return prunedDiscreteProbs;
}
@@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

HybridBayesNet prunedBayesNetFragment;

gttic_(HybridBayesNet_PruneMixtures);
// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
@@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
prunedBayesNetFragment.push_back(conditional);
}
}
gttoc_(HybridBayesNet_PruneMixtures);

return prunedBayesNetFragment;
}
3 changes: 0 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic_(assembleGraphTree);

GaussianFactorGraphTree result;

@@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
}
}

gttoc_(assembleGraphTree);

return result;
}

2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
@@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}