From 034791fa722d52f8f2ff10077e11c2c5b0ac2ff8 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 12 Dec 2023 09:50:42 -0800 Subject: [PATCH] fix get_universal_calculator AttributeError: 'SumCalculator' object has no attribute 'lower' from running get_universal_calculator("mace", model="large", default_dtype="float64", dispersion=True) --- matcalc/util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/matcalc/util.py b/matcalc/util.py index dd1c51b..318fb20 100644 --- a/matcalc/util.py +++ b/matcalc/util.py @@ -1,9 +1,12 @@ """Some utility methods, e.g., for getting calculators from well-known sources.""" + from __future__ import annotations import functools +from typing import TYPE_CHECKING -from ase.calculators.calculator import Calculator +if TYPE_CHECKING: + from ase.calculators.calculator import Calculator # Listing of supported universal calculators. UNIVERSAL_CALCULATORS = ( @@ -34,7 +37,7 @@ def get_universal_calculator(name: str | Calculator, **kwargs) -> Calculator: Returns: Calculator """ - if isinstance(name, Calculator): + if not isinstance(name, str): # e.g. already an ase Calculator instance return name if name.lower().startswith("m3gnet"): @@ -44,7 +47,7 @@ def get_universal_calculator(name: str | Calculator, **kwargs) -> Calculator: # M3GNet is shorthand for latest M3GNet based on DIRECT sampling. name = {"m3gnet": "M3GNet-MP-2021.2.8-DIRECT-PES"}.get(name.lower(), name) model = matgl.load_model(name) - kwargs.setdefault("stress_weight", 1.0 / 160.21766208) + kwargs.setdefault("stress_weight", 1 / 160.21766208) return M3GNetCalculator(potential=model, **kwargs) if name.lower() == "chgnet":