Merge branch 'main' into really-migrate-to-np2
DanielYang59 committed Sep 12, 2024
2 parents 676e788 + 9281cf4 commit ac942d5
Showing 15 changed files with 73 additions and 72 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
@@ -29,11 +29,10 @@ jobs:
cache: pip
cache-dependency-path: pyproject.toml

- name: Install dependencies via uv
- name: Install chgnet through uv
run: |
pip install uv
uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
14 changes: 8 additions & 6 deletions chgnet/graph/converter.py
@@ -3,7 +3,7 @@
import gc
import sys
import warnings
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

import numpy as np
import torch
@@ -13,6 +13,8 @@
from chgnet.graph.graph import Graph, Node

if TYPE_CHECKING:
from typing import Literal

from pymatgen.core import Structure
from typing_extensions import Self

@@ -21,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

datatype = torch.float32
DATATYPE = torch.float32


class CrystalGraphConverter(nn.Module):
@@ -122,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=datatype, requires_grad=True
structure.frac_coords, dtype=DATATYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=datatype, requires_grad=True
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
@@ -150,7 +152,7 @@ def forward(
# Report structures that failed creating bond graph
# This happen occasionally with pymatgen version issue
structure.to(filename="bond_graph_error.cif")
raise SystemExit(
raise RuntimeError(
f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
) from exc
bond_graph = torch.tensor(bond_graph, dtype=torch.int32)
@@ -175,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=datatype),
neighbor_image=torch.tensor(image, dtype=DATATYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
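For context, a minimal sketch (illustrative only, not part of the diff) of the three patterns the converter hunks adopt: a `TYPE_CHECKING`-guarded import so `Literal` is only imported for type checkers, an upper-cased module-level constant per PEP 8, and a catchable `RuntimeError` in place of `SystemExit`. `build_frac_coords` is a hypothetical helper name.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Imported only during static type checking, not at runtime.
    from typing import Literal

# Module-level constant, upper-cased per PEP 8 naming conventions.
DATATYPE = torch.float32


def build_frac_coords(frac_coords) -> torch.Tensor:
    """Hypothetical helper: build a differentiable coordinate tensor."""
    try:
        return torch.tensor(frac_coords, dtype=DATATYPE, requires_grad=True)
    except Exception as exc:
        # RuntimeError can be caught by callers; an uncaught SystemExit
        # would terminate the interpreter instead.
        raise RuntimeError("failed to build fractional-coordinate tensor") from exc
```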
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
@@ -271,7 +271,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
raise ValueError(
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
f"* number of undirected edges={len(self.directed_edges_list)}!"
f"* number of undirected edges={len(self.undirected_edges_list)}!"
f"This indicates directed edges are not complete"
)
line_graph = []
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
@@ -146,7 +146,7 @@ def fit(
if isinstance(structure, Structure):
atomic_number = torch.tensor(
[site.specie.Z for site in structure],
dtype=int,
dtype=torch.int32,
requires_grad=False,
)
else:
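A small illustrative sketch of the dtype change (the input list is hypothetical, standing in for `[site.specie.Z for site in structure]`): spelling out `torch.int32` makes the integer width explicit instead of relying on how torch maps the Python built-in `int`.

```python
import torch

# Hypothetical atomic numbers for a LiMnO2-like composition.
atomic_numbers = [3, 25, 8, 8]

atomic_number = torch.tensor(atomic_numbers, dtype=torch.int32, requires_grad=False)
print(atomic_number.dtype)  # torch.int32
```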
9 changes: 5 additions & 4 deletions chgnet/model/dynamics.py
@@ -5,14 +5,14 @@
import io
import pickle
import sys
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen, NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md.verlet import VelocityVerlet
from ase.optimize.bfgs import BFGS
@@ -610,11 +610,12 @@ def __init__(
except Exception:
bulk_modulus_au = 2 / 160.2176
compressibility_au = 1 / bulk_modulus_au
print(
warnings.warn(
"Warning!!! Equation of State fitting failed, setting bulk "
"modulus to 2 GPa. NPT simulation can proceed with incorrect "
"pressure relaxation time."
"User input for bulk modulus is recommended."
"User input for bulk modulus is recommended.",
stacklevel=2,
)
self.bulk_modulus = bulk_modulus

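A minimal sketch of the new warning pattern (the helper and its names are hypothetical): `warnings.warn` with `stacklevel=2` goes through the warnings filter machinery and points at the caller's line, which a bare `print` does not.

```python
import warnings


def fit_bulk_modulus(fit_fn, default_gpa: float = 2.0) -> float:
    """Hypothetical helper: fall back to a default bulk modulus on failure."""
    try:
        return fit_fn()
    except Exception:
        warnings.warn(
            "Equation of State fitting failed, setting bulk modulus to "
            f"{default_gpa} GPa. User input for bulk modulus is recommended.",
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )
        return default_gpa


def failing_fit() -> float:
    raise ValueError("EOS fit did not converge")


bulk_modulus = fit_bulk_modulus(failing_fit)  # emits a UserWarning, returns 2.0
```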
2 changes: 1 addition & 1 deletion chgnet/model/model.py
@@ -672,7 +672,7 @@ def from_dict(cls, dct: dict, **kwargs) -> Self:
@classmethod
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
state = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
return cls.from_dict(state["model"], **kwargs)

@classmethod
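A hedged round-trip sketch of the `weights_only` change (the file name and checkpoint contents are illustrative): recent PyTorch releases move the default toward `weights_only=True`, so passing `False` explicitly preserves loading of checkpoints that contain ordinary pickled Python objects alongside tensors.

```python
import torch

# Save a checkpoint that mixes tensors with plain Python objects, then load it back.
checkpoint = {"model": {"state_dict": {"w": torch.zeros(3)}, "epoch": 7}}
torch.save(checkpoint, "tmp_checkpoint.pth.tar")

state = torch.load(
    "tmp_checkpoint.pth.tar",
    map_location=torch.device("cpu"),
    weights_only=False,  # keep full-pickle loading for the non-tensor entries
)
print(state["model"]["epoch"])  # 7
```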
4 changes: 2 additions & 2 deletions chgnet/trainer/trainer.py
@@ -1,11 +1,11 @@
from __future__ import annotations

import datetime
import inspect
import os
import random
import shutil
import time
from datetime import datetime
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
@@ -285,7 +285,7 @@ def train(
raise ValueError("Model needs to be initialized")
global best_checkpoint # noqa: PLW0603
if save_dir is None:
save_dir = f"{datetime.now():%m-%d-%Y}"
save_dir = f"{datetime.datetime.now(tz=datetime.timezone.utc):%m-%d-%Y}"

print(f"Begin Training: using {self.device} device")
print(f"training targets: {self.targets}")
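An illustrative one-liner for the timestamp fix: importing the `datetime` module (rather than the class) and passing an explicit `tz` keeps the save-directory name unambiguous and satisfies ruff's timezone rules.

```python
import datetime

# Timezone-aware replacement for a naive datetime.now() call.
save_dir = f"{datetime.datetime.now(tz=datetime.timezone.utc):%m-%d-%Y}"
print(save_dir)  # e.g. 09-12-2024
```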
2 changes: 1 addition & 1 deletion chgnet/utils/vasp_utils.py
@@ -38,7 +38,7 @@ def parse_vasp_dir(
dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
energy_per_atom, force, magmom, stress.
"""
if os.path.isdir(base_dir) is False:
if not os.path.isdir(base_dir):
raise NotADirectoryError(f"{base_dir=} is not a directory")

oszicar_path = zpath(f"{base_dir}/OSZICAR")
28 changes: 13 additions & 15 deletions pyproject.toml
@@ -20,17 +20,17 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Chemistry",
"Topic :: Scientific/Engineering :: Physics",
]

[project.optional-dependencies]
test = ["pytest-cov>=4", "pytest>=8"]
test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
@@ -46,6 +46,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.setuptools.package-data]
"chgnet" = ["*.json", "py.typed"]
"chgnet.graph.fast_converter_libraries" = ["*"]
"chgnet.pretrained" = ["*", "**/*"]

[build-system]
@@ -59,35 +60,32 @@ target-version = "py39"
select = ["ALL"]
ignore = [
"ANN001", # TODO add missing type annotations
"ANN003",
"ANN101",
"ANN102",
"ANN003", # Missing type annotation for **{name}
"ANN101", # Missing type annotation for {name} in method
"ANN102", # Missing type annotation for {name} in classmethod
"B019", # Use of functools.lru_cache on methods can lead to memory leaks
"BLE001",
"BLE001", # use of general except Exception
"C408", # unnecessary-collection-call
"C901", # function is too complex
"COM812", # trailing comma missing
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"DTZ005", # use of datetime.now() without timezone
"E731", # do not assign a lambda expression, use a def
"EM",
"EM", # error message related
"ERA001", # found commented out code
"ISC001",
"NPY002", # TODO replace legacy np.random.seed
"PLR0912", # too many branches
"PLR0913", # too many args in function def
"PLR0915", # too many statements
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT006", # pytest-parametrize-names-wrong-type
"PTH", # prefer Path to os.path
"S108",
"S108", # Probable insecure usage of temporary file or directory
"S301", # pickle can be unsafe
"S310",
"S311",
"TRY003",
"TRY300",
"S310", # Audit URL open for permitted schemes
"S311", # pseudo-random generators not suitable for cryptographic purposes
"TRY003", # Avoid specifying long messages outside the exception class
"TRY300", # Consider moving this statement to an else block
]
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
2 changes: 1 addition & 1 deletion tests/test_converter.py
@@ -29,7 +29,7 @@ def _set_make_graph() -> Generator[None, None, None]:


@pytest.mark.parametrize(
"atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)]
("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (5, None), (4, 2)]
)
def test_crystal_graph_converter_cutoff(
atom_graph_cutoff: float | None, bond_graph_cutoff: float | None
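The same parametrize style is applied in several test files below; a minimal sketch of the form ruff's PT006 prefers, with hypothetical test logic:

```python
import pytest


@pytest.mark.parametrize(
    ("atom_graph_cutoff", "bond_graph_cutoff"),  # tuple of names, not one comma-separated string
    [(5, 3), (5, None), (4, 2)],
)
def test_cutoff_pairs(atom_graph_cutoff, bond_graph_cutoff):
    # Hypothetical check: the bond-graph cutoff never exceeds the atom-graph cutoff.
    assert bond_graph_cutoff is None or bond_graph_cutoff <= atom_graph_cutoff
```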
61 changes: 31 additions & 30 deletions tests/test_crystal_graph.py
@@ -1,15 +1,14 @@
from __future__ import annotations

from time import perf_counter
from unittest.mock import patch

import numpy as np
from pymatgen.core import Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter

np.random.seed(0)

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
converter_legacy = CrystalGraphConverter(
@@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast():


def test_crystal_graph_perturb_legacy():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_legacy(structure_perturbed)
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_perturb_fast():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_fast(structure_perturbed)
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_isotropic_strained_legacy():
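A self-contained sketch of the new seeding strategy (the `jitter` function is a hypothetical stand-in for `Structure.perturb`): once the perturbation draws from `np.random.default_rng()` internally, seeding the legacy global state no longer pins the result, so the tests patch `default_rng` to return a fixed generator.

```python
from unittest.mock import patch

import numpy as np


def jitter(values: np.ndarray, distance: float = 0.1) -> np.ndarray:
    """Hypothetical stand-in for Structure.perturb: random displacement."""
    rng = np.random.default_rng()  # internal, unseeded generator
    return values + distance * rng.standard_normal(values.shape)


# Patching default_rng makes the internal generator deterministic at the call site.
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
    first = jitter(np.zeros(3))

fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
    second = jitter(np.zeros(3))

assert np.allclose(first, second)  # identical draws from the re-seeded generator
```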
2 changes: 1 addition & 1 deletion tests/test_encoders.py
@@ -28,7 +28,7 @@ def test_atom_embedding(atom_feature_dim: int, max_num_elements: int) -> None:
assert "index out of range" in str(exc_info.value)


@pytest.mark.parametrize("atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (6, 4)])
@pytest.mark.parametrize(("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (6, 4)])
def test_bond_encoder(atom_graph_cutoff: float, bond_graph_cutoff: float) -> None:
undirected2directed = torch.tensor([0, 1])
image = torch.zeros((2, 3))
2 changes: 1 addition & 1 deletion tests/test_relaxation.py
@@ -14,7 +14,7 @@


@pytest.mark.parametrize(
"algorithm, ase_filter, assign_magmoms",
("algorithm", "ase_filter", "assign_magmoms"),
[("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)],
)
def test_relaxation(
8 changes: 4 additions & 4 deletions tests/test_trainer.py
@@ -153,19 +153,19 @@ def test_wandb_init(mock_wandb):
)


def test_wandb_log_frequency(mock_wandb):
def test_wandb_log_frequency(tmp_path, mock_wandb):
trainer = Trainer(model=chgnet, wandb_path="test-project/test-run", epochs=1)

# Test epoch logging
trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir="")
trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir=tmp_path)
assert (
mock_wandb.log.call_count == 2 * trainer.epochs
), "Expected one train and one val log per epoch"

mock_wandb.log.reset_mock()

# Test batch logging
trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir="")
trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir=tmp_path)
expected_batch_calls = trainer.epochs * len(train_loader)
assert (
mock_wandb.log.call_count > expected_batch_calls
@@ -183,5 +183,5 @@ def test_wandb_log_frequency(mock_wandb):

# Test no logging when wandb_path is not provided
trainer_no_wandb = Trainer(model=chgnet, epochs=1)
trainer_no_wandb.train(train_loader, val_loader)
trainer_no_wandb.train(train_loader, val_loader, save_dir=tmp_path)
mock_wandb.log.assert_not_called()
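A short sketch of the `tmp_path` switch (hypothetical test body): the built-in pytest fixture hands each test an isolated `pathlib.Path` directory, so checkpoints no longer land in the repository working directory.

```python
def test_writes_checkpoint(tmp_path):
    # tmp_path is a per-test pathlib.Path created and cleaned up by pytest.
    checkpoint = tmp_path / "epoch1.ckpt"
    checkpoint.write_text("state")
    assert checkpoint.exists()
```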