diff --git a/matcalc/neb.py b/matcalc/neb.py index 42389af..3681658 100644 --- a/matcalc/neb.py +++ b/matcalc/neb.py @@ -2,20 +2,18 @@ from __future__ import annotations import os -from inspect import isclass from typing import TYPE_CHECKING -from ase import optimize from ase.io import Trajectory from ase.neb import NEB, NEBTools -from ase.optimize.optimize import Optimizer from pymatgen.core import Structure from matcalc.base import PropCalc -from matcalc.utils import get_universal_calculator +from matcalc.utils import get_ase_optimizer, get_universal_calculator if TYPE_CHECKING: from ase.calculators.calculator import Calculator + from ase.optimize.optimize import Optimizer class NEBCalc(PropCalc): @@ -34,9 +32,9 @@ def __init__( """ Args: images(list): A list of pymatgen structures as NEB image structures. - calculator(str|Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES. - optimizer(str|Optimizer): The optimization algorithm. Defaults to "BEGS". - traj_folder(str|None): The folder address to store NEB trajectories. Defaults to None. + calculator(str | Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES. + optimizer(str | Optimizer): The optimization algorithm. Defaults to "BEGS". + traj_folder(str | None): The folder address to store NEB trajectories. Defaults to None. interval(int): The step interval for saving the trajectories. Defaults to 1. climb(bool): Whether to enable climb image NEB. Defaults to True. kwargs: Other arguments passed to ASE NEB object. @@ -44,15 +42,7 @@ def __init__( self.images = images self.calculator = get_universal_calculator(calculator) - # check str is valid optimizer key - def is_ase_optimizer(key): - return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer) - - valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)] - if isinstance(optimizer, str) and optimizer not in valid_keys: - raise ValueError(f"Unknown {optimizer=}, must be one of {valid_keys}") - - self.optimizer: Optimizer = getattr(optimize, optimizer) if isinstance(optimizer, str) else optimizer + self.optimizer = get_ase_optimizer(optimizer) self.traj_folder = traj_folder self.interval = interval self.climb = climb @@ -83,7 +73,7 @@ def from_end_images( Args: start_struct(Structure): The starting image as a pymatgen Structure. end_struct(Structure): The ending image as a pymatgen Structure. - calculator(str|Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES. + calculator(str | Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES. n_images(int): The number of intermediate image structures to create. interpolate_lattices(bool): Whether to interpolate the lattices when creating NEB path with Structure.interpolate() in pymatgen. Defaults to False. diff --git a/matcalc/relaxation.py b/matcalc/relaxation.py index 1fce42b..4cac971 100644 --- a/matcalc/relaxation.py +++ b/matcalc/relaxation.py @@ -4,20 +4,20 @@ import contextlib import io import pickle -from inspect import isclass from typing import TYPE_CHECKING -from ase import optimize from ase.constraints import ExpCellFilter -from ase.optimize.optimize import Optimizer from pymatgen.io.ase import AseAtomsAdaptor +from matcalc.utils import get_ase_optimizer + from .base import PropCalc if TYPE_CHECKING: import numpy as np from ase import Atoms from ase.calculators.calculator import Calculator + from ase.optimize.optimize import Optimizer from pymatgen.core import Structure @@ -95,15 +95,7 @@ def __init__( """ self.calculator = calculator - # check str is valid optimizer key - def is_ase_optimizer(key): - return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer) - - valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)] - if isinstance(optimizer, str) and optimizer not in valid_keys: - raise ValueError(f"Unknown {optimizer=}, must be one of {valid_keys}") - - self.optimizer: Optimizer = getattr(optimize, optimizer) if isinstance(optimizer, str) else optimizer + self.optimizer = get_ase_optimizer(optimizer) self.fmax = fmax self.interval = interval self.max_steps = max_steps diff --git a/matcalc/utils.py b/matcalc/utils.py index 318fb20..726de2d 100644 --- a/matcalc/utils.py +++ b/matcalc/utils.py @@ -3,8 +3,12 @@ from __future__ import annotations import functools +from inspect import isclass from typing import TYPE_CHECKING +import ase.optimize +from ase.optimize.optimize import Optimizer + if TYPE_CHECKING: from ase.calculators.calculator import Calculator @@ -61,3 +65,32 @@ def get_universal_calculator(name: str | Calculator, **kwargs) -> Calculator: return mace_mp(**kwargs) raise ValueError(f"Unrecognized {name=}, must be one of {UNIVERSAL_CALCULATORS}") + + +def is_ase_optimizer(key: str) -> bool: + """Check if key is the name of an ASE optimizer class.""" + return isclass(obj := getattr(ase.optimize, key)) and issubclass(obj, Optimizer) + + +VALID_OPTIMIZERS = [key for key in dir(ase.optimize) if is_ase_optimizer(key)] + + +def get_ase_optimizer(optimizer: str | Optimizer) -> Optimizer: + """Validate optimizer is a valid ASE Optimizer. + + Args: + optimizer (str | Optimizer): The optimization algorithm. + + Raises: + ValueError: on unrecognized optimizer name. + + Returns: + Optimizer: ASE Optimizer class. + """ + if isclass(optimizer) and issubclass(optimizer, Optimizer): + return optimizer + + if optimizer not in VALID_OPTIMIZERS: + raise ValueError(f"Unknown {optimizer=}, must be one of {VALID_OPTIMIZERS}") + + return getattr(ase.optimize, optimizer) if isinstance(optimizer, str) else optimizer diff --git a/tests/test_utils.py b/tests/test_utils.py index f0d67c4..4a9997c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,17 @@ from __future__ import annotations +import ase.optimize import pytest from ase.calculators.calculator import Calculator +from ase.optimize.optimize import Optimizer -from matcalc.utils import UNIVERSAL_CALCULATORS, get_universal_calculator +from matcalc.utils import ( + UNIVERSAL_CALCULATORS, + VALID_OPTIMIZERS, + get_ase_optimizer, + get_universal_calculator, + is_ase_optimizer, +) def test_get_universal_calculator() -> None: @@ -21,3 +29,17 @@ def test_get_universal_calculator() -> None: # cover edge case like https://github.com/materialsvirtuallab/matcalc/issues/14 # where non-str and non-ASE Calculator instances are passed in assert get_universal_calculator(42) == 42 # test non-str input is returned as-is + + +def test_get_ase_optimizer() -> None: + for name in dir(ase.optimize): + if is_ase_optimizer(name): + optimizer = get_ase_optimizer(name) + assert issubclass(optimizer, Optimizer) + same_optimizer = get_ase_optimizer(optimizer) # test ASE Optimizer classes are returned as-is + assert optimizer is same_optimizer + + for optimizer in ("whatever", 42): + with pytest.raises(ValueError, match=f"Unknown {optimizer=}") as exc: + get_ase_optimizer(optimizer) + assert str(exc.value) == f"Unknown {optimizer=}, must be one of {VALID_OPTIMIZERS}"