Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP python bindings via pybind11 #1429

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

# vscode files
.vscode/
python/build/*
python/grf/__pycache__/*
python/grf.egg-info/*
python/grf/**/*.so
33 changes: 22 additions & 11 deletions core/src/forest/ForestPredictor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,28 @@ std::vector<Prediction> ForestPredictor::predict(const Forest& forest,
const Data& data,
bool estimate_variance,
bool oob_prediction) const {
if (estimate_variance && forest.get_ci_group_size() <= 1) {
throw std::runtime_error("To estimate variance during prediction, the forest must"
" be trained with ci_group_size greater than 1.");
}

std::vector<std::vector<size_t>> leaf_nodes_by_tree = tree_traverser.get_leaf_nodes(forest, data, oob_prediction);
std::vector<std::vector<bool>> trees_by_sample = tree_traverser.get_valid_trees_by_sample(forest, data, oob_prediction);

return prediction_collector->collect_predictions(forest, train_data, data,
leaf_nodes_by_tree, trees_by_sample,
estimate_variance, oob_prediction);
// debug information
std::cout << "Entering ForestPredictor::predict" << std::endl;
std::cout << "Number of trees: " << forest.get_trees().size() << std::endl;
std::cout << "Train data dimensions: " << train_data.get_num_rows() << "x" << train_data.get_num_cols() << std::endl;
std::cout << "Test data dimensions: " << data.get_num_rows() << "x" << data.get_num_cols() << std::endl;

if (estimate_variance && forest.get_ci_group_size() <= 1) {
throw std::runtime_error("To estimate variance during prediction, the forest must"
" be trained with ci_group_size greater than 1.");
}

std::vector<std::vector<size_t>> leaf_nodes_by_tree = tree_traverser.get_leaf_nodes(forest, data, oob_prediction);
// debug
std::cout << "Leaf nodes by tree size: " << leaf_nodes_by_tree.size() << std::endl;

std::vector<std::vector<bool>> trees_by_sample = tree_traverser.get_valid_trees_by_sample(forest, data, oob_prediction);
// debug
std::cout << "Trees by sample size: " << trees_by_sample.size() << std::endl;

return prediction_collector->collect_predictions(forest, train_data, data,
leaf_nodes_by_tree, trees_by_sample,
estimate_variance, oob_prediction);
}

} // namespace grf
114 changes: 114 additions & 0 deletions python/grf/PyUtilities.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// PyUtilities.cpp

#include "PyUtilities.h"

// Bundles a forest and, optionally, its predictions into a single Python dict.
//
// The serialized forest is always stored under "forest". A "predictions"
// entry is added only when `predictions` is non-empty.
py::dict PyUtilities::create_forest_object(Forest& forest, const std::vector<Prediction>& predictions) {
  py::dict result;
  result["forest"] = serialize_forest(forest);

  const bool has_predictions = !predictions.empty();
  if (has_predictions) {
    result["predictions"] = create_prediction_object(predictions);
  }

  return result;
}

// Converts a forest into a Python dict with three entries:
//   "trees"         - list holding one serialized dict per tree,
//   "num_variables" - value of forest.get_num_variables(),
//   "ci_group_size" - value of forest.get_ci_group_size().
py::dict PyUtilities::serialize_forest(const Forest& forest) {
  py::list serialized_trees;
  const auto& forest_trees = forest.get_trees();
  for (auto it = forest_trees.begin(); it != forest_trees.end(); ++it) {
    serialized_trees.append(serialize_tree(*it));
  }

  py::dict result;
  result["trees"] = serialized_trees;
  result["num_variables"] = forest.get_num_variables();
  result["ci_group_size"] = forest.get_ci_group_size();
  return result;
}

// Rebuilds a Forest from the dict layout written by serialize_forest.
//
// Reads the "trees" list (each element is deserialized as one tree) followed
// by the "num_variables" and "ci_group_size" entries, in that order.
Forest PyUtilities::deserialize_forest(const py::dict& forest_object) {
  std::vector<std::unique_ptr<Tree>> trees;

  py::list serialized_trees = forest_object["trees"].cast<py::list>();
  for (const auto& serialized_tree : serialized_trees) {
    py::dict tree_fields = serialized_tree.cast<py::dict>();
    trees.push_back(deserialize_tree(tree_fields));
  }

  // Scalar forest parameters stored next to the tree list.
  const auto num_variables = forest_object["num_variables"].cast<size_t>();
  const auto ci_group_size = forest_object["ci_group_size"].cast<size_t>();

  // `trees` is handed over as a plain (non-const) lvalue.
  return Forest(trees, num_variables, ci_group_size);
}

