diff --git a/matcalc/phonon.py b/matcalc/phonon.py index 9c95db8..2672ffb 100644 --- a/matcalc/phonon.py +++ b/matcalc/phonon.py @@ -71,14 +71,9 @@ def calc(self, structure) -> dict: phonon = phonopy.Phonopy(cell, self.supercell_matrix) phonon.generate_displacements(distance=self.atom_disp) disp_supercells = phonon.supercells_with_displacements - forces = [ - _calc_forces(self.calculator, supercell) - for supercell in [phonon.supercell, *disp_supercells] - if supercell is not None + phonon.forces = [ + _calc_forces(self.calculator, supercell) for supercell in disp_supercells if supercell is not None ] - # parallel = Parallel(n_jobs=1) - # forces = parallel(delayed(_calc_forces)(self.calculator, s) for s in structure_list) - phonon.forces = forces[1:] phonon.produce_force_constants() phonon.run_mesh() phonon.run_thermal_properties(t_step=self.t_step, t_max=self.t_max, t_min=self.t_min) diff --git a/matcalc/util.py b/matcalc/util.py index f1023ae..5ab69e6 100644 --- a/matcalc/util.py +++ b/matcalc/util.py @@ -5,10 +5,12 @@ @functools.lru_cache -def get_calculator(name: str, **kwargs): +def get_universal_calculator(name: str, **kwargs): """ - Helper method to get some well-known calculators. Note that imports should be within the if statements to ensure - that all these are optional. + Helper method to get some well-known **universal** calculators. Note that imports should be within the if + statements to ensure that all these are optional. It should be stressed that this method is for universal + calculators encompassing a wide swath of the periodic table only. Though matcalc can be used with any MLIP, it is + not the intention for this method to provide a listing of all MLIPs. Args: name (str): Name of calculator. diff --git a/tests/test_util.py b/tests/test_util.py index 079cc9a..d292b7b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,12 +1,12 @@ import pytest -from matcalc.util import get_calculator +from matcalc.util import get_universal_calculator from ase.calculators.calculator import Calculator def test_get_calculator(): for name in ("M3GNet", "M3GNet-MP-2021.2.8-PES", "CHGNet"): - calc = get_calculator(name) + calc = get_universal_calculator(name) assert isinstance(calc, Calculator) with pytest.raises(ValueError, match="Unsupported model name"): - get_calculator("whatever") + get_universal_calculator("whatever")