Skip to content

Commit

Permalink
Add get_calculator method.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Aug 7, 2023
1 parent ca2b1a2 commit 003e97c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
33 changes: 33 additions & 0 deletions matcalc/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Some utility methods, e.g., for getting calculators from well-known sources."""
from __future__ import annotations

import functools


@functools.lru_cache
def get_calculator(name: str, **kwargs):
"""
Helper method to get some well-known calculators. Note that imports should be within the if statements to ensure
that all these are optional.
Args:
name (str): Name of calculator.
**kwargs: Passthrough to calculator init.
Returns:
Calculator
"""
if name in ("M3GNet-MP-2021.2.8-PES", "M3GNet-MP-2021.2.8-DIRECT-PES"):
import matgl
from matgl.ext.ase import M3GNetCalculator

potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
return M3GNetCalculator(potential=potential, stress_weight=0.01, **kwargs)

if name == "CHGNet":
from chgnet.model.dynamics import CHGNetCalculator
from chgnet.model.model import CHGNet

return CHGNetCalculator(CHGNet.load(), stress_weight=0.01, **kwargs)

raise ValueError(f"Unsupported model name: {name}")
1 change: 1 addition & 0 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ mypy
ruff
black
matgl
chgnet
12 changes: 12 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest

from matcalc.util import get_calculator
from ase.calculators.calculator import Calculator


def test_get_calculator():
for name in ("M3GNet-MP-2021.2.8-DIRECT-PES", "CHGNet"):
calc = get_calculator(name)
assert isinstance(calc, Calculator)
with pytest.raises(ValueError, match="Unsupported model name"):
get_calculator("whatever")

0 comments on commit 003e97c

Please sign in to comment.