diff --git a/docs/releasehistory.md b/docs/releasehistory.md index 3a73dfcc..e41babb1 100644 --- a/docs/releasehistory.md +++ b/docs/releasehistory.md @@ -22,6 +22,7 @@ Please note that all releases prior to a version 1.0.0 are considered pre-releas * Several classes and methods which were deprecated in the 0.3 line of releases are now removed. * Previously-deprecated examples are removed. * `ProperTorsionKey` no longer accepts an empty tuple as atom indices. +* Fixes a regression in which some `ElectrostaticsCollection.charges` properties did not return cached values. ## 0.3.30 - 2024-08 diff --git a/openff/interchange/_annotations.py b/openff/interchange/_annotations.py index 7cfd18eb..2387d080 100644 --- a/openff/interchange/_annotations.py +++ b/openff/interchange/_annotations.py @@ -57,6 +57,7 @@ def _unit_validator_factory(unit: str) -> Callable: _is_kj_mol, _is_nanometer, _is_degree, + _is_elementary_charge, ) = ( _unit_validator_factory(unit=_unit) for _unit in [ @@ -64,6 +65,7 @@ def _unit_validator_factory(unit: str) -> Callable: "kilojoule / mole", "nanometer", "degree", + "elementary_charge", ] ) @@ -153,6 +155,13 @@ def quantity_json_serializer( WrapSerializer(quantity_json_serializer), ] +_ElementaryChargeQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_elementary_charge), + WrapSerializer(quantity_json_serializer), +] + _kJMolQuantity = Annotated[ Quantity, WrapValidator(quantity_validator), diff --git a/openff/interchange/_tests/test_issues.py b/openff/interchange/_tests/test_issues.py index 7b381ec2..6a3919ad 100644 --- a/openff/interchange/_tests/test_issues.py +++ b/openff/interchange/_tests/test_issues.py @@ -112,3 +112,10 @@ def test_issue_1031(monkeypatch): # check a few atom names to ensure these didn't end up being empty sets for atom_name in ("NE2", "H3", "HA", "CH3", "CA", "CB", "CE1"): assert atom_name in openff_atom_names + + +def test_issue_1052(sage, ethanol): + """Test that _SMIRNOFFElectrostaticsCollection.charges is populated.""" + out = sage.create_interchange(ethanol.to_topology()) + + assert len(out["Electrostatics"].charges) > 0 diff --git a/openff/interchange/common/_nonbonded.py b/openff/interchange/common/_nonbonded.py index 7d5cb9da..ecb47d60 100644 --- a/openff/interchange/common/_nonbonded.py +++ b/openff/interchange/common/_nonbonded.py @@ -3,9 +3,9 @@ from typing import Literal from openff.toolkit import Quantity, unit -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, computed_field -from openff.interchange._annotations import _DistanceQuantity +from openff.interchange._annotations import _DistanceQuantity, _ElementaryChargeQuantity from openff.interchange.components.potentials import Collection from openff.interchange.constants import _PME from openff.interchange.models import ( @@ -101,20 +101,17 @@ class ElectrostaticsCollection(_NonbondedCollection): nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb") exception_potential: Literal["Coulomb"] = Field("Coulomb") - _charges: dict[ - TopologyKey | LibraryChargeTopologyKey, - Quantity, - ] = PrivateAttr( - default_factory=dict, - ) + # TODO: Charge caching doesn't work when this is defined in the model + # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) _charges_cached: bool = PrivateAttr(default=False) + @computed_field # type: ignore[misc] @property def charges( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" - if len(self._charges) == 0 or self._charges_cached is False: + if len(self._charges) == 0 or self._charges_cached is False: # type: ignore[has-type] self._charges = self._get_charges(include_virtual_sites=False) self._charges_cached = True @@ -123,7 +120,7 @@ def charges( def _get_charges( self, include_virtual_sites: bool = False, - ) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, Quantity]: + ) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]: if include_virtual_sites: raise NotImplementedError() diff --git a/openff/interchange/components/potentials.py b/openff/interchange/components/potentials.py index 9d6b334e..5161f2cb 100644 --- a/openff/interchange/components/potentials.py +++ b/openff/interchange/components/potentials.py @@ -33,7 +33,7 @@ if has_package("jax"): from jax import Array else: - Array = Any + Array = Any # type: ignore class Potential(_BaseModel): diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index ee751f53..74f360f2 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -13,9 +13,10 @@ ToolkitAM1BCCHandler, vdWHandler, ) -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, computed_field from typing_extensions import Self +from openff.interchange._annotations import _ElementaryChargeQuantity from openff.interchange.common._nonbonded import ( ElectrostaticsCollection, _NonbondedCollection, @@ -272,8 +273,9 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect ) # type: ignore[assignment] exception_potential: Literal["Coulomb"] = Field("Coulomb") - _charges = PrivateAttr(default_factory=dict) - _charges_cached: bool + # TODO: Charge caching doesn't work when this is defined in the model + # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) + _charges_cached: bool = PrivateAttr(default=False) @classmethod def allowed_parameter_handlers(cls): @@ -292,14 +294,15 @@ def supported_parameters(cls): @property def _charges_without_virtual_sites( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, excluding virtual sites.""" return self._get_charges(include_virtual_sites=False) + @computed_field # type: ignore[misc] @property def charges( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges(include_virtual_sites=True) @@ -310,10 +313,10 @@ def charges( def _get_charges( self, include_virtual_sites=True, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom or particle.""" # Keyed by index for atoms and by VirtualSiteKey for virtual sites. - charges: dict[VirtualSiteKey | int, Quantity] = dict() + charges: dict[VirtualSiteKey | int, _ElementaryChargeQuantity] = dict() for topology_key, potential_key in self.key_map.items(): potential = self.potentials[potential_key]