From d8fdb09c8e806a033d7f6a2e60451821dcf63941 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 7 Mar 2024 09:50:01 +0100 Subject: [PATCH] skip failing test_callable_sort_criteria test with TODO and link to matgl issue --- .../advanced_transformations.py | 20 +++++++++---------- .../test_advanced_transformations.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pymatgen/transformations/advanced_transformations.py b/pymatgen/transformations/advanced_transformations.py index 610c4098bad..751c9d6026e 100644 --- a/pymatgen/transformations/advanced_transformations.py +++ b/pymatgen/transformations/advanced_transformations.py @@ -338,7 +338,9 @@ def __init__( if max_cell_size and max_disordered_sites: raise ValueError("Cannot set both max_cell_size and max_disordered_sites!") - def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False): + def apply_transformation( + self, structure: Structure, return_ranked_list: bool | int = False + ) -> Structure | list[dict]: """Returns either a single ordered structure or a sequence of all ordered structures. @@ -879,7 +881,7 @@ def apply_transformation( # remove dummy species and replace Spin.up or Spin.down # with spin magnitudes given in mag_species_spin arg alls = self._remove_dummy_species(alls) - alls = self._add_spin_magnitudes(alls) + alls = self._add_spin_magnitudes(alls) # type: ignore[arg-type] else: for idx in range(len(alls)): alls[idx]["structure"] = self._remove_dummy_species(alls[idx]["structure"]) @@ -891,7 +893,7 @@ def apply_transformation( num_to_return = 1 if num_to_return == 1 or not return_ranked_list: - return alls[0]["structure"] if num_to_return else alls + return alls[0]["structure"] if num_to_return else alls # type: ignore[return-value] # remove duplicate structures and group according to energy model matcher = StructureMatcher(comparator=SpinComparator()) @@ -1010,11 +1012,10 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | Args: structure (Structure): Input structure to dope return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures. - is returned. If False, only the single lowest energy structure is returned. Defaults to False. Returns: - [{"structure": Structure, "energy": float}] + list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}. """ comp = structure.composition logger.info(f"Composition: {comp}") @@ -1059,7 +1060,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | logger.info(f"{lengths=}") logger.info(f"{scaling=}") - all_structures = [] + all_structures: list[dict] = [] trafo = EnumerateStructureTransformation(**self.kwargs) for sp in compatible_species: @@ -1131,10 +1132,9 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | } ) - ss = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum) - logger.info(f"{len(ss)} distinct structures") - all_structures.extend(ss) - + structs = trafo.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum) + logger.info(f"{len(structs)} distinct structures") + all_structures.extend(structs) logger.info(f"Total {len(all_structures)} doped structures") if return_ranked_list: return all_structures[:return_ranked_list] diff --git a/tests/transformations/test_advanced_transformations.py b/tests/transformations/test_advanced_transformations.py index b20cf223a71..4aa982e55e0 100644 --- a/tests/transformations/test_advanced_transformations.py +++ b/tests/transformations/test_advanced_transformations.py @@ -204,6 +204,7 @@ def test_m3gnet(self): # Check ordering of energy/atom assert alls[0]["energy"] / alls[0]["num_sites"] <= alls[-1]["energy"] / alls[-1]["num_sites"] + @pytest.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved") def test_callable_sort_criteria(self): matgl = pytest.importorskip("matgl") from matgl.ext.ase import Relaxer @@ -212,8 +213,8 @@ def test_callable_sort_criteria(self): m3gnet_model = Relaxer(potential=pot) - def sort_criteria(s): - relax_results = m3gnet_model.relax(s) + def sort_criteria(struct: Structure) -> tuple[Structure, float]: + relax_results = m3gnet_model.relax(struct) energy = float(relax_results["trajectory"].energies[-1]) return relax_results["final_structure"], energy