Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static calculation feature to matcalc/relaxation.py #29

Merged
merged 11 commits into from
Sep 6, 2024
Merged
66 changes: 41 additions & 25 deletions matcalc/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
traj_file: str | None = None,
interval: int = 1,
fmax: float = 0.1,
relax_atoms: bool = True,
relax_cell: bool = True,
cell_filter: Filter = FrechetCellFilter,
) -> None:
Expand All @@ -90,6 +91,7 @@ def __init__(
interval (int): The step interval for saving the trajectories. Defaults to 1.
fmax (float): Total force tolerance for relaxation convergence.
fmax is a sum of force and stress forces. Defaults to 0.1 (eV/A).
relax_atoms (bool): Whether to relax the atoms (or just static calculation).
relax_cell (bool): Whether to relax the cell (or just atoms).
cell_filter (Filter): The ASE Filter used to relax the cell. Default is FrechetCellFilter.

Expand All @@ -104,6 +106,7 @@ def __init__(
self.max_steps = max_steps
self.traj_file = traj_file
self.relax_cell = relax_cell
self.relax_atoms = relax_atoms
self.cell_filter = cell_filter

def calc(self, structure: Structure) -> dict:
Expand All @@ -114,7 +117,9 @@ def calc(self, structure: Structure) -> dict:

Returns: {
final_structure: final_structure,
energy: trajectory observer final energy in eV,
energy: static energy or trajectory observer final energy in eV,
forces: forces in eV/A,
stress: stress in eV/A^3,
volume: lattice.volume in A^3,
a: lattice.a in A,
b: lattice.b in A,
Expand All @@ -126,31 +131,42 @@ def calc(self, structure: Structure) -> dict:
"""
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = self.calculator
stream = io.StringIO()
with contextlib.redirect_stdout(stream):
obs = TrajectoryObserver(atoms)
if self.relax_atoms:
stream = io.StringIO()
with contextlib.redirect_stdout(stream):
obs = TrajectoryObserver(atoms)
if self.relax_cell:
atoms = self.cell_filter(atoms)
optimizer = self.optimizer(atoms)
optimizer.attach(obs, interval=self.interval)
optimizer.run(fmax=self.fmax, steps=self.max_steps)
if self.traj_file is not None:
obs()
obs.save(self.traj_file)
if self.relax_cell:
atoms = self.cell_filter(atoms)
optimizer = self.optimizer(atoms)
optimizer.attach(obs, interval=self.interval)
optimizer.run(fmax=self.fmax, steps=self.max_steps)
if self.traj_file is not None:
obs()
obs.save(self.traj_file)
if self.relax_cell:
atoms = atoms.atoms

final_structure = AseAtomsAdaptor.get_structure(atoms)
lattice = final_structure.lattice
atoms = atoms.atoms
energy = obs.energies[-1]
final_structure = AseAtomsAdaptor.get_structure(atoms)
lattice = final_structure.lattice

return {
"final_structure": final_structure,
"energy": energy,
"a": lattice.a,
"b": lattice.b,
"c": lattice.c,
"alpha": lattice.alpha,
"beta": lattice.beta,
"gamma": lattice.gamma,
"volume": lattice.volume,
}

energy = atoms.get_potential_energy()
forces = atoms.get_forces()
stresses = atoms.get_stress()

return {
"final_structure": final_structure,
"energy": obs.energies[-1],
"a": lattice.a,
"b": lattice.b,
"c": lattice.c,
"alpha": lattice.alpha,
"beta": lattice.beta,
"gamma": lattice.gamma,
"volume": lattice.volume,
"energy": energy,
"forces": forces,
"stress": stresses,
}
87 changes: 82 additions & 5 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

import numpy as np
import pytest
from ase.filters import ExpCellFilter, FrechetCellFilter

Expand All @@ -12,22 +13,38 @@

from ase.filters import Filter
from matgl.ext.ase import M3GNetCalculator
from nuampy.typing import ArrayLike
from pymatgen.core import Structure


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
def test_relax_calc_single(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, cell_filter: Filter, expected_a: float
@pytest.mark.parametrize(
("cell_filter", "expected_a", "expected_energy"),
[(ExpCellFilter, 3.288585, -14.176867), (FrechetCellFilter, 3.291072, -14.176713)],
)
def test_relax_calc_relax_cell(
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
tmp_path: Path,
cell_filter: Filter,
expected_a: float,
expected_energy: float,
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", cell_filter=cell_filter
M3GNetCalc,
traj_file=f"{tmp_path}/li2o_relax.txt",
optimizer="FIRE",
cell_filter=cell_filter,
relax_atoms=True,
relax_cell=True,
)
result = relax_calc.calc(Li2O)
final_struct: Structure = result["final_structure"]
energy: float = result["energy"]
missing_keys = {*final_struct.lattice.params_dict} - {*result}
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
Expand All @@ -37,7 +54,67 @@ def test_relax_calc_single(
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
@pytest.mark.parametrize(("expected_a", "expected_energy"), [(3.291072, -14.176713)])
def test_relax_calc_relax_atoms(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, expected_a: float, expected_energy: float
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", relax_atoms=True, relax_cell=False
)
result = relax_calc.calc(Li2O)
final_struct: Structure = result["final_structure"]
energy: float = result["energy"]
missing_keys = {*final_struct.lattice.params_dict} - {*result}
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
assert alpha == pytest.approx(60, abs=0.5)
assert beta == pytest.approx(60, abs=0.5)
assert gamma == pytest.approx(60, abs=0.5)
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)


@pytest.mark.parametrize(
("expected_energy", "expected_forces", "expected_stresses"),
[
(
-14.176713,
np.array(
[
[6.577218e-06, 1.851469e-06, -7.080846e-06],
[-4.507415e-03, -3.310852e-03, -7.090813e-03],
[4.500971e-03, 3.309000e-03, 7.097944e-03],
],
dtype=np.float32,
),
np.array([0.003883, 0.004126, 0.003089, -0.000617, -0.000839, -0.000391], dtype=np.float32),
),
],
)
def test_static_calc(
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
expected_energy: float,
expected_forces: ArrayLike,
expected_stresses: ArrayLike,
) -> None:
relax_calc = RelaxCalc(M3GNetCalc, relax_atoms=False, relax_cell=False)
result = relax_calc.calc(Li2O)

energy: float = result["energy"]
forces: ArrayLike = result["forces"]
stresses: ArrayLike = result["stress"]

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert np.allclose(forces, expected_forces, rtol=1e-3)
assert np.allclose(stresses, expected_stresses, rtol=1e-3)


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.288585), (FrechetCellFilter, 3.291072)])
def test_relax_calc_many(Li2O: Structure, M3GNetCalc: M3GNetCalculator, cell_filter: Filter, expected_a: float) -> None:
relax_calc = RelaxCalc(M3GNetCalc, optimizer="FIRE", cell_filter=cell_filter)
results = list(relax_calc.calc_many([Li2O] * 2))
Expand Down
Loading