Skip to content

Commit

Permalink
Use MPS backend if available and use_device=None, add `CHGNET_DEVIC…
Browse files Browse the repository at this point in the history
…E` env var (#131)

* use MPS backend if available and use_device is None

(prev would default to CPU in that case)
also fix type errors

* revert torch.det for volume to torch.dot and torch.cross (which have MPS support)

* try PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 in CI to avoid OOM error

E   RuntimeError: MPS backend out of memory (MPS allocated: 0 bytes, other allocations: 0 bytes, max allowed: 7.93 GB). Tried to allocate 512 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

* add support for CHGNET_DEVICE environment variable

* test CHGNET_DEVICE in test_model_load_version_params()

update deprecated ruff lint config

* set CHGNET_DEVICE=cpu in test.yml since no MPS hardware available even on macos-14 runners

pytorch/pytorch#111449 (comment)

* fix setting CHGNET_DEVICE env var on windows
  • Loading branch information
janosh authored Feb 28, 2024
1 parent 5a25b2c commit 3bc34b5
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ jobs:
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml .
env:
CHGNET_DEVICE: cpu

- name: Codacy coverage reporter
if: ${{ matrix.os == 'ubuntu-latest' && github.event_name == 'push' }}
Expand Down
8 changes: 5 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import inspect
import io
import os
import pickle
import sys
from typing import TYPE_CHECKING, Literal
Expand Down Expand Up @@ -51,7 +52,7 @@
class CHGNetCalculator(Calculator):
"""CHGNet Calculator for ASE applications."""

implemented_properties = ("energy", "forces", "stress", "magmoms")
implemented_properties = ("energy", "forces", "stress", "magmoms") # type: ignore

def __init__(
self,
Expand Down Expand Up @@ -81,7 +82,8 @@ def __init__(
super().__init__(**kwargs)

# Determine the device to use
if use_device == "mps" and torch.backends.mps.is_available():
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -95,7 +97,7 @@ def __init__(
print(f"CHGNet will run on {self.device}")

@property
def version(self) -> str:
def version(self) -> str | None:
"""The version of CHGNet."""
return self.model.version

Expand Down
10 changes: 5 additions & 5 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,8 @@ def load(
)

# Determine the device to use
if use_device == "mps" and torch.backends.mps.is_available():
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
device = "mps"
else:
device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -763,7 +764,7 @@ class BatchedGraph:
directed2undirected: Tensor
atom_positions: Sequence[Tensor]
strains: Sequence[Tensor]
volumes: Sequence[Tensor]
volumes: Sequence[Tensor] | Tensor

@classmethod
def from_graphs(
Expand All @@ -790,8 +791,7 @@ def from_graphs(
batched_atom_graph, batched_bond_graph = [], []
directed2undirected = []
atom_owners = []
atom_offset_idx = 0
n_undirected = 0
atom_offset_idx = n_undirected = 0

for graph_idx, graph in enumerate(graphs):
# Atoms
Expand All @@ -807,7 +807,7 @@ def from_graphs(
else:
strain = None
lattice = graph.lattice
volumes.append(torch.det(lattice))
volumes.append(torch.dot(lattice[0], torch.cross(lattice[1], lattice[2])))
strains.append(strain)

# Bonds
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
[tool.ruff]
target-version = "py39"
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
Expand Down Expand Up @@ -94,6 +95,7 @@ ignore = [
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"DTZ005", # use of datetime.now() without timezone
"E731", # do not assign a lambda expression, use a def
"EM",
"ERA001", # found commented out code
"FBT001", # Boolean positional argument in function
Expand All @@ -114,7 +116,7 @@ pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
isort.split-on-trailing-comma = false

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"site/*" = ["INP001", "S602"]
"tests/*" = ["ANN201", "D103", "INP001", "S101"]
# E402 Module level import not at top of file
Expand Down
38 changes: 23 additions & 15 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,29 +220,37 @@ def test_as_to_from_dict() -> None:
assert model_3.todict() == to_dict


def test_model_load_version_params(capsys: pytest.CaptureFixture) -> None:
def test_model_load_version_params(
capsys: pytest.CaptureFixture, monkeypatch: pytest.MonkeyPatch
) -> None:
model = CHGNet.load(use_device="cpu")
assert model.version == "0.3.0"
assert model.n_params == 412_525
v030_key, v030_params = "0.3.0", 412_525
assert model.version == v030_key
assert model.n_params == v030_params
stdout, stderr = capsys.readouterr()
assert (
stdout
== f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters
CHGNet will run on cpu\n"""
expected_stdout = lambda version, params: (
f"CHGNet v{version} initialized with {params:,} parameters\n"
"CHGNet will run on cpu\n"
)
assert stdout == expected_stdout(v030_key, v030_params)
assert stderr == ""

model = CHGNet.load(model_name="0.2.0", use_device="cpu")
assert model.version == "0.2.0"
assert model.n_params == 400_438
v020_key, v020_params = "0.2.0", 400_438
model = CHGNet.load(model_name=v020_key, use_device="cpu")
assert model.version == v020_key
assert model.n_params == v020_params
stdout, stderr = capsys.readouterr()
assert (
stdout
== f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters
CHGNet will run on cpu\n"""
)
assert stdout == expected_stdout(v020_key, v020_params)
assert stderr == ""

model_name = "0.1.0" # invalid
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

# # set CHGNET_DEVICE to "cuda" and test
monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar")
with pytest.raises(
RuntimeError,
match=f"Expected one of cpu, .+type at start of device string: {env_device}",
):
CHGNet.load()

0 comments on commit 3bc34b5

Please sign in to comment.