Skip to content

Commit

Permalink
assert efs based on value
Browse files Browse the repository at this point in the history
Signed-off-by: Runze Liu <[email protected]>
  • Loading branch information
rul048 committed Sep 6, 2024
1 parent 266da49 commit a2a0376
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

import numpy as np
import pytest
from ase.filters import ExpCellFilter, FrechetCellFilter

Expand All @@ -12,16 +13,29 @@

from ase.filters import Filter
from matgl.ext.ase import M3GNetCalculator
from numpy.typing import ArrayLike
from nuampy.typing import ArrayLike
from pymatgen.core import Structure


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
@pytest.mark.parametrize(
("cell_filter", "expected_a", "expected_energy"),
[(ExpCellFilter, 3.288585, -14.176882), (FrechetCellFilter, 3.291071, -14.176743)],
)
def test_relax_calc_relax_cell(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, cell_filter: Filter, expected_a: float
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
tmp_path: Path,
cell_filter: Filter,
expected_a: float,
expected_energy: float,
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", cell_filter=cell_filter, relax_atoms=True, relax_cell=True
M3GNetCalc,
traj_file=f"{tmp_path}/li2o_relax.txt",
optimizer="FIRE",
cell_filter=cell_filter,
relax_atoms=True,
relax_cell=True,
)
result = relax_calc.calc(Li2O)
final_struct: Structure = result["final_structure"]
Expand All @@ -30,19 +44,19 @@ def test_relax_calc_relax_cell(
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
assert alpha == pytest.approx(60, abs=0.5)
assert beta == pytest.approx(60, abs=0.5)
assert gamma == pytest.approx(60, abs=0.5)
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)
assert isinstance(energy, float)


@pytest.mark.parametrize("expected_a", [3.288585])
@pytest.mark.parametrize(("expected_a", "expected_energy"), [(3.291071, -14.1767423)])
def test_relax_calc_relax_atoms(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, expected_a: float
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, expected_a: float, expected_energy: float
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", relax_atoms=True, relax_cell=False
Expand All @@ -54,32 +68,54 @@ def test_relax_calc_relax_atoms(
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
assert alpha == pytest.approx(60, abs=0.5)
assert beta == pytest.approx(60, abs=0.5)
assert gamma == pytest.approx(60, abs=0.5)
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)
assert isinstance(energy, float)


@pytest.mark.parametrize(
("expected_energy", "expected_forces", "expected_stress"),
[
(
-14.176743,
np.array(
[
[-4.252263e-03, -3.029412e-03, -7.360560e-03],
[4.272716e-03, 3.017864e-03, 7.360807e-03],
[-2.035008e-05, 1.152878e-05, -2.714805e-07],
],
dtype=np.float32,
),
np.array([0.003945, 0.004185, 0.003000, -0.000584, -0.000826, -0.000337], dtype=np.float32),
),
],
)
def test_static_calc(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
tmp_path: Path,
expected_energy: float,
expected_forces: ArrayLike,
expected_stresses: ArrayLike,
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", relax_atoms=False, relax_cell=False
)
relax_calc = RelaxCalc(M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", relax_atoms=False, relax_cell=False)
result = relax_calc.calc(Li2O)

energy: float = result["energy"]
forces: ArrayLike = result["forces"]
stresses: ArrayLike = result["stress"]

assert isinstance(energy, float)
assert list(forces.shape) == [3, 3]
assert list(stresses.shape) == [6]
assert energy == pytest.approx(expected_energy, rel=1e-3)
assert np.allclose(forces, expected_forces, rtol=1e-3)
assert np.allclose(stresses, expected_stresses, rtol=1e-3)


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.288585), (FrechetCellFilter, 3.291071)])
def test_relax_calc_many(Li2O: Structure, M3GNetCalc: M3GNetCalculator, cell_filter: Filter, expected_a: float) -> None:
relax_calc = RelaxCalc(M3GNetCalc, optimizer="FIRE", cell_filter=cell_filter)
results = list(relax_calc.calc_many([Li2O] * 2))
Expand Down

0 comments on commit a2a0376

Please sign in to comment.