Use model_construct in direct methods
If the BaseModel.construct() method turns out to be slow, we can still
reimplement it as was previously done in the individual direct() methods,
while the direct() methods themselves remain much simpler.
timj committed Jul 19, 2023
1 parent fce6d3a commit 44d0345
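Background for the diff below: pydantic v2's BaseModel.model_construct() builds a model instance without running validation, which is what makes it suitable for these direct() methods, whose contract is that their inputs are already trusted. The repository's _BaseModelCompat shim is presumably what exposes the method under both pydantic major versions; that detail is assumed here. A minimal sketch of the semantics (SerializedThing is a made-up model, not part of daf_butler):

import uuid

from pydantic import BaseModel


class SerializedThing(BaseModel):
    id: uuid.UUID
    name: str


# Regular construction validates and coerces every field.
validated = SerializedThing(id=str(uuid.uuid4()), name="demo")

# model_construct() skips validation entirely, so the caller must already
# supply correct types -- the same contract the direct() methods rely on.
trusted = SerializedThing.model_construct(id=uuid.uuid4(), name="demo")
print(type(validated.id) is type(trusted.id))  # True: both hold uuid.UUID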
Showing 8 changed files with 52 additions and 162 deletions.
50 changes: 13 additions & 37 deletions python/lsst/daf/butler/_quantum_backed.py
@@ -31,7 +31,7 @@
 from typing import TYPE_CHECKING, Any

 from deprecated.sphinx import deprecated
-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from lsst.resources import ResourcePathExpression

 from ._butlerConfig import ButlerConfig
@@ -745,40 +745,16 @@ def _to_uuid_set(uuids: Iterable[str | uuid.UUID]) -> set[uuid.UUID]:
"""
return {uuid.UUID(id) if isinstance(id, str) else id for id in uuids}

if PYDANTIC_V2:
data = cls.model_construct(
_fields_set={
"predicted_inputs",
"available_inputs",
"actual_inputs",
"predicted_outputs",
"actual_outputs",
"datastore_records",
},
predicted_inputs=_to_uuid_set(predicted_inputs),
available_inputs=_to_uuid_set(available_inputs),
actual_inputs=_to_uuid_set(actual_inputs),
predicted_outputs=_to_uuid_set(predicted_outputs),
actual_outputs=_to_uuid_set(actual_outputs),
datastore_records={
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)
else:
data = QuantumProvenanceData.__new__(cls)
setter = object.__setattr__
setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
setter(data, "available_inputs", _to_uuid_set(available_inputs))
setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
setter(
data,
"datastore_records",
{
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)
data = cls.model_construct(
predicted_inputs=_to_uuid_set(predicted_inputs),
available_inputs=_to_uuid_set(available_inputs),
actual_inputs=_to_uuid_set(actual_inputs),
predicted_outputs=_to_uuid_set(predicted_outputs),
actual_outputs=_to_uuid_set(actual_outputs),
datastore_records={
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)

return data
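Note that the rewritten call above omits the explicit _fields_set argument, while some direct() methods below still pass one. In pydantic v2, model_construct() with no _fields_set records exactly the keyword arguments that were supplied, so spelling out the full set is redundant when every field is passed; an explicit set mainly matters when exclude_unset-style dumps must treat defaulted fields as set. A small sketch of the difference (Model is hypothetical):

from pydantic import BaseModel


class Model(BaseModel):
    a: int = 0
    b: int = 0


# Without _fields_set, pydantic infers the set from the keywords given: {"a"}.
m1 = Model.model_construct(a=1)
# An explicit _fields_set marks "b" as set even though it keeps its default.
m2 = Model.model_construct(_fields_set={"a", "b"}, a=1)

print(m1.model_dump(exclude_unset=True))  # {'a': 1}
print(m2.model_dump(exclude_unset=True))  # {'a': 1, 'b': 0}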
28 changes: 9 additions & 19 deletions python/lsst/daf/butler/core/datasets/ref.py
@@ -35,7 +35,7 @@
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, runtime_checkable

-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from lsst.utils.classes import immutable
 from pydantic import StrictStr, validator
@@ -226,24 +226,14 @@ def direct(
         )
         serialized_dataId = SerializedDataCoordinate.direct(**dataId) if dataId is not None else None

-        if PYDANTIC_V2:
-            node = cls.model_construct(
-                _fields_set=_serializedDatasetRefFieldsSet,
-                id=uuid.UUID(id),
-                datasetType=serialized_datasetType,
-                dataId=serialized_dataId,
-                run=sys.intern(run),
-                component=component,
-            )
-        else:
-            node = SerializedDatasetRef.__new__(cls)
-            setter = object.__setattr__
-            setter(node, "id", uuid.UUID(id))
-            setter(node, "datasetType", serialized_datasetType)
-            setter(node, "dataId", serialized_dataId)
-            setter(node, "run", sys.intern(run))
-            setter(node, "component", component)
-            setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
+        node = cls.model_construct(
+            _fields_set=_serializedDatasetRefFieldsSet,
+            id=uuid.UUID(id),
+            datasetType=serialized_datasetType,
+            dataId=serialized_dataId,
+            run=sys.intern(run),
+            component=component,
+        )

         return node
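A side note on the run=sys.intern(run) argument above: serialized graphs tend to repeat the same run collection name across many dataset refs, and interning lets every deserialized ref share a single string object rather than holding its own copy. A minimal illustration (the run name is invented):

import sys

# A run name assembled at runtime is normally a distinct object from any
# equal literal elsewhere in the program...
raw = "".join(["u/demo/", "run1"])
interned_a = sys.intern("u/demo/run1")
interned_b = sys.intern(raw)

# ...but sys.intern() maps equal strings onto one shared object.
print(interned_a is interned_b)  # True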
30 changes: 8 additions & 22 deletions python/lsst/daf/butler/core/datasets/type.py
@@ -29,7 +29,7 @@
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, ClassVar

-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from pydantic import StrictBool, StrictStr

 from ..configSupport import LookupKey
@@ -85,27 +85,13 @@ def direct(
             SerializedDimensionGraph.direct(**dimensions) if dimensions is not None else None
         )

-        if PYDANTIC_V2:
-            node = cls.model_construct(
-                name=name,
-                storageClass=storageClass,
-                dimensions=serialized_dimensions,
-                parentStorageClass=parentStorageClass,
-                isCalibration=isCalibration,
-            )
-        else:
-            node = SerializedDatasetType.__new__(cls)
-            setter = object.__setattr__
-            setter(node, "name", name)
-            setter(node, "storageClass", storageClass)
-            setter(node, "dimensions", serialized_dimensions)
-            setter(node, "parentStorageClass", parentStorageClass)
-            setter(node, "isCalibration", isCalibration)
-            setter(
-                node,
-                "__fields_set__",
-                {"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"},
-            )
+        node = cls.model_construct(
+            name=name,
+            storageClass=storageClass,
+            dimensions=serialized_dimensions,
+            parentStorageClass=parentStorageClass,
+            isCalibration=isCalibration,
+        )

         if cache is not None:
             cache[key] = node
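The tail of this hunk also shows the memoization pattern several of these direct() methods share: when a cache mapping is supplied, a payload that was seen before deserializes to the very same node object. A simplified sketch of the idea, not the daf_butler implementation (SerializedNode and its bare string cache key are stand-ins):

from pydantic import BaseModel


class SerializedNode(BaseModel):
    name: str

    @classmethod
    def direct(cls, *, name: str, cache: dict | None = None) -> "SerializedNode":
        # daf_butler derives cache keys from the raw payload; a plain string
        # stands in for that here.
        key = name
        if cache is not None and (node := cache.get(key)) is not None:
            return node
        node = cls.model_construct(name=name)
        if cache is not None:
            cache[key] = node
        return node


cache: dict = {}
n1 = SerializedNode.direct(name="thing", cache=cache)
n2 = SerializedNode.direct(name="thing", cache=cache)
print(n1 is n2)  # True: the second call reuses the cached instance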
17 changes: 6 additions & 11 deletions python/lsst/daf/butler/core/datastoreRecordData.py
@@ -86,18 +86,13 @@ def direct(
         if (id := record.get("dataset_id")) is not None:
             record["dataset_id"] = uuid.UUID(id) if isinstance(id, str) else id

-        if PYDANTIC_V2:
-            data = cls.model_construct(
-                _fields_set={"dataset_ids", "records"},
-                dataset_ids=[uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids],
-                records=records,
-            )
-        else:
-            data = SerializedDatastoreRecordData.__new__(cls)
-            setter = object.__setattr__
+        data = cls.model_construct(
+            _fields_set={"dataset_ids", "records"},
             # JSON makes strings out of UUIDs, need to convert them back
-            setter(data, "dataset_ids", [uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids])
-            setter(data, "records", records)
+            dataset_ids=[uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids],
+            records=records,
+        )
+
         return data


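The comment preserved in this hunk names the underlying issue: JSON has no UUID type, so dataset ids arrive as plain strings and must be rebuilt into uuid.UUID objects, with an isinstance guard for values that are already UUIDs. A quick demonstration of the round trip:

import json
import uuid

original = uuid.uuid4()
payload = json.loads(json.dumps({"dataset_id": str(original)}))

# What comes back from JSON is a string; convert it, guarding as direct() does.
restored = payload["dataset_id"]
restored = uuid.UUID(restored) if isinstance(restored, str) else restored
print(restored == original)  # True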
11 changes: 2 additions & 9 deletions python/lsst/daf/butler/core/dimensions/_coordinate.py
@@ -34,7 +34,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload

 from deprecated.sphinx import deprecated
-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from lsst.sphgeom import IntersectionRegion, Region

 from ..json import from_json_pydantic, to_json_pydantic
@@ -89,14 +89,7 @@ def direct(
         else:
             serialized_records = {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()}

-        if PYDANTIC_V2:
-            node = cls.model_construct(dataId=dataId, records=serialized_records)
-        else:
-            node = SerializedDataCoordinate.__new__(cls)
-            setter = object.__setattr__
-            setter(node, "dataId", dataId)
-            setter(node, "records", serialized_records)
-            setter(node, "__fields_set__", {"dataId", "records"})
+        node = cls.model_construct(dataId=dataId, records=serialized_records)

         if cache is not None:
             cache[key] = node
10 changes: 2 additions & 8 deletions python/lsst/daf/butler/core/dimensions/_graph.py
@@ -28,7 +28,7 @@
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, ClassVar

-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from lsst.utils.classes import cached_getter, immutable

 from .._topology import TopologicalFamily, TopologicalSpace
@@ -57,13 +57,7 @@ def direct(cls, *, names: list[str]) -> SerializedDimensionGraph:
         This method should only be called when the inputs are trusted.
         """
-        if PYDANTIC_V2:
-            return cls.model_construct(names=names)
-        else:
-            node = SerializedDimensionGraph.__new__(cls)
-            object.__setattr__(node, "names", names)
-            object.__setattr__(node, "__fields_set__", {"names"})
-            return node
+        return cls.model_construct(names=names)


 @immutable
10 changes: 1 addition & 9 deletions python/lsst/daf/butler/core/dimensions/_records.py
@@ -186,15 +186,7 @@ def direct(
         # transform to tuples
         serialized_record = {k: v if type(v) != list else tuple(v) for k, v in record.items()}  # type: ignore

-        if PYDANTIC_V2:
-            node = cls.model_construct(definition=definition, record=serialized_record)  # type: ignore
-        else:
-            node = SerializedDimensionRecord.__new__(cls)
-            setter = object.__setattr__
-            setter(node, "definition", definition)
-            setter(node, "record", serialized_record)
-
-            setter(node, "__fields_set__", {"definition", "record"})
+        node = cls.model_construct(definition=definition, record=serialized_record)  # type: ignore

         if cache is not None:
             cache[key] = node
58 changes: 11 additions & 47 deletions python/lsst/daf/butler/core/quantum.py
@@ -28,7 +28,7 @@
 from collections.abc import Iterable, Mapping, MutableMapping, Sequence
 from typing import Any

-from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
+from lsst.daf.butler._compat import _BaseModelCompat
 from lsst.utils import doImportType
 from lsst.utils.introspection import find_outside_stacklevel
@@ -126,52 +126,16 @@ def direct(
             else None
         )

-        if PYDANTIC_V2:
-            node = cls.model_construct(
-                _fields_set={
-                    "taskName",
-                    "dataId",
-                    "datasetTypeMapping",
-                    "initInputs",
-                    "inputs",
-                    "outputs",
-                    "dimensionRecords",
-                    "datastoreRecords",
-                },
-                taskName=sys.intern(taskName or ""),
-                dataId=serialized_dataId,
-                datasetTypeMapping=serialized_datasetTypeMapping,
-                initInputs=serialized_initInputs,
-                inputs=serialized_inputs,
-                outputs=serialized_outputs,
-                dimensionRecords=serialized_records,
-                datastoreRecords=serialized_datastore_records,
-            )
-        else:
-            node = SerializedQuantum.__new__(cls)
-            setter = object.__setattr__
-            setter(node, "taskName", sys.intern(taskName or ""))
-            setter(node, "dataId", serialized_dataId)
-            setter(node, "datasetTypeMapping", serialized_datasetTypeMapping)
-            setter(node, "initInputs", serialized_initInputs)
-            setter(node, "inputs", serialized_inputs)
-            setter(node, "outputs", serialized_outputs)
-            setter(node, "dimensionRecords", serialized_records)
-            setter(node, "datastoreRecords", serialized_datastore_records)
-            setter(
-                node,
-                "__fields_set__",
-                {
-                    "taskName",
-                    "dataId",
-                    "datasetTypeMapping",
-                    "initInputs",
-                    "inputs",
-                    "outputs",
-                    "dimensionRecords",
-                    "datastoreRecords",
-                },
-            )
+        node = cls.model_construct(
+            taskName=sys.intern(taskName or ""),
+            dataId=serialized_dataId,
+            datasetTypeMapping=serialized_datasetTypeMapping,
+            initInputs=serialized_initInputs,
+            inputs=serialized_inputs,
+            outputs=serialized_outputs,
+            dimensionRecords=serialized_records,
+            datastoreRecords=serialized_datastore_records,
+        )

         return node
