Skip to content

Commit

Permalink
updated the weight matrix to respect HS and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sidc321 committed Jul 25, 2023
1 parent 8883f95 commit a15848c
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 61 deletions.
4 changes: 2 additions & 2 deletions R/R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ rcpp_OBBPredictInterface <- function(forest) {
.Call(`_Rforestry_rcpp_OBBPredictInterface`, forest)
}

rcpp_OBBPredictionsInterface <- function(forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx) {
.Call(`_Rforestry_rcpp_OBBPredictionsInterface`, forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx)
rcpp_OBBPredictionsInterface <- function(forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx, hier_shrinkage, lambda_shrinkage) {
.Call(`_Rforestry_rcpp_OBBPredictionsInterface`, forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx, hier_shrinkage, lambda_shrinkage)
}

rcpp_getObservationSizeInterface <- function(df) {
Expand Down
22 changes: 16 additions & 6 deletions R/R/forestry.R
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,9 @@ predict.forestry <- function(object,
use_weights = use_weights,
use_hold_out_idx = TRUE,
tree_weights = tree_weights,
hold_out_idx = (holdOutIdx-1)) # Change to 0 indexed for C++
hold_out_idx = (holdOutIdx-1), # Change to 0 indexed for C++
hier_shrinkage,
lambda_shrinkage)
}, error = function(err) {
print(err)
return(NULL)
Expand All @@ -1498,7 +1500,9 @@ predict.forestry <- function(object,
weightMatrix,
exact,
useTrainingIndices,
trainingIndices
trainingIndices,
hier_shrinkage,
lambda_shrinkage
)
}, error = function(err) {
print(err)
Expand Down Expand Up @@ -1535,7 +1539,9 @@ predict.forestry <- function(object,
weightMatrix,
exact,
useTrainingIndices,
trainingIndices
trainingIndices,
hier_shrinkage,
lambda_shrinkage
)
}, error = function(err) {
print(err)
Expand Down Expand Up @@ -1695,7 +1701,9 @@ getOOB <- function(object,
getOOBpreds <- function(object,
newdata = NULL,
doubleOOB = FALSE,
noWarning = FALSE
noWarning = FALSE,
hier_shrinkage = FALSE,
lambda_shrinkage = FALSE
) {

if (!object@replace &&
Expand Down Expand Up @@ -1760,7 +1768,9 @@ getOOBpreds <- function(object,
FALSE,
TRUE,
FALSE,
c(-1))
c(-1),
hier_shrinkage,
lambda_shrinkage)

# If we have scaled the observations, we want to rescale the predictions
if (object@scale) {
Expand Down Expand Up @@ -1893,7 +1903,7 @@ getVI <- function(object,
#' @param aggregation Specifies which aggregation version is used to predict for the
#' observation, must be one of `average`,`oob`, and `doubleOOB`.
#' @return A list with four entries. `weightMatrix` is a matrix specifying the
#' weight given to training observatio i when prediction on observation j.
#' weight given to training observation i when predicting on observation j.
#' `avgIndices` gives the indices which are in the averaging set for each new
#' observation. `avgWeights` gives the weights corresponding to each averaging
#' observation returned in `avgIndices`. `obsInfo` gives the full observation vectors
Expand Down
107 changes: 85 additions & 22 deletions R/tests/testthat/test-forestry_hierarchicalShrinkage.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
test_that("Tests hierarchical shrinkage works as expected", {
context("check total node number equals length of full weights and count vector")
context("Check total node number equals length of full weights and count vector")

set.seed(238943202)
# Test example with simple step functio
#set.seed(238943202)
x <- iris[, -1]
y <- iris[, 1]
rf <- forestry(x,
Expand All @@ -17,9 +16,8 @@ test_that("Tests hierarchical shrinkage works as expected", {
expect_equal(length(rf@R_forest[[1]]$average_count),num_nodes)
expect_equal(length(rf@R_forest[[1]]$weightsFull),num_nodes)

context("check output predictions when using hierarchical shrinkage for small tree")
context("Check output predictions when using hierarchical shrinkage for small tree")

# Test example with simple step functio
test_idx <- sample(nrow(iris), 100)
x_train <- data.frame(width= iris[-test_idx, 2])
y_train <- iris[-test_idx, 1]
Expand All @@ -45,29 +43,94 @@ test_that("Tests hierarchical shrinkage works as expected", {
weightRightPath = fdata$weightsFull[1]*(1-1/(1+lambda_shrinkage/fdata$average_count[1])) + fdata$weightsFull[3]/(1+lambda_shrinkage/fdata$average_count[1])
expectedPredictions[expectedPredictions==1] = weightLeftPath
expectedPredictions[expectedPredictions==2] = weightRightPath
shrinked_pred = predict(rf, x_test, hier_shrinkage=T, lambda_shrinkage=2)

shrinked_pred = predict(rf, x_test, hier_shrinkage=T, lambda_shrinkage=lambda_shrinkage)
expect_equal(shrinked_pred,expectedPredictions)

context("check output predictions when using hierarchical shrinkage for lambda=0")
context("Check output predictions when using hierarchical shrinkage for lambda=0 and large lambda")
test_idx <- sample(nrow(iris), 100)
x_train <- data.frame(width= iris[-test_idx, 2])
x_train <- data.frame(width= iris[-test_idx, -1])
y_train <- iris[-test_idx, 1]
x_test <- data.frame(width= iris[test_idx, 2])
x_test <- data.frame(width= iris[test_idx, -1])
y_test <- iris[test_idx, 1]

rf <- forestry(x = x_train, y = y_train, ntree = 10)
unshrink_predictions = predict(rf,x_test)
rf <- forestry(x = x_train, y = y_train, ntree = 10,replace=F,sampsize = length(y_train),mtry=ncol(x))
noshrink_predictions = predict(rf,x_test)
lambda0_predictions = predict(rf, x_test, hier_shrinkage=T, lambda_shrinkage=0)
expect_equal(noshrink_predictions,lambda0_predictions)

context("check output predictions when using hierarchical shrinkage for large lambda")
test_idx <- sample(nrow(iris), 100)
x_train <- data.frame(width= iris[-test_idx, 2])
y_train <- iris[-test_idx, 1]
x_test <- data.frame(width= iris[test_idx, 2])
lambdalarge_predictions = predict(rf, x_test, hier_shrinkage=T, lambda_shrinkage=1e10)
tot_prediction_diffs = lambdalarge_predictions-mean(y_train)
expect_true(all.equal(tot_prediction_diffs ,rep(0,length(tot_prediction_diffs) )))

context("Check hierarchical shrinkage prediction matches getOOBpreds")
rf <- forestry(x = iris[,-1],
y = iris[,1],
OOBhonest = TRUE,ntree=10)

doubleOOBpreds <- getOOBpreds(rf, doubleOOB = TRUE,
noWarning = TRUE,hier_shrinkage = T,lambda_shrinkage = 10)
OOBpreds <- getOOBpreds(rf, noWarning = TRUE,hier_shrinkage = T,lambda_shrinkage = 10)
predict_doubleOOBpreds <- predict(rf, aggregation = "doubleOOB",hier_shrinkage = T,lambda_shrinkage = 10)
predict_OOBpreds <- predict(rf, aggregation = "oob",hier_shrinkage = T,lambda_shrinkage = 10)

# Expect OOB preds from getOOB preds and predict to be the same
expect_equal(all.equal(predict_OOBpreds,
OOBpreds), TRUE)

# Expect double OOB preds to be the same from predict and getOOBpreds
expect_equal(all.equal(predict_doubleOOBpreds,
doubleOOBpreds), TRUE)

context("Check Weight matrix matches predictions from hierarchical shrinkage")
x = iris[,-1]
y = iris[,1]
rf <- forestry(x=x,y=y,ntree=10)
shrink_preds = predict(rf,x,weightMatrix = TRUE,hier_shrinkage = T,lambda_shrinkage =10)
noshrink = predict(rf,x,weightMatrix = TRUE)
# now we reconstruct predictions from the weight matrix and check they match
weight_preds = as.vector(shrink_preds$weightMatrix %*% y)
expect_equal(all.equal(weight_preds,shrink_preds$predictions),TRUE)

context("Check hierarchical shrinkage predictions and weightMatrix oob and double oob matches expectations for lambda = 0 and large lambda")

x = iris[,-1]
y = iris[,1]
rf <- forestry(x,
y,
OOBhonest = TRUE)
predict_doubleOOBpreds_lambda0 <- predict(rf, aggregation = "doubleOOB",hier_shrinkage = T,lambda_shrinkage = 0, weightMatrix = TRUE)
predict_doubleOOBpreds <- predict(rf, aggregation = "doubleOOB", weightMatrix = TRUE)

predict_OOBpreds_lambda0 <- predict(rf, aggregation = "oob",hier_shrinkage = T,lambda_shrinkage = 0, weightMatrix = TRUE)
predict_OOBpreds <- predict(rf, aggregation = "oob", weightMatrix = TRUE)

expect_equal(all.equal(predict_doubleOOBpreds_lambda0$predictions,
predict_doubleOOBpreds$predictions), TRUE)
expect_equal(all.equal(predict_OOBpreds_lambda0$predictions,
predict_OOBpreds$predictions), TRUE)
expect_equal(all.equal(predict_doubleOOBpreds_lambda0$weightMatrix,
predict_doubleOOBpreds$weightMatrix), TRUE)
expect_equal(all.equal(predict_OOBpreds_lambda0$weightMatrix,
predict_OOBpreds$weightMatrix), TRUE)


predict_doubleOOBpreds_lambdalarge <- predict(rf, aggregation = "doubleOOB",hier_shrinkage = T,lambda_shrinkage = 1e10)
predict_OOBpreds_lambdalarge <- predict(rf, aggregation = "oob",hier_shrinkage = T,lambda_shrinkage = 1e10)

# With a very large lambda the shrunken OOB / double-OOB predictions are
# expected to sit inside the (5.75, 5.92) band. The bounds must be compared
# against the mean of the predictions themselves: taking mean() of the logical
# vector `x < 5.92` yields a proportion, and any non-zero proportion is
# coerced to TRUE by `&&`, which makes the assertion nearly vacuous.
expect_true(mean(predict_doubleOOBpreds_lambdalarge) < 5.92 &&
            mean(predict_doubleOOBpreds_lambdalarge) > 5.75)
expect_true(mean(predict_OOBpreds_lambdalarge) < 5.92 &&
            mean(predict_OOBpreds_lambdalarge) > 5.75)

context("Check Weight matrix for large lambda")
x = iris[,-1]
y = iris[,1]
rf <- forestry(x,
y, sampsize=length(y), replace=FALSE,ntree=10)
# Predict on the training features with hierarchical shrinkage enabled and a
# very large lambda, requesting the weight matrix alongside the predictions.
predict_matrix_lambdalarge <- predict(rf, x, hier_shrinkage = TRUE, lambda_shrinkage = 1e10, weightMatrix = TRUE)
# Each row of the weight matrix must sum to one. rowSums() returns one value
# per row, so the expected vector is sized with nrow(); sizing it with ncol()
# is only correct when the matrix is square.
expect_equal(rowSums(predict_matrix_lambdalarge$weightMatrix),
             rep(1, nrow(predict_matrix_lambdalarge$weightMatrix)))
dims <- dim(predict_matrix_lambdalarge$weightMatrix)
# Under very strong shrinkage every entry is expected to equal 1 / nrow(x),
# i.e. a uniform weight over the observations in x.
expect_equal(all.equal(predict_matrix_lambdalarge$weightMatrix,
                       matrix(rep(1 / nrow(x), prod(dims)), nrow = dims[1])
),TRUE)

rf <- forestry(x = x_train, y = y_train, ntree = 10)
unshrink_predictions = predict(rf,x_test)
lambda0_predictions = predict(rf, x_test, hier_shrinkage=T, lambda_shrinkage=1e10)
tot_prediction_diffs = sum(abs(lambda0_predictions-lambda0_predictions[0]))
expect_equal(0, tot_prediction_diffs)
})
57 changes: 44 additions & 13 deletions src/RFNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ void RFNode::predict(
} else {

// Give all updateIndex the mean of the node as prediction values
// Weight by shrinkage factor if shrinkage is turned on
for (
std::vector<size_t>::iterator it = (*updateIndex).begin();
it != (*updateIndex).end();
Expand Down Expand Up @@ -231,8 +232,13 @@ void RFNode::predict(
}

for (size_t i = 0; i<idx_in_leaf.size(); i++) {
(*weightMatrix)(idx, idx_in_leaf[i] - 1) +=
(double) 1.0 / ((double) idx_in_leaf.size());
if(hier_shrinkage){
(*weightMatrix)(idx, idx_in_leaf[i] - 1) +=
(double) 1.0 / ((double) idx_in_leaf.size()) * (double) 1.0 /(1+lambda_shrinkage/parentAverageCount);
}else{
(*weightMatrix)(idx, idx_in_leaf[i] - 1) +=
(double) 1.0 / ((double) idx_in_leaf.size());
}
}
}
}
Expand All @@ -253,6 +259,42 @@ void RFNode::predict(
// If not a leaf then we need to separate the prediction tasks
} else {

// shrink predictions (and weight) on non-leaf nodes if hierarchical shrinkage is on
if(hier_shrinkage){
for (
std::vector<size_t>::iterator it = (*updateIndex).begin();
it != (*updateIndex).end();
++it
) {
double current_level_weight = 1/(1+lambda_shrinkage / getAverageCount());
double parent_level_weight = 1/(1+lambda_shrinkage/parentAverageCount);
outputPrediction[*it] += predictedMean * (parent_level_weight-current_level_weight);

// need to update the weights outside leaf node if shrinkage applied
if(weightMatrix){
// If weightMatrix is not a NULL pointer, then we want to update it,
// because we have chosen aggregation = "weightmatrix".
std::vector<size_t> idx_in_leaf =
(*trainingData).get_all_row_idx(predictionAveragingIndices);


// The following will lock the access to weightMatrix
std::lock_guard<std::mutex> lock(mutex_weightMatrix);

// Set the row which we update in the weightMatrix
size_t idx = *it;
if (OOBIndex) {
idx = (*OOBIndex)[*it];
}

for (size_t i = 0; i<idx_in_leaf.size(); i++) {
(*weightMatrix)(idx, idx_in_leaf[i] - 1) +=
(double) 1.0 / ((double) idx_in_leaf.size()) * (parent_level_weight-current_level_weight);
}
}
}
}

// Separate prediction tasks to two children
std::vector<size_t>* leftPartitionIndex = new std::vector<size_t>();
std::vector<size_t>* rightPartitionIndex = new std::vector<size_t>();
Expand Down Expand Up @@ -490,17 +532,6 @@ void RFNode::predict(

}

if(hier_shrinkage){
for (
std::vector<size_t>::iterator it = (*updateIndex).begin();
it != (*updateIndex).end();
++it
) {
double current_level_weight = 1/(1+lambda_shrinkage / getAverageCount());
double parent_level_weight = 1/(1+lambda_shrinkage/parentAverageCount);
outputPrediction[*it] += predictedMean*(parent_level_weight-current_level_weight);
}
}
// Recursively get predictions from its children
if ((*leftPartitionIndex).size() > 0) {
(*getLeftChild()).predict(
Expand Down
10 changes: 6 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ BEGIN_RCPP
END_RCPP
}
// rcpp_OBBPredictionsInterface
Rcpp::List rcpp_OBBPredictionsInterface(SEXP forest, Rcpp::List x, bool existing_df, bool doubleOOB, bool returnWeightMatrix, bool exact, bool use_training_idx, Rcpp::IntegerVector training_idx);
RcppExport SEXP _Rforestry_rcpp_OBBPredictionsInterface(SEXP forestSEXP, SEXP xSEXP, SEXP existing_dfSEXP, SEXP doubleOOBSEXP, SEXP returnWeightMatrixSEXP, SEXP exactSEXP, SEXP use_training_idxSEXP, SEXP training_idxSEXP) {
Rcpp::List rcpp_OBBPredictionsInterface(SEXP forest, Rcpp::List x, bool existing_df, bool doubleOOB, bool returnWeightMatrix, bool exact, bool use_training_idx, Rcpp::IntegerVector training_idx, bool hier_shrinkage, double lambda_shrinkage);
RcppExport SEXP _Rforestry_rcpp_OBBPredictionsInterface(SEXP forestSEXP, SEXP xSEXP, SEXP existing_dfSEXP, SEXP doubleOOBSEXP, SEXP returnWeightMatrixSEXP, SEXP exactSEXP, SEXP use_training_idxSEXP, SEXP training_idxSEXP, SEXP hier_shrinkageSEXP, SEXP lambda_shrinkageSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -142,7 +142,9 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< bool >::type exact(exactSEXP);
Rcpp::traits::input_parameter< bool >::type use_training_idx(use_training_idxSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type training_idx(training_idxSEXP);
rcpp_result_gen = Rcpp::wrap(rcpp_OBBPredictionsInterface(forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx));
Rcpp::traits::input_parameter< bool >::type hier_shrinkage(hier_shrinkageSEXP);
Rcpp::traits::input_parameter< double >::type lambda_shrinkage(lambda_shrinkageSEXP);
rcpp_result_gen = Rcpp::wrap(rcpp_OBBPredictionsInterface(forest, x, existing_df, doubleOOB, returnWeightMatrix, exact, use_training_idx, training_idx, hier_shrinkage, lambda_shrinkage));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -250,7 +252,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_Rforestry_rcpp_cppBuildInterface", (DL_FUNC) &_Rforestry_rcpp_cppBuildInterface, 45},
{"_Rforestry_rcpp_cppPredictInterface", (DL_FUNC) &_Rforestry_rcpp_cppPredictInterface, 13},
{"_Rforestry_rcpp_OBBPredictInterface", (DL_FUNC) &_Rforestry_rcpp_OBBPredictInterface, 1},
{"_Rforestry_rcpp_OBBPredictionsInterface", (DL_FUNC) &_Rforestry_rcpp_OBBPredictionsInterface, 8},
{"_Rforestry_rcpp_OBBPredictionsInterface", (DL_FUNC) &_Rforestry_rcpp_OBBPredictionsInterface, 10},
{"_Rforestry_rcpp_getObservationSizeInterface", (DL_FUNC) &_Rforestry_rcpp_getObservationSizeInterface, 1},
{"_Rforestry_rcpp_AddTreeInterface", (DL_FUNC) &_Rforestry_rcpp_AddTreeInterface, 2},
{"_Rforestry_rcpp_CppToR_translator", (DL_FUNC) &_Rforestry_rcpp_CppToR_translator, 1},
Expand Down
16 changes: 12 additions & 4 deletions src/forestry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,9 @@ std::vector<double> forestry::predictOOB(
std::vector<size_t>* treeCounts,
bool doubleOOB,
bool exact,
std::vector<size_t> &training_idx
std::vector<size_t> &training_idx,
bool hier_shrinkage,
double lambda_shrinkage
) {

bool use_training_idx = !training_idx.empty();
Expand Down Expand Up @@ -721,7 +723,9 @@ std::vector<double> forestry::predictOOB(
getMinNodeSizeToSplitAvg(),
xNew,
weightMatrix,
training_idx
training_idx,
hier_shrinkage,
lambda_shrinkage
);
#if DOPARELLEL
std::lock_guard<std::mutex> lock(threadLock);
Expand Down Expand Up @@ -831,7 +835,9 @@ std::vector<double> forestry::predictOOB(
}

void forestry::calculateOOBError(
bool doubleOOB
bool doubleOOB,
bool hier_shrinkage,
double lambda_shrinkage
) {

size_t numObservations = getTrainingData()->getNumRows();
Expand Down Expand Up @@ -889,7 +895,9 @@ void forestry::calculateOOBError(
getMinNodeSizeToSplitAvg(),
nullptr,
NULL,
training_idx
training_idx,
hier_shrinkage,
lambda_shrinkage
);

#if DOPARELLEL
Expand Down
8 changes: 6 additions & 2 deletions src/forestry.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ class forestry {
std::vector<size_t>* treeCounts,
bool doubleOOB,
bool exact,
std::vector<size_t> &training_idx
std::vector<size_t> &training_idx,
bool hier_shrinkage,
double lambda_shrinkage
);

void fillinTreeInfo(
Expand All @@ -92,7 +94,9 @@ class forestry {
size_t getTotalNodeCount();

void calculateOOBError(
bool doubleOOB = false
bool doubleOOB = false,
bool hier_shrinkage = false,
double lambda_shrinkage = 0
);

double getOOBError() {
Expand Down
Loading

0 comments on commit a15848c

Please sign in to comment.