
Commit

fix test
michaelbenayoun committed Jan 4, 2024
1 parent ec7a8ad commit 5ded810
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/distributed/distributed.py
@@ -31,12 +31,15 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import torch_xla.distributed.xla_backend as xbn
 from _pytest.fixtures import FixtureLookupError
 from _pytest.outcomes import Skipped
 
 from optimum.neuron.utils.cache_utils import get_num_neuron_cores
-from optimum.neuron.utils.import_utils import is_neuronx_distributed_available
+from optimum.neuron.utils.import_utils import is_neuronx_distributed_available, is_torch_xla_available
 
 
+if is_torch_xla_available():
+    import torch_xla.distributed.xla_backend as xbn
+
 if is_neuronx_distributed_available():
     import neuronx_distributed
@@ -123,8 +126,10 @@ def _get_fixture_kwargs(self, request, func):
         return fixture_kwargs
 
     def _launch_procs(self, num_procs, tp_size, pp_size):
-        if not is_neuronx_distributed_available():
-            raise RuntimeError("The `neuronx_distributed` package is required to run a distributed test.")
+        if not is_torch_xla_available() or not is_neuronx_distributed_available():
+            raise RuntimeError(
+                "The `torch_xla` and `neuronx_distributed` packages are required to run a distributed test."
+            )
 
         # Verify we have enough accelerator devices to run this test
         num_cores = get_num_neuron_cores()
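For context, the `is_torch_xla_available` helper used to guard the `torch_xla` import above typically boils down to a module-lookup check. Below is a minimal sketch of such a check, written as an assumption rather than the actual implementation in `optimum.neuron.utils.import_utils`, followed by the guarded-import pattern this commit adopts.

# Hypothetical sketch of an availability check (assumption; the real helper in
# `optimum.neuron.utils.import_utils` may be implemented differently).
import importlib.util


def is_torch_xla_available() -> bool:
    # True when the `torch_xla` package is installed in the current environment,
    # checked without importing it (importing torch_xla can be expensive).
    return importlib.util.find_spec("torch_xla") is not None


if is_torch_xla_available():
    # Mirrors the guarded import introduced by this commit.
    import torch_xla.distributed.xla_backend as xbn  # noqa: F401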
