Skip to content

Commit

Permalink
[EM] Small fixes for the example (#10929)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 25, 2024
1 parent 18edf86 commit 6fef44d
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions demo/guide-python/distributed_extmem_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
- cupy
- rmm
- python-cuda
"""

Expand Down Expand Up @@ -100,7 +99,6 @@ def reset(self) -> None:
def setup_rmm() -> None:
"""Setup RMM for GPU-based external memory training."""
import rmm
from cuda import cudart
from rmm.allocators.cupy import rmm_cupy_allocator

if not xgboost.build_info()["USE_RMM"]:
Expand All @@ -119,11 +117,10 @@ def hist_train(worker_idx: int, tmpdir: str, device: str, rabit_args: dict) -> N
"""The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for
faster initialization and lower memory usage.
.. versionadded:: 3.0.0
"""

with coll.CommunicatorContext(**rabit_args):
# Make sure XGBoost is using RMM for all allocations.
with coll.CommunicatorContext(**rabit_args), xgboost.config_context(use_rmm=True):
# Generate the data for demonstration. The sythetic data is sharded by workers.
files = make_batches(
n_samples_per_batch=4096,
Expand Down Expand Up @@ -168,9 +165,11 @@ def main(tmpdir: str, args: argparse.Namespace) -> None:
def initializer(device: str) -> None:
# Set CUDA device before launching child processes.
if device == "cuda":
# name: LokyProcess-1
lop, sidx = mp.current_process().name.split("-")
idx = int(sidx) # 1-based indexing from loky
os.environ["CUDA_VISIBLE_DEVICES"] = str(int(sidx) - 1)
os.environ["CUDA_VISIBLE_DEVICES"] = str(idx - 1)
setup_rmm()

with get_reusable_executor(
max_workers=n_workers, initargs=(args.device,), initializer=initializer
Expand All @@ -196,10 +195,8 @@ def initializer(device: str) -> None:
# external memory to improve performance. If XGBoost is not built with RMM
# support, a warning is raised when constructing the `DMatrix`.
setup_rmm()
# Make sure XGBoost is using RMM for all allocations.
with xgboost.config_context(use_rmm=True):
with tempfile.TemporaryDirectory() as tmpdir:
main(tmpdir, args)
with tempfile.TemporaryDirectory() as tmpdir:
main(tmpdir, args)
else:
with tempfile.TemporaryDirectory() as tmpdir:
main(tmpdir, args)

0 comments on commit 6fef44d

Please sign in to comment.