Skip to content

Commit

Permalink
Merge pull request #29 from rul048/patch-7
Browse files Browse the repository at this point in the history
Add static calculation feature to matcalc/relaxation.py
  • Loading branch information
shyuep authored Sep 6, 2024
2 parents f70fbea + 2e3b2c3 commit 00be81a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 30 deletions.
66 changes: 41 additions & 25 deletions matcalc/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
traj_file: str | None = None,
interval: int = 1,
fmax: float = 0.1,
relax_atoms: bool = True,
relax_cell: bool = True,
cell_filter: Filter = FrechetCellFilter,
) -> None:
Expand All @@ -90,6 +91,7 @@ def __init__(
interval (int): The step interval for saving the trajectories. Defaults to 1.
fmax (float): Total force tolerance for relaxation convergence.
fmax is a sum of force and stress forces. Defaults to 0.1 (eV/A).
relax_atoms (bool): Whether to relax the atoms (or just static calculation).
relax_cell (bool): Whether to relax the cell (or just atoms).
cell_filter (Filter): The ASE Filter used to relax the cell. Default is FrechetCellFilter.
Expand All @@ -104,6 +106,7 @@ def __init__(
self.max_steps = max_steps
self.traj_file = traj_file
self.relax_cell = relax_cell
self.relax_atoms = relax_atoms
self.cell_filter = cell_filter

def calc(self, structure: Structure) -> dict:
Expand All @@ -114,7 +117,9 @@ def calc(self, structure: Structure) -> dict:
Returns: {
final_structure: final_structure,
energy: trajectory observer final energy in eV,
energy: static energy or trajectory observer final energy in eV,
forces: forces in eV/A,
stress: stress in eV/A^3,
volume: lattice.volume in A^3,
a: lattice.a in A,
b: lattice.b in A,
Expand All @@ -126,31 +131,42 @@ def calc(self, structure: Structure) -> dict:
"""
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = self.calculator
stream = io.StringIO()
with contextlib.redirect_stdout(stream):
obs = TrajectoryObserver(atoms)
if self.relax_atoms:
stream = io.StringIO()
with contextlib.redirect_stdout(stream):
obs = TrajectoryObserver(atoms)
if self.relax_cell:
atoms = self.cell_filter(atoms)
optimizer = self.optimizer(atoms)
optimizer.attach(obs, interval=self.interval)
optimizer.run(fmax=self.fmax, steps=self.max_steps)
if self.traj_file is not None:
obs()
obs.save(self.traj_file)
if self.relax_cell:
atoms = self.cell_filter(atoms)
optimizer = self.optimizer(atoms)
optimizer.attach(obs, interval=self.interval)
optimizer.run(fmax=self.fmax, steps=self.max_steps)
if self.traj_file is not None:
obs()
obs.save(self.traj_file)
if self.relax_cell:
atoms = atoms.atoms

final_structure = AseAtomsAdaptor.get_structure(atoms)
lattice = final_structure.lattice
atoms = atoms.atoms
energy = obs.energies[-1]
final_structure = AseAtomsAdaptor.get_structure(atoms)
lattice = final_structure.lattice

return {
"final_structure": final_structure,
"energy": energy,
"a": lattice.a,
"b": lattice.b,
"c": lattice.c,
"alpha": lattice.alpha,
"beta": lattice.beta,
"gamma": lattice.gamma,
"volume": lattice.volume,
}

energy = atoms.get_potential_energy()
forces = atoms.get_forces()
stresses = atoms.get_stress()

return {
"final_structure": final_structure,
"energy": obs.energies[-1],
"a": lattice.a,
"b": lattice.b,
"c": lattice.c,
"alpha": lattice.alpha,
"beta": lattice.beta,
"gamma": lattice.gamma,
"volume": lattice.volume,
"energy": energy,
"forces": forces,
"stress": stresses,
}
87 changes: 82 additions & 5 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

import numpy as np
import pytest
from ase.filters import ExpCellFilter, FrechetCellFilter

Expand All @@ -12,22 +13,38 @@

from ase.filters import Filter
from matgl.ext.ase import M3GNetCalculator
from nuampy.typing import ArrayLike
from pymatgen.core import Structure


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
def test_relax_calc_single(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, cell_filter: Filter, expected_a: float
@pytest.mark.parametrize(
("cell_filter", "expected_a", "expected_energy"),
[(ExpCellFilter, 3.288585, -14.176867), (FrechetCellFilter, 3.291072, -14.176713)],
)
def test_relax_calc_relax_cell(
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
tmp_path: Path,
cell_filter: Filter,
expected_a: float,
expected_energy: float,
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", cell_filter=cell_filter
M3GNetCalc,
traj_file=f"{tmp_path}/li2o_relax.txt",
optimizer="FIRE",
cell_filter=cell_filter,
relax_atoms=True,
relax_cell=True,
)
result = relax_calc.calc(Li2O)
final_struct: Structure = result["final_structure"]
energy: float = result["energy"]
missing_keys = {*final_struct.lattice.params_dict} - {*result}
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
Expand All @@ -37,7 +54,67 @@ def test_relax_calc_single(
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)])
@pytest.mark.parametrize(("expected_a", "expected_energy"), [(3.291072, -14.176713)])
def test_relax_calc_relax_atoms(
Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, expected_a: float, expected_energy: float
) -> None:
relax_calc = RelaxCalc(
M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", relax_atoms=True, relax_cell=False
)
result = relax_calc.calc(Li2O)
final_struct: Structure = result["final_structure"]
energy: float = result["energy"]
missing_keys = {*final_struct.lattice.params_dict} - {*result}
assert len(missing_keys) == 0, f"{missing_keys=}"
a, b, c, alpha, beta, gamma = final_struct.lattice.parameters

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert a == pytest.approx(expected_a, rel=1e-3)
assert b == pytest.approx(expected_a, rel=1e-3)
assert c == pytest.approx(expected_a, rel=1e-3)
assert alpha == pytest.approx(60, abs=0.5)
assert beta == pytest.approx(60, abs=0.5)
assert gamma == pytest.approx(60, abs=0.5)
assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1)


@pytest.mark.parametrize(
("expected_energy", "expected_forces", "expected_stresses"),
[
(
-14.176713,
np.array(
[
[6.577218e-06, 1.851469e-06, -7.080846e-06],
[-4.507415e-03, -3.310852e-03, -7.090813e-03],
[4.500971e-03, 3.309000e-03, 7.097944e-03],
],
dtype=np.float32,
),
np.array([0.003883, 0.004126, 0.003089, -0.000617, -0.000839, -0.000391], dtype=np.float32),
),
],
)
def test_static_calc(
Li2O: Structure,
M3GNetCalc: M3GNetCalculator,
expected_energy: float,
expected_forces: ArrayLike,
expected_stresses: ArrayLike,
) -> None:
relax_calc = RelaxCalc(M3GNetCalc, relax_atoms=False, relax_cell=False)
result = relax_calc.calc(Li2O)

energy: float = result["energy"]
forces: ArrayLike = result["forces"]
stresses: ArrayLike = result["stress"]

assert energy == pytest.approx(expected_energy, rel=1e-3)
assert np.allclose(forces, expected_forces, rtol=1e-3)
assert np.allclose(stresses, expected_stresses, rtol=1e-3)


@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.288585), (FrechetCellFilter, 3.291072)])
def test_relax_calc_many(Li2O: Structure, M3GNetCalc: M3GNetCalculator, cell_filter: Filter, expected_a: float) -> None:
relax_calc = RelaxCalc(M3GNetCalc, optimizer="FIRE", cell_filter=cell_filter)
results = list(relax_calc.calc_many([Li2O] * 2))
Expand Down

0 comments on commit 00be81a

Please sign in to comment.