Skip to content

Commit

Permalink
Merge pull request #176 from PainterQubits/#169-add-pydantic
Browse files Browse the repository at this point in the history
#169 Add Pydantic
  • Loading branch information
alexhad6 authored Apr 25, 2024
2 parents afaa97a + 2018782 commit 8807f0a
Show file tree
Hide file tree
Showing 11 changed files with 612 additions and 180 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build:
- pip install poetry==1.8.2
post_install:
# See https://docs.readthedocs.io/en/stable/build-customization.html#install-dependencies-with-poetry
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --without dev
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install -E pydantic --without dev
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry run python -m ipykernel install --user

sphinx:
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- If Pydantic is installed, parameter data classes automatically have Pydantic type
validation enabled.

### Changed

- All `ParamData` objects now internally track the latest time that they or any of their
Expand Down
34 changes: 32 additions & 2 deletions docs/parameter-data.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ when building up dataclasses through inheritance.
from dataclasses import field
class KeywordOnlyParam(ParamDataclass, kw_only=True):
values: list[int] = field(default_factory=list)
count: int
values: list[int] = field(default_factory=list)
keyword_only_param = KeywordOnlyParam(count=123)
keyword_only_param
Expand Down Expand Up @@ -162,6 +162,34 @@ nested_param.child_param.root is nested_param
See [Type Mixins](#type-mixins) for information on how to get the parent and root
properties to work better with static type checkers.

### Type Validation

If [Pydantic] is installed, parameter data classes will automatically be converted to
[Pydantic data classes], enabling runtime type validation. Some [Pydantic configuration]
have modified defaults; see {py:class}`ParamDataclass` for more information.

Pydantic type validation will enforce type hints at runtime by raising an exception. For
example:

```{jupyter-execute}
import pydantic
try:
CustomParam(value="123")
except pydantic.ValidationError as exception:
print(exception)
```

Type validation can be disabled for a particular parameter data class (and its subclasses)
using the class keyword argument `type_validation`:

```{jupyter-execute}
class NoTypeValidationParam(CustomParam, type_validation=False):
pass
NoTypeValidationParam(value="123")
```

## Collections

Ordinary lists and dictionaries can be used within parameter data; however, any
Expand Down Expand Up @@ -265,5 +293,7 @@ This does nothing to the functionality, but static type checkers will now know t
[mutable default values]: https://docs.python.org/3/library/dataclasses.html#mutable-default-values
[`@property`]: https://docs.python.org/3/library/functions.html#property
[`__post_init__`]: https://docs.python.org/3/library/dataclasses.html#post-init-processing
[`abc.abc`]: https://docs.python.org/3/library/abc.html#abc.ABC
[Pydantic]: https://docs.pydantic.dev/latest/
[Pydantic data classes]: https://docs.pydantic.dev/latest/concepts/dataclasses/
[Pydantic configuration]: https://docs.pydantic.dev/latest/api/config/
[`collections.abc`]: https://docs.python.org/3/library/collections.abc.html
8 changes: 4 additions & 4 deletions paramdb/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
try:
from astropy.units import Quantity # type: ignore

ASTROPY_INSTALLED = True
_ASTROPY_INSTALLED = True
except ImportError:
ASTROPY_INSTALLED = False
_ASTROPY_INSTALLED = False

T = TypeVar("T")
SelectT = TypeVar("SelectT", bound=Select[Any])
Expand Down Expand Up @@ -62,7 +62,7 @@ def _to_dict(obj: Any) -> Any:
class_full_name_dict = {CLASS_NAME_KEY: class_full_name}
if isinstance(obj, datetime):
return class_full_name_dict | {"isoformat": obj.isoformat()}
if ASTROPY_INSTALLED and isinstance(obj, Quantity):
if _ASTROPY_INSTALLED and isinstance(obj, Quantity):
return class_full_name_dict | {"value": obj.value, "unit": str(obj.unit)}
if isinstance(obj, ParamData):
return {CLASS_NAME_KEY: type(obj).__name__} | obj.to_dict()
Expand All @@ -86,7 +86,7 @@ def _from_dict(json_dict: dict[str, Any]) -> Any:
class_name = json_dict.pop(CLASS_NAME_KEY)
if class_name == _full_class_name(datetime):
return datetime.fromisoformat(json_dict["isoformat"]).astimezone()
if ASTROPY_INSTALLED and class_name == _full_class_name(Quantity):
if _ASTROPY_INSTALLED and class_name == _full_class_name(Quantity):
return Quantity(**json_dict)
param_class = get_param_class(class_name)
if param_class is not None:
Expand Down
102 changes: 85 additions & 17 deletions paramdb/_param_data/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from typing_extensions import Self, dataclass_transform
from paramdb._param_data._param_data import ParamData

try:
import pydantic
import pydantic.dataclasses

_PYDANTIC_INSTALLED = True
except ImportError:
_PYDANTIC_INSTALLED = False


@dataclass_transform()
class ParamDataclass(ParamData):
Expand All @@ -20,32 +28,92 @@ class CustomParam(ParamDataclass):
value2: int
Any keyword arguments given when creating a subclass are passed internally to the
standard ``@dataclass()`` decorator.
Any class keyword arguments (other than those described below) given when creating a
subclass are passed internally to the ``@dataclass()`` decorator.
If Pydantic is installed, then subclasses will have Pydantic runtime validation
enabled by default. This can be disabled using the class keyword argument
``type_validation``. The following Pydantic configuration values are set by default:
- extra: ``'forbid'`` (forbid extra attributes)
- validate_assignment: ``True`` (validate on assignment as well as initialization)
- arbitrary_types_allowed: ``True`` (allow arbitrary type hints)
- strict: ``True`` (disable value coercion, e.g. '2' -> 2)
- validate_default: ``True`` (validate default values)
Pydantic configuration options can be updated using the class keyword argument
``pydantic_config``, which will merge new options with the existing configuration.
See https://docs.pydantic.dev/latest/api/config for full configuration options.
"""

__field_names: set[str]
__field_names: set[str] # Data class field names
__type_validation: bool = True # Whether to use Pydantic
__pydantic_config: pydantic.ConfigDict = {
"extra": "forbid",
"validate_assignment": True,
"arbitrary_types_allowed": True,
"strict": True,
"validate_default": True,
}

# Set in __init_subclass__() and used to set attributes within __setattr__()
# pylint: disable-next=unused-argument
def __base_setattr(self: Any, name: str, value: Any) -> None: ...

def __init_subclass__(
cls,
/,
type_validation: bool | None = None,
pydantic_config: pydantic.ConfigDict | None = None,
**kwargs: Any,
) -> None:
super().__init_subclass__() # kwargs are passed to dataclass constructor
if type_validation is not None:
cls.__type_validation = type_validation
if pydantic_config is not None:
# Merge new Pydantic config with the old one
cls.__pydantic_config = cls.__pydantic_config | pydantic_config
cls.__base_setattr = object.__setattr__ # type: ignore
if _PYDANTIC_INSTALLED and cls.__type_validation:
# Transform the class into a Pydantic data class, with custom handling for
# validate_assignment
pydantic.dataclasses.dataclass(
config=cls.__pydantic_config | {"validate_assignment": False}, **kwargs
)(cls)
if (
"validate_assignment" in cls.__pydantic_config
and cls.__pydantic_config["validate_assignment"]
):
pydantic_validator = (
pydantic.dataclasses.is_pydantic_dataclass(cls)
and cls.__pydantic_validator__ # pylint: disable=no-member
)
if pydantic_validator:

def __base_setattr(self: Any, name: str, value: Any) -> None:
pydantic_validator.validate_assignment(self, name, value)

cls.__base_setattr = __base_setattr # type: ignore
else:
# Transform the class into a data class
dataclass(**kwargs)(cls)
cls.__field_names = (
{f.name for f in fields(cls)} if is_dataclass(cls) else set()
)

# pylint: disable-next=unused-argument
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# Prevent instantiating ParamDataclass and call the superclass __init__() here
# since __init__() will be overwritten by dataclass()
if cls is ParamDataclass:
raise TypeError("only subclasses of ParamDataclass can be instantiated")
self = super(ParamDataclass, cls).__new__(cls)
super(ParamDataclass, self).__init__() # pylint: disable=super-with-arguments
raise TypeError(
f"only subclasses of {ParamDataclass.__name__} can be instantiated"
)
self = super().__new__(cls)
super().__init__(self)
return self

def __init_subclass__(cls, /, **kwargs: Any) -> None:
# Convert subclasses into dataclasses
super().__init_subclass__() # kwargs are passed to dataclass constructor
dataclass(**kwargs)(cls)
cls.__field_names = (
{f.name for f in fields(cls)} if is_dataclass(cls) else set()
)

def __post_init__(self) -> None:
# Called by the self.__init__() generated by dataclass()
for field_name in self.__field_names:
self._add_child(getattr(self, field_name))

Expand All @@ -61,12 +129,12 @@ def __setattr__(self, name: str, value: Any) -> None:
# If this attribute is a Data Class field, update last updated and children
if name in self.__field_names:
old_value = getattr(self, name) if hasattr(self, name) else None
super().__setattr__(name, value)
self.__base_setattr(name, value)
self._update_last_updated()
self._remove_child(old_value)
self._add_child(value)
return
super().__setattr__(name, value)
self.__base_setattr(name, value)

def _to_json(self) -> dict[str, Any]:
if is_dataclass(self):
Expand Down
33 changes: 15 additions & 18 deletions paramdb/_param_data/_param_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def get_param_class(class_name: str) -> type[ParamData] | None:
class ParamData(ABC):
"""Abstract base class for all parameter data."""

__parent: ParamData | None = None
__last_updated: datetime
_parent: ParamData | None = None
_last_updated: datetime

def __init_subclass__(cls, /, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
Expand All @@ -34,19 +34,17 @@ def __init_subclass__(cls, /, **kwargs: Any) -> None:
_param_classes[cls.__name__] = cls

def __init__(self) -> None:
self.__last_updated = datetime.now(timezone.utc).astimezone()
super().__setattr__("_last_updated", datetime.now(timezone.utc).astimezone())

def _add_child(self, child: Any) -> None:
"""Add the given object as a child, if it is ``ParamData``."""
if isinstance(child, ParamData):
# pylint: disable-next=protected-access,unused-private-member
child.__parent = self
super(ParamData, child).__setattr__("_parent", self)

def _remove_child(self, child: Any) -> None:
"""Remove the given object as a child, if it is ``ParamData``."""
if isinstance(child, ParamData):
# pylint: disable-next=protected-access,unused-private-member
child.__parent = None
super(ParamData, child).__setattr__("_parent", None)

def _update_last_updated(self) -> None:
"""Update last updated for this object and its chain of parents."""
Expand All @@ -57,10 +55,10 @@ def _update_last_updated(self) -> None:
# Continue up the chain of parents, stopping if we reach a last updated
# timestamp that is more recent than the new one
while current and not (
current.__last_updated and current.__last_updated >= new_last_updated
current._last_updated and current._last_updated >= new_last_updated
):
current.__last_updated = new_last_updated
current = current.__parent
super(ParamData, current).__setattr__("_last_updated", new_last_updated)
current = current._parent

@abstractmethod
def _to_json(self) -> Any:
Expand Down Expand Up @@ -91,7 +89,7 @@ def to_dict(self) -> dict[str, Any]:
Return a dictionary representation of this parameter data object, which can be
used to reconstruct the object by passing it to :py:meth:`from_dict`.
"""
return {_LAST_UPDATED_KEY: self.last_updated, _DATA_KEY: self._to_json()}
return {_LAST_UPDATED_KEY: self._last_updated, _DATA_KEY: self._to_json()}

@classmethod
def from_dict(cls, data_dict: dict[str, Any]) -> Self:
Expand All @@ -100,14 +98,13 @@ def from_dict(cls, data_dict: dict[str, Any]) -> Self:
``json.loads()`` and originally constructed by :py:meth:`from_dict`.
"""
param_data = cls._from_json(data_dict[_DATA_KEY])
# pylint: disable-next=protected-access,unused-private-member
param_data.__last_updated = data_dict[_LAST_UPDATED_KEY]
super().__setattr__(param_data, "_last_updated", data_dict[_LAST_UPDATED_KEY])
return param_data

@property
def last_updated(self) -> datetime:
"""When any parameter within this parameter data was last updated."""
return self.__last_updated
return self._last_updated

@property
def parent(self) -> ParamData:
Expand All @@ -119,12 +116,12 @@ def parent(self) -> ParamData:
Raises a ``ValueError`` if there is currently no parent, which can occur if the
parent is still being initialized.
"""
if self.__parent is None:
if self._parent is None:
raise ValueError(
f"'{type(self).__name__}' object has no parent, or its parent has not"
" been initialized yet"
)
return self.__parent
return self._parent

@property
def root(self) -> ParamData:
Expand All @@ -134,6 +131,6 @@ def root(self) -> ParamData:
"""
# pylint: disable=protected-access
root = self
while root.__parent is not None:
root = root.__parent
while root._parent is not None:
root = root._parent
return root
Loading

0 comments on commit 8807f0a

Please sign in to comment.