diff --git a/pymatgen/analysis/structure_prediction/substitutor.py b/pymatgen/analysis/structure_prediction/substitutor.py index 3d0bde07031..e4ac2565262 100644 --- a/pymatgen/analysis/structure_prediction/substitutor.py +++ b/pymatgen/analysis/structure_prediction/substitutor.py @@ -38,6 +38,8 @@ class Substitutor(MSONable): Inorganic Chemistry, 50(2), 656-663. doi:10.1021/ic102031h. """ + charge_balanced_tol: float = 1e-9 + def __init__(self, threshold=1e-3, symprec: float = 0.1, **kwargs): """ This substitutor uses the substitution probability class to @@ -158,9 +160,9 @@ def pred_from_structures( return transmuter.transformed_structures @staticmethod - def _is_charge_balanced(struct): + def _is_charge_balanced(struct) -> bool: """Checks if the structure object is charge balanced.""" - return sum(site.specie.oxi_state for site in struct) == 0.0 + return abs(sum(site.specie.oxi_state for site in struct)) < Substitutor.charge_balanced_tol @staticmethod def _is_from_chemical_system(chemical_system, struct): @@ -175,11 +177,9 @@ def pred_from_list(self, species_list): through these possibilities. The brute force method would be:: output = [] - for p in itertools.product(self._sp.species_list - , repeat = len(species_list)): - if self._sp.conditional_probability_list(p, species_list) - > self._threshold: - output.append(dict(zip(species_list,p))) + for p in itertools.product(self._sp.species_list, repeat=len(species_list)): + if self._sp.conditional_probability_list(p, species_list) > self._threshold: + output.append(dict(zip(species_list, p))) return output Instead of that we do a branch and bound. diff --git a/pymatgen/core/composition.py b/pymatgen/core/composition.py index 2ef6b9be88d..f2c70048f31 100644 --- a/pymatgen/core/composition.py +++ b/pymatgen/core/composition.py @@ -70,6 +70,7 @@ class Composition(collections.abc.Hashable, collections.abc.Mapping, MSONable, S # 1e-8 is fairly tight, but should cut out most floating point arithmetic # errors. amount_tolerance = 1e-8 + charge_balanced_tolerance = 1e-8 # Special formula handling for peroxides and certain elements. This is so # that formula output does not write LiO instead of Li2O2 for example. @@ -689,6 +690,38 @@ def to_data_dict(self) -> dict: "nelements": len(self), } + @property + def charge(self) -> float | None: + """Total charge based on oxidation states. If any oxidation states + are None or they're all 0, returns None. Use add_charges_from_oxi_state_guesses to + assign oxidation states to elements based on charge balancing. + """ + warnings.warn( + "Composition.charge is experimental and may produce incorrect results. Use with " + "caution and open a GitHub issue pinging @janosh to report bad behavior." + ) + oxi_states = [getattr(specie, "oxi_state", None) for specie in self] + if {*oxi_states} <= {0, None}: + # all oxidation states are None or 0 + return None + return sum(oxi * amt for oxi, amt in zip(oxi_states, self.values())) + + @property + def charge_balanced(self) -> bool | None: + """True if composition is charge balanced, False otherwise. If any oxidation states + are None, returns None. Use add_charges_from_oxi_state_guesses to assign oxidation + states to elements. + """ + warnings.warn( + "Composition.charge_balanced is experimental and may produce incorrect results. " + "Use with caution and open a GitHub issue pinging @janosh to report bad behavior." + ) + if self.charge is None: + if {getattr(el, "oxi_state", None) for el in self} == {0}: + return False + return None + return abs(self.charge) < Composition.charge_balanced_tolerance + def oxi_state_guesses( self, oxi_states_override: dict | None = None, @@ -797,16 +830,16 @@ def add_charges_from_oxi_state_guesses( routine. Args: - oxi_states_override (dict): dict of str->list to override an - element's common oxidation states, e.g. {"V": [2,3,4,5]} - target_charge (int): the desired total charge on the structure. + oxi_states_override (dict[str, list[float]]): Override an + element's common oxidation states, e.g. {"V": [2, 3, 4, 5]} + target_charge (float): the desired total charge on the structure. Default is 0 signifying charge balance. - all_oxi_states (bool): if True, an element defaults to + all_oxi_states (bool): If True, an element defaults to all oxidation states in pymatgen Element.icsd_oxidation_states. Otherwise, default is Element.common_oxidation_states. Note that the full oxidation state list is *very* inclusive and can produce nonsensical results. - max_sites (int): if possible, will reduce Compositions to at most + max_sites (int): If possible, will reduce Compositions to at most this many sites to speed up oxidation state guesses. If the composition cannot be reduced to this many sites a ValueError will be raised. Set to -1 to just reduce fully. If set to a diff --git a/tests/analysis/structure_prediction/test_substitution_probability.py b/tests/analysis/structure_prediction/test_substitution_probability.py index 8d12adc141a..f936d2198c5 100644 --- a/tests/analysis/structure_prediction/test_substitution_probability.py +++ b/tests/analysis/structure_prediction/test_substitution_probability.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os import unittest from pytest import approx @@ -20,14 +19,9 @@ def get_table(): initialization time, and make unit tests insensitive to changes in the default lambda table. """ - data_dir = os.path.join( - TEST_FILES_DIR, - "struct_predictor", - ) - - json_file = f"{data_dir}/test_lambda.json" - with open(json_file) as f: - return json.load(f) + json_path = f"{TEST_FILES_DIR}/struct_predictor/test_lambda.json" + with open(json_path) as file: + return json.load(file) class TestSubstitutionProbability(unittest.TestCase): diff --git a/tests/analysis/structure_prediction/test_substitutor.py b/tests/analysis/structure_prediction/test_substitutor.py index 512ba835528..fbc59fa8a95 100644 --- a/tests/analysis/structure_prediction/test_substitutor.py +++ b/tests/analysis/structure_prediction/test_substitutor.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os from pymatgen.analysis.structure_prediction.substitutor import Substitutor from pymatgen.core import Composition, Species @@ -14,13 +13,9 @@ def get_table(): initialization time, and make unit tests insensitive to changes in the default lambda table. """ - data_dir = os.path.join( - TEST_FILES_DIR, - "struct_predictor", - ) - json_file = f"{data_dir}/test_lambda.json" - with open(json_file) as f: - return json.load(f) + json_path = f"{TEST_FILES_DIR}/struct_predictor/test_lambda.json" + with open(json_path) as file: + return json.load(file) class TestSubstitutor(PymatgenTest): diff --git a/tests/analysis/test_local_env.py b/tests/analysis/test_local_env.py index ce58988f8f1..529652626ff 100644 --- a/tests/analysis/test_local_env.py +++ b/tests/analysis/test_local_env.py @@ -59,15 +59,15 @@ def setUp(self): [0.5, 0.5, 0.5], ] self._mgo_uc = Structure(mgo_latt, mgo_specie, mgo_frac_cord, validate_proximity=True, to_unit_cell=True) - self._mgo_valrad_evaluator = ValenceIonicRadiusEvaluator(self._mgo_uc) + self._mgo_val_rad_evaluator = ValenceIonicRadiusEvaluator(self._mgo_uc) def test_valences_ionic_structure(self): - valence_dict = self._mgo_valrad_evaluator.valences + valence_dict = self._mgo_val_rad_evaluator.valences for val in list(valence_dict.values()): assert val in {2, -2} def test_radii_ionic_structure(self): - radii_dict = self._mgo_valrad_evaluator.radii + radii_dict = self._mgo_val_rad_evaluator.radii for rad in list(radii_dict.values()): assert rad in {0.86, 1.26} diff --git a/tests/core/test_composition.py b/tests/core/test_composition.py index d11411d2231..01dc53c778f 100644 --- a/tests/core/test_composition.py +++ b/tests/core/test_composition.py @@ -712,6 +712,24 @@ def test_replace(self): c_new_4 = Ca2NF_oxi.replace(example_sub_4) assert c_new_4 == Composition("Mg2O2").add_charges_from_oxi_state_guesses() + def test_is_charge_balanced(self): + false_dct = dict.fromkeys("FeO FeO2 MgO Mg2O3 Mg2O4".split(), False) + true_dct = dict.fromkeys("Fe2O3 FeO CaTiO3 SrTiO3 MgO Mg2O2".split(), True) + + for formula, expected in (false_dct | true_dct).items(): + comp = Composition(formula) + # by default, compositions contain elements, not species and hence have no oxidation states + assert comp.charge is None + + # convert elements to species with oxidation states + oxi_comp = comp.add_charges_from_oxi_state_guesses() + assert oxi_comp.charge_balanced is expected, f"Failed for {formula=}" + + if expected is True: + assert abs(oxi_comp.charge) < Composition.charge_balanced_tolerance + else: + assert oxi_comp.charge is None + class TestChemicalPotential(unittest.TestCase): def test_init(self): diff --git a/tests/transformations/test_advanced_transformations.py b/tests/transformations/test_advanced_transformations.py index 6d5ba14566c..426d428ddab 100644 --- a/tests/transformations/test_advanced_transformations.py +++ b/tests/transformations/test_advanced_transformations.py @@ -59,10 +59,9 @@ def get_table(): initialization time, and make unit tests insensitive to changes in the default lambda table. """ - data_dir = f"{TEST_FILES_DIR}/struct_predictor" - json_file = f"{data_dir}/test_lambda.json" - with open(json_file) as f: - return json.load(f) + json_path = f"{TEST_FILES_DIR}/struct_predictor/test_lambda.json" + with open(json_path) as file: + return json.load(file) enum_cmd = which("enum.x") or which("multienum.x")