diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py index 5cfd3339f..f3233e9b0 100644 --- a/torchrec/distributed/test_utils/multi_process.py +++ b/torchrec/distributed/test_utils/multi_process.py @@ -24,6 +24,11 @@ ) +# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail. +# Therefore, we use spawn for the HIP runtime until AMD fixes the issue. +_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn" + + class MultiProcessContext: def __init__( self, @@ -126,7 +131,7 @@ def _run_multi_process_test( # pyre-ignore **kwargs, ) -> None: - ctx = multiprocessing.get_context("forkserver") + ctx = multiprocessing.get_context(_MP_INIT_MODE) processes = [] for rank in range(world_size): kwargs["rank"] = rank @@ -152,7 +157,7 @@ def _run_multi_process_test_per_rank( world_size: int, kwargs_per_rank: List[Dict[str, Any]], ) -> None: - ctx = multiprocessing.get_context("forkserver") + ctx = multiprocessing.get_context(_MP_INIT_MODE) processes = [] for rank in range(world_size): kwargs = {}