Skip to content

Commit

Permalink
Merge pull request #1575 from borglab/hybrid-tablefactor-2
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jul 19, 2023
2 parents b56a04e + cb084b3 commit ba7c077
Show file tree
Hide file tree
Showing 24 changed files with 298 additions and 173 deletions.
13 changes: 12 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
}

/** Write graphviz format to stream `os`. */
Expand Down Expand Up @@ -827,6 +828,16 @@ namespace gtsam {
return total;
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
Expand Down
36 changes: 36 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,42 @@ namespace gtsam {
/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
Expand Down
8 changes: 8 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ namespace gtsam {
return DecisionTreeFactor(keys, result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
Expand Down
6 changes: 6 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ namespace gtsam {
/// @name Advanced Interface
/// @{

/**
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;

/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Symbol.h>

#include <iomanip>

Expand Down
89 changes: 35 additions & 54 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;

// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor;

for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
}
}
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
Expand Down Expand Up @@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) {
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;

std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;

// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();

// Convert pointer from conditional to factor
auto discreteTree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));

gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);

// Add it back to the BayesNet
this->at(i) = conditional;
discreteProbs = discreteProbs * (*conditional->asDiscrete());

Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);

// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

/* To Prune, we visitWith every leaf in the GaussianMixture.
return prunedDiscreteProbs;
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);

/* To prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
Expand Down
13 changes: 3 additions & 10 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
Expand Down Expand Up @@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

private:
/**
* @brief Update the discrete conditionals with the pruned versions.
* @brief Prune all the discrete conditionals.
*
* @param prunedDiscreteProbs
* @param maxNrLeaves
*/
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/Values.h>
Expand Down
5 changes: 3 additions & 2 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
* @date January, 2023
*/

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>

namespace gtsam {
Expand All @@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys;
for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
Expand Down Expand Up @@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
}
}
return keys;
Expand Down
Loading

0 comments on commit ba7c077

Please sign in to comment.