From a4fcc530d74ee3e392c0b56cbea9ba2bc315e06b Mon Sep 17 00:00:00 2001
From: Tim Jenness
Date: Mon, 17 Jul 2023 16:18:23 -0700
Subject: [PATCH] Use model_construct for pydantic v2

This is meant to be the equivalent of using __setattr__ and __fields_set__
directly. The latter has been renamed in pydantic v2. A minimal standalone
sketch of the two idioms appears after the diff.
---
 python/lsst/daf/butler/_quantum_backed.py     | 53 +++++---
 python/lsst/daf/butler/core/datasets/ref.py   | 40 ++++--
 python/lsst/daf/butler/core/datasets/type.py  | 43 +++---
 .../daf/butler/core/datastoreRecordData.py    | 20 ++-
 .../daf/butler/core/dimensions/_coordinate.py | 28 ++--
 .../lsst/daf/butler/core/dimensions/_graph.py | 13 +-
 .../daf/butler/core/dimensions/_records.py    | 20 ++-
 python/lsst/daf/butler/core/quantum.py        | 123 ++++++++++--------
 8 files changed, 211 insertions(+), 129 deletions(-)

diff --git a/python/lsst/daf/butler/_quantum_backed.py b/python/lsst/daf/butler/_quantum_backed.py
index 9f2aabb4a2..fcd8279628 100644
--- a/python/lsst/daf/butler/_quantum_backed.py
+++ b/python/lsst/daf/butler/_quantum_backed.py
@@ -31,7 +31,7 @@
 from typing import TYPE_CHECKING, Any
 
 from deprecated.sphinx import deprecated
-from lsst.daf.butler._compat import _BaseModelCompat
+from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
 from lsst.resources import ResourcePathExpression
 
 from ._butlerConfig import ButlerConfig
@@ -745,19 +745,40 @@ def _to_uuid_set(uuids: Iterable[str | uuid.UUID]) -> set[uuid.UUID]:
             """
             return {uuid.UUID(id) if isinstance(id, str) else id for id in uuids}
 
-        data = QuantumProvenanceData.__new__(cls)
-        setter = object.__setattr__
-        setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
-        setter(data, "available_inputs", _to_uuid_set(available_inputs))
-        setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
-        setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
-        setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
-        setter(
-            data,
-            "datastore_records",
-            {
-                key: SerializedDatastoreRecordData.direct(**records)
-                for key, records in datastore_records.items()
-            },
-        )
+        if PYDANTIC_V2:
+            data = cls.model_construct(
+                _fields_set={
+                    "predicted_inputs",
+                    "available_inputs",
+                    "actual_inputs",
+                    "predicted_outputs",
+                    "actual_outputs",
+                    "datastore_records",
+                },
+                predicted_inputs=_to_uuid_set(predicted_inputs),
+                available_inputs=_to_uuid_set(available_inputs),
+                actual_inputs=_to_uuid_set(actual_inputs),
+                predicted_outputs=_to_uuid_set(predicted_outputs),
+                actual_outputs=_to_uuid_set(actual_outputs),
+                datastore_records={
+                    key: SerializedDatastoreRecordData.direct(**records)
+                    for key, records in datastore_records.items()
+                },
+            )
+        else:
+            data = QuantumProvenanceData.__new__(cls)
+            setter = object.__setattr__
+            setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
+            setter(data, "available_inputs", _to_uuid_set(available_inputs))
+            setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
+            setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
+            setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
+            setter(
+                data,
+                "datastore_records",
+                {
+                    key: SerializedDatastoreRecordData.direct(**records)
+                    for key, records in datastore_records.items()
+                },
+            )
         return data
diff --git a/python/lsst/daf/butler/core/datasets/ref.py b/python/lsst/daf/butler/core/datasets/ref.py
index 030c97e5a0..93ed396edb 100644
--- a/python/lsst/daf/butler/core/datasets/ref.py
+++ b/python/lsst/daf/butler/core/datasets/ref.py
@@ -33,9 +33,9 @@
 import sys
 import uuid
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, ClassVar, Protocol, runtime_checkable
+from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, runtime_checkable
 
-from lsst.daf.butler._compat import _BaseModelCompat
+from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
 from lsst.utils.classes import immutable
 from pydantic import StrictStr, validator
 
@@ -221,22 +221,34 @@ def direct(
 
         This method should only be called when the inputs are trusted.
         """