// Wraps a 2-D NumPy array of doubles in a Data object.
//
// Data is constructed from the array's raw buffer pointer plus its shape.
// Presumably Data does not copy the buffer, so the caller should keep
// `input_data` alive for as long as the returned Data is used (TODO confirm).
//
// Only the pointer and the shape are forwarded, so the array's strides are
// lost. A strided view (e.g. a slice such as a[:, ::2]) would be read as if
// it were densely packed and silently yield wrong values, so non-contiguous
// arrays are rejected.
//
// NOTE(review): both C- and Fortran-contiguous arrays are accepted here. If
// Data indexes its buffer in column-major order (to be confirmed against
// Data's accessor), C-ordered input will be misread; in that case tighten
// the check to py::array::f_style or convert the array in the Python wrapper.
//
// Throws std::runtime_error if the array is not 2-D or is not contiguous.
Data PyUtilities::convert_data(const py::array_t<double>& input_data) {
  py::buffer_info buf = input_data.request();
  if (buf.ndim != 2) {
    throw std::runtime_error("Number of dimensions must be 2");
  }

  // A pointer + shape can only describe a densely packed buffer.
  const int contiguous_flags = py::array::c_style | py::array::f_style;
  if ((input_data.flags() & contiguous_flags) == 0) {
    throw std::runtime_error("Input array must be contiguous in memory");
  }

  const size_t num_rows = static_cast<size_t>(buf.shape[0]);
  const size_t num_cols = static_cast<size_t>(buf.shape[1]);
  double* data_ptr = static_cast<double*>(buf.ptr);
  return Data(data_ptr, num_rows, num_cols);
}

// Builds the Python-side prediction dict. It has four entries, "predictions",
// "variance", "error" and "excess_error", each filled by the matching
// create_*_matrix helper.
py::dict PyUtilities::create_prediction_object(const std::vector<Prediction>& predictions) {
  py::array_t<double> prediction_matrix = create_prediction_matrix(predictions);
  py::array_t<double> variance_matrix = create_variance_matrix(predictions);
  py::array_t<double> error_matrix = create_error_matrix(predictions);
  py::array_t<double> excess_error_matrix = create_excess_error_matrix(predictions);

  py::dict result;
  result["predictions"] = prediction_matrix;
  result["variance"] = variance_matrix;
  result["error"] = error_matrix;
  result["excess_error"] = excess_error_matrix;
  return result;
}

// Placeholder: not implemented yet. Ignores `predictions` and returns a
// default-constructed array; the real implementation depends on the
// structure of the Prediction class.
py::array_t<double> PyUtilities::create_prediction_matrix(const std::vector<Prediction>& predictions) {
  py::array_t<double> empty_matrix;
  return empty_matrix;
}

// Placeholder: not implemented yet. Ignores `predictions` and returns a
// default-constructed array; the real implementation depends on the
// structure of the Prediction class.
py::array_t<double> PyUtilities::create_variance_matrix(const std::vector<Prediction>& predictions) {
  py::array_t<double> empty_matrix;
  return empty_matrix;
}

// Placeholder: not implemented yet. Ignores `predictions` and returns a
// default-constructed array; the real implementation depends on the
// structure of the Prediction class.
py::array_t<double> PyUtilities::create_error_matrix(const std::vector<Prediction>& predictions) {
  py::array_t<double> empty_matrix;
  return empty_matrix;
}

// Placeholder: not implemented yet. Ignores `predictions` and returns a
// default-constructed array; the real implementation depends on the
// structure of the Prediction class.
py::array_t<double> PyUtilities::create_excess_error_matrix(const std::vector<Prediction>& predictions) {
  py::array_t<double> empty_matrix;
  return empty_matrix;
}

// Copies a tree's structural fields into a Python dict, one entry per getter:
// "root_node", "child_nodes", "leaf_samples", "split_vars", "split_values",
// "drawn_samples" and "send_missing_left".
//
// The tree's prediction values are not serialized yet.
py::dict PyUtilities::serialize_tree(const std::unique_ptr<Tree>& tree) {
  Tree& source = *tree;

  py::dict fields;
  fields["root_node"] = source.get_root_node();
  fields["child_nodes"] = source.get_child_nodes();
  fields["leaf_samples"] = source.get_leaf_samples();
  fields["split_vars"] = source.get_split_vars();
  fields["split_values"] = source.get_split_values();
  fields["drawn_samples"] = source.get_drawn_samples();
  fields["send_missing_left"] = source.get_send_missing_left();
  return fields;
}

