diff --git a/demo/guide-python/distributed_extmem_basic.py b/demo/guide-python/distributed_extmem_basic.py index 3207572bc8e9..3d04d46ac8b7 100644 --- a/demo/guide-python/distributed_extmem_basic.py +++ b/demo/guide-python/distributed_extmem_basic.py @@ -117,11 +117,10 @@ def hist_train(worker_idx: int, tmpdir: str, device: str, rabit_args: dict) -> N """The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for faster initialization and lower memory usage. - .. versionadded:: 3.0.0 - """ - with coll.CommunicatorContext(**rabit_args): + # Make sure XGBoost is using RMM for all allocations. + with coll.CommunicatorContext(**rabit_args), xgboost.config_context(use_rmm=True): # Generate the data for demonstration. The sythetic data is sharded by workers. files = make_batches( n_samples_per_batch=4096, @@ -195,10 +194,8 @@ def initializer(device: str) -> None: # external memory to improve performance. If XGBoost is not built with RMM # support, a warning is raised when constructing the `DMatrix`. setup_rmm() - # Make sure XGBoost is using RMM for all allocations. - with xgboost.config_context(use_rmm=True): - with tempfile.TemporaryDirectory() as tmpdir: - main(tmpdir, args) + with tempfile.TemporaryDirectory() as tmpdir: + main(tmpdir, args) else: with tempfile.TemporaryDirectory() as tmpdir: main(tmpdir, args)