Skip to content

Commit

Permalink
[MNT] handle mps backend for lower versions of pytorch and fix `mps…
Browse files Browse the repository at this point in the history
…` failure on `macOS-latest` runner (#1648)

### Description

This PR handles the issue that may result when setting the device to `mps` if the torch version doesn't support `mps backend` 

Depends on #1633

I used `pytest.MonkeyPatch()` to disable the discovery of the `mps` accelerator. The tests run on CPU for `macOS-latest`

fixes #1596
  • Loading branch information
fnhirwa authored Sep 13, 2024
1 parent b497a6b commit f233d92
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-13, windows-latest]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
Expand Down
9 changes: 8 additions & 1 deletion pytorch_forecasting/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,14 @@ def move_to_device(
x on targeted device
"""
if isinstance(device, str):
device = torch.device(device)
if device == "mps":
if hasattr(torch.backends, device):
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
else:
device = torch.device("cpu")
else:
device = torch.device(device)
if isinstance(x, dict):
for name in x.keys():
x[name] = move_to_device(x[name], device=device)
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,9 @@ def test_dataset(test_data):
randomize_length=None,
)
return training


@pytest.fixture(autouse=True)
def disable_mps(monkeypatch):
"""Disable MPS for all tests"""
monkeypatch.setattr("torch._C._mps_is_available", lambda: False)

0 comments on commit f233d92

Please sign in to comment.