Skip to content

Commit

Permalink
Fix all tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Jul 25, 2023
1 parent d30ae1d commit 5678e8f
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__pycache__/
*.egg-info
.DS_Store
*.o
*.so
Expand Down
2 changes: 1 addition & 1 deletion matcalc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pymatgen.core import Structure


class PropCalc(abc.ABCMeta):
class PropCalc(metaclass=abc.ABCMeta):
"""API for a property calculator."""

@abc.abstractmethod
Expand Down
164 changes: 164 additions & 0 deletions matcalc/relaxation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Phonon properties."""
from __future__ import annotations

import collections
import contextlib
import io
import pickle
from typing import TYPE_CHECKING

from ase.constraints import ExpCellFilter
from ase.optimize.bfgs import BFGS
from ase.optimize.bfgslinesearch import BFGSLineSearch
from ase.optimize.fire import FIRE
from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch
from ase.optimize.mdmin import MDMin
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
from pymatgen.io.ase import AseAtomsAdaptor

if TYPE_CHECKING:
import numpy as np
from ase.optimize.optimize import Optimizer

from .base import PropCalc

OPTIMIZERS = {
"FIRE": FIRE,
"BFGS": BFGS,
"LBFGS": LBFGS,
"LBFGSLineSearch": LBFGSLineSearch,
"MDMin": MDMin,
"SciPyFminCG": SciPyFminCG,
"SciPyFminBFGS": SciPyFminBFGS,
"BFGSLineSearch": BFGSLineSearch,
}
if TYPE_CHECKING:
from ase import Atoms
from ase.calculators.calculator import Calculator


class TrajectoryObserver(collections.abc.Sequence):
"""Trajectory observer is a hook in the relaxation process that saves the
intermediate structures.
"""

def __init__(self, atoms: Atoms) -> None:
"""
Init the Trajectory Observer from a Atoms.
Args:
atoms (Atoms): Structure to observe.
"""
self.atoms = atoms
self.energies: list[float] = []
self.forces: list[np.ndarray] = []
self.stresses: list[np.ndarray] = []
self.atom_positions: list[np.ndarray] = []
self.cells: list[np.ndarray] = []

def __call__(self) -> None:
"""The logic for saving the properties of an Atoms during the relaxation."""
self.energies.append(float(self.atoms.get_potential_energy()))
self.forces.append(self.atoms.get_forces())
self.stresses.append(self.atoms.get_stress())
self.atom_positions.append(self.atoms.get_positions())
self.cells.append(self.atoms.get_cell()[:])

def __getitem__(self, item):
return self.energies[item], self.forces[item], self.stresses[item], self.cells[item], self.atom_positions[item]

def __len__(self):
return len(self.energies)

def save(self, filename: str) -> None:
"""Save the trajectory to file.
Args:
filename (str): filename to save the trajectory.
"""
out = {
"energy": self.energies,
"forces": self.forces,
"stresses": self.stresses,
"atom_positions": self.atom_positions,
"cell": self.cells,
"atomic_number": self.atoms.get_atomic_numbers(),
}
with open(filename, "wb") as file:
pickle.dump(out, file)


class RelaxCalc(PropCalc):
"""Calculator for phonon properties."""

def __init__(
self,
calculator: Calculator,
optimizer: Optimizer | str = "FIRE",
fmax: float = 0.1,
steps: int = 500,
traj_file: str | None = None,
interval=1,
):
"""
Args:
calculator: ASE Calculator to use.
optimizer (str or ase Optimizer): the optimization algorithm.
Defaults to "FIRE"
fmax (float): total force tolerance for relaxation convergence. fmax is a sum of force and stress forces.
steps (int): max number of steps for relaxation.
traj_file (str): the trajectory file for saving
interval (int): the step interval for saving the trajectories.
"""
self.calculator = calculator
if isinstance(optimizer, str):
optimizer_obj = OPTIMIZERS.get(optimizer, None)
elif optimizer is None:
raise ValueError("Optimizer cannot be None")
else:
optimizer_obj = optimizer

self.opt_class: Optimizer = optimizer_obj
self.fmax = fmax
self.interval = interval
self.steps = steps
self.traj_file = traj_file

def calc(self, structure) -> dict:
"""
All PropCalc should implement a calc method that takes in a pymatgen structure and returns a dict. Note that
the method can return more than one property.
Args:
structure: Pymatgen structure.
Returns: {"prop name": value}
"""
ase_adaptor = AseAtomsAdaptor()
atoms = ase_adaptor.get_atoms(structure)
atoms.set_calculator(self.calculator)
stream = io.StringIO()
with contextlib.redirect_stdout(stream):
obs = TrajectoryObserver(atoms)
atoms = ExpCellFilter(atoms)
optimizer = self.opt_class(atoms)
optimizer.attach(obs, interval=self.interval)
optimizer.run(fmax=self.fmax, steps=self.steps)
obs()
if self.traj_file is not None:
obs.save(self.traj_file)
atoms = atoms.atoms

final_structure = ase_adaptor.get_structure(atoms)
lattice = final_structure.lattice

return {
"final_structure": final_structure,
"a": lattice.a,
"b": lattice.b,
"c": lattice.c,
"alpha": lattice.alpha,
"beta": lattice.beta,
"gamma": lattice.gamma,
"volume": lattice.volume,
}
1 change: 1 addition & 0 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ coveralls
mypy
ruff
black
matgl
111 changes: 111 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Define commonly used text fixtures. These are meant to be reused in unittests.
- Fixtures that are formulae (e.g., LiFePO4) returns the appropriate pymatgen Structure or Molecule based on the most
commonly known structure.
- Fixtures that are prefixed with `graph_` returns a (structure, graph, state) tuple.
Given that the fixtures are unlikely to be modified by the underlying code, the fixtures are set with a scope of
"session". In the event that future tests are written that modifies the fixtures, these can be set to the default scope
of "function".
"""
from __future__ import annotations

