Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sc/hier shrinkage #139

Merged
merged 20 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9118a99
Made modifications for non-leaf nodes to store average set counts an…
sidc321 Jun 27, 2023
1ae3fc8
Trees can now be saved and reconstructed
sidc321 Jun 29, 2023
cab8352
Changed parameter names to conform to existing format. Removed unnecc…
sidc321 Jun 29, 2023
206a699
Added unit tests to check weightsFull and average_counts has correct …
sidc321 Jun 30, 2023
e960086
Cleanup comments, new lines, etc.
sidc321 Jun 30, 2023
cd4560e
Cleaned up naming convention to match existing
sidc321 Jun 30, 2023
d6d69b2
Hierarchical shrinkage working, still need to update R interface to a…
sidc12321 Jul 7, 2023
8883f95
Fixed hierarchical shrinkage estimates and created basic tests
sidc321 Jul 11, 2023
a15848c
updated the weight matrix to respect HS and added tests
sidc321 Jul 25, 2023
0c22d2f
Updated documentation to pass check
sidc321 Jul 25, 2023
25bc6a9
Deprecated smaller weight vector and removed overloading of var_id to…
sidc321 Jul 27, 2023
fd19333
Python package successfully builds
sidc321 Jul 27, 2023
4ef3814
fixed indexing issue due to var_id now having a shorter length
sidc321 Jul 30, 2023
dd36bc5
black linting
sidc321 Jul 30, 2023
e6097af
Hierarchical shrinkage implemented in Python package, basic tests are…
sidc321 Aug 1, 2023
8f52104
Linting
sidc321 Aug 1, 2023
eaa6424
cleanup
sidc321 Aug 1, 2023
aabc204
Added better commenting for tests
sidc321 Aug 1, 2023
4c12140
hierarchical shrinkage parameters are renamed (camel case), code modi…
sidc321 Aug 2, 2023
15d502c
reformatted variable names and cleaned up hierShrinkageLambda checking…
sidc321 Aug 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions Python/extension/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ extern "C" {
size_t num_test_rows,
std::vector<double>& predictions,
std::vector<double>& weight_matrix,
std::vector<double>& coefs
std::vector<double>& coefs,
bool hier_shrinkage,
double lambda_shrinkage
) {


Expand Down Expand Up @@ -311,7 +313,9 @@ extern "C" {
nthread,
exact,
false,
nullptr
nullptr,
hier_shrinkage,
lambda_shrinkage
);

size_t idx = 0;
Expand Down Expand Up @@ -356,7 +360,9 @@ extern "C" {
nthread,
exact,
use_weights,
weights
weights,
hier_shrinkage,
lambda_shrinkage
);

}
Expand All @@ -379,7 +385,9 @@ extern "C" {
bool verbose,
std::vector<double>& predictions,
std::vector<double>& weight_matrix,
std::vector<size_t> training_idx
std::vector<size_t> training_idx,
bool hier_shrinkage,
double lambda_shrinkage
) {
if (verbose)
std::cout << forest_pt << std::endl;
Expand Down Expand Up @@ -416,7 +424,9 @@ extern "C" {
&treeCounts,
doubleOOB,
exact,
training_idx_use
training_idx_use,
hier_shrinkage,
lambda_shrinkage
);

size_t idx = 0;
Expand All @@ -435,7 +445,9 @@ extern "C" {
nullptr,
doubleOOB,
exact,
training_idx_use
training_idx_use,
hier_shrinkage,
lambda_shrinkage
);
}

