Skip to content

Commit

Permalink
assert check_cuda_mem defaults to False in test_model_load_version_pa…
Browse files Browse the repository at this point in the history
…rams
  • Loading branch information
janosh committed Jun 11, 2024
1 parent dc2e425 commit 3a9bacd
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect

import numpy as np
import pytest
from pymatgen.core import Structure
Expand Down Expand Up @@ -249,10 +251,15 @@ def test_model_load_version_params(
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

# # set CHGNET_DEVICE to "cuda" and test
monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar")
with pytest.raises(
RuntimeError,
match=f"Expected one of cpu, .+type at start of device string: {env_device}",
bad_env_device = "foobar"
err_msg = f"Expected one of cpu, .+type at start of device string: {bad_env_device}"
with ( # noqa: PT012
monkeypatch.context() as ctx,
pytest.raises(RuntimeError, match=err_msg),
):
ctx.setenv("CHGNET_DEVICE", bad_env_device)
CHGNet.load()

# check check_cuda_mem defaults to False
inspect_signature = inspect.signature(CHGNet.load)
assert inspect_signature.parameters["check_cuda_mem"].default is False

0 comments on commit 3a9bacd

Please sign in to comment.