diff --git a/matcalc/relaxation.py b/matcalc/relaxation.py index 34723e5..25429f1 100644 --- a/matcalc/relaxation.py +++ b/matcalc/relaxation.py @@ -79,6 +79,7 @@ def __init__( traj_file: str | None = None, interval: int = 1, fmax: float = 0.1, + relax_atoms: bool = True, relax_cell: bool = True, cell_filter: Filter = FrechetCellFilter, ) -> None: @@ -90,6 +91,7 @@ def __init__( interval (int): The step interval for saving the trajectories. Defaults to 1. fmax (float): Total force tolerance for relaxation convergence. fmax is a sum of force and stress forces. Defaults to 0.1 (eV/A). + relax_atoms (bool): Whether to relax the atoms (or just static calculation). relax_cell (bool): Whether to relax the cell (or just atoms). cell_filter (Filter): The ASE Filter used to relax the cell. Default is FrechetCellFilter. @@ -104,6 +106,7 @@ def __init__( self.max_steps = max_steps self.traj_file = traj_file self.relax_cell = relax_cell + self.relax_atoms = relax_atoms self.cell_filter = cell_filter def calc(self, structure: Structure) -> dict: @@ -114,7 +117,9 @@ def calc(self, structure: Structure) -> dict: Returns: { final_structure: final_structure, - energy: trajectory observer final energy in eV, + energy: static energy or trajectory observer final energy in eV, + forces: forces in eV/A, + stress: stress in eV/A^3, volume: lattice.volume in A^3, a: lattice.a in A, b: lattice.b in A, @@ -126,31 +131,42 @@ def calc(self, structure: Structure) -> dict: """ atoms = AseAtomsAdaptor.get_atoms(structure) atoms.calc = self.calculator - stream = io.StringIO() - with contextlib.redirect_stdout(stream): - obs = TrajectoryObserver(atoms) + if self.relax_atoms: + stream = io.StringIO() + with contextlib.redirect_stdout(stream): + obs = TrajectoryObserver(atoms) + if self.relax_cell: + atoms = self.cell_filter(atoms) + optimizer = self.optimizer(atoms) + optimizer.attach(obs, interval=self.interval) + optimizer.run(fmax=self.fmax, steps=self.max_steps) + if self.traj_file is not None: + obs() + obs.save(self.traj_file) if self.relax_cell: - atoms = self.cell_filter(atoms) - optimizer = self.optimizer(atoms) - optimizer.attach(obs, interval=self.interval) - optimizer.run(fmax=self.fmax, steps=self.max_steps) - if self.traj_file is not None: - obs() - obs.save(self.traj_file) - if self.relax_cell: - atoms = atoms.atoms - - final_structure = AseAtomsAdaptor.get_structure(atoms) - lattice = final_structure.lattice + atoms = atoms.atoms + energy = obs.energies[-1] + final_structure = AseAtomsAdaptor.get_structure(atoms) + lattice = final_structure.lattice + + return { + "final_structure": final_structure, + "energy": energy, + "a": lattice.a, + "b": lattice.b, + "c": lattice.c, + "alpha": lattice.alpha, + "beta": lattice.beta, + "gamma": lattice.gamma, + "volume": lattice.volume, + } + + energy = atoms.get_potential_energy() + forces = atoms.get_forces() + stresses = atoms.get_stress() return { - "final_structure": final_structure, - "energy": obs.energies[-1], - "a": lattice.a, - "b": lattice.b, - "c": lattice.c, - "alpha": lattice.alpha, - "beta": lattice.beta, - "gamma": lattice.gamma, - "volume": lattice.volume, + "energy": energy, + "forces": forces, + "stress": stresses, } diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index a4dcc98..7b2cbfc 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +import numpy as np import pytest from ase.filters import ExpCellFilter, FrechetCellFilter @@ -12,22 +13,38 @@ from ase.filters import Filter from matgl.ext.ase import M3GNetCalculator + from nuampy.typing import ArrayLike from pymatgen.core import Structure -@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)]) -def test_relax_calc_single( - Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, cell_filter: Filter, expected_a: float +@pytest.mark.parametrize( + ("cell_filter", "expected_a", "expected_energy"), + [(ExpCellFilter, 3.288585, -14.176867), (FrechetCellFilter, 3.291072, -14.176713)], +) +def test_relax_calc_relax_cell( + Li2O: Structure, + M3GNetCalc: M3GNetCalculator, + tmp_path: Path, + cell_filter: Filter, + expected_a: float, + expected_energy: float, ) -> None: relax_calc = RelaxCalc( - M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", cell_filter=cell_filter + M3GNetCalc, + traj_file=f"{tmp_path}/li2o_relax.txt", + optimizer="FIRE", + cell_filter=cell_filter, + relax_atoms=True, + relax_cell=True, ) result = relax_calc.calc(Li2O) final_struct: Structure = result["final_structure"] + energy: float = result["energy"] missing_keys = {*final_struct.lattice.params_dict} - {*result} assert len(missing_keys) == 0, f"{missing_keys=}" a, b, c, alpha, beta, gamma = final_struct.lattice.parameters + assert energy == pytest.approx(expected_energy, rel=1e-3) assert a == pytest.approx(expected_a, rel=1e-3) assert b == pytest.approx(expected_a, rel=1e-3) assert c == pytest.approx(expected_a, rel=1e-3) @@ -37,7 +54,67 @@ def test_relax_calc_single( assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1) -@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.291071), (FrechetCellFilter, 3.288585)]) +@pytest.mark.parametrize(("expected_a", "expected_energy"), [(3.291072, -14.176713)]) +def test_relax_calc_relax_atoms( + Li2O: Structure, M3GNetCalc: M3GNetCalculator, tmp_path: Path, expected_a: float, expected_energy: float +) -> None: + relax_calc = RelaxCalc( + M3GNetCalc, traj_file=f"{tmp_path}/li2o_relax.txt", optimizer="FIRE", relax_atoms=True, relax_cell=False + ) + result = relax_calc.calc(Li2O) + final_struct: Structure = result["final_structure"] + energy: float = result["energy"] + missing_keys = {*final_struct.lattice.params_dict} - {*result} + assert len(missing_keys) == 0, f"{missing_keys=}" + a, b, c, alpha, beta, gamma = final_struct.lattice.parameters + + assert energy == pytest.approx(expected_energy, rel=1e-3) + assert a == pytest.approx(expected_a, rel=1e-3) + assert b == pytest.approx(expected_a, rel=1e-3) + assert c == pytest.approx(expected_a, rel=1e-3) + assert alpha == pytest.approx(60, abs=0.5) + assert beta == pytest.approx(60, abs=0.5) + assert gamma == pytest.approx(60, abs=0.5) + assert final_struct.volume == pytest.approx(a * b * c / 2**0.5, abs=0.1) + + +@pytest.mark.parametrize( + ("expected_energy", "expected_forces", "expected_stresses"), + [ + ( + -14.176713, + np.array( + [ + [6.577218e-06, 1.851469e-06, -7.080846e-06], + [-4.507415e-03, -3.310852e-03, -7.090813e-03], + [4.500971e-03, 3.309000e-03, 7.097944e-03], + ], + dtype=np.float32, + ), + np.array([0.003883, 0.004126, 0.003089, -0.000617, -0.000839, -0.000391], dtype=np.float32), + ), + ], +) +def test_static_calc( + Li2O: Structure, + M3GNetCalc: M3GNetCalculator, + expected_energy: float, + expected_forces: ArrayLike, + expected_stresses: ArrayLike, +) -> None: + relax_calc = RelaxCalc(M3GNetCalc, relax_atoms=False, relax_cell=False) + result = relax_calc.calc(Li2O) + + energy: float = result["energy"] + forces: ArrayLike = result["forces"] + stresses: ArrayLike = result["stress"] + + assert energy == pytest.approx(expected_energy, rel=1e-3) + assert np.allclose(forces, expected_forces, rtol=1e-3) + assert np.allclose(stresses, expected_stresses, rtol=1e-3) + + +@pytest.mark.parametrize(("cell_filter", "expected_a"), [(ExpCellFilter, 3.288585), (FrechetCellFilter, 3.291072)]) def test_relax_calc_many(Li2O: Structure, M3GNetCalc: M3GNetCalculator, cell_filter: Filter, expected_a: float) -> None: relax_calc = RelaxCalc(M3GNetCalc, optimizer="FIRE", cell_filter=cell_filter) results = list(relax_calc.calc_many([Li2O] * 2))