// Rebuilds a Tree from the dict layout written by serialize_tree.
//
// Every structural field is read back from its key and cast to the matching
// C++ container. Prediction values are not part of the dict yet, so the tree
// is constructed with a default-constructed PredictionValues.
std::unique_ptr<Tree> PyUtilities::deserialize_tree(const py::dict& tree_dict) {
  using IndexVector = std::vector<size_t>;
  using IndexMatrix = std::vector<IndexVector>;

  auto root = tree_dict["root_node"].cast<size_t>();
  auto children = tree_dict["child_nodes"].cast<IndexMatrix>();
  auto samples_by_leaf = tree_dict["leaf_samples"].cast<IndexMatrix>();
  auto splitting_vars = tree_dict["split_vars"].cast<IndexVector>();
  auto splitting_values = tree_dict["split_values"].cast<std::vector<double>>();
  auto drawn = tree_dict["drawn_samples"].cast<IndexVector>();
  auto missing_goes_left = tree_dict["send_missing_left"].cast<std::vector<bool>>();

  // Stand-in until prediction values are serialized as well.
  PredictionValues default_prediction_values;

  return std::make_unique<Tree>(root,
                                children,
                                samples_by_leaf,
                                splitting_vars,
                                splitting_values,
                                drawn,
                                missing_goes_left,
                                default_prediction_values);
}
36 changes: 36 additions & 0 deletions python/grf/PyUtilities.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// PyUtilities.h

#ifndef PY_UTILITIES_H
#define PY_UTILITIES_H

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <memory>
#include <vector>
#include "forest/Forest.h"
#include "tree/Tree.h"
#include "commons/Data.h"
#include "prediction/Prediction.h"

namespace py = pybind11;
using namespace grf;

// Static helpers that translate between grf core objects (Forest, Tree, Data,
// Prediction) and the Python objects exchanged through the pybind11 bindings.
// All members are static; the class is used purely as a namespace.
class PyUtilities {
public:
// Returns a dict holding the serialized forest and, when `predictions` is
// non-empty, a prediction dict.
static py::dict create_forest_object(Forest& forest, const std::vector<Prediction>& predictions);
// Converts a forest (its trees plus num_variables and ci_group_size) to a dict.
static py::dict serialize_forest(const Forest& forest);
// Inverse of serialize_forest: rebuilds a Forest from such a dict.
static Forest deserialize_forest(const py::dict& forest_object);
// Wraps a 2-D NumPy array of doubles in a Data object.
static Data convert_data(const py::array_t<double>& input_data);
// Returns a dict with the prediction, variance, error and excess-error matrices.
static py::dict create_prediction_object(const std::vector<Prediction>& predictions);
// Per-quantity matrix builders used by create_prediction_object.
static py::array_t<double> create_prediction_matrix(const std::vector<Prediction>& predictions);
static py::array_t<double> create_variance_matrix(const std::vector<Prediction>& predictions);
static py::array_t<double> create_error_matrix(const std::vector<Prediction>& predictions);
static py::array_t<double> create_excess_error_matrix(const std::vector<Prediction>& predictions);

private:
// Per-tree counterparts of serialize_forest / deserialize_forest.
static py::dict serialize_tree(const std::unique_ptr<Tree>& tree);
static std::unique_ptr<Tree> deserialize_tree(const py::dict& tree_dict);
};

#endif // PY_UTILITIES_H
1 change: 1 addition & 0 deletions python/grf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .regression import RegressionForest
109 changes: 109 additions & 0 deletions python/grf/_grf_python.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// _grf_python.cpp

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "forest/ForestTrainers.h"
#include "forest/ForestPredictors.h"
#include "PyUtilities.h"

namespace py = pybind11;
using namespace grf;

