Skip to content

Commit

Permalink
Python export_json method introduced
Browse files Browse the repository at this point in the history
pybind11 code simplified: forest* everywhere, removed excess functions
Using py::return_value_policy::reference to prevent Python GC from deleting the trained tree object
  • Loading branch information
Ilia-Shutov committed Aug 11, 2023
1 parent 04034eb commit adcdfe8
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 60 deletions.
35 changes: 15 additions & 20 deletions Python/extension/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ extern "C" {
}


void* train_forest(
forestry* train_forest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -247,7 +247,7 @@ extern "C" {
}

void predict_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
unsigned int seed,
Expand All @@ -264,9 +264,6 @@ extern "C" {
bool hier_shrinkage,
double lambda_shrinkage
) {


forestry* forest = reinterpret_cast<forestry *>(forest_pt);
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);

forest->setDataframe(dta_frame);
Expand Down Expand Up @@ -376,7 +373,7 @@ extern "C" {


void predictOOB_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
bool doubleOOB,
Expand All @@ -390,9 +387,8 @@ extern "C" {
double lambda_shrinkage
) {
if (verbose)
std::cout << forest_pt << std::endl;
std::cout << forest << std::endl;

forestry* forest = reinterpret_cast<forestry *>(forest_pt);
DataFrame* dta_frame = reinterpret_cast<DataFrame *>(dataframe_pt);
forest->setDataframe(dta_frame);

Expand Down Expand Up @@ -456,15 +452,13 @@ extern "C" {
}

void fill_tree_info(
void* forest_ptr,
forestry* forest,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
) {

forestry* forest = reinterpret_cast<forestry *>(forest_ptr);

std::unique_ptr<tree_info> info_holder;

info_holder = forest->getForest()->at(tree_idx)->getTreeInfo(forest->getTrainingData());
Expand Down Expand Up @@ -710,24 +704,25 @@ extern "C" {
return forest;

}

size_t get_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);

size_t get_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getNodeCount());
}

size_t get_split_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);
size_t get_split_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getSplitNodeCount());
}

