Skip to content

Commit

Permalink
add test type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Dec 11, 2023
1 parent e099a3b commit 03aaa07
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 43 deletions.
10 changes: 6 additions & 4 deletions matcalc/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
from pymatgen.analysis.elasticity import DeformedStructureSet, ElasticTensor, Strain
Expand All @@ -12,6 +12,8 @@
from .relaxation import RelaxCalc

if TYPE_CHECKING:
from collections.abc import Sequence

from ase.calculators.calculator import Calculator
from pymatgen.core import Structure

Expand All @@ -22,8 +24,8 @@ class ElasticityCalc(PropCalc):
def __init__(
self,
calculator: Calculator,
norm_strains: tuple[float, ...] | float = (-0.01, -0.005, 0.005, 0.01),
shear_strains: tuple[float, ...] | float = (-0.06, -0.03, 0.03, 0.06),
norm_strains: Sequence[float] | float = (-0.01, -0.005, 0.005, 0.01),
shear_strains: Sequence[float] | float = (-0.06, -0.03, 0.03, 0.06),
fmax: float = 0.1,
relax_structure: bool = True,
use_equilibrium: bool = True,
Expand Down Expand Up @@ -58,7 +60,7 @@ def __init__(
else:
self.use_equilibrium = True

def calc(self, structure: Structure) -> dict[str, float | ElasticTensor | Structure]:
def calc(self, structure: Structure) -> dict[str, Any]:
"""
Calculates elastic properties of Pymatgen structure with units determined by the calculator,
(often the stress_weight).
Expand Down
22 changes: 14 additions & 8 deletions tests/test_elasticity.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Tests for ElasticCalc class"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

from matcalc.elasticity import ElasticityCalc

if TYPE_CHECKING:
from matgl.ext.ase import M3GNetCalculator
from pymatgen.core import Structure


def test_elastic_calc(Li2O, M3GNetCalc):
def test_elastic_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator) -> None:
"""Tests for ElasticCalc class"""
e_calc = ElasticityCalc(
elast_calc = ElasticityCalc(
M3GNetCalc,
fmax=0.1,
norm_strains=list(np.linspace(-0.004, 0.004, num=4)),
Expand All @@ -18,7 +24,7 @@ def test_elastic_calc(Li2O, M3GNetCalc):
)

# Test Li2O with equilibrium structure
results = e_calc.calc(Li2O)
results = elast_calc.calc(Li2O)
assert results["elastic_tensor"].shape == (3, 3, 3, 3)
assert results["elastic_tensor"][0][1][1][0] == pytest.approx(0.5014895636122672, rel=1e-3)
assert results["bulk_modulus_vrh"] == pytest.approx(0.6737897607182401, rel=1e-3)
Expand All @@ -28,32 +34,32 @@ def test_elastic_calc(Li2O, M3GNetCalc):
assert results["structure"].lattice.a == pytest.approx(3.2885851104196875, rel=1e-4)

# Test Li2O without the equilibrium structure
e_calc = ElasticityCalc(
elast_calc = ElasticityCalc(
M3GNetCalc,
fmax=0.1,
norm_strains=list(np.linspace(-0.004, 0.004, num=4)),
shear_strains=list(np.linspace(-0.004, 0.004, num=4)),
use_equilibrium=False,
)

results = e_calc.calc(Li2O)
results = elast_calc.calc(Li2O)
assert results["residuals_sum"] == pytest.approx(2.9257237571340992e-08, rel=1e-2)

# Test Li2O with float
e_calc = ElasticityCalc(
elast_calc = ElasticityCalc(
M3GNetCalc,
fmax=0.1,
norm_strains=0.004,
shear_strains=0.004,
use_equilibrium=True,
)

results = e_calc.calc(Li2O)
results = elast_calc.calc(Li2O)
assert results["residuals_sum"] == 0.0
assert results["bulk_modulus_vrh"] == pytest.approx(0.6631894154825593, rel=1e-3)


def test_elastic_calc_invalid_states(Li2O, M3GNetCalc):
def test_elastic_calc_invalid_states(M3GNetCalc: M3GNetCalculator):
with pytest.raises(ValueError, match="shear_strains must be nonempty"):
ElasticityCalc(M3GNetCalc, shear_strains=[])
with pytest.raises(ValueError, match="norm_strains must be nonempty"):
Expand Down
30 changes: 20 additions & 10 deletions tests/test_eos.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
"""Tests for PhononCalc class"""
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from matcalc.eos import EOSCalc

if TYPE_CHECKING:
from matgl.ext.ase import M3GNetCalculator
from pymatgen.core import Structure


def test_eos_calc(Li2O, LiFePO4, M3GNetCalc):
def test_eos_calc(
Li2O: Structure,
LiFePO4: Structure,
M3GNetCalc: M3GNetCalculator,
) -> None:
"""Tests for EOSCalc class"""
# Note that the fmax is probably too high. This is for testing purposes only.
pcalc = EOSCalc(M3GNetCalc, fmax=0.1)
results = pcalc.calc(Li2O)
eos_calc = EOSCalc(M3GNetCalc, fmax=0.1)
result = eos_calc.calc(Li2O)

assert {*results} == {"eos", "r2_score_bm", "bulk_modulus_bm"}
assert results["bulk_modulus_bm"] == pytest.approx(65.57980045603279, rel=1e-2)
assert {*results["eos"]} == {"volumes", "energies"}
assert results["eos"]["volumes"] == pytest.approx(
assert {*result} == {"eos", "r2_score_bm", "bulk_modulus_bm"}
assert result["bulk_modulus_bm"] == pytest.approx(65.57980045603279, rel=1e-2)
assert {*result["eos"]} == {"volumes", "energies"}
assert result["eos"]["volumes"] == pytest.approx(
[18.38, 19.63, 20.94, 22.3, 23.73, 25.21, 26.75, 28.36, 30.02, 31.76, 33.55],
rel=1e-3,
)
assert results["eos"]["energies"] == pytest.approx(
assert result["eos"]["energies"] == pytest.approx(
[-13.52, -13.77, -13.94, -14.08, -14.15, -14.18, -14.16, -14.11, -14.03, -13.94, -13.83],
rel=1e-3,
)
pcalc = EOSCalc(M3GNetCalc, relax_structure=False)
results = list(pcalc.calc_many([Li2O, LiFePO4]))
eos_calc = EOSCalc(M3GNetCalc, relax_structure=False)
results = list(eos_calc.calc_many([Li2O, LiFePO4]))
assert len(results) == 2
assert results[1]["bulk_modulus_bm"] == pytest.approx(54.5953851822073, rel=1e-2)
10 changes: 9 additions & 1 deletion tests/test_neb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from matcalc.neb import NEBCalc

if TYPE_CHECKING:
from pathlib import Path

from matgl.ext.ase import M3GNetCalculator
from pymatgen.core import Structure


def test_neb_calc(LiFePO4, M3GNetCalc, tmp_path):
def test_neb_calc(LiFePO4: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path) -> None:
"""Tests for NEBCalc class"""
image_start = LiFePO4.copy()
image_start.remove_sites([2])
Expand Down
24 changes: 16 additions & 8 deletions tests/test_phonon.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
"""Tests for PhononCalc class"""
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from matcalc.phonon import PhononCalc

if TYPE_CHECKING:
from matgl.ext.ase import M3GNetCalculator
from pymatgen.core import Structure


def test_phonon_calc(Li2O, M3GNetCalc):
def test_phonon_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator) -> None:
"""Tests for PhononCalc class"""
# Note that the fmax is probably too high. This is for testing purposes only.
pcalc = PhononCalc(M3GNetCalc, supercell_matrix=((2, 0, 0), (0, 2, 0), (0, 0, 2)), fmax=0.1, t_step=50, t_max=1000)
results = pcalc.calc(Li2O)
phonon_calc = PhononCalc(
M3GNetCalc, supercell_matrix=((2, 0, 0), (0, 2, 0), (0, 0, 2)), fmax=0.1, t_step=50, t_max=1000
)
result = phonon_calc.calc(Li2O)

# Test values at 100 K
ind = results["thermal_properties"]["temperatures"].tolist().index(300)
assert results["thermal_properties"]["heat_capacity"][ind] == pytest.approx(58.42898370395005, rel=1e-2)
assert results["thermal_properties"]["entropy"][ind] == pytest.approx(49.3774618162247, rel=1e-2)
assert results["thermal_properties"]["free_energy"][ind] == pytest.approx(13.245478097108784, rel=1e-2)
ind = result["thermal_properties"]["temperatures"].tolist().index(300)
assert result["thermal_properties"]["heat_capacity"][ind] == pytest.approx(58.42898370395005, rel=1e-2)
assert result["thermal_properties"]["entropy"][ind] == pytest.approx(49.3774618162247, rel=1e-2)
assert result["thermal_properties"]["free_energy"][ind] == pytest.approx(13.245478097108784, rel=1e-2)

results = list(pcalc.calc_many([Li2O, Li2O]))
results = list(phonon_calc.calc_many([Li2O, Li2O]))
assert len(results) == 2
30 changes: 19 additions & 11 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from matcalc.relaxation import RelaxCalc

if TYPE_CHECKING:
from pathlib import Path

from matgl.ext.ase import M3GNetCalculator
from pymatgen.core import Structure


def test_relax_calc(Li2O, M3GNetCalc, tmp_path):
pcalc = RelaxCalc(M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE")
results = pcalc.calc(Li2O)
assert results["a"] == pytest.approx(3.291071792359756, rel=0.002)
assert results["b"] == pytest.approx(3.291071899625086, rel=0.002)
assert results["c"] == pytest.approx(3.291072056855788, rel=0.002)
assert results["alpha"] == pytest.approx(60, abs=1)
assert results["beta"] == pytest.approx(60, abs=1)
assert results["gamma"] == pytest.approx(60, abs=1)
assert results["volume"] == pytest.approx(results["a"] * results["b"] * results["c"] / 2**0.5, abs=0.1)
def test_relax_calc(Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path) -> None:
relax_calc = RelaxCalc(M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE")
result = relax_calc.calc(Li2O)
assert result["a"] == pytest.approx(3.291071792359756, rel=0.002)
assert result["b"] == pytest.approx(3.291071899625086, rel=0.002)
assert result["c"] == pytest.approx(3.291072056855788, rel=0.002)
assert result["alpha"] == pytest.approx(60, abs=1)
assert result["beta"] == pytest.approx(60, abs=1)
assert result["gamma"] == pytest.approx(60, abs=1)
assert result["volume"] == pytest.approx(result["a"] * result["b"] * result["c"] / 2**0.5, abs=0.1)

results = list(pcalc.calc_many([Li2O] * 2))
results = list(relax_calc.calc_many([Li2O] * 2))
assert len(results) == 2
assert results[-1]["a"] == pytest.approx(3.291071792359756, rel=0.002)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from matcalc.util import UNIVERSAL_CALCULATORS, get_universal_calculator


def test_get_universal_calculator():
def test_get_universal_calculator() -> None:
for name in UNIVERSAL_CALCULATORS:
calc = get_universal_calculator(name)
assert isinstance(calc, Calculator)
Expand Down

0 comments on commit 03aaa07

Please sign in to comment.