-        node = SerializedDatasetRef.__new__(cls)
-        setter = object.__setattr__
-        setter(node, "id", uuid.UUID(id))
-        setter(
-            node,
-            "datasetType",
-            datasetType if datasetType is None else SerializedDatasetType.direct(**datasetType),
+        serialized_datasetType = (
+            SerializedDatasetType.direct(**datasetType) if datasetType is not None else None
         )
-        setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
-        setter(node, "run", sys.intern(run))
-        setter(node, "component", component)
-        setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
+        serialized_dataId = SerializedDataCoordinate.direct(**dataId) if dataId is not None else None
+
+        if PYDANTIC_V2:
+            node = cls.model_construct(
+                _fields_set=_serializedDatasetRefFieldsSet,
+                id=uuid.UUID(id),
+                datasetType=serialized_datasetType,
+                dataId=serialized_dataId,
+                run=sys.intern(run),
+                component=component,
+            )
+        else:
+            node = SerializedDatasetRef.__new__(cls)
+            setter = object.__setattr__
+            setter(node, "id", uuid.UUID(id))
+            setter(node, "datasetType", serialized_datasetType)
+            setter(node, "dataId", serialized_dataId)
+            setter(node, "run", sys.intern(run))
+            setter(node, "component", component)
+            setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
+
         return node
 
 
-DatasetId = uuid.UUID
+DatasetId: TypeAlias = uuid.UUID
 """A type-annotation alias for dataset ID providing typing flexibility.
""" diff --git a/python/lsst/daf/butler/core/datasets/type.py b/python/lsst/daf/butler/core/datasets/type.py index 80e5f31ecb..55b3885366 100644 --- a/python/lsst/daf/butler/core/datasets/type.py +++ b/python/lsst/daf/butler/core/datasets/type.py @@ -29,7 +29,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, ClassVar -from lsst.daf.butler._compat import _BaseModelCompat +from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat from pydantic import StrictBool, StrictStr from ..configSupport import LookupKey @@ -80,22 +80,33 @@ def direct( key = (name, storageClass or "") if cache is not None and (type_ := cache.get(key, None)) is not None: return type_ - node = SerializedDatasetType.__new__(cls) - setter = object.__setattr__ - setter(node, "name", name) - setter(node, "storageClass", storageClass) - setter( - node, - "dimensions", - dimensions if dimensions is None else SerializedDimensionGraph.direct(**dimensions), - ) - setter(node, "parentStorageClass", parentStorageClass) - setter(node, "isCalibration", isCalibration) - setter( - node, - "__fields_set__", - {"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"}, + + serialized_dimensions = ( + SerializedDimensionGraph.direct(**dimensions) if dimensions is not None else None ) + + if PYDANTIC_V2: + node = cls.model_construct( + name=name, + storageClass=storageClass, + dimensions=serialized_dimensions, + parentStorageClass=parentStorageClass, + isCalibration=isCalibration, + ) + else: + node = SerializedDatasetType.__new__(cls) + setter = object.__setattr__ + setter(node, "name", name) + setter(node, "storageClass", storageClass) + setter(node, "dimensions", serialized_dimensions) + setter(node, "parentStorageClass", parentStorageClass) + setter(node, "isCalibration", isCalibration) + setter( + node, + "__fields_set__", + {"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"}, + ) + if cache is not None: cache[key] = node return node diff --git a/python/lsst/daf/butler/core/datastoreRecordData.py b/python/lsst/daf/butler/core/datastoreRecordData.py index c2a13c8fe8..e8345673eb 100644 --- a/python/lsst/daf/butler/core/datastoreRecordData.py +++ b/python/lsst/daf/butler/core/datastoreRecordData.py @@ -30,7 +30,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from lsst.daf.butler._compat import _BaseModelCompat +from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat from lsst.utils import doImportType from lsst.utils.introspection import get_full_type_name @@ -71,10 +71,6 @@ def direct( This method should only be called when the inputs are trusted. """ - data = SerializedDatastoreRecordData.__new__(cls) - setter = object.__setattr__ - # JSON makes strings out of UUIDs, need to convert them back - setter(data, "dataset_ids", [uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids]) # See also comments in record_ids_to_uuid() for table_data in records.values(): for table_records in table_data.values(): @@ -83,7 +79,20 @@ def direct( # columns that are UUIDs we'd need more generic approach. 
                     if (id := record.get("dataset_id")) is not None:
                         record["dataset_id"] = uuid.UUID(id) if isinstance(id, str) else id
-        setter(data, "records", records)
+
+        if PYDANTIC_V2:
+            data = cls.model_construct(
+                _fields_set={"dataset_ids", "records"},
+                dataset_ids=[uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids],
+                records=records,
+            )
+        else:
+            data = SerializedDatastoreRecordData.__new__(cls)
+            setter = object.__setattr__
+            # JSON makes strings out of UUIDs, need to convert them back
+            setter(data, "dataset_ids", [uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids])
+            setter(data, "records", records)
         return data
diff --git a/python/lsst/daf/butler/core/dimensions/_coordinate.py b/python/lsst/daf/butler/core/dimensions/_coordinate.py
index 6c66194ad7..9492c437bf 100644
--- a/python/lsst/daf/butler/core/dimensions/_coordinate.py
+++ b/python/lsst/daf/butler/core/dimensions/_coordinate.py
@@ -34,7 +34,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload
 
 from deprecated.sphinx import deprecated
-from lsst.daf.butler._compat import _BaseModelCompat
+from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
 from lsst.sphgeom import IntersectionRegion, Region
 
 from ..json import from_json_pydantic, to_json_pydantic
@@ -81,17 +81,21 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) ->
         cache = PersistenceContextVars.serializedDataCoordinateMapping.get()
         if cache is not None and (result := cache.get(key)) is not None:
             return result
-        node = SerializedDataCoordinate.__new__(cls)
-        setter = object.__setattr__
-        setter(node, "dataId", dataId)
-        setter(
-            node,
-            "records",
-            records
-            if records is None
-            else {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()},
-        )
-        setter(node, "__fields_set__", {"dataId", "records"})
+
+        if records is None:
+            serialized_records = None
+        else:
+            serialized_records = {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()}
+
+        if PYDANTIC_V2:
+            node = cls.model_construct(dataId=dataId, records=serialized_records)
+        else:
+            node = SerializedDataCoordinate.__new__(cls)
+            setter = object.__setattr__
+            setter(node, "dataId", dataId)
+            setter(node, "records", serialized_records)
+            setter(node, "__fields_set__", {"dataId", "records"})
+
         if cache is not None:
             cache[key] = node
         return node
diff --git a/python/lsst/daf/butler/core/dimensions/_graph.py b/python/lsst/daf/butler/core/dimensions/_graph.py
index d98749fe03..a03c28a6ab 100644
--- a/python/lsst/daf/butler/core/dimensions/_graph.py
+++ b/python/lsst/daf/butler/core/dimensions/_graph.py
@@ -28,7 +28,7 @@
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, ClassVar
 
-from lsst.daf.butler._compat import _BaseModelCompat
+from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
 from lsst.utils.classes import cached_getter, immutable
 
 from .._topology import TopologicalFamily, TopologicalSpace
@@ -57,10 +57,13 @@ def direct(cls, *, names: list[str]) -> SerializedDimensionGraph:
 
         This method should only be called when the inputs are trusted.
""" - node = SerializedDimensionGraph.__new__(cls) - object.__setattr__(node, "names", names) - object.__setattr__(node, "__fields_set__", {"names"}) - return node + if PYDANTIC_V2: + return cls.model_construct(names=names) + else: + node = SerializedDimensionGraph.__new__(cls) + object.__setattr__(node, "names", names) + object.__setattr__(node, "__fields_set__", {"names"}) + return node @immutable diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py index 71c2ab27b7..eaffd1ee14 100644 --- a/python/lsst/daf/butler/core/dimensions/_records.py +++ b/python/lsst/daf/butler/core/dimensions/_records.py @@ -180,16 +180,22 @@ def direct( cache = PersistenceContextVars.serializedDimensionRecordMapping.get() if cache is not None and (result := cache.get(key)) is not None: return result - node = SerializedDimensionRecord.__new__(cls) - setter = object.__setattr__ - setter(node, "definition", definition) + # This method requires tuples as values of the mapping, but JSON # readers will read things in as lists. Be kind and transparently # transform to tuples - setter( - node, "record", {k: v if type(v) != list else tuple(v) for k, v in record.items()} # type: ignore - ) - setter(node, "__fields_set__", {"definition", "record"}) + serialized_record = {k: v if type(v) != list else tuple(v) for k, v in record.items()} + + if PYDANTIC_V2: + node = cls.model_construct(definition=definition, record=serialized_record) + else: + node = SerializedDimensionRecord.__new__(cls) + setter = object.__setattr__ + setter(node, "definition", definition) + setter(node, "record", serialized_record) # type: ignore + + setter(node, "__fields_set__", {"definition", "record"}) + if cache is not None: cache[key] = node return node diff --git a/python/lsst/daf/butler/core/quantum.py b/python/lsst/daf/butler/core/quantum.py index 96dfaefa1e..925dbc4d9a 100644 --- a/python/lsst/daf/butler/core/quantum.py +++ b/python/lsst/daf/butler/core/quantum.py @@ -28,7 +28,7 @@ from collections.abc import Iterable, Mapping, MutableMapping, Sequence from typing import Any -from lsst.daf.butler._compat import _BaseModelCompat +from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat from lsst.utils import doImportType from lsst.utils.introspection import find_outside_stacklevel @@ -102,60 +102,77 @@ def direct( This method should only be called when the inputs are trusted. 
""" - node = SerializedQuantum.__new__(cls) - setter = object.__setattr__ - setter(node, "taskName", sys.intern(taskName or "")) - setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId)) - - setter( - node, - "datasetTypeMapping", - {k: SerializedDatasetType.direct(**v) for k, v in datasetTypeMapping.items()}, - ) - - setter( - node, - "initInputs", - {k: (SerializedDatasetRef.direct(**v), refs) for k, (v, refs) in initInputs.items()}, - ) - setter( - node, - "inputs", - {k: [(SerializedDatasetRef.direct(**ref), id) for ref, id in v] for k, v in inputs.items()}, - ) - setter( - node, - "outputs", - {k: [(SerializedDatasetRef.direct(**ref), id) for ref, id in v] for k, v in outputs.items()}, - ) - setter( - node, - "dimensionRecords", - dimensionRecords - if dimensionRecords is None - else {int(k): SerializedDimensionRecord.direct(**v) for k, v in dimensionRecords.items()}, - ) - setter( - node, - "datastoreRecords", - datastoreRecords - if datastoreRecords is None - else {k: SerializedDatastoreRecordData.direct(**v) for k, v in datastoreRecords.items()}, + serialized_dataId = SerializedDataCoordinate.direct(**dataId) if dataId is not None else None + serialized_datasetTypeMapping = { + k: SerializedDatasetType.direct(**v) for k, v in datasetTypeMapping.items() + } + serialized_initInputs = { + k: (SerializedDatasetRef.direct(**v), refs) for k, (v, refs) in initInputs.items() + } + serialized_inputs = { + k: [(SerializedDatasetRef.direct(**ref), id) for ref, id in v] for k, v in inputs.items() + } + serialized_outputs = { + k: [(SerializedDatasetRef.direct(**ref), id) for ref, id in v] for k, v in outputs.items() + } + serialized_records = ( + {int(k): SerializedDimensionRecord.direct(**v) for k, v in dimensionRecords.items()} + if dimensionRecords is not None + else None ) - setter( - node, - "__fields_set__", - { - "taskName", - "dataId", - "datasetTypeMapping", - "initInputs", - "inputs", - "outputs", - "dimensionRecords", - "datastore_records", - }, + serialized_datastore_records = ( + {k: SerializedDatastoreRecordData.direct(**v) for k, v in datastoreRecords.items()} + if datastoreRecords is not None + else None ) + + if PYDANTIC_V2: + node = cls.model_construct( + _fields_set={ + "taskName", + "dataId", + "datasetTypeMapping", + "initInputs", + "inputs", + "outputs", + "dimensionRecords", + "datastoreRecords", + }, + taskName=sys.intern(taskName or ""), + dataId=serialized_dataId, + datasetTypeMapping=serialized_datasetTypeMapping, + initInputs=serialized_initInputs, + inputs=serialized_inputs, + outputs=serialized_outputs, + dimensionRecords=serialized_records, + datastoreRecords=serialized_datastore_records, + ) + else: + node = SerializedQuantum.__new__(cls) + setter = object.__setattr__ + setter(node, "taskName", sys.intern(taskName or "")) + setter(node, "dataId", serialized_dataId) + setter(node, "datasetTypeMapping", serialized_datasetTypeMapping) + setter(node, "initInputs", serialized_initInputs) + setter(node, "inputs", serialized_inputs) + setter(node, "outputs", serialized_outputs) + setter(node, "dimensionRecords", serialized_records) + setter(node, "datastoreRecords", serialized_datastore_records) + setter( + node, + "__fields_set__", + { + "taskName", + "dataId", + "datasetTypeMapping", + "initInputs", + "inputs", + "outputs", + "dimensionRecords", + "datastoreRecords", + }, + ) + return node