[c++] Fix dump_model() information for root node #6569

Open

neNasko1 wants to merge 34 commits into microsoft:master from neNasko1:fix-root-values

Commits (34)
12102cc Fix value calculation in root node (neNasko1, Jul 24, 2024)
c933399 Fix dask tests (neNasko1, Jul 26, 2024)
c240016 Merge branch 'master' into fix-root-values (neNasko1, Jul 26, 2024)
2f1de57 Create proper tests (neNasko1, Jul 29, 2024)
273a1df Merge branch 'master' into fix-root-values (neNasko1, Jul 29, 2024)
208df85 Test only on cpu (neNasko1, Jul 29, 2024)
130879b Merge branch 'fix-root-values' of github.com:neNasko1/LightGBM into f… (neNasko1, Jul 29, 2024)
48e6b96 Disable new tests for CUDA (neNasko1, Jul 30, 2024)
26b9859 Merge with #5964 (neNasko1, Aug 3, 2024)
88e3dec Finish merging with dump_model unification (neNasko1, Aug 3, 2024)
e1274dc Improve tests (neNasko1, Aug 3, 2024)
38ee92c Add linear test for stump (neNasko1, Aug 4, 2024)
3b423de Fix CUDA compilation (neNasko1, Aug 5, 2024)
c89e257 Merge branch 'master' into fix-root-values (neNasko1, Aug 5, 2024)
3de14d9 Merge branch 'master' into fix-root-values (neNasko1, Aug 6, 2024)
fc42c1c Merge branch 'master' into fix-root-values (neNasko1, Aug 14, 2024)
3ffcac6 Comments after code review (neNasko1, Aug 14, 2024)
d5a82c4 Fix test (neNasko1, Aug 15, 2024)
be7675d Reenable cuda testing (neNasko1, Aug 15, 2024)
f616e03 Tests (neNasko1, Aug 15, 2024)
6c6bc33 Merge branch 'microsoft:master' into fix-root-values (neNasko1, Aug 15, 2024)
c28a2cf test cuda (neNasko1, Aug 15, 2024)
6113f90 . (neNasko1, Aug 15, 2024)
94cf7f0 Fix warning (neNasko1, Aug 15, 2024)
01aa952 reenable tests (neNasko1, Aug 15, 2024)
fadaa83 . (neNasko1, Aug 15, 2024)
b9c681b Merge branch 'fix-cuda' into fix-root-values (neNasko1, Aug 15, 2024)
a323acb fix cuda (neNasko1, Aug 15, 2024)
0fd0c59 Fix compilation error (neNasko1, Aug 15, 2024)
4cc5dd4 Fix weight (neNasko1, Aug 15, 2024)
a743a87 Fix numerical (neNasko1, Aug 15, 2024)
031c945 Make tests more robust (neNasko1, Aug 16, 2024)
91993a9 Merge branch 'master' into fix-root-values (neNasko1, Sep 2, 2024)
f744f64 Merge branch 'master' into fix-root-values (neNasko1, Sep 5, 2024)
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_tree.hpp
@@ -77,7 +77,7 @@ class CUDATree : public Tree {
                 const data_size_t* used_data_indices,
                 data_size_t num_data, double* score) const override;
 
-  inline void AsConstantTree(double val) override;
+  inline void AsConstantTree(double val, int count) override;
 
   const int* cuda_leaf_parent() const { return cuda_leaf_parent_; }
 
5 changes: 3 additions & 2 deletions include/LightGBM/tree.h
@@ -228,13 +228,14 @@ class Tree {
     shrinkage_ = 1.0f;
   }
 
-  virtual inline void AsConstantTree(double val) {
+  virtual inline void AsConstantTree(double val, int count = 0) {
     num_leaves_ = 1;
     shrinkage_ = 1.0f;
     leaf_value_[0] = val;
     if (is_linear_) {
       leaf_const_[0] = val;
     }
+    leaf_count_[0] = count;
   }
 
   /*! \brief Serialize this object to string*/
@@ -563,7 +564,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
   leaf_parent_[leaf] = new_node_idx;
   leaf_parent_[num_leaves_] = new_node_idx;
   // save current leaf value to internal node before change
-  internal_weight_[new_node_idx] = leaf_weight_[leaf];
+  internal_weight_[new_node_idx] = left_weight + right_weight;
   internal_value_[new_node_idx] = leaf_value_[leaf];
   internal_count_[new_node_idx] = left_cnt + right_cnt;
   leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
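Note on the Split() change: before the first split, leaf_weight_[leaf] is never populated for the root, so the dumped root weight was a stale zero; summing left_weight + right_weight reconstructs the true total. A hedged sketch of how this surfaces through the Python API (column names taken from trees_to_dataframe(); the data and model are illustrative, not from this PR):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = X[:, 0] + rng.normal(scale=0.1, size=100)
bst = lgb.train({"objective": "regression", "verbose": -1}, lgb.Dataset(X, y), num_boost_round=2)

df = bst.trees_to_dataframe()
roots = df[df["parent_index"].isnull()]
# With this fix the root rows report a real weight (total hessian mass reaching
# the node) and count == len(X), instead of an uninitialized zero.
print(roots[["tree_index", "value", "weight", "count"]])
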
2 changes: 1 addition & 1 deletion python-package/lightgbm/basic.py
@@ -3913,7 +3913,7 @@ def _get_split_feature(
         return feature_name
 
     def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
-        return set(tree.keys()) == {"leaf_value"}
+        return set(tree.keys()) == {"leaf_value", "leaf_count"}
 
     # Create the node record, and populate universal data members
     node: Dict[str, Union[int, str, None]] = OrderedDict()
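Since the C++ side now always emits leaf_count alongside leaf_value for a stump, the key-set check must include both. A minimal sketch of the two shapes the helper distinguishes (dict contents abridged and assumed from the new ToJSON output):

from typing import Any, Dict

def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
    return set(tree.keys()) == {"leaf_value", "leaf_count"}

stump = {"leaf_value": 0.5, "leaf_count": 100}  # single-leaf tree
node = {"split_index": 0, "split_feature": 1, "left_child": {}, "right_child": {}}  # internal node (abridged)
assert _is_single_node_tree(stump)
assert not _is_single_node_tree(node)
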
5 changes: 4 additions & 1 deletion src/boosting/gbdt.cpp
@@ -419,7 +419,10 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
           score_updater->AddScore(init_scores[cur_tree_id], cur_tree_id);
         }
       }
-      new_tree->AsConstantTree(init_scores[cur_tree_id]);
+      new_tree->AsConstantTree(init_scores[cur_tree_id], num_data_);
+    } else {
+      // extend init_scores with zeros
+      new_tree->AsConstantTree(0, num_data_);
     }
   }
   // add model
2 changes: 1 addition & 1 deletion src/boosting/rf.hpp
@@ -168,7 +168,7 @@ class RF : public GBDT {
         output = init_scores_[cur_tree_id];
       }
     }
-    new_tree->AsConstantTree(output);
+    new_tree->AsConstantTree(output, num_data_);
     MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
     UpdateScore(new_tree.get(), cur_tree_id);
     MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
5 changes: 3 additions & 2 deletions src/io/cuda/cuda_tree.cpp
@@ -330,9 +330,10 @@ void CUDATree::SyncLeafOutputFromCUDAToHost() {
   CopyFromCUDADeviceToHost<double>(leaf_value_.data(), cuda_leaf_value_, leaf_value_.size(), __FILE__, __LINE__);
 }
 
-void CUDATree::AsConstantTree(double val) {
-  Tree::AsConstantTree(val);
+void CUDATree::AsConstantTree(double val, int count) {
+  Tree::AsConstantTree(val, count);
   CopyFromHostToCUDADevice<double>(cuda_leaf_value_, &val, 1, __FILE__, __LINE__);
+  CopyFromHostToCUDADevice<int>(cuda_leaf_count_, &count, 1, __FILE__, __LINE__);
 }
 
 }  // namespace LightGBM
4 changes: 2 additions & 2 deletions src/io/cuda/cuda_tree.cu
@@ -94,7 +94,7 @@ __global__ void SplitKernel(  // split information
     split_gain[new_node_index] = static_cast<float>(cuda_split_info->gain);
   } else if (thread_index == 4) {
     // save current leaf value to internal node before change
-    internal_weight[new_node_index] = leaf_weight[leaf_index];
+    internal_weight[new_node_index] = cuda_split_info->left_sum_hessians + cuda_split_info->right_sum_hessians;
     leaf_weight[leaf_index] = cuda_split_info->left_sum_hessians;
   } else if (thread_index == 5) {
     internal_value[new_node_index] = leaf_value[leaf_index];
@@ -210,7 +210,7 @@ __global__ void SplitCategoricalKernel(  // split information
     split_gain[new_node_index] = static_cast<float>(cuda_split_info->gain);
   } else if (thread_index == 4) {
     // save current leaf value to internal node before change
-    internal_weight[new_node_index] = leaf_weight[leaf_index];
+    internal_weight[new_node_index] = cuda_split_info->left_sum_hessians + cuda_split_info->right_sum_hessians;
     leaf_weight[leaf_index] = cuda_split_info->left_sum_hessians;
   } else if (thread_index == 5) {
     internal_value[new_node_index] = leaf_value[leaf_index];
21 changes: 12 additions & 9 deletions src/io/tree.cpp
@@ -416,12 +416,15 @@ std::string Tree::ToJSON() const {
   str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
   str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
   if (num_leaves_ == 1) {
+    str_buf << "\"tree_structure\":{";
+    str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n';
     if (is_linear_) {
-      str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
-      str_buf << LinearModelToJSON(0) << "}" << "\n";
+      str_buf << "\"leaf_count\":" << leaf_count_[0] << ", " << '\n';
+      str_buf << LinearModelToJSON(0);
     } else {
-      str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
+      str_buf << "\"leaf_count\":" << leaf_count_[0];
     }
+    str_buf << "}" << '\n';
   } else {
     str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
   }
@@ -731,6 +734,12 @@ Tree::Tree(const char* str, size_t* used_len) {
     is_linear_ = false;
   }
 
+  if (key_vals.count("leaf_count")) {
+    leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
+  } else {
+    leaf_count_.resize(num_leaves_);
+  }
+
 #ifdef USE_CUDA
   is_cuda_tree_ = false;
 #endif  // USE_CUDA
@@ -793,12 +802,6 @@ Tree::Tree(const char* str, size_t* used_len) {
     leaf_weight_.resize(num_leaves_);
   }
 
-  if (key_vals.count("leaf_count")) {
-    leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
-  } else {
-    leaf_count_.resize(num_leaves_);
-  }
-
   if (key_vals.count("decision_type")) {
     decision_type_ = CommonC::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
   } else {
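A hedged sketch of the resulting dump for a single-leaf tree (a constant target forces a stump; the exact leaf_value depends on the objective and boosting setup, so treat the values as illustrative):

import numpy as np
import lightgbm as lgb

X = np.random.default_rng(0).normal(size=(100, 3))
y = np.full(100, 2.0)  # constant target -> no useful split -> stump
bst = lgb.train({"objective": "regression", "verbose": -1}, lgb.Dataset(X, y), num_boost_round=1)

tree = bst.dump_model()["tree_info"][0]["tree_structure"]
# Before this PR: {"leaf_value": ...}
# After this PR:  {"leaf_value": ..., "leaf_count": 100}
assert tree["leaf_count"] == 100
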
7 changes: 5 additions & 2 deletions src/treelearner/cuda/cuda_leaf_splits.cpp
@@ -38,12 +38,14 @@ void CUDALeafSplits::InitValues(
     const double lambda_l1, const double lambda_l2,
     const score_t* cuda_gradients, const score_t* cuda_hessians,
     const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf,
-    const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) {
+    const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf,
+    double* root_sum_gradients, double* root_sum_hessians) {
   cuda_gradients_ = cuda_gradients;
   cuda_hessians_ = cuda_hessians;
   cuda_sum_of_gradients_buffer_.SetValue(0);
   cuda_sum_of_hessians_buffer_.SetValue(0);
   LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf);
+  CopyFromCUDADeviceToHost<double>(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__);
   CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
   SynchronizeCUDADevice(__FILE__, __LINE__);
 }
@@ -53,11 +55,12 @@ void CUDALeafSplits::InitValues(
     const int16_t* cuda_gradients_and_hessians,
     const data_size_t* cuda_bagging_data_indices,
     const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
-    hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
+    hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians,
     const score_t* grad_scale, const score_t* hess_scale) {
   cuda_gradients_ = reinterpret_cast<const score_t*>(cuda_gradients_and_hessians);
   cuda_hessians_ = nullptr;
   LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale);
+  CopyFromCUDADeviceToHost<double>(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__);
   CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
   SynchronizeCUDADevice(__FILE__, __LINE__);
 }
4 changes: 2 additions & 2 deletions src/treelearner/cuda/cuda_leaf_splits.hpp
@@ -44,14 +44,14 @@ class CUDALeafSplits {
       const score_t* cuda_gradients, const score_t* cuda_hessians,
       const data_size_t* cuda_bagging_data_indices,
       const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
-      hist_t* cuda_hist_in_leaf, double* root_sum_hessians);
+      hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians);
 
   void InitValues(
       const double lambda_l1, const double lambda_l2,
       const int16_t* cuda_gradients_and_hessians,
       const data_size_t* cuda_bagging_data_indices,
       const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
-      hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
+      hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians,
       const score_t* grad_scale, const score_t* hess_scale);
 
   void InitValues();
26 changes: 16 additions & 10 deletions src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp
@@ -66,6 +66,7 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
   leaf_best_split_default_left_.resize(config_->num_leaves, 0);
   leaf_num_data_.resize(config_->num_leaves, 0);
   leaf_data_start_.resize(config_->num_leaves, 0);
+  leaf_sum_gradients_.resize(config_->num_leaves, 0.0f);
   leaf_sum_hessians_.resize(config_->num_leaves, 0.0f);
 
   if (!boosting_on_cuda_) {
@@ -122,6 +123,7 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
       cuda_data_partition_->cuda_data_indices(),
       root_num_data,
       cuda_histogram_constructor_->cuda_hist_pointer(),
+      &leaf_sum_gradients_[0],
       &leaf_sum_hessians_[0],
       cuda_gradient_discretizer_->grad_scale_ptr(),
       cuda_gradient_discretizer_->hess_scale_ptr());
@@ -137,6 +139,7 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
       cuda_data_partition_->cuda_data_indices(),
      root_num_data,
       cuda_histogram_constructor_->cuda_hist_pointer(),
+      &leaf_sum_gradients_[0],
       &leaf_sum_hessians_[0]);
   }
   leaf_num_data_[0] = root_num_data;
@@ -162,6 +165,12 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
   const bool track_branch_features = !(config_->interaction_constraints_vector.empty());
   std::unique_ptr<CUDATree> tree(new CUDATree(config_->num_leaves, track_branch_features,
     config_->linear_tree, config_->gpu_device_id, has_categorical_feature_));
+  // set the root value by hand, as it is not handled by splits
+  tree->SetLeafOutput(0, CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(
+    leaf_sum_gradients_[smaller_leaf_index_], leaf_sum_hessians_[smaller_leaf_index_],
+    config_->lambda_l1, config_->lambda_l2, config_->path_smooth,
+    static_cast<data_size_t>(num_data_), 0));
+  tree->SyncLeafOutputFromHostToCUDA();
   for (int i = 0; i < config_->num_leaves - 1; ++i) {
     global_timer.Start("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf");
     const data_size_t num_data_in_smaller_leaf = leaf_num_data_[smaller_leaf_index_];
@@ -293,8 +302,6 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
       best_split_info);
   }
 
-  double sum_left_gradients = 0.0f;
-  double sum_right_gradients = 0.0f;
   cuda_data_partition_->Split(best_split_info,
     best_leaf_index_,
     right_leaf_index,
@@ -313,10 +320,10 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
     &leaf_data_start_[right_leaf_index],
     &leaf_sum_hessians_[best_leaf_index_],
     &leaf_sum_hessians_[right_leaf_index],
-    &sum_left_gradients,
-    &sum_right_gradients);
+    &leaf_sum_gradients_[best_leaf_index_],
+    &leaf_sum_gradients_[right_leaf_index]);
 #ifdef DEBUG
-  CheckSplitValid(best_leaf_index_, right_leaf_index, sum_left_gradients, sum_right_gradients);
+  CheckSplitValid(best_leaf_index_, right_leaf_index);
 #endif  // DEBUG
   smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index);
   larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? right_leaf_index : best_leaf_index_);
@@ -374,6 +381,7 @@ void CUDASingleGPUTreeLearner::ResetConfig(const Config* config) {
     leaf_best_split_default_left_.resize(config_->num_leaves, 0);
     leaf_num_data_.resize(config_->num_leaves, 0);
     leaf_data_start_.resize(config_->num_leaves, 0);
+    leaf_sum_gradients_.resize(config_->num_leaves, 0.0f);
     leaf_sum_hessians_.resize(config_->num_leaves, 0.0f);
   }
   cuda_histogram_constructor_->ResetConfig(config);
@@ -562,9 +570,7 @@ void CUDASingleGPUTreeLearner::SelectFeatureByNode(const Tree* tree) {
 #ifdef DEBUG
 void CUDASingleGPUTreeLearner::CheckSplitValid(
     const int left_leaf,
-    const int right_leaf,
-    const double split_sum_left_gradients,
-    const double split_sum_right_gradients) {
+    const int right_leaf) {
   std::vector<data_size_t> left_data_indices(leaf_num_data_[left_leaf]);
   std::vector<data_size_t> right_data_indices(leaf_num_data_[right_leaf]);
   CopyFromCUDADeviceToHost<data_size_t>(left_data_indices.data(),
@@ -585,9 +591,9 @@ void CUDASingleGPUTreeLearner::CheckSplitValid(
     sum_right_gradients += host_gradients_[index];
     sum_right_hessians += host_hessians_[index];
   }
-  CHECK_LE(std::fabs(sum_left_gradients - split_sum_left_gradients), 1e-6f);
+  CHECK_LE(std::fabs(sum_left_gradients - leaf_sum_gradients_[left_leaf]), 1e-6f);
   CHECK_LE(std::fabs(sum_left_hessians - leaf_sum_hessians_[left_leaf]), 1e-6f);
-  CHECK_LE(std::fabs(sum_right_gradients - split_sum_right_gradients), 1e-6f);
+  CHECK_LE(std::fabs(sum_right_gradients - leaf_sum_gradients_[right_leaf]), 1e-6f);
   CHECK_LE(std::fabs(sum_right_hessians - leaf_sum_hessians_[right_leaf]), 1e-6f);
 }
 #endif  // DEBUG
4 changes: 2 additions & 2 deletions src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp
@@ -71,8 +71,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
 
 #ifdef DEBUG
   void CheckSplitValid(
-    const int left_leaf, const int right_leaf,
-    const double sum_left_gradients, const double sum_right_gradients);
+    const int left_leaf, const int right_leaf);
 #endif  // DEBUG
 
   void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree);
@@ -103,6 +102,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
   std::vector<uint8_t> leaf_best_split_default_left_;
   std::vector<data_size_t> leaf_num_data_;
   std::vector<data_size_t> leaf_data_start_;
+  std::vector<double> leaf_sum_gradients_;
   std::vector<double> leaf_sum_hessians_;
   int smaller_leaf_index_;
   int larger_leaf_index_;
6 changes: 6 additions & 0 deletions src/treelearner/serial_tree_learner.cpp
@@ -201,6 +201,12 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   auto tree_ptr = tree.get();
   constraints_->ShareTreePointer(tree_ptr);
 
+  // set the root value by hand, as it is not handled by splits
+  tree->SetLeafOutput(0, FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
+      smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
+      config_->lambda_l1, config_->lambda_l2, config_->max_delta_step,
+      BasicConstraint(), config_->path_smooth, static_cast<data_size_t>(num_data_), 0));
+
   // root leaf
   int left_leaf = 0;
   int cur_depth = 1;
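For intuition, a hedged sketch of what CalculateSplittedLeafOutput computes for the root in the simplest configuration (no L1 penalty, no max_delta_step, no path smoothing; the template flags above select the full version, so this is an approximation):

def root_output(sum_gradients: float, sum_hessians: float, lambda_l2: float) -> float:
    # Newton step for one leaf: argmin_w  g * w + 0.5 * (h + lambda_l2) * w^2
    return -sum_gradients / (sum_hessians + lambda_l2)

# For squared-error loss g_i = pred_i - y_i and h_i = 1, so with zero initial
# predictions the root value is the (L2-shrunk) mean of the targets.
print(root_output(sum_gradients=-50.0, sum_hessians=100.0, lambda_l2=0.0))  # 0.5
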
17 changes: 12 additions & 5 deletions tests/python_package_test/test_dask.py
@@ -1444,8 +1444,8 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
 @pytest.mark.parametrize("task", tasks)
 @pytest.mark.parametrize("output", data_output)
 def test_init_score(task, output, cluster):
-    if task == "ranking" and output == "scipy_csr_matrix":
-        pytest.skip("LGBMRanker is not currently tested on sparse matrices")
+    if task == "ranking":
+        pytest.skip("LGBMRanker is not currently tested for init_score")
 
     with Client(cluster) as client:
         _, _, _, _, dX, dy, dw, dg = _create_data(objective=task, output=output, group=None)
@@ -1462,10 +1462,17 @@ def test_init_score(task, output, cluster):
             init_scores = dy.map_partitions(lambda x: pd.DataFrame([[init_score] * size_factor] * x.size))
         else:
             init_scores = dy.map_blocks(lambda x: np.full((x.size, size_factor), init_score))
+
         model = model_factory(client=client, **params)
-        model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)
-        # value of the root node is 0 when init_score is set
-        assert model.booster_.trees_to_dataframe()["value"][0] == 0
+        model.fit(dX, dy, sample_weight=dw, group=dg)
+        pred = model.predict(dX, raw_score=True)
+
+        model_init_score = model_factory(client=client, **params)
+        model_init_score.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)
+        pred_init_score = model_init_score.predict(dX, raw_score=True)
+
+        # check if init score changes predictions
+        assert not np.allclose(pred, pred_init_score)
 
 
 def sklearn_checks_to_run():
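The dropped assertion relied on the old behavior (the root value was dumped as 0 whenever init_score was set); after this PR the root carries its computed value, so the test verifies the behavioral contract instead. A hedged non-Dask sketch of the same check (estimator choice and data are illustrative, not from this test suite):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] > 0).astype(float)

plain = lgb.LGBMRegressor(n_estimators=5).fit(X, y)
shifted = lgb.LGBMRegressor(n_estimators=5).fit(X, y, init_score=np.full(200, 10.0))

# Training against an offset changes the fitted trees, so raw scores must differ.
assert not np.allclose(plain.predict(X, raw_score=True), shifted.predict(X, raw_score=True))
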