diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index af4e50e2..9e176e87 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,6 +44,8 @@ jobs: - name: Run Tests run: pytest --capture=no --cov --cov-report=xml . + env: + CHGNET_DEVICE: cpu - name: Codacy coverage reporter if: ${{ matrix.os == 'ubuntu-latest' && github.event_name == 'push' }} diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 35a8819b..23127c9d 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -3,6 +3,7 @@ import contextlib import inspect import io +import os import pickle import sys from typing import TYPE_CHECKING, Literal @@ -51,7 +52,7 @@ class CHGNetCalculator(Calculator): """CHGNet Calculator for ASE applications.""" - implemented_properties = ("energy", "forces", "stress", "magmoms") + implemented_properties = ("energy", "forces", "stress", "magmoms") # type: ignore def __init__( self, @@ -81,7 +82,8 @@ def __init__( super().__init__(**kwargs) # Determine the device to use - if use_device == "mps" and torch.backends.mps.is_available(): + use_device = use_device or os.getenv("CHGNET_DEVICE") + if use_device in ("mps", None) and torch.backends.mps.is_available(): self.device = "mps" else: self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu") @@ -95,7 +97,7 @@ def __init__( print(f"CHGNet will run on {self.device}") @property - def version(self) -> str: + def version(self) -> str | None: """The version of CHGNet.""" return self.model.version diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 8ee1f481..f1efad7f 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -707,7 +707,8 @@ def load( ) # Determine the device to use - if use_device == "mps" and torch.backends.mps.is_available(): + use_device = use_device or os.getenv("CHGNET_DEVICE") + if use_device in ("mps", None) and torch.backends.mps.is_available(): device = "mps" else: device = use_device or ("cuda" 
if torch.cuda.is_available() else "cpu") @@ -763,7 +764,7 @@ class BatchedGraph: directed2undirected: Tensor atom_positions: Sequence[Tensor] strains: Sequence[Tensor] - volumes: Sequence[Tensor] + volumes: Sequence[Tensor] | Tensor @classmethod def from_graphs( @@ -790,8 +791,7 @@ def from_graphs( batched_atom_graph, batched_bond_graph = [], [] directed2undirected = [] atom_owners = [] - atom_offset_idx = 0 - n_undirected = 0 + atom_offset_idx = n_undirected = 0 for graph_idx, graph in enumerate(graphs): # Atoms @@ -807,7 +807,7 @@ def from_graphs( else: strain = None lattice = graph.lattice - volumes.append(torch.det(lattice)) + volumes.append(torch.dot(lattice[0], torch.cross(lattice[1], lattice[2]))) strains.append(strain) # Bonds diff --git a/pyproject.toml b/pyproject.toml index f53c5992..e0d99c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } [tool.ruff] target-version = "py39" include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"] +[tool.ruff.lint] select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions @@ -94,6 +95,7 @@ ignore = [ "D104", # Missing docstring in public package "D205", # 1 blank line required between summary line and description "DTZ005", # use of datetime.now() without timezone + "E731", # do not assign a lambda expression, use a def "EM", "ERA001", # found commented out code "FBT001", # Boolean positional argument in function @@ -114,7 +116,7 @@ pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"] isort.split-on-trailing-comma = false -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "site/*" = ["INP001", "S602"] "tests/*" = ["ANN201", "D103", "INP001", "S101"] # E402 Module level import not at top of file diff --git a/tests/test_model.py b/tests/test_model.py index 4a67f55c..380da0fd 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -220,29 +220,37 @@ def 
test_as_to_from_dict() -> None: assert model_3.todict() == to_dict -def test_model_load_version_params(capsys: pytest.CaptureFixture) -> None: +def test_model_load_version_params( + capsys: pytest.CaptureFixture, monkeypatch: pytest.MonkeyPatch +) -> None: model = CHGNet.load(use_device="cpu") - assert model.version == "0.3.0" - assert model.n_params == 412_525 + v030_key, v030_params = "0.3.0", 412_525 + assert model.version == v030_key + assert model.n_params == v030_params stdout, stderr = capsys.readouterr() - assert ( - stdout - == f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters -CHGNet will run on cpu\n""" + expected_stdout = lambda version, params: ( + f"CHGNet v{version} initialized with {params:,} parameters\n" + "CHGNet will run on cpu\n" ) + assert stdout == expected_stdout(v030_key, v030_params) assert stderr == "" - model = CHGNet.load(model_name="0.2.0", use_device="cpu") - assert model.version == "0.2.0" - assert model.n_params == 400_438 + v020_key, v020_params = "0.2.0", 400_438 + model = CHGNet.load(model_name=v020_key, use_device="cpu") + assert model.version == v020_key + assert model.n_params == v020_params stdout, stderr = capsys.readouterr() - assert ( - stdout - == f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters -CHGNet will run on cpu\n""" - ) + assert stdout == expected_stdout(v020_key, v020_params) assert stderr == "" model_name = "0.1.0" # invalid with pytest.raises(ValueError, match=f"Unknown {model_name=}"): CHGNet.load(model_name=model_name) + + # set CHGNET_DEVICE to an invalid device name and check the resulting torch error + monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar") + with pytest.raises( + RuntimeError, + match=f"Expected one of cpu, .+type at start of device string: {env_device}", + ): + CHGNet.load()