[EM] Concatenate ellpack pages for ExtMemQdm. (#10887)
- Optional page concat for the host cache.
- New parameter to control the cache.
trivialfis authored Oct 14, 2024
1 parent 78b82e4 commit 3f9bfaf
Showing 38 changed files with 746 additions and 310 deletions.
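
To make the change concrete, below is a minimal, illustrative Python sketch of how the new cache parameter is expected to be wired through a DataIter for GPU-based external memory training. The iterator body, the 256 MiB value, and the training parameters are placeholders, not values prescribed by this commit; cupy batches would be typical for GPU training, numpy is used here only for brevity.

import numpy as np
import xgboost


class Iterator(xgboost.DataIter):
    """Yield pre-partitioned batches for external-memory training."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        # Keep the ellpack cache in host memory and ask XGBoost to concatenate
        # small pages into pages of at least 256 MiB (illustrative value).
        super().__init__(on_host=True, min_cache_page_bytes=256 * 1024 * 1024)

    def next(self, input_data):
        if self._it == len(self._batches):
            return False  # signal the end of the iteration
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return True

    def reset(self):
        self._it = 0


batches = [(np.random.rand(2048, 16), np.random.rand(2048)) for _ in range(4)]
Xy = xgboost.ExtMemQuantileDMatrix(Iterator(batches), max_bin=256)
booster = xgboost.train(
    {"tree_method": "hist", "device": "cuda", "max_bin": 256}, Xy, num_boost_round=8
)
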
32 changes: 23 additions & 9 deletions demo/guide-python/external_memory.py
@@ -142,21 +142,35 @@ def main(tmpdir: str, args: argparse.Namespace) -> None:
approx_train(it)


def setup_rmm() -> None:
"""Setup RMM for GPU-based external memory training."""
import rmm
from cuda import cudart
from rmm.allocators.cupy import rmm_cupy_allocator

if not xgboost.build_info()["USE_RMM"]:
return

# The combination of pool and async is by design. As XGBoost needs to allocate large
# pages repeatedly, it's not easy to handle fragmentation. There is room for more
# experimentation here.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
args = parser.parse_args()
if args.device == "cuda":
import cupy as cp
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator

# It's important to use RMM for GPU-based external memory to improve performance.
# If XGBoost is not built with RMM support, a warning will be raised.
mr = rmm.mr.CudaAsyncMemoryResource()
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)

# It's important to use RMM with `CudaAsyncMemoryResource` for GPU-based
# external memory to improve performance. If XGBoost is not built with RMM
# support, a warning is raised when constructing the `DMatrix`.
setup_rmm()
# Make sure XGBoost is using RMM for all allocations.
with xgboost.config_context(use_rmm=True):
with tempfile.TemporaryDirectory() as tmpdir:
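
As a side note, the pool can also be sized explicitly instead of being grown on demand. The sketch below is a hypothetical variant of setup_rmm; initial_pool_size is a standard argument of RMM's PoolMemoryResource, but the 80% fraction is only an illustrative assumption, not a value used by this commit.

import cupy as cp
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator

import xgboost


def setup_rmm_with_explicit_pool() -> None:
    """Hypothetical variant: reserve a fixed-size RMM pool up front."""
    if not xgboost.build_info()["USE_RMM"]:
        return

    _, total = cp.cuda.Device().mem_info  # (free, total) bytes on the current device
    # Reserve ~80% of device memory; the pool size must be a multiple of 256 bytes.
    pool_size = int(total * 0.8) // 256 * 256
    mr = rmm.mr.PoolMemoryResource(
        rmm.mr.CudaAsyncMemoryResource(), initial_pool_size=pool_size
    )
    rmm.mr.set_current_device_resource(mr)
    # Route cupy allocations through RMM as well.
    cp.cuda.set_allocator(rmm_cupy_allocator)
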
14 changes: 8 additions & 6 deletions doc/tutorials/external_memory.rst
@@ -134,7 +134,7 @@ the GPU. This is a current limitation we aim to address in the future.
# It's important to use RMM for GPU-based external memory to improve performance.
# If XGBoost is not built with RMM support, a warning will be raised.
mr = rmm.mr.CudaAsyncMemoryResource()
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)
@@ -159,9 +159,8 @@ the GPU. This is a current limitation we aim to address in the future.
It's crucial to use `RAPIDS Memory Manager (RMM) <https://github.com/rapidsai/rmm>`__ for
all memory allocation when training with external memory. XGBoost relies on the memory
pool to reduce the overhead for data fetching. The size of each batch should be slightly
smaller than a quarter of the available GPU memory. In addition, the open source `NVIDIA
Linux driver
pool to reduce the overhead for data fetching. In addition, the open source `NVIDIA Linux
driver
<https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
is required for ``Heterogeneous memory management (HMM)`` support.

@@ -200,9 +199,12 @@ The newer NVIDIA platforms like `Grace-Hopper
interconnect between the CPU and the GPU. With the host memory serving as the data cache,
XGBoost can retrieve data with significantly lower overhead. When the input data is dense,
there's minimal to no performance loss for training, except for the initial construction
of the :py:class:`~xgboost.ExtMemQuantileDMatrix`. The initial construction iterates
of the :py:class:`~xgboost.ExtMemQuantileDMatrix`. The initial construction iterates
through the input data twice; as a result, the most significant overhead compared to
in-core training is one additional data read when the data is dense.
in-core training is one additional data read when the data is dense. Please note that
there are multiple variants of the platform, and they come with different C2C
bandwidths. During the initial development of this feature, we used the LPDDR5 480G
version, which has about 350GB/s of host-to-device transfer bandwidth.

To run experiments on these platforms, the open source `NVIDIA Linux driver
<https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
3 changes: 3 additions & 0 deletions include/xgboost/c_api.h
@@ -523,6 +523,9 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
the corresponding booster training parameter.
* - on_host (optional): Whether the data should be placed on host memory. Used by GPU inputs.
 * - min_cache_page_bytes (optional): The minimum number of bytes for each internal GPU
 *   page. Set to 0 to disable page concatenation. The page size is chosen automatically
 *   if the parameter is not provided or is set to None.
* @param out The created Quantile DMatrix.
*
* @return 0 when success, -1 when failure happens
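
For reference, the config argument above is a JSON string. Assembled from the fields read in src/c_api/c_api.cc in this commit, it would look roughly like the following; the concrete values are illustrative, and the Python package builds this string internally rather than exposing it.

import json

config = json.dumps(
    {
        "missing": float("nan"),       # values treated as missing
        "nthread": 0,                  # 0 lets XGBoost pick the thread count
        "cache_prefix": "",            # unused when the cache stays in host memory
        "on_host": True,               # keep the ellpack cache in host memory
        "max_bin": 256,                # must match the booster's max_bin
        "min_cache_page_bytes": None,  # omitted/None -> automatic page size
    }
)
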
26 changes: 18 additions & 8 deletions include/xgboost/data.h
@@ -517,6 +517,20 @@ class BatchSet {

struct XGBAPIThreadLocalEntry;

struct ExtMemConfig {
// Cache prefix; not used if the cache is in the host memory (on_host is true).
std::string cache;
// Whether the ellpack page is stored in the host memory.
bool on_host{true};
// Minimum number of bytes for each ellpack page in the cache. Only used for the in-host
// ExtMemQdm.
std::int64_t min_cache_page_bytes{0};
// Missing value.
float missing{std::numeric_limits<float>::quiet_NaN()};
// The number of CPU threads.
std::int32_t n_threads{0};
};

/**
* @brief Internal data structured used by XGBoost to hold all external data.
*
@@ -637,18 +651,14 @@ class DMatrix {
* @param proxy A handle to ProxyDMatrix
* @param reset Callback for reset
* @param next Callback for next
* @param missing Value that should be treated as missing.
* @param nthread number of threads used for initialization.
* @param cache Prefix of cache file path.
* @param on_host Used for GPU, whether the data should be cached on host memory.
* @param config Configuration for the cache.
*
* @return A created external memory DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, std::int32_t nthread,
std::string cache, bool on_host);
XGDMatrixCallbackNext* next, ExtMemConfig const& config);

/**
* @brief Create an external memory quantile DMatrix with callbacks.
@@ -660,8 +670,8 @@ class DMatrix {
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host);
DataIterResetCallback* reset, XGDMatrixCallbackNext* next,
bst_bin_t max_bin, ExtMemConfig const& config);

virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;

20 changes: 20 additions & 0 deletions python-package/xgboost/core.py
@@ -536,16 +536,34 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
This is an experimental parameter.
min_cache_page_bytes :
The minimum number of bytes for each cached page. Only used for the on-host cache
with the GPU-based :py:class:`ExtMemQuantileDMatrix`. When using GPU-based external
memory with the data cached in the host memory, XGBoost can concatenate the pages
internally to increase the batch size for the GPU. The default page size is about
1/8 of the total device memory. Users can manually set the value based on the
actual hardware and datasets. Set this to 0 to disable page concatenation.

.. versionadded:: 3.0.0

.. warning::

    This is an experimental parameter.
"""

def __init__(
self,
cache_prefix: Optional[str] = None,
release_data: bool = True,
*,
on_host: bool = True,
min_cache_page_bytes: Optional[int] = None,
) -> None:
self.cache_prefix = cache_prefix
self.on_host = on_host
self.min_cache_page_bytes = min_cache_page_bytes

self._handle = _ProxyDMatrix()
self._exception: Optional[Exception] = None
@@ -940,6 +958,7 @@ def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
nthread=self.nthread,
cache_prefix=it.cache_prefix if it.cache_prefix else "",
on_host=it.on_host,
min_cache_page_bytes=it.min_cache_page_bytes,
)
handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(enable_categorical)
@@ -1727,6 +1746,7 @@ def _init(
cache_prefix=it.cache_prefix if it.cache_prefix else "",
on_host=it.on_host,
max_bin=self.max_bin,
min_cache_page_bytes=it.min_cache_page_bytes,
)
handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(enable_categorical)
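
The docstring above states that the default page size is about 1/8 of the total device memory. A user overriding it could derive a value from the device's capacity along these lines; the 1/4 fraction is only an example, not a recommendation from this commit.

import cupy as cp

# mem_info returns (free, total) in bytes for the current device.
_, total_device_memory = cp.cuda.Device().mem_info

# Target concatenated pages of roughly 1/4 of device memory instead of the
# default ~1/8, then pass the value to DataIter(min_cache_page_bytes=...).
min_cache_page_bytes = total_device_memory // 4
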
7 changes: 6 additions & 1 deletion python-package/xgboost/testing/__init__.py
@@ -227,13 +227,18 @@ def __init__( # pylint: disable=too-many-arguments
*,
cache: Optional[str],
on_host: bool = False,
min_cache_page_bytes: Optional[int] = None,
) -> None:
assert len(X) == len(y)
self.X = X
self.y = y
self.w = w
self.it = 0
super().__init__(cache_prefix=cache, on_host=on_host)
super().__init__(
cache_prefix=cache,
on_host=on_host,
min_cache_page_bytes=min_cache_page_bytes,
)

def next(self, input_data: Callable) -> bool:
if self.it == len(self.X):
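
Assuming the __init__ above belongs to the iterator class in xgboost.testing (named IteratorForTest in the current sources, though the class name sits outside the hunk shown here), the new keyword can be exercised like this:

import numpy as np
from xgboost import testing as tm

X = [np.random.rand(256, 8) for _ in range(4)]
y = [np.random.rand(256) for _ in range(4)]

# w=None means no sample weights; min_cache_page_bytes=0 disables page
# concatenation so each batch keeps its own ellpack page.
it = tm.IteratorForTest(X, y, None, cache=None, on_host=True, min_cache_page_bytes=0)
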
101 changes: 55 additions & 46 deletions src/c_api/c_api.cc
@@ -3,47 +3,48 @@
*/
#include "xgboost/c_api.h"

#include <algorithm> // for copy, transform
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
#include <string> // for char_traits, basic_string, operator==, string
#include <system_error> // for errc
#include <utility> // for pair
#include <vector> // for vector

#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/error_msg.h" // for NoFederated
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/ellpack_page.h" // for EllpackPage
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
#include "xgboost/feature_map.h" // for FeatureMap
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
#include "xgboost/learner.h" // for Learner, PredictionType
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
#include "xgboost/predictor.h" // for PredictionCacheEntry
#include "xgboost/span.h" // for Span
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
#include <algorithm> // for copy, transform
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
#include <limits> // for numeric_limits
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
#include <string> // for char_traits, basic_string, operator==, string
#include <system_error> // for errc
#include <utility> // for pair
#include <vector> // for vector

#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/error_msg.h" // for NoFederated
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/batch_utils.h" // for MatchingPageBytes, CachePageRatio
#include "../data/ellpack_page.h" // for EllpackPage
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
#include "xgboost/feature_map.h" // for FeatureMap
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
#include "xgboost/learner.h" // for Learner, PredictionType
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
#include "xgboost/predictor.h" // for PredictionCacheEntry
#include "xgboost/span.h" // for Span
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...

using namespace xgboost; // NOLINT(*);

@@ -296,15 +297,20 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
std::int32_t n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
auto min_cache_page_bytes = OptionalArg<Integer, std::int64_t>(jconfig, "min_cache_page_bytes",
cuda_impl::MatchingPageBytes());
CHECK_EQ(min_cache_page_bytes, cuda_impl::MatchingPageBytes())
<< "Page concatenation is not supported by the DMatrix yet.";

xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
xgboost_CHECK_C_ARG_PTR(out);

auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads};
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)};
xgboost::DMatrix::Create(iter, proxy, reset, next, config)};
API_END();
}

@@ -368,17 +374,20 @@ XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
auto missing = GetMissing(jconfig);
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
std::int32_t n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
auto min_cache_page_bytes = OptionalArg<Integer, std::int64_t>(jconfig, "min_cache_page_bytes",
cuda_impl::AutoCachePageBytes());

xgboost_CHECK_C_ARG_PTR(next);
xgboost_CHECK_C_ARG_PTR(reset);
xgboost_CHECK_C_ARG_PTR(out);

*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads};
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, max_bin, config)};
API_END();
}

10 changes: 10 additions & 0 deletions src/common/cuda_rt_utils.cc
@@ -7,6 +7,7 @@
#include <cuda_runtime_api.h>
#endif // defined(XGBOOST_USE_CUDA)

#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <mutex> // for once_flag, call_once

@@ -65,6 +66,13 @@ void SetDevice(std::int32_t device) {
}
}

[[nodiscard]] std::size_t TotalMemory() {
std::size_t device_free = 0;
std::size_t device_total = 0;
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
return device_total;
}

namespace {
template <typename Fn>
void GetVersionImpl(Fn&& fn, std::int32_t* major, std::int32_t* minor) {
@@ -101,6 +109,8 @@ bool SupportsPageableMem() { return false; }

bool SupportsAts() { return false; }

[[nodiscard]] std::size_t TotalMemory() { return 0; }

void CheckComputeCapability() {}

void SetDevice(std::int32_t device) {