diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 863b250..95b4a9a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,11 +29,10 @@ jobs: cache: pip cache-dependency-path: pyproject.toml - - name: Install dependencies via uv + - name: Install chgnet through uv run: | pip install uv - - uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }} + uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system - name: Run Tests run: pytest --capture=no --cov --cov-report=xml diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 477654c..4b26db5 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -3,7 +3,7 @@ import gc import sys import warnings -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import numpy as np import torch @@ -13,6 +13,8 @@ from chgnet.graph.graph import Graph, Node if TYPE_CHECKING: + from typing import Literal + from pymatgen.core import Structure from typing_extensions import Self @@ -21,7 +23,7 @@ except (ImportError, AttributeError): make_graph = None -datatype = torch.float32 +DATATYPE = torch.float32 class CrystalGraphConverter(nn.Module): @@ -122,10 +124,10 @@ def forward( requires_grad=False, ) atom_frac_coord = torch.tensor( - structure.frac_coords, dtype=datatype, requires_grad=True + structure.frac_coords, dtype=DATATYPE, requires_grad=True ) lattice = torch.tensor( - structure.lattice.matrix, dtype=datatype, requires_grad=True + structure.lattice.matrix, dtype=DATATYPE, requires_grad=True ) center_index, neighbor_index, image, distance = structure.get_neighbor_list( r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8 @@ -150,7 +152,7 @@ def forward( # Report structures that failed creating bond graph # This happen occasionally with pymatgen version issue structure.to(filename="bond_graph_error.cif") - raise SystemExit( + raise RuntimeError( f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif" ) from exc bond_graph = torch.tensor(bond_graph, dtype=torch.int32) @@ -175,7 +177,7 @@ def forward( atomic_number=atomic_number, atom_frac_coord=atom_frac_coord, atom_graph=atom_graph, - neighbor_image=torch.tensor(image, dtype=datatype), + neighbor_image=torch.tensor(image, dtype=DATATYPE), directed2undirected=directed2undirected, undirected2directed=undirected2directed, bond_graph=bond_graph, diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index ba619a7..a32e993 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -271,7 +271,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]] if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list): raise ValueError( f"Error: number of directed edges={len(self.directed_edges_list)} != 2 " - f"* number of undirected edges={len(self.directed_edges_list)}!" + f"* number of undirected edges={len(self.undirected_edges_list)}!" f"This indicates directed edges are not complete" ) line_graph = [] diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 4a7e0e9..71f0276 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -146,7 +146,7 @@ def fit( if isinstance(structure, Structure): atomic_number = torch.tensor( [site.specie.Z for site in structure], - dtype=int, + dtype=torch.int32, requires_grad=False, ) else: diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 37c2431..c1a8601 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -5,14 +5,14 @@ import io import pickle import sys +import warnings from typing import TYPE_CHECKING, Literal import numpy as np from ase import Atoms, units from ase.calculators.calculator import Calculator, all_changes, all_properties from ase.md.npt import NPT -from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen -from ase.md.nvtberendsen import NVTBerendsen +from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen, NVTBerendsen from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary from ase.md.verlet import VelocityVerlet from ase.optimize.bfgs import BFGS @@ -610,11 +610,12 @@ def __init__( except Exception: bulk_modulus_au = 2 / 160.2176 compressibility_au = 1 / bulk_modulus_au - print( + warnings.warn( "Warning!!! Equation of State fitting failed, setting bulk " "modulus to 2 GPa. NPT simulation can proceed with incorrect " "pressure relaxation time." - "User input for bulk modulus is recommended." + "User input for bulk modulus is recommended.", + stacklevel=2, ) self.bulk_modulus = bulk_modulus diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d2edb92..d203033 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -672,7 +672,7 @@ def from_dict(cls, dct: dict, **kwargs) -> Self: @classmethod def from_file(cls, path: str, **kwargs) -> Self: """Build a CHGNet from a saved file.""" - state = torch.load(path, map_location=torch.device("cpu")) + state = torch.load(path, map_location=torch.device("cpu"), weights_only=False) return cls.from_dict(state["model"], **kwargs) @classmethod diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 3d3cdf4..60543ab 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -1,11 +1,11 @@ from __future__ import annotations +import datetime import inspect import os import random import shutil import time -from datetime import datetime from typing import TYPE_CHECKING, Literal, get_args import numpy as np @@ -285,7 +285,7 @@ def train( raise ValueError("Model needs to be initialized") global best_checkpoint # noqa: PLW0603 if save_dir is None: - save_dir = f"{datetime.now():%m-%d-%Y}" + save_dir = f"{datetime.datetime.now(tz=datetime.timezone.utc):%m-%d-%Y}" print(f"Begin Training: using {self.device} device") print(f"training targets: {self.targets}") diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py index bfb54a0..5b71226 100644 --- a/chgnet/utils/vasp_utils.py +++ b/chgnet/utils/vasp_utils.py @@ -38,7 +38,7 @@ def parse_vasp_dir( dict: a dictionary of lists with keys for structure, uncorrected_total_energy, energy_per_atom, force, magmom, stress. """ - if os.path.isdir(base_dir) is False: + if not os.path.isdir(base_dir): raise NotADirectoryError(f"{base_dir=} is not a directory") oszicar_path = zpath(f"{base_dir}/OSZICAR") diff --git a/pyproject.toml b/pyproject.toml index 08b1412..5d199ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,17 +20,17 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Chemistry", "Topic :: Scientific/Engineering :: Physics", ] [project.optional-dependencies] -test = ["pytest-cov>=4", "pytest>=8"] +test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"] # needed to run interactive example notebooks examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"] docs = ["lazydocs>=0.4"] @@ -46,6 +46,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } [tool.setuptools.package-data] "chgnet" = ["*.json", "py.typed"] +"chgnet.graph.fast_converter_libraries" = ["*"] "chgnet.pretrained" = ["*", "**/*"] [build-system] @@ -59,35 +60,32 @@ target-version = "py39" select = ["ALL"] ignore = [ "ANN001", # TODO add missing type annotations - "ANN003", - "ANN101", - "ANN102", + "ANN003", # Missing type annotation for **{name} + "ANN101", # Missing type annotation for {name} in method + "ANN102", # Missing type annotation for {name} in classmethod "B019", # Use of functools.lru_cache on methods can lead to memory leaks - "BLE001", + "BLE001", # use of general except Exception "C408", # unnecessary-collection-call "C901", # function is too complex "COM812", # trailing comma missing "D100", # Missing docstring in public module "D104", # Missing docstring in public package "D205", # 1 blank line required between summary line and description - "DTZ005", # use of datetime.now() without timezone "E731", # do not assign a lambda expression, use a def - "EM", + "EM", # error message related "ERA001", # found commented out code - "ISC001", "NPY002", # TODO replace legacy np.random.seed "PLR0912", # too many branches "PLR0913", # too many args in function def "PLR0915", # too many statements "PLW2901", # Outer for loop variable overwritten by inner assignment target - "PT006", # pytest-parametrize-names-wrong-type "PTH", # prefer Path to os.path - "S108", + "S108", # Probable insecure usage of temporary file or directory "S301", # pickle can be unsafe - "S310", - "S311", - "TRY003", - "TRY300", + "S310", # Audit URL open for permitted schemes + "S311", # pseudo-random generators not suitable for cryptographic purposes + "TRY003", # Avoid specifying long messages outside the exception class + "TRY300", # Consider moving this statement to an else block ] pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"] diff --git a/tests/test_converter.py b/tests/test_converter.py index 46600ac..647f5c4 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -29,7 +29,7 @@ def _set_make_graph() -> Generator[None, None, None]: @pytest.mark.parametrize( - "atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)] + ("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (5, None), (4, 2)] ) def test_crystal_graph_converter_cutoff( atom_graph_cutoff: float | None, bond_graph_cutoff: float | None diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index 8834d23..4022e70 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations from time import perf_counter +from unittest.mock import patch import numpy as np from pymatgen.core import Structure @@ -8,8 +9,6 @@ from chgnet import ROOT from chgnet.graph import CrystalGraphConverter -np.random.seed(0) - structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif") converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) converter_legacy = CrystalGraphConverter( @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast(): def test_crystal_graph_perturb_legacy(): - np.random.seed(0) structure_perturbed = structure.copy() - structure_perturbed.perturb(distance=0.1) + fixed_rng = np.random.default_rng(0) + with patch("numpy.random.default_rng", return_value=fixed_rng): + structure_perturbed.perturb(distance=0.1) start = perf_counter() graph = converter_legacy(structure_perturbed) print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [410, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 - - assert list(graph.bond_graph.shape) == [688, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 + assert list(graph.atom_graph.shape) == [420, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 + + assert list(graph.bond_graph.shape) == [850, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [205] - assert list(graph.directed2undirected.shape) == [410] + assert list(graph.undirected2directed.shape) == [210] + assert list(graph.directed2undirected.shape) == [420] def test_crystal_graph_perturb_fast(): - np.random.seed(0) structure_perturbed = structure.copy() - structure_perturbed.perturb(distance=0.1) + fixed_rng = np.random.default_rng(0) + with patch("numpy.random.default_rng", return_value=fixed_rng): + structure_perturbed.perturb(distance=0.1) start = perf_counter() graph = converter_fast(structure_perturbed) print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [410, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 - - assert list(graph.bond_graph.shape) == [688, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 + assert list(graph.atom_graph.shape) == [420, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 + + assert list(graph.bond_graph.shape) == [850, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [205] - assert list(graph.directed2undirected.shape) == [410] + assert list(graph.undirected2directed.shape) == [210] + assert list(graph.directed2undirected.shape) == [420] def test_crystal_graph_isotropic_strained_legacy(): diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 6beda80..27d5b61 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -28,7 +28,7 @@ def test_atom_embedding(atom_feature_dim: int, max_num_elements: int) -> None: assert "index out of range" in str(exc_info.value) -@pytest.mark.parametrize("atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (6, 4)]) +@pytest.mark.parametrize(("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (6, 4)]) def test_bond_encoder(atom_graph_cutoff: float, bond_graph_cutoff: float) -> None: undirected2directed = torch.tensor([0, 1]) image = torch.zeros((2, 3)) diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index 87241f5..c23b675 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize( - "algorithm, ase_filter, assign_magmoms", + ("algorithm", "ase_filter", "assign_magmoms"), [("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)], ) def test_relaxation( diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 71fe1f8..89bf16b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -153,11 +153,11 @@ def test_wandb_init(mock_wandb): ) -def test_wandb_log_frequency(mock_wandb): +def test_wandb_log_frequency(tmp_path, mock_wandb): trainer = Trainer(model=chgnet, wandb_path="test-project/test-run", epochs=1) # Test epoch logging - trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir="") + trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir=tmp_path) assert ( mock_wandb.log.call_count == 2 * trainer.epochs ), "Expected one train and one val log per epoch" @@ -165,7 +165,7 @@ def test_wandb_log_frequency(mock_wandb): mock_wandb.log.reset_mock() # Test batch logging - trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir="") + trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir=tmp_path) expected_batch_calls = trainer.epochs * len(train_loader) assert ( mock_wandb.log.call_count > expected_batch_calls @@ -183,5 +183,5 @@ def test_wandb_log_frequency(mock_wandb): # Test no logging when wandb_path is not provided trainer_no_wandb = Trainer(model=chgnet, epochs=1) - trainer_no_wandb.train(train_loader, val_loader) + trainer_no_wandb.train(train_loader, val_loader, save_dir=tmp_path) mock_wandb.log.assert_not_called() diff --git a/tests/test_vasp_utils.py b/tests/test_vasp_utils.py index 8def656..b16c2ad 100644 --- a/tests/test_vasp_utils.py +++ b/tests/test_vasp_utils.py @@ -62,7 +62,7 @@ def test_parse_vasp_dir_without_magmoms(tmp_path: Path): def test_parse_vasp_dir_no_data(): # test non-existing directory - with pytest.raises(FileNotFoundError, match="is not a directory"): + with pytest.raises(NotADirectoryError, match="is not a directory"): parse_vasp_dir(f"{ROOT}/tests/files/non-existent") # test existing directory without VASP files