size_t get_leaf_node_count(void* forest_pt, int tree_idx) {
forestry* forest = reinterpret_cast<forestry *>(forest_pt);
size_t get_leaf_node_count(forestry* forest, int tree_idx) {
return(forest->getForest()->at(tree_idx)->getLeafNodeCount());
}

void delete_forestry(void* forest_pt, void* dataframe_pt) {
void delete_forestry(forestry* forest, void* dataframe_pt) {
delete(reinterpret_cast<DataFrame* >(dataframe_pt));
delete(reinterpret_cast<forestry* >(forest_pt));
delete(forest);
}
}

/**
 * Export a forest to a Treelite JSON string.
 *
 * Thin adapter over exportJson(): the pointer received from the binding
 * layer is turned into a reference and both vectors are forwarded unchanged.
 *
 * @param forest    Forest to export; it is dereferenced here.
 *                  NOTE(review): there is no null check - presumably callers
 *                  always pass a valid forest; confirm.
 * @param colSds    Forwarded as-is to exportJson() (presumably per-column
 *                  standard deviations - confirm against exportJson()).
 * @param colMeans  Forwarded as-is to exportJson() (presumably per-column
 *                  means - confirm against exportJson()).
 * @return The string produced by exportJson().
 */
std::string export_json(forestry* forest, const std::vector<double>& colSds, const std::vector<double>& colMeans) {
    forestry& trainedForest = *forest;
    return exportJson(trainedForest, colSds, colMeans);
}
34 changes: 13 additions & 21 deletions Python/extension/api.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#pragma once

#include <vector>
#include <string>
#include <iostream>
#include <random>
#include "forestry.h"
#include "utils.h"

#ifndef FORESTRYCPP_API_H
#define FORESTRYCPP_API_H

#endif //FORESTRYCPP_API_H


extern "C" {
void* train_forest(
forestry* train_forest(
void* data_ptr,
size_t ntree,
bool replace,
Expand Down Expand Up @@ -101,7 +98,7 @@ extern "C" {
unsigned int* tree_seeds
);
void predictOOB_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
bool doubleOOB,
Expand All @@ -115,7 +112,7 @@ extern "C" {
double lambda_shrinkage
);
void predict_forest(
void* forest_pt,
forestry* forest,
void* dataframe_pt,
double* test_data,
unsigned int seed,
Expand All @@ -133,21 +130,16 @@ extern "C" {
double lambda_shrinkage = 0
);
void fill_tree_info(
void* forest_ptr,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
);
void fill_tree_info(
void* forest_ptr,
forestry* forest,
int tree_idx,
std::vector<double>& treeInfo,
std::vector<int>& split_info,
std::vector<int>& av_info
);
size_t get_node_count(void* forest_pt, int tree_idx);
size_t get_split_node_count(void* forest_pt, int tree_idx);
size_t get_leaf_node_count(void* forest_pt, int tree_idx);
void delete_forestry(void* forest_pt, void* dataframe_pt);
}
size_t get_node_count(forestry* forest, int tree_idx);
size_t get_split_node_count(forestry* forest, int tree_idx);
size_t get_leaf_node_count(forestry* forest, int tree_idx);
void delete_forestry(forestry* forest, void* dataframe_pt);
}

std::string export_json(forestry* forest, const std::vector<double>& colSds, const std::vector<double>& colMeans);
43 changes: 24 additions & 19 deletions Python/extension/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void *reconstructree_wrapper(
}

py::tuple predictOOB_forest_wrapper(
void *forest_pt,
forestry* forest,
void *dataframe_pt,
py::array_t<double> test_data,
bool doubleOOB,
Expand All @@ -192,7 +192,7 @@ py::tuple predictOOB_forest_wrapper(
}

predictOOB_forest(
forest_pt,
forest,
dataframe_pt,
static_cast<double *>(test_data.request().ptr),
doubleOOB,
Expand Down Expand Up @@ -232,7 +232,7 @@ void show_array(std::vector<double> array) {
}

py::tuple predict_forest_wrapper(
void *forest_pt,
forestry* forest,
void *dataframe_pt,
py::array_t<double> test_data,
unsigned int seed,
Expand All @@ -259,7 +259,7 @@ py::tuple predict_forest_wrapper(
std::vector<double> coefficients_vector(n_coefficients);

predict_forest(
forest_pt,
forest,
dataframe_pt,
static_cast<double *>(test_data.request().ptr),
seed,
Expand All @@ -285,7 +285,7 @@ py::tuple predict_forest_wrapper(
}

void fill_tree_info_wrapper(
void *forest_ptr,
forestry* forest,
int tree_idx,
py::array_t<double> treeInfo,
py::array_t<int> split_info,
Expand All @@ -296,7 +296,7 @@ void fill_tree_info_wrapper(
auto av_info_vector = create_vector_from_numpy_array(av_info);

fill_tree_info(
forest_ptr,
forest,
tree_idx,
treeInfo_vector,
split_info_vector,
Expand All @@ -308,20 +308,22 @@ void fill_tree_info_wrapper(
copy_vector_to_numpy_array(av_info_vector, av_info);
}

size_t getTreeNodeCount(void *forest_ptr, int tree_idx) {
return get_node_count(forest_ptr,tree_idx);
}

size_t getTreeSplitNodeCount(void *forest_ptr, int tree_idx) {
return get_split_node_count(forest_ptr,tree_idx);
}
std::string export_json_wrapper(forestry* forest, py::array_t<double> colSdsNp, py::array_t<double> colMeansNp) {
auto colSds = create_vector_from_numpy_array(colSdsNp);
auto colMeans = create_vector_from_numpy_array(colMeansNp);

size_t getTreeLeafNodeCount(void *forest_ptr, int tree_idx) {
return get_leaf_node_count(forest_ptr,tree_idx);
return export_json(forest, colSds, colMeans);
}

PYBIND11_MODULE(extension, m)
{
py::class_<forestry>(m, "forestry", py::dynamic_attr())
.def(py::init([]() {
throw py::value_error("forestry instances cannot be created from python");
return nullptr;
})
);

m.doc() = R"pbdoc(
RForestry Python extension module
-----------------------
Expand All @@ -334,7 +336,7 @@ PYBIND11_MODULE(extension, m)
vector_get
)pbdoc";

m.def("train_forest", &train_forest, R"pbdoc(
m.def("train_forest", &train_forest, py::return_value_policy::reference, R"pbdoc(
Some help text here
Some other explanation about the train_forest function.
Expand All @@ -344,17 +346,17 @@ PYBIND11_MODULE(extension, m)
Some other explanation about the get_data function.
)pbdoc");
m.def("get_tree_node_count", &getTreeNodeCount, R"pbdoc(
m.def("get_tree_node_count", &get_node_count, R"pbdoc(
Some help text here
Some other explanation about the getTreeNodeCount function.
)pbdoc");
m.def("get_tree_split_count", &getTreeSplitNodeCount, R"pbdoc(
m.def("get_tree_split_count", &get_split_node_count, R"pbdoc(
Some help text here
Some other explanation about the getTreeSplitNodeCount function.
)pbdoc");
m.def("get_tree_leaf_count", &getTreeLeafNodeCount, R"pbdoc(
m.def("get_tree_leaf_count", &get_leaf_node_count, R"pbdoc(
Some help text here
Some other explanation about the getTreeLeafNodeCount function.
Expand Down Expand Up @@ -384,6 +386,9 @@ PYBIND11_MODULE(extension, m)
Some other explanation about the fill_tree_info function.
)pbdoc");
m.def("export_json", &export_json_wrapper, R"pbdoc(
Export forest to Treelite JSON string
)pbdoc");

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
Expand Down
6 changes: 6 additions & 0 deletions Python/random_forestry/forestry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,12 @@ def load_forestry(filename: Path) -> Self:
with open(filename, "rb") as input_file:
return pickle.load(input_file) # nosec B301

def export_json(self):
    """
    Export forest to Treelite JSON string.

    Passes ``self.forest`` together with the ``col_sd`` and ``col_means``
    attributes of ``self.processed_dta`` to ``extension.export_json`` and
    returns whatever that call produces.
    """
    preprocessing = self.processed_dta
    return extension.export_json(
        self.forest,
        preprocessing.col_sd,
        preprocessing.col_means,
    )

def __del__(self):
# Free the pointers to forestry and dataframe
extension.delete_forestry(self.forest, self.dataframe)

0 comments on commit adcdfe8

Please sign in to comment.