Skip to content

Commit

Permalink
fix bagging by query with pairwise lambdarank
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Sep 20, 2024
1 parent 0258f07 commit 90a95fa
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 25 deletions.
14 changes: 9 additions & 5 deletions src/boosting/bagging.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class BaggingSampleStrategy : public SampleStrategy {
train_data_ = train_data;
num_data_ = train_data->num_data();
num_queries_ = train_data->metadata().num_queries();
query_boundaries_ = train_data->metadata().query_boundaries();
if (config->objective == std::string("pairwise_lambdarank")) {
query_boundaries_ = train_data->metadata().pairwise_query_boundaries();
} else {
query_boundaries_ = train_data->metadata().query_boundaries();
}
objective_function_ = objective_function;
num_tree_per_iteration_ = num_tree_per_iteration;
num_threads_ = OMP_NUM_THREADS();
Expand Down Expand Up @@ -62,14 +66,14 @@ class BaggingSampleStrategy : public SampleStrategy {
sampled_query_boundaries_[0] = 0;
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_)
for (data_size_t i = 0; i < num_sampled_queries_; ++i) {
for (data_size_t i = 0; i < num_queries_; ++i) {
OMP_LOOP_EX_BEGIN();
sampled_query_boundaries_[i + 1] = query_boundaries_[bag_query_indices_[i] + 1] - query_boundaries_[bag_query_indices_[i]];
OMP_LOOP_EX_END();
}
OMP_THROW_EX();

const int num_blocks = Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
const int num_blocks = Threading::For<data_size_t>(0, num_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
for (data_size_t i = start_index + 1; i < end_index; ++i) {
sampled_query_boundaries_[i] += sampled_query_boundaries_[i - 1];
}
Expand All @@ -80,7 +84,7 @@ class BaggingSampleStrategy : public SampleStrategy {
sampled_query_boundaires_thread_buffer_[thread_index] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
}

Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
Threading::For<data_size_t>(0, num_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
if (thread_index > 0) {
for (data_size_t i = start_index; i < end_index; ++i) {
sampled_query_boundaries_[i] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
Expand All @@ -90,7 +94,7 @@ class BaggingSampleStrategy : public SampleStrategy {

bag_data_cnt_ = sampled_query_boundaries_[num_sampled_queries_];

Threading::For<data_size_t>(0, num_sampled_queries_, 1, [this](int /*thread_index*/, data_size_t start_index, data_size_t end_index) {
Threading::For<data_size_t>(0, num_queries_, 1, [this](int /*thread_index*/, data_size_t start_index, data_size_t end_index) {
for (data_size_t sampled_query_id = start_index; sampled_query_id < end_index; ++sampled_query_id) {
const data_size_t query_index = bag_query_indices_[sampled_query_id];
const data_size_t data_index_start = query_boundaries_[query_index];
Expand Down
15 changes: 8 additions & 7 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,14 +917,15 @@ void Dataset::CreatePairWiseRankingData(const Dataset* dataset, const bool is_va
feature2subfeature_.clear();
has_raw_ = dataset->has_raw();
numeric_feature_map_ = dataset->numeric_feature_map_;
for (const int feature_index : dataset->numeric_feature_map_) {
if (feature_index != -1) {
numeric_feature_map_.push_back(feature_index + dataset->num_features_);
num_numeric_features_ = dataset->num_numeric_features_;
for (const int nuermic_feature_index : dataset->numeric_feature_map_) {
if (nuermic_feature_index != -1) {
numeric_feature_map_.push_back(num_numeric_features_);
++num_numeric_features_;
} else {
numeric_feature_map_.push_back(-1);
}
}
num_numeric_features_ = dataset->num_numeric_features_ * 2;
// copy feature bin mapper data
feature_need_push_zeros_.clear();
group_bin_boundaries_.clear();
Expand Down Expand Up @@ -2102,9 +2103,9 @@ void Dataset::CreatePairwiseRankingDifferentialFeatures(
const int feature_index = diff_original_feature_index->at(i);
const data_size_t num_samples_for_feature = static_cast<data_size_t>(sample_values[feature_index].size());
if (config.zero_as_missing) {
int cur_query = 0;
for (int j = 0; j < num_samples_for_feature; ++j) {
const double value = sample_values[feature_index][j];
int cur_query = 0;
data_size_t cur_data_index = sample_indices[feature_index][j];
while (query_boundaries[cur_query + 1] <= cur_data_index) {
++cur_query;
Expand All @@ -2117,8 +2118,8 @@ void Dataset::CreatePairwiseRankingDifferentialFeatures(
} else {
CHECK_GT(sample_indices[feature_index].size(), 0);
int cur_pos_j = 0;
int cur_query = 0;
for (int j = 0; j < sample_indices[feature_index].back() + 1; ++j) {
int cur_query = 0;
while (query_boundaries[cur_query + 1] <= j) {
++cur_query;
}
Expand All @@ -2144,7 +2145,7 @@ void Dataset::CreatePairwiseRankingDifferentialFeatures(
differential_feature_bin_mappers->operator[](i)->FindBin(
sampled_differential_values[i].data(),
static_cast<int>(sampled_differential_values[i].size()),
static_cast<size_t>(num_total_sample_data * (num_total_sample_data) / 2),
static_cast<size_t>(num_total_sample_data * (num_total_sample_data + 1) / 2),
config.max_bin, config.min_data_in_bin, filter_cnt, config.feature_pre_filter,
BinType::NumericalBin, config.use_missing, config.zero_as_missing, forced_upper_bounds
);
Expand Down
29 changes: 16 additions & 13 deletions src/objective/rank_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,15 +606,17 @@ class PairwiseLambdarankNDCG: public LambdarankNDCG {
}
}

void GetGradients(const double* score_pairwise, score_t* gradients_pairwise,
score_t* hessians_pairwise) const override {
void GetGradients(const double* score_pairwise, const data_size_t num_sampled_queries, const data_size_t* sampled_query_indices,
score_t* gradients_pairwise, score_t* hessians_pairwise) const override {
const data_size_t num_queries = (sampled_query_indices == nullptr ? num_queries_ : num_sampled_queries);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
for (data_size_t i = 0; i < num_queries_; ++i) {
for (data_size_t i = 0; i < num_queries; ++i) {
global_timer.Start("pairwise_lambdarank::GetGradients part 0");
const data_size_t start_pointwise = query_boundaries_[i];
const data_size_t cnt_pointwise = query_boundaries_[i + 1] - query_boundaries_[i];
const data_size_t start_pairwise = query_boundaries_pairwise_[i];
const data_size_t cnt_pairwise = query_boundaries_pairwise_[i + 1] - query_boundaries_pairwise_[i];
const data_size_t query_index = (sampled_query_indices == nullptr ? i : sampled_query_indices[i]);
const data_size_t start_pointwise = query_boundaries_[query_index];
const data_size_t cnt_pointwise = query_boundaries_[query_index + 1] - query_boundaries_[query_index];
const data_size_t start_pairwise = query_boundaries_pairwise_[query_index];
const data_size_t cnt_pairwise = query_boundaries_pairwise_[query_index + 1] - query_boundaries_pairwise_[query_index];
std::vector<double> score_adjusted_pairwise;
if (num_position_ids_ > 0) {
for (data_size_t j = 0; j < cnt_pairwise; ++j) {
Expand All @@ -624,25 +626,26 @@ class PairwiseLambdarankNDCG: public LambdarankNDCG {
}
global_timer.Stop("pairwise_lambdarank::GetGradients part 0");
global_timer.Start("pairwise_lambdarank::GetGradients part 1");
GetGradientsForOneQuery(i, cnt_pointwise, cnt_pairwise, label_ + start_pointwise, scores_pointwise_.data() + start_pointwise, num_position_ids_ > 0 ? score_adjusted_pairwise.data() : score_pairwise + start_pairwise,
right2left_map_byquery_[i], left2right_map_byquery_[i], left_right2pair_map_byquery_[i],
GetGradientsForOneQuery(query_index, cnt_pointwise, cnt_pairwise, label_ + start_pointwise, scores_pointwise_.data() + start_pointwise, num_position_ids_ > 0 ? score_adjusted_pairwise.data() : score_pairwise + start_pairwise,
right2left_map_byquery_[query_index], left2right_map_byquery_[query_index], left_right2pair_map_byquery_[query_index],
gradients_pairwise + start_pairwise, hessians_pairwise + start_pairwise);
std::vector<data_size_t> all_pairs(cnt_pairwise);
std::iota(all_pairs.begin(), all_pairs.end(), 0);
global_timer.Stop("pairwise_lambdarank::GetGradients part 1");
global_timer.Start("pairwise_lambdarank::GetGradients part 2");
UpdatePointwiseScoresForOneQuery(i, scores_pointwise_.data() + start_pointwise, score_pairwise + start_pairwise, cnt_pointwise, cnt_pairwise, all_pairs.data(),
paired_index_map_ + start_pairwise, right2left_map_byquery_[i], left2right_map_byquery_[i], left_right2pair_map_byquery_[i], truncation_level_, sigmoid_, sigmoid_cache_);
paired_index_map_ + start_pairwise, right2left_map_byquery_[query_index], left2right_map_byquery_[query_index], left_right2pair_map_byquery_[query_index], truncation_level_, sigmoid_, sigmoid_cache_);
global_timer.Stop("pairwise_lambdarank::GetGradients part 2");
}
if (num_position_ids_ > 0) {
std::vector<score_t> gradients_pointwise(num_data_);
std::vector<score_t> hessians_pointwise(num_data_);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
for (data_size_t i = 0; i < num_queries_; ++i) {
const data_size_t cnt_pointwise = query_boundaries_[i + 1] - query_boundaries_[i];
const data_size_t cnt_pairwise = query_boundaries_pairwise_[i + 1] - query_boundaries_pairwise_[i];
TransformGradientsPairwiseIntoPointwiseForOneQuery(i, cnt_pointwise, cnt_pairwise, gradients_pairwise, hessians_pairwise, gradients_pointwise.data(), hessians_pointwise.data());
const data_size_t query_index = (sampled_query_indices == nullptr ? i : sampled_query_indices[i]);
const data_size_t cnt_pointwise = query_boundaries_[query_index + 1] - query_boundaries_[query_index];
const data_size_t cnt_pairwise = query_boundaries_pairwise_[query_index + 1] - query_boundaries_pairwise_[query_index];
TransformGradientsPairwiseIntoPointwiseForOneQuery(query_index, cnt_pointwise, cnt_pairwise, gradients_pairwise, hessians_pairwise, gradients_pointwise.data(), hessians_pointwise.data());
}
UpdatePositionBiasFactors(gradients_pointwise.data(), hessians_pointwise.data());
}
Expand Down

0 comments on commit 90a95fa

Please sign in to comment.