[WIP] Optimal intercept initialization for simple objectives #10298

Open · wants to merge 23 commits into master
38 changes: 38 additions & 0 deletions R-package/tests/testthat/test_poisson_regression.R
@@ -13,3 +13,41 @@ test_that("Poisson regression works", {
expect_equal(length(pred), 32)
expect_lt(sqrt(mean((pred - mtcars[, 11])^2)), 1.2)
})

test_that("Poisson regression is centered around mean", {
m <- 50L
n <- 10L
y <- rpois(m, n)
x <- matrix(rnorm(m * n), nrow = m)
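# A huge gamma blocks every split, so the trained model reduces to its
# intercept and all predictions equal base_score.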
model <- xgb.train(
data = xgb.DMatrix(x, label = y),
params = list(objective = "count:poisson", gamma = 1e4),
nrounds = 1
)
model_json <- xgb.save.raw(model, "json") |> rawToChar() |> jsonlite::fromJSON()
expect_equal(
model_json$learner$learner_model_param$base_score |> as.numeric(),
mean(y),
tolerance = 1e-4
)

pred <- predict(model, x)
expect_equal(
pred,
rep(mean(y), m),
tolerance = 1e-4
)

w <- y + 1
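# Non-uniform weights: the intercept should now match the weighted mean of y.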
model_weighted <- xgb.train(
data = xgb.DMatrix(x, label = y, weight = w),
params = list(objective = "count:poisson", gamma = 1e4),
nrounds = 1
)
model_json <- xgb.save.raw(model_weighted, "json") |> rawToChar() |> jsonlite::fromJSON()
expect_equal(
model_json$learner$learner_model_param$base_score |> as.numeric(),
weighted.mean(y, w),
tolerance = 1e-4
)
})
8 changes: 5 additions & 3 deletions include/xgboost/objective.h
@@ -82,10 +82,12 @@ class ObjFunction : public Configurable {
*/
[[nodiscard]] virtual bst_float ProbToMargin(bst_float base_score) const { return base_score; }
/**
-   * \brief Make initialize estimation of prediction.
+   * @brief Obtain the initial estimation of prediction.
    *
-   * \param info MetaInfo that contains label.
-   * \param base_score Output estimation.
+   * The output in `base_score` represents the prediction after applying the inverse link function.
+   *
+   * @param info MetaInfo that contains label.
+   * @param base_score Output estimation.
*/
virtual void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const;
/*!
14 changes: 10 additions & 4 deletions src/collective/aggregator.h
@@ -134,6 +134,15 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
return value;
}

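/**
 * @brief Overload of GlobalSum that takes the split type directly instead of reading it
 *        from a MetaInfo. The sum only needs to cross workers when data is split by rows.
 */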
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const* ctx, bool is_column_split,
linalg::TensorView<T, kDim> values) {
if (!is_column_split) {
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}

/**
* @brief Find the global sum of the given values across all workers.
*
@@ -148,10 +157,7 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
-  if (info.IsRowSplit()) {
-    return collective::Allreduce(ctx, values, collective::Op::kSum);
-  }
-  return Success();
+  return GlobalSum(ctx, info.IsColumnSplit(), values);
}

/**
67 changes: 61 additions & 6 deletions src/common/stats.cc
@@ -1,13 +1,14 @@
/**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
*/
#include "stats.h"

#include <cstddef>  // std::size_t
#include <numeric>  // std::accumulate

#include "common.h" // OptionalWeights
#include "linalg_op.h"
#include "../collective/aggregator.h" // for GlobalSum
#include "linalg_op.h" // for Matrix
#include "optional_weight.h" // OptionalWeights
#include "threading_utils.h" // ParallelFor, MemStackAllocator
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/context.h" // Context
@@ -16,7 +17,7 @@
#include "xgboost/logging.h" // CHECK_EQ

namespace xgboost::common {
-void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
+void Median(Context const* ctx, linalg::Matrix<float> const& t,
HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out) {
if (ctx->IsCUDA()) {
weights.SetDevice(ctx->Device());
@@ -61,4 +62,58 @@ void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<flo
out->HostView()(0) = ret;
}
}

void SampleMean(Context const* ctx, bool is_column_split, linalg::Matrix<float> const& v,
linalg::Vector<float>* out) {
*out = linalg::Zeros<float>(ctx, std::max(v.Shape(1), decltype(v.Shape(1)){1}));
if (ctx->IsCPU()) {
auto h_v = v.HostView();
CHECK(h_v.CContiguous());
std::int64_t n_samples = v.Shape(0);
SafeColl(collective::GlobalSum(ctx, is_column_split, linalg::MakeVec(&n_samples, 1)));
auto n_columns = v.Shape(1);
auto h_out = out->HostView();

auto n_rows_f64 = static_cast<double>(n_samples);
for (std::size_t j = 0; j < n_columns; ++j) {
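// Per-thread partial means avoid write contention; dividing by the global row
// count up front keeps the partials directly summable across threads and workers.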
MemStackAllocator<double, DefaultMaxThreads()> mean_tloc(ctx->Threads(), 0.0f);
ParallelFor(v.Shape(0), ctx->Threads(),
[&](auto i) { mean_tloc[omp_get_thread_num()] += (h_v(i, j) / n_rows_f64); });
auto mean = std::accumulate(mean_tloc.cbegin(), mean_tloc.cend(), 0.0);
h_out(j) = mean;
}
SafeColl(collective::GlobalSum(ctx, is_column_split, h_out));
} else {
auto d_v = v.View(ctx->Device());
auto d_out = out->View(ctx->Device());
cuda_impl::SampleMean(ctx, is_column_split, d_v, d_out);
}
}

void WeightedSampleMean(Context const* ctx, bool is_column_split, linalg::Matrix<float> const& v,
HostDeviceVector<float> const& w, linalg::Vector<float>* out) {
*out = linalg::Zeros<float>(ctx, std::max(v.Shape(1), decltype(v.Shape(1)){1}));
CHECK_EQ(v.Shape(0), w.Size());
if (ctx->IsCPU()) {
auto h_v = v.HostView();
auto h_w = w.ConstHostSpan();
auto sum_w = std::accumulate(h_w.data(), h_w.data() + h_w.size(), 0.0);
SafeColl(collective::GlobalSum(ctx, is_column_split, linalg::MakeVec(&sum_w, 1)));
auto h_out = out->HostView();
for (std::size_t j = 0; j < v.Shape(1); ++j) {
MemStackAllocator<double, DefaultMaxThreads()> mean_tloc(ctx->Threads(), 0.0f);
ParallelFor(v.Shape(0), ctx->Threads(),
[&](auto i) { mean_tloc[omp_get_thread_num()] += (h_v(i, j) * h_w(i) / sum_w); });
auto mean = std::accumulate(mean_tloc.cbegin(), mean_tloc.cend(), 0.0);
h_out(j) = mean;
}
SafeColl(collective::GlobalSum(ctx, is_column_split, h_out));
} else {
auto d_v = v.View(ctx->Device());
w.SetDevice(ctx->Device());
auto d_w = w.ConstDeviceSpan();
auto d_out = out->View(ctx->Device());
cuda_impl::WeightedSampleMean(ctx, is_column_split, d_v, d_w, d_out);
}
}
} // namespace xgboost::common
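For reference, a sketch of what the two new helpers compute (not part of the diff): for each column $j$ of the input, the (weighted) sample mean

$$\bar{y}_j = \frac{\sum_i w_i\, y_{ij}}{\sum_i w_i},$$

with $w_i = 1$ in the unweighted case. Each term is divided by the globally reduced $\sum_i w_i$ (or row count) before the local reduction, so under row-wise data splits the per-worker partial sums can simply be added together by the final GlobalSum.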
76 changes: 67 additions & 9 deletions src/common/stats.cu
@@ -1,19 +1,20 @@
/**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
*/

#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator

#include <cstddef>  // size_t

#include "cuda_context.cuh" // CUDAContext
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#include "optional_weight.h" // common::OptionalWeights
#include "../collective/aggregator.h" // for GlobalSum
#include "cuda_context.cuh" // CUDAContext
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#include "optional_weight.h" // common::OptionalWeights
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // linalg::TensorView, UnravelIndex, Apply
#include "xgboost/base.h" // for XGBOOST_DEVICE
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for TensorView, UnravelIndex, Apply

namespace xgboost::common::cuda_impl {
void Median(Context const* ctx, linalg::TensorView<float const, 2> t,
@@ -58,4 +59,61 @@ void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorV
dh::TemporaryArray<char> temp{bytes};
cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size(), s);
}

void SampleMean(Context const* ctx, bool is_column_split, linalg::MatrixView<float const> d_v,
linalg::VectorView<float> d_out) {
auto n_samples = d_v.Shape(0);
auto n_total_samples = n_samples;
auto cpu = ctx->MakeCPU();
SafeColl(collective::GlobalSum(&cpu, is_column_split, linalg::MakeVec(&n_total_samples, 1)));
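// Segment key: consecutive flat indices share a column (cidx = i / n_samples); see the
// layout note in WeightedSampleMean below.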
auto column_it = dh::MakeTransformIterator<std::size_t>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) {
auto cidx = i / n_samples;
return cidx;
});
auto n_rows_f64 = static_cast<double>(n_total_samples);
auto val_it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> double {
auto cidx = i / n_samples;
auto ridx = i % n_samples;
return d_v(ridx, cidx) / n_rows_f64;
});
auto cuctx = ctx->CUDACtx();
thrust::reduce_by_key(cuctx->CTP(), column_it, column_it + d_v.Size(), val_it,
thrust::make_discard_iterator(), d_out.Values().data(),
thrust::equal_to<>{}, thrust::plus<double>{});
SafeColl(collective::GlobalSum(ctx, is_column_split, d_out));
}

void WeightedSampleMean(Context const* ctx, bool is_column_split,
linalg::MatrixView<float const> d_v, common::Span<float const> d_w,
linalg::VectorView<float> d_out) {
CHECK(d_v.CContiguous());
auto n_rows = d_v.Shape(0);
// The use of `cidx = i / n_rows` does not imply the input is column-major; it only
// fixes the traversal order of the reduction, and we want to reduce over the first
// dimension (rows). `thrust::reduce_by_key` requires all keys within the same
// reduction segment to be next to each other, while `array(ridx, cidx)` works with
// any memory layout.
auto column_it = dh::MakeTransformIterator<std::size_t>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) {
auto cidx = i / n_rows;
return cidx;
});
auto cuctx = ctx->CUDACtx();
auto sum_w =
dh::Reduce(cuctx->CTP(), d_w.data(), d_w.data() + d_w.size(), 0.0, thrust::plus<double>{});
auto cpu = ctx->MakeCPU();
SafeColl(collective::GlobalSum(&cpu, is_column_split, linalg::MakeVec(&sum_w, 1)));
auto val_it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> double {
auto cidx = i / n_rows;
auto ridx = i % n_rows;
return d_v(ridx, cidx) * d_w(ridx) / sum_w;
});
thrust::reduce_by_key(cuctx->CTP(), column_it, column_it + d_v.Size(), val_it,
thrust::make_discard_iterator(), d_out.Values().data(),
thrust::equal_to<>{}, thrust::plus<double>{});
SafeColl(collective::GlobalSum(ctx, is_column_split, d_out));
}
} // namespace xgboost::common::cuda_impl
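The layout comment in WeightedSampleMean above is the heart of both kernels. A minimal plain-C++ analogue of the segmented reduction (illustrative sketch only, no xgboost types assumed):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::size_t const n_rows = 3, n_cols = 2;
  // Row-major storage of a 3x2 matrix; the traversal below is still column-by-column.
  std::vector<double> v{1.0, 10.0, 2.0, 20.0, 3.0, 30.0};
  auto at = [&](std::size_t r, std::size_t c) { return v[r * n_cols + c]; };

  std::vector<double> out(n_cols, 0.0);
  for (std::size_t i = 0; i < n_rows * n_cols; ++i) {
    auto cidx = i / n_rows;  // segment key: consecutive i share a column
    auto ridx = i % n_rows;
    out[cidx] += at(ridx, cidx) / static_cast<double>(n_rows);
  }
  for (auto m : out) std::cout << m << "\n";  // prints 2 then 20
}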
56 changes: 46 additions & 10 deletions src/common/stats.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_STATS_H_
#define XGBOOST_COMMON_STATS_H_
@@ -8,13 +8,15 @@
#include <limits>
#include <vector>

#include "algorithm.h" // for StableSort
#include "common.h" // AssertGPUSupport, OptionalWeights
#include "optional_weight.h" // OptionalWeights
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h" // TensorView,VectorView
#include "xgboost/logging.h" // CHECK_GE
#include "algorithm.h" // for StableSort
#include "optional_weight.h" // OptionalWeights
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h" // TensorView,VectorView
#include "xgboost/logging.h" // CHECK_GE

#if !defined(XGBOOST_USE_CUDA)
#include "common.h" // AssertGPUSupport
#endif

namespace xgboost {
namespace common {
@@ -112,6 +114,13 @@ void Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWe

void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);

void SampleMean(Context const* ctx, bool is_column_split, linalg::MatrixView<float const> d_v,
linalg::VectorView<float> d_out);

void WeightedSampleMean(Context const* ctx, bool is_column_split,
linalg::MatrixView<float const> d_v, common::Span<float const> d_w,
linalg::VectorView<float> d_out);

#if !defined(XGBOOST_USE_CUDA)
inline void Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights,
linalg::Tensor<float, 1>*) {
@@ -120,16 +129,43 @@ inline void Median(Context const*, linalg::TensorView<float const, 2>, OptionalW
inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
common::AssertGPUSupport();
}

inline void SampleMean(Context const*, bool, linalg::MatrixView<float const>,
linalg::VectorView<float>) {
common::AssertGPUSupport();
}

inline void WeightedSampleMean(Context const*, bool, linalg::MatrixView<float const>,
common::Span<float const>, linalg::VectorView<float>) {
common::AssertGPUSupport();
}

#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl

/**
- * \brief Calculate medians for each column of the input matrix.
+ * @brief Calculate medians for each column of the input matrix.
  */
-void Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
+void Median(Context const* ctx, linalg::Matrix<float> const& t,
HostDeviceVector<float> const& weights, linalg::Tensor<float, 1>* out);

/**
* @brief Calculate the mean value of a vector.
*/
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);

/**
 * @brief Calculate the mean over the first axis (one value per column).
*/
void SampleMean(Context const* ctx, bool is_column_split, linalg::Matrix<float> const& v,
linalg::Vector<float>* out);

/**
 * @brief Calculate the weighted mean over the first axis (one value per column); the
 * weights are assumed to be non-negative.
*/
void WeightedSampleMean(Context const* ctx, bool is_column_split, linalg::Matrix<float> const& v,
HostDeviceVector<float> const& w, linalg::Vector<float>* out);
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_STATS_H_
2 changes: 1 addition & 1 deletion src/learner.cc
@@ -426,7 +426,7 @@ class LearnerConfiguration : public Learner {
info.Validate(Ctx()->Device());
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
-    InitEstimation(info, &base_score);
+    this->InitEstimation(info, &base_score);
CHECK_EQ(base_score.Size(), 1);
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
17 changes: 16 additions & 1 deletion src/objective/init_estimation.cc
@@ -34,9 +34,24 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
bst_target_t n_targets = this->Targets(info);
linalg::Vector<float> leaf_weight;
tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);
-  // workaround, we don't support multi-target due to binary model serialization for
+  // Workaround, we don't support multi-target due to binary model serialization for
// base margin.
common::Mean(this->ctx_, leaf_weight, base_score);
this->PredTransform(base_score->Data());
}

void FitInterceptGlmLike::InitEstimation(MetaInfo const& info,
linalg::Vector<float>* base_score) const {
if (this->Task().task == ObjInfo::kRegression) {
CheckInitInputs(info);
}
linalg::Vector<float> out;
if (info.weights_.Empty()) {
common::SampleMean(this->ctx_, info.IsColumnSplit(), info.labels, &out);
} else {
common::WeightedSampleMean(this->ctx_, info.IsColumnSplit(), info.labels, info.weights_, &out);
}
common::Mean(this->ctx_, out, base_score);
CHECK_EQ(base_score->Size(), 1);
}
} // namespace xgboost::obj
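Background for the new estimator (a sketch, not part of the diff): for these GLM-like objectives the constant prediction minimizing the loss is the weighted label mean, which is why WeightedSampleMean suffices. For count:poisson with its log link, minimizing the weighted negative log-likelihood over a constant margin $b$ gives

$$\frac{\partial}{\partial b} \sum_i w_i \left(e^{b} - y_i b\right) = \sum_i w_i \left(e^{b} - y_i\right) = 0 \;\Rightarrow\; e^{b} = \frac{\sum_i w_i y_i}{\sum_i w_i},$$

so the intercept in the output space is exactly the weighted mean that the R test above checks against base_score.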