Expand All @@ -457,21 +469,21 @@ extern "C" {

info_holder = forest->getForest()->at(tree_idx)->getTreeInfo(forest->getTrainingData());
int num_nodes = forest->getForest()->at(tree_idx)->getNodeCount();
int num_leaf_nodes = forest->getForest()->at(tree_idx)->getLeafNodeCount();

for (int i = 0; i < num_nodes + num_leaf_nodes; i++) {
for (int i = 0; i < num_nodes; i++) {
treeInfo[i] = (double)info_holder->var_id.at(i);
}

for (int i = 0; i < num_leaf_nodes; i++) {
treeInfo[num_nodes + num_leaf_nodes + i] = info_holder->values.at(i);
for (int i = 0; i < num_nodes; i++) {
treeInfo[num_nodes + i] = info_holder->values.at(i);
}

for (int i = 0; i < num_nodes; i++) {
treeInfo[num_nodes + num_leaf_nodes*2 + i] = info_holder->split_val.at(i);
treeInfo[num_nodes *2 + num_leaf_nodes*2 + i] = (double)info_holder->naLeftCount.at(i);
treeInfo[num_nodes *3 + num_leaf_nodes*2 + i] = (double)info_holder->naRightCount.at(i);
treeInfo[num_nodes *4 + num_leaf_nodes*2 + i] = (double)info_holder->naDefaultDirection.at(i);
treeInfo[num_nodes *2 + i] = info_holder->split_val.at(i);
treeInfo[num_nodes *3 + i] = (double)info_holder->naLeftCount.at(i);
treeInfo[num_nodes *4 + i] = (double)info_holder->naRightCount.at(i);
treeInfo[num_nodes *5 + i] = (double)info_holder->naDefaultDirection.at(i);
treeInfo[num_nodes *6 + i] = (double)info_holder->average_count.at(i);
treeInfo[num_nodes *7 + i] = (double)info_holder->split_count.at(i);
}

// Populate splitting samples for the tree
Expand All @@ -488,7 +500,7 @@ extern "C" {
av_info[i+1] = info_holder->averagingSampleIndex.at(i);
}

treeInfo[num_nodes *5 + num_leaf_nodes*2] = info_holder->seed;
treeInfo[num_nodes *8] = info_holder->seed;
}


Expand Down Expand Up @@ -578,6 +590,12 @@ extern "C" {
std::unique_ptr< std::vector< std::vector<int> > > var_ids(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > average_counts(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<int> > > split_counts(
new std::vector< std::vector<int> >
);
std::unique_ptr< std::vector< std::vector<double> > > split_vals(
new std::vector< std::vector<double> >
);
Expand Down Expand Up @@ -608,6 +626,8 @@ extern "C" {

// Reserve space for each of the vectors equal to ntree
var_ids->reserve(ntree);
average_counts->reserve(ntree);
split_counts->reserve(ntree);
split_vals->reserve(ntree);
averagingSampleIndex->reserve(ntree);
splittingSampleIndex->reserve(ntree);
Expand All @@ -619,36 +639,32 @@ extern "C" {
predictWeights->reserve(ntree);

// Now actually populate the vectors
size_t ind = 0, ind_s = 0, ind_a = 0, ind_var = 0, ind_weights = 0;
size_t ind = 0, ind_s = 0, ind_a = 0;
for(size_t i = 0; i < ntree; i++){
// Should be num total nodes + num leaf nodes
std::vector<int> cur_var_ids((tree_counts[4*i]+tree_counts[4*i+3]), 0);
// Should be num total nodes
std::vector<int> cur_var_ids((tree_counts[4*i]), 0);
std::vector<int> cur_average_counts((tree_counts[4*i]), 0);
std::vector<int> cur_split_counts((tree_counts[4*i]), 0);
std::vector<double> cur_split_vals(tree_counts[4*i], 0);
std::vector<int> curNaLeftCounts(tree_counts[4*i], 0);
std::vector<int> curNaRightCounts(tree_counts[4*i], 0);
std::vector<int> curNaDefaultDirections(tree_counts[4*i], 0);
std::vector<size_t> curSplittingSampleIndex(tree_counts[4*i+1], 0);
std::vector<size_t> curAveragingSampleIndex(tree_counts[4*i+2], 0);
std::vector<double> cur_predict_weights(tree_counts[4*i+3], 0);
std::vector<double> cur_predict_weights(tree_counts[4*i], 0);

for(size_t j = 0; j < tree_counts[4*i]; j++){
cur_split_vals.at(j) = thresholds[ind];
curNaLeftCounts.at(j) = na_left_count[ind];
curNaRightCounts.at(j) = na_right_count[ind];
curNaDefaultDirections.at(j) = na_default_directions[ind];
cur_predict_weights.at(j) = predict_weights[ind];
cur_var_ids.at(j) = features[ind];
cur_average_counts.at(j) = features[ind];
cur_split_counts.at(j) = features[ind];

ind++;
}

for (size_t j = 0; j < tree_counts[4*i+3]; j++){
cur_predict_weights.at(j) = predict_weights[ind_weights];
ind_weights++;
}

for (size_t j = 0; j < (tree_counts[4*i]+tree_counts[4*i+3]); j++) {
cur_var_ids.at(j) = features[ind_var];
ind_var++;
}

for(size_t j = 0; j < tree_counts[4*i+1]; j++){
curSplittingSampleIndex.at(j) = split_idx[ind_s];
Expand All @@ -661,6 +677,8 @@ extern "C" {
}

var_ids->push_back(cur_var_ids);
average_counts->push_back(cur_average_counts);
split_counts->push_back(cur_split_counts);
split_vals->push_back(cur_split_vals);
naLeftCounts->push_back(curNaLeftCounts);
naRightCounts->push_back(curNaRightCounts);
Expand All @@ -677,6 +695,8 @@ extern "C" {
categoricalFeatureCols_copy,
treeSeeds,
var_ids,
average_counts,
split_counts,
split_vals,
naLeftCounts,
naRightCounts,
Expand Down
8 changes: 6 additions & 2 deletions Python/extension/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ extern "C" {
bool verbose,
std::vector<double>& predictions,
std::vector<double>& weight_matrix,
std::vector<size_t> training_idx
std::vector<size_t> training_idx,
bool hier_shrinkage,
double lambda_shrinkage
);
void predict_forest(
void* forest_pt,
Expand All @@ -126,7 +128,9 @@ extern "C" {
size_t num_test_rows,
std::vector<double>& predictions,
std::vector<double>& weight_matrix,
std::vector<double>& coefs
std::vector<double>& coefs,
bool hier_shrinkage = false,
double lambda_shrinkage = 0
);
void fill_tree_info(
void* forest_ptr,
Expand Down
16 changes: 12 additions & 4 deletions Python/extension/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ py::tuple predictOOB_forest_wrapper(
bool use_training_idx,
unsigned int n_preds,
unsigned int n_weight_matrix,
py::array_t<size_t> training_idx
py::array_t<size_t> training_idx,
bool hier_shrinkage,
double lambda_shrinkage
) {
py::array_t<double> predictions = create_numpy_array(n_preds);
std::vector<double> predictions_vector(n_preds);
Expand All @@ -199,7 +201,9 @@ py::tuple predictOOB_forest_wrapper(
verbose,
predictions_vector,
weight_matrix_vector,
training_idx_vector
training_idx_vector,
hier_shrinkage,
lambda_shrinkage
);

copy_vector_to_numpy_array(predictions_vector, predictions);
Expand Down Expand Up @@ -241,7 +245,9 @@ py::tuple predict_forest_wrapper(
size_t num_test_rows,
unsigned int n_preds,
unsigned int n_weight_matrix,
unsigned int n_coefficients
unsigned int n_coefficients,
bool hier_shrinkage,
double lambda_shrinkage
) {
py::array_t<double> predictions = create_numpy_array(n_preds);
std::vector<double> predictions_vector(n_preds);
Expand All @@ -266,7 +272,9 @@ py::tuple predict_forest_wrapper(
num_test_rows,
predictions_vector,
weight_matrix_vector,
coefficients_vector
coefficients_vector,
hier_shrinkage,
lambda_shrinkage
);

copy_vector_to_numpy_array(predictions_vector, predictions);
Expand Down
Loading
Loading