From ce64ac99e4604fd7e6a0099b55d503748a719376 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 16 Oct 2024 15:31:16 +0800 Subject: [PATCH] Rename. --- doc/tutorials/external_memory.rst | 26 ++++++---- include/xgboost/c_api.h | 49 ++++++++++--------- python-package/xgboost/core.py | 13 ++--- src/common/error_msg.h | 2 +- src/tree/hist/param.h | 4 +- src/tree/updater_approx.cc | 2 +- src/tree/updater_gpu_hist.cu | 2 +- src/tree/updater_quantile_hist.cc | 2 +- .../gpu_hist/test_gradient_based_sampler.cu | 2 +- tests/cpp/tree/test_gpu_hist.cu | 18 +++---- tests/python-gpu/test_gpu_data_iterator.py | 2 +- 11 files changed, 68 insertions(+), 54 deletions(-) diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index 4445abc0791b..a77044455011 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -120,8 +120,11 @@ the ``hist`` tree method is employed. For a GPU device, the main memory is the d memory, whereas the external memory can be either a disk or the CPU memory. XGBoost stages the cache on CPU memory by default. Users can change the backing storage to disk by specifying the ``on_host`` parameter in the :py:class:`~xgboost.DataIter`. However, using -the disk is not recommended. It's likely to make the GPU slower than the CPU. The option is -here for experimental purposes only. +the disk is not recommended as it's likely to make the GPU slower than the CPU. The option +is here for experimental purposes only. In addition, +:py:class:`~xgboost.ExtMemQuantileDMatrix` parameters ``max_num_device_pages``, +``min_cache_page_bytes``, and ``max_quantile_batches`` can help control the data placement +and memory usage. Inputs to the :py:class:`~xgboost.ExtMemQuantileDMatrix` (through the iterator) must be on the GPU. This is a current limitation we aim to address in the future. @@ -157,12 +160,17 @@ the GPU. This is a current limitation we aim to address in the future. evals=[(Xy_train, "Train"), (Xy_valid, "Valid")] ) -It's crucial to use `RAPIDS Memory Manager (RMM) `__ for -all memory allocation when training with external memory. XGBoost relies on the memory -pool to reduce the overhead for data fetching. In addition, the open source `NVIDIA Linux -driver +It's crucial to use `RAPIDS Memory Manager (RMM) `__ with +an asynchronous memory resource for all memory allocation when training with external +memory. XGBoost relies on the asynchronous memory pool to reduce the overhead of data +fetching. In addition, the open source `NVIDIA Linux driver `__ -is required for ``Heterogeneous memory management (HMM)`` support. +is required for ``Heterogeneous memory management (HMM)`` support. Usually, users need not +to change :py:class:`~xgboost.ExtMemQuantileDMatrix` parameters ``max_num_device_pages`` +and ``min_cache_page_bytes``, they are automatically configured based on the device and +don't change model accuracy. However, the ``max_quantile_batches`` can be useful if +:py:class:`~xgboost.ExtMemQuantileDMatrix` is running out of device memory during +construction, see :py:class:`~xgboost.QuantileDMatrix` for more info. In addition to the batch-based data fetching, the GPU version supports concatenating batches into a single blob for the training data to improve performance. For GPUs @@ -181,7 +189,7 @@ concatenation can be enabled by: param = { "device": "cuda", - "extmem_concat_pages": true, + "extmem_single_page": true, 'subsample': 0.2, 'sampling_method': 'gradient_based', } @@ -200,7 +208,7 @@ interconnect between the CPU and the GPU. With the host memory serving as the da XGBoost can retrieve data with significantly lower overhead. When the input data is dense, there's minimal to no performance loss for training, except for the initial construction of the :py:class:`~xgboost.ExtMemQuantileDMatrix`. The initial construction iterates -through the input data twice, as a result, the most significantly overhead compared to +through the input data twice, as a result, the most significant overhead compared to in-core training is one additional data read when the data is dense. Please note that there are multiple variants of the platform and they come with different C2C bandwidths. During initial development of the feature, we used the LPDDR5 480G version, diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index b26e984fd0b8..89a8b12e928f 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -308,35 +308,40 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const * * used by JVM packages. It uses `XGBoostBatchCSR` to accept batches for CSR formated * input, and concatenate them into 1 final big CSR. The related functions are: * - * - \ref XGBCallbackSetData - * - \ref XGBCallbackDataIterNext - * - \ref XGDMatrixCreateFromDataIter + * - @ref XGBCallbackSetData + * - @ref XGBCallbackDataIterNext + * - @ref XGDMatrixCreateFromDataIter * - * Another set is used by external data iterator. It accept foreign data iterators as + * Another set is used by external data iterator. It accepts foreign data iterators as * callbacks. There are 2 different senarios where users might want to pass in callbacks - * instead of raw data. First it's the Quantile DMatrix used by hist and GPU Hist. For - * this case, the data is first compressed by quantile sketching then merged. This is - * particular useful for distributed setting as it eliminates 2 copies of data. 1 by a - * `concat` from external library to make the data into a blob for normal DMatrix - * initialization, another by the internal CSR copy of DMatrix. The second use case is - * external memory support where users can pass a custom data iterator into XGBoost for - * loading data in batches. There are short notes on each of the use cases in respected - * DMatrix factory function. + * instead of raw data. First it's the Quantile DMatrix used by the hist and GPU-based + * hist tree method. For this case, the data is first compressed by quantile sketching + * then merged. This is particular useful for distributed setting as it eliminates 2 + * copies of data. First one by a `concat` from external library to make the data into a + * blob for normal DMatrix initialization, another one by the internal CSR copy of + * DMatrix. + * + * The second use case is external memory support where users can pass a custom data + * iterator into XGBoost for loading data in batches. For both cases, the iterator is only + * used during the construction of the DMatrix and can be safely freed after construction + * finishes. There are short notes on each of the use cases in respected DMatrix factory + * function. * * Related functions are: * * # Factory functions - * - \ref XGDMatrixCreateFromCallback for external memory - * - \ref XGQuantileDMatrixCreateFromCallback for quantile DMatrix + * - @ref XGDMatrixCreateFromCallback for external memory + * - @ref XGQuantileDMatrixCreateFromCallback for quantile DMatrix + * - @ref XGExtMemQuantileDMatrixCreateFromCallback for External memory Quantile DMatrix * * # Proxy that callers can use to pass data to XGBoost - * - \ref XGProxyDMatrixCreate - * - \ref XGDMatrixCallbackNext - * - \ref DataIterResetCallback - * - \ref XGProxyDMatrixSetDataCudaArrayInterface - * - \ref XGProxyDMatrixSetDataCudaColumnar - * - \ref XGProxyDMatrixSetDataDense - * - \ref XGProxyDMatrixSetDataCSR + * - @ref XGProxyDMatrixCreate + * - @ref XGDMatrixCallbackNext + * - @ref DataIterResetCallback + * - @ref XGProxyDMatrixSetDataCudaArrayInterface + * - @ref XGProxyDMatrixSetDataCudaColumnar + * - @ref XGProxyDMatrixSetDataDense + * - @ref XGProxyDMatrixSetDataCSR * - ... (data setters) * * @{ @@ -515,7 +520,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand * * @since 3.0.0 * - * @note This is still under development, not ready for test yet. + * @note This is experimental and subject to change. * * @param iter A handle to external data iterator. * @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate. diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f2d90e01a5ee..f40aec6ddc36 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1574,12 +1574,13 @@ class QuantileDMatrix(DMatrix): applied to the validation/test data max_quantile_batches : - For GPU-based inputs, XGBoost handles incoming batches with multiple growing - substreams. This parameter sets the maximum number of batches before XGBoost can - cut the sub-stream and create a new one. This can help bound the memory - usage. By default, XGBoost grows new sub-streams exponentially until batches are - exhausted. Only used for the training dataset and the default is None - (unbounded). + For GPU-based inputs from an iterator, XGBoost handles incoming batches with + multiple growing substreams. This parameter sets the maximum number of batches + before XGBoost can cut the sub-stream and create a new one. This can help bound + the memory usage. By default, XGBoost grows new sub-streams exponentially until + batches are exhausted. Only used for the training dataset and the default is + None (unbounded). Lastly, if the `data` is a single batch instead of an + iterator, this parameter has no effect. .. versionadded:: 3.0.0 diff --git a/src/common/error_msg.h b/src/common/error_msg.h index c2ee4a0589b3..7cd16232a065 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -108,7 +108,7 @@ inline auto NoCategorical(std::string name) { inline void NoPageConcat(bool concat_pages) { if (concat_pages) { - LOG(FATAL) << "`extmem_concat_pages` must be false when there's no sampling or when it's " + LOG(FATAL) << "`extmem_single_page` must be false when there's no sampling or when it's " "running on the CPU."; } } diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h index e06eff027cd3..53e79f0da2f7 100644 --- a/src/tree/hist/param.h +++ b/src/tree/hist/param.h @@ -23,7 +23,7 @@ struct HistMakerTrainParam : public XGBoostParameter { constexpr static std::size_t CudaDefaultNodes() { return static_cast(1) << 12; } bool debug_synchronize{false}; - bool extmem_concat_pages{false}; + bool extmem_single_page{false}; void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const; @@ -43,7 +43,7 @@ struct HistMakerTrainParam : public XGBoostParameter { .set_default(NotSet()) .set_lower_bound(1) .describe("Maximum number of nodes in histogram cache."); - DMLC_DECLARE_FIELD(extmem_concat_pages).set_default(false); + DMLC_DECLARE_FIELD(extmem_single_page).set_default(false); } }; } // namespace xgboost::tree diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 51c8a5b21f65..fa34e9829c2f 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -278,7 +278,7 @@ class GlobalApproxUpdater : public TreeUpdater { *sampled = linalg::Empty(ctx_, gpair->Size(), 1); auto in = gpair->HostView().Values(); std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data()); - error::NoPageConcat(this->hist_param_.extmem_concat_pages); + error::NoPageConcat(this->hist_param_.extmem_single_page); SampleGradient(ctx_, param, sampled->HostView()); } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index ab4ccd05a47c..9ed0bd409e30 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -162,7 +162,7 @@ struct GPUHistMakerDevice { interaction_constraints(param, static_cast(info.num_col_)), sampler{std::make_unique( ctx, info.num_row_, batch_param, param.subsample, param.sampling_method, - batch_ptr_.size() > 2 && this->hist_param_->extmem_concat_pages)} { + batch_ptr_.size() > 2 && this->hist_param_->extmem_single_page)} { if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index bafe525913be..277c844162dd 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -539,7 +539,7 @@ class QuantileHistMaker : public TreeUpdater { // Copy gradient into buffer for sampling. This converts C-order to F-order. std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out)); } - error::NoPageConcat(this->hist_param_.extmem_concat_pages); + error::NoPageConcat(this->hist_param_.extmem_single_page); SampleGradient(ctx_, *param, h_sample_out); auto *h_out_position = &out_position[tree_it - trees.begin()]; if ((*tree_it)->IsMultiTarget()) { diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 45b3f7967e7a..fc51fa99ae17 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -85,7 +85,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { [&] { GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true); }, - GMockThrow("extmem_concat_pages")); + GMockThrow("extmem_single_page")); } TEST(GradientBasedSampler, UniformSampling) { diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index f997573db81b..d0f546e6134e 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -39,7 +39,7 @@ void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix ObjInfo task{ObjInfo::kRegression}; std::unique_ptr hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)}; if (subsample < 1.0) { - hist_maker->Configure(Args{{"extmem_concat_pages", std::to_string(concat_pages)}}); + hist_maker->Configure(Args{{"extmem_single_page", std::to_string(concat_pages)}}); } else { hist_maker->Configure(Args{}); } @@ -240,31 +240,31 @@ TEST(GpuHist, PageConcatConfig) { auto learner = std::unique_ptr(Learner::Create({p_fmat})); learner->SetParam("device", ctx.DeviceName()); - learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("extmem_single_page", "true"); learner->SetParam("subsample", "0.8"); learner->Configure(); learner->UpdateOneIter(0, p_fmat); - learner->SetParam("extmem_concat_pages", "false"); + learner->SetParam("extmem_single_page", "false"); learner->Configure(); // GPU Hist rebuilds the updater after configuration. Training continues learner->UpdateOneIter(1, p_fmat); - learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("extmem_single_page", "true"); learner->SetParam("subsample", "1.0"); - ASSERT_THAT([&] { learner->UpdateOneIter(2, p_fmat); }, GMockThrow("extmem_concat_pages")); + ASSERT_THAT([&] { learner->UpdateOneIter(2, p_fmat); }, GMockThrow("extmem_single_page")); // Throws error on CPU. { auto learner = std::unique_ptr(Learner::Create({p_fmat})); - learner->SetParam("extmem_concat_pages", "true"); - ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages")); + learner->SetParam("extmem_single_page", "true"); + ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_single_page")); } { auto learner = std::unique_ptr(Learner::Create({p_fmat})); - learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("extmem_single_page", "true"); learner->SetParam("tree_method", "approx"); - ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages")); + ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_single_page")); } } diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 1d565152cd38..3cb6a1dd630f 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -115,7 +115,7 @@ def test_concat_pages_invalid() -> None: "device": "cuda", "subsample": 0.5, "sampling_method": "gradient_based", - "extmem_concat_pages": True, + "extmem_single_page": True, "objective": "reg:absoluteerror", }, Xy,