Merge pull request #866 from lsst/tickets/DM-40002
DM-40002: Try to support pydantic v1 and v2
timj authored Jul 19, 2023
2 parents f8470df + 1ce5be8 commit 1fe3838
Showing 28 changed files with 504 additions and 310 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/build.yaml
@@ -27,16 +27,18 @@ jobs:
channels: conda-forge,defaults
channel-priority: strict
show-channel-urls: true
miniforge-variant: Mambaforge
use-mamba: true

- name: Update pip/wheel infrastructure
shell: bash -l {0}
run: |
conda install -y -q pip wheel
mamba install -y -q pip wheel
- name: Install sqlite
shell: bash -l {0}
run: |
conda install -y -q sqlite
mamba install -y -q sqlite
# Postgres-14 is already installed from official postgres repo, but we
# also need pgsphere which is not installed. The repo is not in the list,
@@ -52,13 +54,13 @@
- name: Install postgresql Python packages
shell: bash -l {0}
run: |
conda install -y -q psycopg2
mamba install -y -q psycopg2
pip install testing.postgresql
- name: Install cryptography package for moto
shell: bash -l {0}
run: |
conda install -y -q cryptography
mamba install -y -q cryptography
- name: Install dependencies
shell: bash -l {0}
@@ -69,13 +71,13 @@
- name: Install pytest packages
shell: bash -l {0}
run: |
conda install -y -q \
mamba install -y -q \
pytest pytest-xdist pytest-openfiles pytest-cov
- name: List installed packages
shell: bash -l {0}
run: |
conda list
mamba list
pip list -v
- name: Build and install
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yaml
@@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
cache: "pip"
cache-dependency-path: "setup.cfg"

1 change: 0 additions & 1 deletion .github/workflows/rebase_checker.yaml
@@ -1,4 +1,3 @@
---
name: Check that 'main' is not merged into the development branch

on: pull_request
1 change: 1 addition & 0 deletions doc/changes/DM-40002.feature.rst
@@ -0,0 +1 @@
Modified to work natively with Pydantic v1 and v2.
2 changes: 1 addition & 1 deletion mypy.ini
@@ -71,7 +71,7 @@ disallow_untyped_defs = True
disallow_incomplete_defs = True
strict_equality = True
warn_unreachable = True
warn_unused_ignores = True
warn_unused_ignores = False

# ...except the modules and subpackages below (can't find a way to do line
# breaks in the lists of modules).
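A plausible reading of the `warn_unused_ignores` change (an inference; the diff itself gives no rationale): a `# type: ignore` that one pydantic major version needs is reported as unused when mypy resolves the other, so the strict setting cannot be satisfied in both environments at once. The import fallback removed from `_quantum_backed.py` below shows the pattern:

```python
# Sketch of the dual-version import pattern this PR replaces with _compat.py;
# the trailing ignore is only meaningful under one of the two pydantic versions.
try:
    from pydantic.v1 import BaseModel
except ModuleNotFoundError:
    from pydantic import BaseModel  # type: ignore
```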
198 changes: 198 additions & 0 deletions python/lsst/daf/butler/_compat.py
@@ -0,0 +1,198 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Code to support backwards compatibility."""

__all__ = ["PYDANTIC_V2", "_BaseModelCompat"]

import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic.version import VERSION as PYDANTIC_VERSION

if sys.version_info >= (3, 11, 0):
from typing import Self
else:
from typing import TypeVar

Self = TypeVar("Self", bound="_BaseModelCompat") # type: ignore


PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


if PYDANTIC_V2:

class _BaseModelCompat(BaseModel):
"""Methods from pydantic v1 that we want to emulate in v2.
Some of these methods are provided by v2 but issue deprecation
warnings. We need to decide whether we are also okay with deprecating
them or want to support them without the deprecation message.
"""

def json(
self,
*,
include: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None, # type: ignore
exclude: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None, # type: ignore
by_alias: bool = False,
skip_defaults: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Callable[[Any], Any] | None = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if dumps_kwargs:
raise TypeError("dumps_kwargs no longer supported.")
if encoder is not None:
raise TypeError("json encoder is no longer supported.")
# Can catch warnings and call BaseModel.json() directly.
return self.model_dump_json(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
)

@classmethod
def parse_obj(cls, obj: Any) -> Self:
# Catch warnings and call BaseModel.parse_obj directly?
return cls.model_validate(obj)

if TYPE_CHECKING and not PYDANTIC_V2:
# mypy sees the first definition of a class and ignores any
# redefinition. This means that if mypy is run with pydantic v1
# it will not see the classes defined in the else block below.

@classmethod
def model_construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self:
return cls()

@classmethod
def model_validate(
cls,
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> Self:
return cls()

def model_dump_json(
self,
*,
indent: int | None = None,
include: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None,
exclude: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> str:
return ""

@property
def model_fields(self) -> dict[str, FieldInfo]: # type: ignore
return {}

@classmethod
def model_rebuild(
cls,
*,
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: dict[str, Any] | None = None,
) -> bool | None:
return None

else:
from astropy.utils.decorators import classproperty

class _BaseModelCompat(BaseModel): # type:ignore[no-redef]
"""Methods from pydantic v2 that can be used in pydantic v1."""

@classmethod
def model_validate(
cls,
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> Self:
return cls.parse_obj(obj)

def model_dump_json(
self,
*,
indent: int | None = None,
include: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None,
exclude: set[int] | set[str] | dict[int, Any] | dict[str, Any] | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> str:
return self.json(
include=include, # type: ignore
exclude=exclude, # type: ignore
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

@classmethod # type: ignore
def model_construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self:
# BaseModel.construct() is very close to what we previously
# implemented manually in each direct() method but does have one
# extra loop in it to fill in defaults and handle aliases.
return cls.construct(_fields_set=_fields_set, **values)

@classmethod
@classproperty
def model_fields(cls) -> dict[str, FieldInfo]: # type: ignore
return cls.__fields__ # type: ignore

@classmethod
def model_rebuild(
cls,
*,
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: dict[str, Any] | None = None,
) -> bool | None:
return cls.update_forward_refs()
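The net effect of the new module is that butler models can subclass this shim and use the pydantic v2 method names regardless of which pydantic major version is installed. A minimal usage sketch (`ExampleModel` is hypothetical, not part of this commit):

```python
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat


class ExampleModel(_BaseModelCompat):
    """Hypothetical model used only for illustration."""

    name: str
    value: int = 0


# model_validate() and model_dump_json() are the native pydantic v2 methods
# when v2 is installed, and thin wrappers over parse_obj()/json() under v1.
model = ExampleModel.model_validate({"name": "a", "value": 3})
print(PYDANTIC_V2, model.model_dump_json(exclude_defaults=True))
```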
27 changes: 10 additions & 17 deletions python/lsst/daf/butler/_quantum_backed.py
@@ -31,13 +31,9 @@
from typing import TYPE_CHECKING, Any

from deprecated.sphinx import deprecated
from lsst.daf.butler._compat import _BaseModelCompat
from lsst.resources import ResourcePathExpression

try:
from pydantic.v1 import BaseModel
except ModuleNotFoundError:
from pydantic import BaseModel # type: ignore

from ._butlerConfig import ButlerConfig
from ._deferredDatasetHandle import DeferredDatasetHandle
from ._limited_butler import LimitedButler
@@ -597,7 +593,7 @@ def extract_provenance_data(self) -> QuantumProvenanceData:
)


class QuantumProvenanceData(BaseModel):
class QuantumProvenanceData(_BaseModelCompat):
"""A serializable struct for per-quantum provenance information and
datastore records.
@@ -749,19 +745,16 @@ def _to_uuid_set(uuids: Iterable[str | uuid.UUID]) -> set[uuid.UUID]:
"""
return {uuid.UUID(id) if isinstance(id, str) else id for id in uuids}

-        data = QuantumProvenanceData.__new__(cls)
-        setter = object.__setattr__
-        setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
-        setter(data, "available_inputs", _to_uuid_set(available_inputs))
-        setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
-        setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
-        setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
-        setter(
-            data,
-            "datastore_records",
-            {
+        data = cls.model_construct(
+            predicted_inputs=_to_uuid_set(predicted_inputs),
+            available_inputs=_to_uuid_set(available_inputs),
+            actual_inputs=_to_uuid_set(actual_inputs),
+            predicted_outputs=_to_uuid_set(predicted_outputs),
+            actual_outputs=_to_uuid_set(actual_outputs),
+            datastore_records={
                 key: SerializedDatastoreRecordData.direct(**records)
                 for key, records in datastore_records.items()
             },
         )

return data
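The switch from manual `object.__setattr__` assignment to `cls.model_construct()` keeps essentially the same behaviour (build the model from trusted values without running validation) but through an API the shim provides on both pydantic versions. A schematic illustration (`Provenance` is a hypothetical stand-in, not `QuantumProvenanceData` itself):

```python
from lsst.daf.butler._compat import _BaseModelCompat


class Provenance(_BaseModelCompat):
    """Hypothetical stand-in used only to illustrate the construction pattern."""

    inputs: set[str]


# Normal construction validates and coerces the input.
validated = Provenance(inputs=["a", "a"])  # inputs becomes {"a"}

# model_construct() trusts the caller and skips validation, which is why the
# direct()-style factory methods use it only on already-trusted data.
trusted = Provenance.model_construct(inputs={"a"})
```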
37 changes: 18 additions & 19 deletions python/lsst/daf/butler/core/datasets/ref.py
@@ -33,14 +33,11 @@
import sys
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, runtime_checkable

from lsst.daf.butler._compat import _BaseModelCompat
from lsst.utils.classes import immutable

try:
from pydantic.v1 import BaseModel, StrictStr, validator
except ModuleNotFoundError:
from pydantic import BaseModel, StrictStr, validator # type: ignore
from pydantic import StrictStr, validator

from ..configSupport import LookupKey
from ..dimensions import DataCoordinate, DimensionGraph, DimensionUniverse, SerializedDataCoordinate
@@ -173,7 +170,7 @@ def makeDatasetId(
_serializedDatasetRefFieldsSet = {"id", "datasetType", "dataId", "run", "component"}


class SerializedDatasetRef(BaseModel):
class SerializedDatasetRef(_BaseModelCompat):
"""Simplified model of a `DatasetRef` suitable for serialization."""

id: uuid.UUID
@@ -224,22 +221,24 @@ def direct(
This method should only be called when the inputs are trusted.
"""
-        node = SerializedDatasetRef.__new__(cls)
-        setter = object.__setattr__
-        setter(node, "id", uuid.UUID(id))
-        setter(
-            node,
-            "datasetType",
-            datasetType if datasetType is None else SerializedDatasetType.direct(**datasetType),
+        serialized_datasetType = (
+            SerializedDatasetType.direct(**datasetType) if datasetType is not None else None
         )
+        serialized_dataId = SerializedDataCoordinate.direct(**dataId) if dataId is not None else None

+        node = cls.model_construct(
+            _fields_set=_serializedDatasetRefFieldsSet,
+            id=uuid.UUID(id),
+            datasetType=serialized_datasetType,
+            dataId=serialized_dataId,
+            run=sys.intern(run),
+            component=component,
+        )
-        setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
-        setter(node, "run", sys.intern(run))
-        setter(node, "component", component)
-        setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)

         return node


DatasetId = uuid.UUID
DatasetId: TypeAlias = uuid.UUID
"""A type-annotation alias for dataset ID providing typing flexibility.
"""