import pytest
from pymatgen.core import Lattice, Molecule, Structure
from pymatgen.util.testing import PymatgenTest

from matgl.ext.pymatgen import Molecule2Graph, Structure2Graph, get_element_list
from matgl.graph.compute import (
compute_pair_vector_and_distance,
)


def get_graph(structure, cutoff):
"""
Helper class to generate DGL graph from an input Structure or Molecule.
Returns:
Structure/Molecule, Graph, State
"""
element_types = get_element_list([structure])
if isinstance(structure, Structure):
converter = Structure2Graph(element_types=element_types, cutoff=cutoff) # type: ignore
else:
converter = Molecule2Graph(element_types=element_types, cutoff=cutoff) # type: ignore
graph, state = converter.get_graph(structure)
bond_vec, bond_dist = compute_pair_vector_and_distance(graph)
graph.edata["bond_dist"] = bond_dist
graph.edata["bond_vec"] = bond_vec
return structure, graph, state


@pytest.fixture(scope="session")
def LiFePO4():
return PymatgenTest.get_structure("LiFePO4")


@pytest.fixture(scope="session")
def CH4():
coords = [
[0.000000, 0.000000, 0.000000],
[0.000000, 0.000000, 1.089000],
[1.026719, 0.000000, -0.363000],
[-0.513360, -0.889165, -0.363000],
[-0.513360, 0.889165, -0.363000],
]
return Molecule(["C", "H", "H", "H", "H"], coords)


@pytest.fixture(scope="session")
def CO():
return Molecule(["C", "O"], [[0, 0, 0], [1.1, 0, 0]])


@pytest.fixture(scope="session")
def BaNiO3():
return PymatgenTest.get_structure("BaNiO3")


@pytest.fixture(scope="session")
def MoS():
return Structure(Lattice.cubic(4.0), ["Mo", "S"], [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])


@pytest.fixture(scope="session")
def graph_Mo():
s = Structure(Lattice.cubic(3.17), ["Mo", "Mo"], [[0.01, 0, 0], [0.5, 0.5, 0.5]])
return get_graph(s, 5.0)


@pytest.fixture(scope="session")
def graph_CH4(CH4):
"""
Returns:
Molecule, Graph, State
"""
return get_graph(CH4, 2.0)


@pytest.fixture(scope="session")
def graph_LiFePO4(LiFePO4):
"""
Returns:
Molecule, Graph, State
"""
return get_graph(LiFePO4, 4.0)


@pytest.fixture(scope="session")
def graph_MoS(MoS):
return get_graph(MoS, 5.0)


@pytest.fixture(scope="session")
def graph_CO(CO):
return get_graph(CO, 5.0)


@pytest.fixture(scope="session")
def graph_MoSH():
s = Structure(Lattice.cubic(3.17), ["Mo", "S", "H"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.75, 0.75, 0.75]])
return get_graph(s, 4.0)
17 changes: 17 additions & 0 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import matgl

from matgl.ext.ase import M3GNetCalculator

from matcalc.relaxation import RelaxCalc


def test_RelaxCalc(LiFePO4):
potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
calculator = M3GNetCalculator(potential=potential, stress_weight=0.01)

calc = RelaxCalc(calculator)
results = calc.calc(LiFePO4)
assert results["a"] == pytest.approx(4.755711375217371)
assert results["b"] == pytest.approx(6.131614236614623)
assert results["c"] == pytest.approx(10.43859339794175)

0 comments on commit 5678e8f

Please sign in to comment.