Skip to content

Commit

Permalink
Define util functions is_ase_optimizer and get_ase_optimizer (#16)
Browse files Browse the repository at this point in the history
* define util functions is_ase_optimizer and get_ase_optimizer

* refactor NEBCalc and RelaxCalc to use get_ase_optimizer

* add test_get_ase_optimizer
  • Loading branch information
janosh committed Dec 12, 2023
1 parent 330fecc commit 70ad3b3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 30 deletions.
24 changes: 7 additions & 17 deletions matcalc/neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@
from __future__ import annotations

import os
from inspect import isclass
from typing import TYPE_CHECKING

from ase import optimize
from ase.io import Trajectory
from ase.neb import NEB, NEBTools
from ase.optimize.optimize import Optimizer
from pymatgen.core import Structure

from matcalc.base import PropCalc
from matcalc.utils import get_universal_calculator
from matcalc.utils import get_ase_optimizer, get_universal_calculator

if TYPE_CHECKING:
from ase.calculators.calculator import Calculator
from ase.optimize.optimize import Optimizer


class NEBCalc(PropCalc):
Expand All @@ -34,25 +32,17 @@ def __init__(
"""
Args:
images(list): A list of pymatgen structures as NEB image structures.
calculator(str|Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES.
optimizer(str|Optimizer): The optimization algorithm. Defaults to "BEGS".
traj_folder(str|None): The folder address to store NEB trajectories. Defaults to None.
calculator(str | Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES.
optimizer(str | Optimizer): The optimization algorithm. Defaults to "BEGS".
traj_folder(str | None): The folder address to store NEB trajectories. Defaults to None.
interval(int): The step interval for saving the trajectories. Defaults to 1.
climb(bool): Whether to enable climb image NEB. Defaults to True.
kwargs: Other arguments passed to ASE NEB object.
"""
self.images = images
self.calculator = get_universal_calculator(calculator)

# check str is valid optimizer key
def is_ase_optimizer(key):
return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer)

valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)]
if isinstance(optimizer, str) and optimizer not in valid_keys:
raise ValueError(f"Unknown {optimizer=}, must be one of {valid_keys}")

self.optimizer: Optimizer = getattr(optimize, optimizer) if isinstance(optimizer, str) else optimizer
self.optimizer = get_ase_optimizer(optimizer)
self.traj_folder = traj_folder
self.interval = interval
self.climb = climb
Expand Down Expand Up @@ -83,7 +73,7 @@ def from_end_images(
Args:
start_struct(Structure): The starting image as a pymatgen Structure.
end_struct(Structure): The ending image as a pymatgen Structure.
calculator(str|Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES.
calculator(str | Calculator): ASE Calculator to use. Defaults to M3GNet-MP-2021.2.8-DIRECT-PES.
n_images(int): The number of intermediate image structures to create.
interpolate_lattices(bool): Whether to interpolate the lattices when creating NEB
path with Structure.interpolate() in pymatgen. Defaults to False.
Expand Down
16 changes: 4 additions & 12 deletions matcalc/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
import contextlib
import io
import pickle
from inspect import isclass
from typing import TYPE_CHECKING

from ase import optimize
from ase.constraints import ExpCellFilter
from ase.optimize.optimize import Optimizer
from pymatgen.io.ase import AseAtomsAdaptor

from matcalc.utils import get_ase_optimizer

from .base import PropCalc

if TYPE_CHECKING:
import numpy as np
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.optimize.optimize import Optimizer
from pymatgen.core import Structure


Expand Down Expand Up @@ -95,15 +95,7 @@ def __init__(
"""
self.calculator = calculator

# check str is valid optimizer key
def is_ase_optimizer(key):
return isclass(obj := getattr(optimize, key)) and issubclass(obj, Optimizer)

valid_keys = [key for key in dir(optimize) if is_ase_optimizer(key)]
if isinstance(optimizer, str) and optimizer not in valid_keys:
raise ValueError(f"Unknown {optimizer=}, must be one of {valid_keys}")

self.optimizer: Optimizer = getattr(optimize, optimizer) if isinstance(optimizer, str) else optimizer
self.optimizer = get_ase_optimizer(optimizer)
self.fmax = fmax
self.interval = interval
self.max_steps = max_steps
Expand Down
33 changes: 33 additions & 0 deletions matcalc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from __future__ import annotations

import functools
from inspect import isclass
from typing import TYPE_CHECKING

import ase.optimize
from ase.optimize.optimize import Optimizer

if TYPE_CHECKING:
from ase.calculators.calculator import Calculator

Expand Down Expand Up @@ -61,3 +65,32 @@ def get_universal_calculator(name: str | Calculator, **kwargs) -> Calculator:
return mace_mp(**kwargs)

raise ValueError(f"Unrecognized {name=}, must be one of {UNIVERSAL_CALCULATORS}")


def is_ase_optimizer(key: str) -> bool:
"""Check if key is the name of an ASE optimizer class."""
return isclass(obj := getattr(ase.optimize, key)) and issubclass(obj, Optimizer)


VALID_OPTIMIZERS = [key for key in dir(ase.optimize) if is_ase_optimizer(key)]


def get_ase_optimizer(optimizer: str | Optimizer) -> Optimizer:
"""Validate optimizer is a valid ASE Optimizer.
Args:
optimizer (str | Optimizer): The optimization algorithm.
Raises:
ValueError: on unrecognized optimizer name.
Returns:
Optimizer: ASE Optimizer class.
"""
if isclass(optimizer) and issubclass(optimizer, Optimizer):
return optimizer

if optimizer not in VALID_OPTIMIZERS:
raise ValueError(f"Unknown {optimizer=}, must be one of {VALID_OPTIMIZERS}")

return getattr(ase.optimize, optimizer) if isinstance(optimizer, str) else optimizer
24 changes: 23 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

import ase.optimize
import pytest
from ase.calculators.calculator import Calculator
from ase.optimize.optimize import Optimizer

from matcalc.utils import UNIVERSAL_CALCULATORS, get_universal_calculator
from matcalc.utils import (
UNIVERSAL_CALCULATORS,
VALID_OPTIMIZERS,
get_ase_optimizer,
get_universal_calculator,
is_ase_optimizer,
)


def test_get_universal_calculator() -> None:
Expand All @@ -21,3 +29,17 @@ def test_get_universal_calculator() -> None:
# cover edge case like https://github.com/materialsvirtuallab/matcalc/issues/14
# where non-str and non-ASE Calculator instances are passed in
assert get_universal_calculator(42) == 42 # test non-str input is returned as-is


def test_get_ase_optimizer() -> None:
for name in dir(ase.optimize):
if is_ase_optimizer(name):
optimizer = get_ase_optimizer(name)
assert issubclass(optimizer, Optimizer)
same_optimizer = get_ase_optimizer(optimizer) # test ASE Optimizer classes are returned as-is
assert optimizer is same_optimizer

for optimizer in ("whatever", 42):
with pytest.raises(ValueError, match=f"Unknown {optimizer=}") as exc:
get_ase_optimizer(optimizer)
assert str(exc.value) == f"Unknown {optimizer=}, must be one of {VALID_OPTIMIZERS}"

0 comments on commit 70ad3b3

Please sign in to comment.