From 6fef44dbf9404663cccf00a0389ddab7d2d4d0c7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 26 Oct 2024 00:25:25 +0800 Subject: [PATCH] [EM] Small fixes for the example (#10929) --- demo/guide-python/distributed_extmem_basic.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/demo/guide-python/distributed_extmem_basic.py b/demo/guide-python/distributed_extmem_basic.py index 007bc733c965..2ee9b33f6684 100644 --- a/demo/guide-python/distributed_extmem_basic.py +++ b/demo/guide-python/distributed_extmem_basic.py @@ -14,7 +14,6 @@ - cupy - rmm -- python-cuda """ @@ -100,7 +99,6 @@ def reset(self) -> None: def setup_rmm() -> None: """Setup RMM for GPU-based external memory training.""" import rmm - from cuda import cudart from rmm.allocators.cupy import rmm_cupy_allocator if not xgboost.build_info()["USE_RMM"]: @@ -119,11 +117,10 @@ def hist_train(worker_idx: int, tmpdir: str, device: str, rabit_args: dict) -> N """The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for faster initialization and lower memory usage. - .. versionadded:: 3.0.0 - """ - with coll.CommunicatorContext(**rabit_args): + # Make sure XGBoost is using RMM for all allocations. + with coll.CommunicatorContext(**rabit_args), xgboost.config_context(use_rmm=True): # Generate the data for demonstration. The sythetic data is sharded by workers. files = make_batches( n_samples_per_batch=4096, @@ -168,9 +165,11 @@ def main(tmpdir: str, args: argparse.Namespace) -> None: def initializer(device: str) -> None: # Set CUDA device before launching child processes. if device == "cuda": + # name: LokyProcess-1 lop, sidx = mp.current_process().name.split("-") idx = int(sidx) # 1-based indexing from loky - os.environ["CUDA_VISIBLE_DEVICES"] = str(int(sidx) - 1) + os.environ["CUDA_VISIBLE_DEVICES"] = str(idx - 1) + setup_rmm() with get_reusable_executor( max_workers=n_workers, initargs=(args.device,), initializer=initializer @@ -196,10 +195,8 @@ def initializer(device: str) -> None: # external memory to improve performance. If XGBoost is not built with RMM # support, a warning is raised when constructing the `DMatrix`. setup_rmm() - # Make sure XGBoost is using RMM for all allocations. - with xgboost.config_context(use_rmm=True): - with tempfile.TemporaryDirectory() as tmpdir: - main(tmpdir, args) + with tempfile.TemporaryDirectory() as tmpdir: + main(tmpdir, args) else: with tempfile.TemporaryDirectory() as tmpdir: main(tmpdir, args)