Skip to content

Commit

Permalink
RelaxCalc allow selecting any ASE Optimizer by name and raise better …
Browse files Browse the repository at this point in the history
…error msg on invalid key
  • Loading branch information
janosh committed Aug 13, 2023
1 parent 1538245 commit ee5848b
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions matcalc/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,21 @@
import contextlib
import io
import pickle
from inspect import isclass
from typing import TYPE_CHECKING

from ase import optimize
from ase.constraints import ExpCellFilter
from ase.optimize.bfgs import BFGS
from ase.optimize.bfgslinesearch import BFGSLineSearch
from ase.optimize.fire import FIRE
from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch
from ase.optimize.mdmin import MDMin
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
from ase.optimize.optimize import Optimizer
from pymatgen.io.ase import AseAtomsAdaptor

if TYPE_CHECKING:
import numpy as np
from ase.optimize.optimize import Optimizer
from pymatgen.core import Structure

from .base import PropCalc

OPTIMIZERS = {
"FIRE": FIRE,
"BFGS": BFGS,
"LBFGS": LBFGS,
"LBFGSLineSearch": LBFGSLineSearch,
"MDMin": MDMin,
"SciPyFminCG": SciPyFminCG,
"SciPyFminBFGS": SciPyFminBFGS,
"BFGSLineSearch": BFGSLineSearch,
}
if TYPE_CHECKING:
import numpy as np
from ase import Atoms
from ase.calculators.calculator import Calculator
from pymatgen.core import Structure


class TrajectoryObserver:
Expand Down Expand Up @@ -104,9 +88,21 @@ def __init__(
interval (int): The step interval for saving the trajectories.
fmax (float): Total force tolerance for relaxation convergence. fmax is a sum of force and stress forces.
relax_cell (bool): Whether to relax the cell (or just atoms).
Raises:
ValueError: If the optimizer is not a valid ASE optimizer.
"""
self.calculator = calculator
self.optimizer: Optimizer = OPTIMIZERS[optimizer] if isinstance(optimizer, str) else optimizer

# check str is valid optimizer key
def is_ase_optimizer(key):
return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer)

valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)]
if isinstance(optimizer, str) and optimizer not in valid_keys:
raise ValueError(f"Unknown {optimizer=}, must be one of {valid_keys}")

Check warning on line 103 in matcalc/relaxation.py

View check run for this annotation

Codecov / codecov/patch

matcalc/relaxation.py#L103

Added line #L103 was not covered by tests

self.optimizer: Optimizer = getattr(optimize, optimizer) if isinstance(optimizer, str) else optimizer
self.fmax = fmax
self.interval = interval
self.steps = steps
Expand Down

0 comments on commit ee5848b

Please sign in to comment.