-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Shyue Ping Ong
committed
Jul 25, 2023
1 parent
d30ae1d
commit 5678e8f
Showing
6 changed files
with
295 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__pycache__/ | ||
*.egg-info | ||
.DS_Store | ||
*.o | ||
*.so | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
"""Phonon properties.""" | ||
from __future__ import annotations | ||
|
||
import collections | ||
import contextlib | ||
import io | ||
import pickle | ||
from typing import TYPE_CHECKING | ||
|
||
from ase.constraints import ExpCellFilter | ||
from ase.optimize.bfgs import BFGS | ||
from ase.optimize.bfgslinesearch import BFGSLineSearch | ||
from ase.optimize.fire import FIRE | ||
from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch | ||
from ase.optimize.mdmin import MDMin | ||
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG | ||
from pymatgen.io.ase import AseAtomsAdaptor | ||
|
||
if TYPE_CHECKING: | ||
import numpy as np | ||
from ase.optimize.optimize import Optimizer | ||
|
||
from .base import PropCalc | ||
|
||
OPTIMIZERS = { | ||
"FIRE": FIRE, | ||
"BFGS": BFGS, | ||
"LBFGS": LBFGS, | ||
"LBFGSLineSearch": LBFGSLineSearch, | ||
"MDMin": MDMin, | ||
"SciPyFminCG": SciPyFminCG, | ||
"SciPyFminBFGS": SciPyFminBFGS, | ||
"BFGSLineSearch": BFGSLineSearch, | ||
} | ||
if TYPE_CHECKING: | ||
from ase import Atoms | ||
from ase.calculators.calculator import Calculator | ||
|
||
|
||
class TrajectoryObserver(collections.abc.Sequence): | ||
"""Trajectory observer is a hook in the relaxation process that saves the | ||
intermediate structures. | ||
""" | ||
|
||
def __init__(self, atoms: Atoms) -> None: | ||
""" | ||
Init the Trajectory Observer from a Atoms. | ||
Args: | ||
atoms (Atoms): Structure to observe. | ||
""" | ||
self.atoms = atoms | ||
self.energies: list[float] = [] | ||
self.forces: list[np.ndarray] = [] | ||
self.stresses: list[np.ndarray] = [] | ||
self.atom_positions: list[np.ndarray] = [] | ||
self.cells: list[np.ndarray] = [] | ||
|
||
def __call__(self) -> None: | ||
"""The logic for saving the properties of an Atoms during the relaxation.""" | ||
self.energies.append(float(self.atoms.get_potential_energy())) | ||
self.forces.append(self.atoms.get_forces()) | ||
self.stresses.append(self.atoms.get_stress()) | ||
self.atom_positions.append(self.atoms.get_positions()) | ||
self.cells.append(self.atoms.get_cell()[:]) | ||
|
||
def __getitem__(self, item): | ||
return self.energies[item], self.forces[item], self.stresses[item], self.cells[item], self.atom_positions[item] | ||
|
||
def __len__(self): | ||
return len(self.energies) | ||
|
||
def save(self, filename: str) -> None: | ||
"""Save the trajectory to file. | ||
Args: | ||
filename (str): filename to save the trajectory. | ||
""" | ||
out = { | ||
"energy": self.energies, | ||
"forces": self.forces, | ||
"stresses": self.stresses, | ||
"atom_positions": self.atom_positions, | ||
"cell": self.cells, | ||
"atomic_number": self.atoms.get_atomic_numbers(), | ||
} | ||
with open(filename, "wb") as file: | ||
pickle.dump(out, file) | ||
|
||
|
||
class RelaxCalc(PropCalc): | ||
"""Calculator for phonon properties.""" | ||
|
||
def __init__( | ||
self, | ||
calculator: Calculator, | ||
optimizer: Optimizer | str = "FIRE", | ||
fmax: float = 0.1, | ||
steps: int = 500, | ||
traj_file: str | None = None, | ||
interval=1, | ||
): | ||
""" | ||
Args: | ||
calculator: ASE Calculator to use. | ||
optimizer (str or ase Optimizer): the optimization algorithm. | ||
Defaults to "FIRE" | ||
fmax (float): total force tolerance for relaxation convergence. fmax is a sum of force and stress forces. | ||
steps (int): max number of steps for relaxation. | ||
traj_file (str): the trajectory file for saving | ||
interval (int): the step interval for saving the trajectories. | ||
""" | ||
self.calculator = calculator | ||
if isinstance(optimizer, str): | ||
optimizer_obj = OPTIMIZERS.get(optimizer, None) | ||
elif optimizer is None: | ||
raise ValueError("Optimizer cannot be None") | ||
else: | ||
optimizer_obj = optimizer | ||
|
||
self.opt_class: Optimizer = optimizer_obj | ||
self.fmax = fmax | ||
self.interval = interval | ||
self.steps = steps | ||
self.traj_file = traj_file | ||
|
||
def calc(self, structure) -> dict: | ||
""" | ||
All PropCalc should implement a calc method that takes in a pymatgen structure and returns a dict. Note that | ||
the method can return more than one property. | ||
Args: | ||
structure: Pymatgen structure. | ||
Returns: {"prop name": value} | ||
""" | ||
ase_adaptor = AseAtomsAdaptor() | ||
atoms = ase_adaptor.get_atoms(structure) | ||
atoms.set_calculator(self.calculator) | ||
stream = io.StringIO() | ||
with contextlib.redirect_stdout(stream): | ||
obs = TrajectoryObserver(atoms) | ||
atoms = ExpCellFilter(atoms) | ||
optimizer = self.opt_class(atoms) | ||
optimizer.attach(obs, interval=self.interval) | ||
optimizer.run(fmax=self.fmax, steps=self.steps) | ||
obs() | ||
if self.traj_file is not None: | ||
obs.save(self.traj_file) | ||
atoms = atoms.atoms | ||
|
||
final_structure = ase_adaptor.get_structure(atoms) | ||
lattice = final_structure.lattice | ||
|
||
return { | ||
"final_structure": final_structure, | ||
"a": lattice.a, | ||
"b": lattice.b, | ||
"c": lattice.c, | ||
"alpha": lattice.alpha, | ||
"beta": lattice.beta, | ||
"gamma": lattice.gamma, | ||
"volume": lattice.volume, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ coveralls | |
mypy | ||
ruff | ||
black | ||
matgl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
Define commonly used text fixtures. These are meant to be reused in unittests. | ||
- Fixtures that are formulae (e.g., LiFePO4) returns the appropriate pymatgen Structure or Molecule based on the most | ||
commonly known structure. | ||
- Fixtures that are prefixed with `graph_` returns a (structure, graph, state) tuple. | ||
Given that the fixtures are unlikely to be modified by the underlying code, the fixtures are set with a scope of | ||
"session". In the event that future tests are written that modifies the fixtures, these can be set to the default scope | ||
of "function". | ||
""" | ||
from __future__ import annotations | ||
|
||
import pytest | ||
from pymatgen.core import Lattice, Molecule, Structure | ||
from pymatgen.util.testing import PymatgenTest | ||
|
||
from matgl.ext.pymatgen import Molecule2Graph, Structure2Graph, get_element_list | ||
from matgl.graph.compute import ( | ||
compute_pair_vector_and_distance, | ||
) | ||
|
||
|
||
def get_graph(structure, cutoff): | ||
""" | ||
Helper class to generate DGL graph from an input Structure or Molecule. | ||
Returns: | ||
Structure/Molecule, Graph, State | ||
""" | ||
element_types = get_element_list([structure]) | ||
if isinstance(structure, Structure): | ||
converter = Structure2Graph(element_types=element_types, cutoff=cutoff) # type: ignore | ||
else: | ||
converter = Molecule2Graph(element_types=element_types, cutoff=cutoff) # type: ignore | ||
graph, state = converter.get_graph(structure) | ||
bond_vec, bond_dist = compute_pair_vector_and_distance(graph) | ||
graph.edata["bond_dist"] = bond_dist | ||
graph.edata["bond_vec"] = bond_vec | ||
return structure, graph, state | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def LiFePO4(): | ||
return PymatgenTest.get_structure("LiFePO4") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def CH4(): | ||
coords = [ | ||
[0.000000, 0.000000, 0.000000], | ||
[0.000000, 0.000000, 1.089000], | ||
[1.026719, 0.000000, -0.363000], | ||
[-0.513360, -0.889165, -0.363000], | ||
[-0.513360, 0.889165, -0.363000], | ||
] | ||
return Molecule(["C", "H", "H", "H", "H"], coords) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def CO(): | ||
return Molecule(["C", "O"], [[0, 0, 0], [1.1, 0, 0]]) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def BaNiO3(): | ||
return PymatgenTest.get_structure("BaNiO3") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def MoS(): | ||
return Structure(Lattice.cubic(4.0), ["Mo", "S"], [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_Mo(): | ||
s = Structure(Lattice.cubic(3.17), ["Mo", "Mo"], [[0.01, 0, 0], [0.5, 0.5, 0.5]]) | ||
return get_graph(s, 5.0) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_CH4(CH4): | ||
""" | ||
Returns: | ||
Molecule, Graph, State | ||
""" | ||
return get_graph(CH4, 2.0) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_LiFePO4(LiFePO4): | ||
""" | ||
Returns: | ||
Molecule, Graph, State | ||
""" | ||
return get_graph(LiFePO4, 4.0) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_MoS(MoS): | ||
return get_graph(MoS, 5.0) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_CO(CO): | ||
return get_graph(CO, 5.0) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def graph_MoSH(): | ||
s = Structure(Lattice.cubic(3.17), ["Mo", "S", "H"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.75, 0.75, 0.75]]) | ||
return get_graph(s, 4.0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import pytest | ||
import matgl | ||
|
||
from matgl.ext.ase import M3GNetCalculator | ||
|
||
from matcalc.relaxation import RelaxCalc | ||
|
||
|
||
def test_RelaxCalc(LiFePO4): | ||
potential = matgl.load_model("M3GNet-MP-2021.2.8-PES") | ||
calculator = M3GNetCalculator(potential=potential, stress_weight=0.01) | ||
|
||
calc = RelaxCalc(calculator) | ||
results = calc.calc(LiFePO4) | ||
assert results["a"] == pytest.approx(4.755711375217371) | ||
assert results["b"] == pytest.approx(6.131614236614623) | ||
assert results["c"] == pytest.approx(10.43859339794175) |