Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Significant speedup #1881

Merged
merged 12 commits into from
Oct 23, 2024
63 changes: 58 additions & 5 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

Expand Down Expand Up @@ -286,6 +287,10 @@ namespace gtsam {
return branches_;
}

std::vector<NodePtr>& branches() {
return branches_;
}

/** add a branch: TODO merge into constructor */
void push_back(NodePtr&& node) {
// allSame_ is restricted to leaf nodes in a decision tree
Expand Down Expand Up @@ -482,8 +487,8 @@ namespace gtsam {
/****************************************************************************/
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {}
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}

template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
Expand Down Expand Up @@ -554,6 +559,36 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label);
}

/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Unary& op,
DecisionTree&& other) noexcept
: root_(std::move(other.root_)) {
// Apply the unary operation directly to each leaf in the tree
if (root_) {
// Define a helper function to traverse and apply the operation
struct ApplyUnary {
const Unary& op;
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
// Apply the unary operation to the leaf's constant value
leaf->constant_ = op(leaf->constant_);
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
// Recurse into the choice branches
for (NodePtr& branch : choice->branches()) {
(*this)(branch);
}
}
}
};

ApplyUnary applyUnary{op};
applyUnary(root_);
}
// Reset the other tree's root to nullptr to avoid dangling references
other.root_ = nullptr;
}

/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
Expand Down Expand Up @@ -694,7 +729,7 @@ namespace gtsam {
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
Expand All @@ -710,7 +745,7 @@ namespace gtsam {

// If leaf, apply unary conversion "op" and create a unique leaf.
using LXLeaf = typename DecisionTree<L, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

Expand Down Expand Up @@ -951,11 +986,16 @@ namespace gtsam {
return root_->equals(*other.root_);
}

/****************************************************************************/
template<typename L, typename Y>
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
if (root_ == nullptr)
throw std::invalid_argument(
"DecisionTree::operator() called on empty tree");
return root_->operator ()(x);
}

/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
Expand All @@ -966,6 +1006,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}

/****************************************************************************/
/// Apply unary operator with assignment
template <typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
Expand Down Expand Up @@ -1049,6 +1090,18 @@ namespace gtsam {
return ss.str();
}

/******************************************************************************/
/******************************************************************************/
template <typename L, typename Y>
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const {
using AB = std::pair<A, B>;
const DecisionTree<L, AB> ab(*this, AB_of_Y);
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
return {a, b};
}

/******************************************************************************/

} // namespace gtsam
33 changes: 27 additions & 6 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ namespace gtsam {

/** ------------------------ Node base class --------------------------- */
struct Node {
using Ptr = std::shared_ptr<const Node>;
using Ptr = std::shared_ptr<Node>;

#ifdef DT_DEBUG_MEMORY
static int nrNodes;
Expand Down Expand Up @@ -156,10 +156,10 @@ namespace gtsam {
template <typename It, typename ValueIt>
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);

/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds thetree bottom-up,
* before we prune in a top-down fashion.
/**
* Internal helper function to create a tree from keys, cardinalities, and Y
* values. Calls `build` which builds the tree bottom-up, before we prune in
* a top-down fashion.
*/
template <typename It, typename ValueIt>
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
Expand Down Expand Up @@ -228,6 +228,15 @@ namespace gtsam {
DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f1);

/**
* @brief Move constructor for DecisionTree. Very efficient as does not
* allocate anything, just changes in-place. But `other` is consumed.
*
* @param op The unary operation to apply to the moved DecisionTree.
* @param other The DecisionTree to move from, will be empty afterwards.
*/
DecisionTree(const Unary& op, DecisionTree&& other) noexcept;

/**
* @brief Convert from a different value type.
*
Expand All @@ -239,7 +248,7 @@ namespace gtsam {
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);

/**
* @brief Convert from a different value type X to value type Y, also transate
* @brief Convert from a different value type X to value type Y, also translate
* labels via map from type M to L.
*
* @tparam M Previous label type.
Expand Down Expand Up @@ -406,6 +415,18 @@ namespace gtsam {
const ValueFormatter& valueFormatter,
bool showZero = true) const;

/**
* @brief Convert into two trees with value types A and B.
*
* @tparam A First new value type.
* @tparam B Second new value type.
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
* @return A pair of DecisionTrees with value types A and B respectively.
*/
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;

/// @name Advanced Interface
/// @{

Expand Down
55 changes: 54 additions & 1 deletion gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

/*
* @file testDecisionTree.cpp
* @brief Develop DecisionTree
* @brief DecisionTree unit tests
* @author Frank Dellaert
* @author Can Erdogan
* @date Jan 30, 2012
Expand Down Expand Up @@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
std::cout << s;
Base::print("", keyFormatter, valueFormatter);
}

/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
Expand Down Expand Up @@ -271,6 +272,58 @@ TEST(DecisionTree, Example) {
DOT(acnotb);
}

/* ************************************************************************** */
// Test that we can create two trees out of one, using a function that returns a pair.
TEST(DecisionTree, Split) {
// Create labels
string A("A"), B("B");

// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));

// Define a function that returns an int/bool pair
auto split_function = [](const int& value) -> std::pair<int, bool> {
return {value*3, value*3 % 2 == 0};
};

// Split the original tree into two new trees
auto [la,lb] = original.split<int,bool>(split_function);

// Check the first resulting tree
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));

// Check the second resulting tree
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
}


/* ************************************************************************** */
// Test that we can create a tree by modifying an rvalue.
TEST(DecisionTree, Consume) {
// Create labels
string A("A"), B("B");

// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));

DT modified([](int i){return i*2;}, std::move(original));

// Check the first resulting tree
EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));

// Check original was moved
EXPECT(original.root_ == nullptr);
}

/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };
Expand Down
Loading
Loading