// Trains a regression forest on `train_matrix` and returns it as a Python
// object built by PyUtilities::create_forest_object.
//
// `outcome_index` selects the outcome column. `sample_weight_index` is applied
// only when `use_sample_weights` is true. The remaining tuning parameters are
// forwarded to ForestOptions. When `compute_oob_predictions` is true,
// out-of-bag predictions (without variance estimates) are computed and
// attached to the result.
py::object regression_train(
    py::array_t<double> train_matrix,
    size_t outcome_index,
    size_t sample_weight_index,
    bool use_sample_weights,
    unsigned int mtry,
    unsigned int num_trees,
    unsigned int min_node_size,
    double sample_fraction,
    bool honesty,
    double honesty_fraction,
    bool honesty_prune_leaves,
    size_t ci_group_size,
    double alpha,
    double imbalance_penalty,
    std::vector<size_t> clusters,
    unsigned int samples_per_cluster,
    bool compute_oob_predictions,
    unsigned int num_threads,
    unsigned int seed)
{
  // Prepare the training data view.
  Data training_data = PyUtilities::convert_data(train_matrix);
  training_data.set_outcome_index(outcome_index);
  if (use_sample_weights) {
    training_data.set_weight_index(sample_weight_index);
  }

  // Train the forest.
  ForestOptions forest_options(num_trees,
                               ci_group_size,
                               sample_fraction,
                               mtry,
                               min_node_size,
                               honesty,
                               honesty_fraction,
                               honesty_prune_leaves,
                               alpha,
                               imbalance_penalty,
                               num_threads,
                               seed,
                               clusters,
                               samples_per_cluster);
  ForestTrainer forest_trainer = regression_trainer();
  Forest forest = forest_trainer.train(training_data, forest_options);

  // Optionally attach out-of-bag predictions; stays empty otherwise.
  std::vector<Prediction> oob_predictions;
  if (compute_oob_predictions) {
    ForestPredictor oob_predictor = regression_predictor(num_threads);
    oob_predictions = oob_predictor.predict_oob(forest, training_data, false);
  }

  return PyUtilities::create_forest_object(forest, oob_predictions);
}

// Predicts on `test_matrix` with a previously trained regression forest.
//
// `forest_object` is the serialized forest dict, `train_matrix` the data the
// forest was trained on (with `outcome_index` selecting its outcome column),
// and `estimate_variance` is forwarded to the predictor. Returns the
// prediction dict built by PyUtilities::create_prediction_object.
//
// Errors raised by the core library propagate as C++ exceptions; pybind11
// translates std::exception into a Python exception carrying what(), so no
// local logging or catch-and-rethrow is needed. Nothing is written to
// stdout/stderr, which keeps the host Python process's output clean.
py::dict regression_predict(
    py::object forest_object,
    py::array_t<double> train_matrix,
    size_t outcome_index,
    py::array_t<double> test_matrix,
    unsigned int num_threads,
    bool estimate_variance)
{
  Data train_data = PyUtilities::convert_data(train_matrix);
  train_data.set_outcome_index(outcome_index);

  Data data = PyUtilities::convert_data(test_matrix);
  Forest forest = PyUtilities::deserialize_forest(forest_object);

  ForestPredictor predictor = regression_predictor(num_threads);
  std::vector<Prediction> predictions =
      predictor.predict(forest, train_data, data, estimate_variance);

  return PyUtilities::create_prediction_object(predictions);
}

// Computes out-of-bag predictions for the data a regression forest was
// trained on.
//
// `forest_object` is the serialized forest dict and `outcome_index` selects
// the outcome column of `train_matrix`. `estimate_variance` is forwarded to
// the predictor. Returns the prediction dict built by
// PyUtilities::create_prediction_object.
py::dict regression_predict_oob(
    py::object forest_object,
    py::array_t<double> train_matrix,
    size_t outcome_index,
    unsigned int num_threads,
    bool estimate_variance)
{
  Data training_data = PyUtilities::convert_data(train_matrix);
  training_data.set_outcome_index(outcome_index);

  Forest trained_forest = PyUtilities::deserialize_forest(forest_object);
  ForestPredictor oob_predictor = regression_predictor(num_threads);

  std::vector<Prediction> oob_predictions =
      oob_predictor.predict_oob(trained_forest, training_data, estimate_variance);
  return PyUtilities::create_prediction_object(oob_predictions);
}

// Python module definition: exposes the three regression entry points under
// their C++ names. Arguments are positional, in C++ declaration order.
PYBIND11_MODULE(_grf_python, m) {
  m.def("regression_train", &regression_train, "Train a regression forest")
      .def("regression_predict", &regression_predict, "Predict using a regression forest")
      .def("regression_predict_oob", &regression_predict_oob, "Predict OOB using a regression forest");
}
Loading