diff --git a/python/lsst/daf/butler/core/datasets/ref.py b/python/lsst/daf/butler/core/datasets/ref.py
index 64796f0c21..9edac6bccf 100644
--- a/python/lsst/daf/butler/core/datasets/ref.py
+++ b/python/lsst/daf/butler/core/datasets/ref.py
@@ -293,6 +293,10 @@ def __init__(
 
     @property
     def id(self) -> DatasetId:
+        """Primary key of the dataset (`DatasetId`).
+
+        Cannot be changed after a `DatasetRef` is constructed.
+        """
         return uuid.UUID(int=self._id)
 
     def __eq__(self, other: Any) -> bool:
@@ -708,9 +712,3 @@ class associated with the dataset type of the other ref can be
 
     Cannot be changed after a `DatasetRef` is constructed.
     """
-
-    id: DatasetId
-    """Primary key of the dataset (`DatasetId`).
-
-    Cannot be changed after a `DatasetRef` is constructed.
-    """
diff --git a/python/lsst/daf/butler/core/datastoreRecordData.py b/python/lsst/daf/butler/core/datastoreRecordData.py
index 58b8da8b42..aad556427a 100644
--- a/python/lsst/daf/butler/core/datastoreRecordData.py
+++ b/python/lsst/daf/butler/core/datastoreRecordData.py
@@ -214,8 +214,8 @@ def from_simple(
         """
         cache = PersistenceContextVars.dataStoreRecords.get()
         key = frozenset(simple.dataset_ids)
-        if cache is not None and (record := cache.get(key)) is not None:
-            return record
+        if cache is not None and (cachedRecord := cache.get(key)) is not None:
+            return cachedRecord
         records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {}
         # make sure that all dataset IDs appear in the dict even if they don't
         # have records.
@@ -228,7 +228,7 @@ def from_simple(
                 info = klass.from_record(record)
                 dataset_type_records = records.setdefault(info.dataset_id, {})
                 dataset_type_records.setdefault(table_name, []).append(info)
-        record = cls(records=records)
+        newRecord = cls(records=records)
         if cache is not None:
-            cache[key] = record
-        return record
+            cache[key] = newRecord
+        return newRecord
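The rename above splits the single reused `record` name into `cachedRecord` and `newRecord`, keeping the cache-hit value distinct from both the loop variable and the freshly built object. A minimal sketch of the context-local caching pattern these `from_simple` methods share, assuming made-up stand-in names rather than the daf_butler API:

```python
from contextvars import ContextVar

# Hypothetical stand-in for a PersistenceContextVars attribute: a cache that
# only exists while code runs inside a persistence context.
_cache: ContextVar[dict[frozenset, dict] | None] = ContextVar("_cache", default=None)


def from_simple_like(dataset_ids: list[int]) -> dict:
    cache = _cache.get()
    key = frozenset(dataset_ids)
    # Cache hit: return the previously deserialized object for this key.
    if cache is not None and (cachedRecord := cache.get(key)) is not None:
        return cachedRecord
    # Cache miss: build the object, then store it if a cache is active.
    newRecord = {"ids": sorted(dataset_ids)}  # stand-in for real deserialization
    if cache is not None:
        cache[key] = newRecord
    return newRecord
```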
diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py
index b5a0e6a04f..d35f23bd36 100644
--- a/python/lsst/daf/butler/core/dimensions/_records.py
+++ b/python/lsst/daf/butler/core/dimensions/_records.py
@@ -167,9 +167,12 @@ def direct(
 
         This method should only be called when the inputs are trusted.
         """
+        _recItems = record.items()
+        # Type ignore because the ternary statement seems to confuse mypy
+        # based on conflicting inferred types of v.
         key = (
             definition,
-            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()),
+            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in _recItems),  # type: ignore
         )
         cache = PersistenceContextVars.serializedDimensionRecordMapping.get()
         if cache is not None and (result := cache.get(key)) is not None:
@@ -376,9 +379,12 @@ def from_simple(
             if universe is None:
                 # this is for mypy
                 raise ValueError("Unable to determine a usable universe")
+        _recItems = simple.record.items()
+        # Type ignore because the ternary statement seems to confuse mypy
+        # based on conflicting inferred types of v.
         key = (
             simple.definition,
-            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in simple.record.items()),
+            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in _recItems),  # type: ignore
        )
         cache = PersistenceContextVars.dimensionRecords.get()
         if cache is not None and (result := cache.get(key)) is not None:
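Hoisting `record.items()` into `_recItems` works around the mypy inference problem noted in the new comments; the list-to-tuple conversion itself is what makes the cache key hashable in the first place. A small self-contained illustration, with made-up record values:

```python
# frozenset members must be hashable, and a (key, value) tuple is only
# hashable when the value is; list values therefore have to become tuples.
record = {"instrument": "HSC", "detector": 42, "bands": ["g", "r"]}

key = (
    "visit",  # stand-in for the dimension definition
    frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()),
)
hash(key)  # succeeds; with the raw list value, even building the frozenset
           # would raise TypeError: unhashable type: 'list'
```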
diff --git a/python/lsst/daf/butler/core/persistenceContext.py b/python/lsst/daf/butler/core/persistenceContext.py
index 3001aae5bc..7d9f616e17 100644
--- a/python/lsst/daf/butler/core/persistenceContext.py
+++ b/python/lsst/daf/butler/core/persistenceContext.py
@@ -27,7 +27,7 @@
 import uuid
 from collections.abc import Callable
 from contextvars import Context, ContextVar, Token, copy_context
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast
 
 if TYPE_CHECKING:
     from .datasets.ref import DatasetRef
@@ -37,6 +37,10 @@ from .dimensions._records import DimensionRecord, SerializedDimensionRecord
 
 
 _T = TypeVar("_T")
+_V = TypeVar("_V")
+
+_P = ParamSpec("_P")
+_Q = ParamSpec("_Q")
 
 
 class PersistenceContextVars:
@@ -76,6 +80,7 @@ class PersistenceContextVars:
     until process completion. It was determined the runtime cost of recreating
    the `SerializedDatasetRef`\ s was worth the memory savings.
     """
+
     serializedDatasetTypeMapping: ContextVar[
         dict[tuple[str, str], SerializedDatasetType] | None
     ] = ContextVar("serializedDatasetTypeMapping", default=None)
@@ -141,11 +146,11 @@ def _getContextVars(cls) -> dict[str, ContextVar]:
                 classAttributes[k] = v
         return classAttributes
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._ctx: Context | None = None
         self._tokens: dict[str, Token] | None = None
 
-    def _functionRunner(self, function: Callable[..., _T], *args, **kwargs) -> _T:
+    def _functionRunner(self, function: Callable[_P, _V], *args: _P.args, **kwargs: _P.kwargs) -> _V:
         # create a storage space for the tokens returned from setting the
         # context variables
         self._tokens = {}
@@ -168,7 +173,7 @@ def _functionRunner(self, function: Callable[..., _T], *args, **kwargs) -> _T:
         self._tokens = None
         return result
 
-    def run(self, function: Callable[..., _T], *args, **kwargs) -> _T:
+    def run(self, function: Callable[_Q, _T], *args: _Q.args, **kwargs: _Q.kwargs) -> _T:
         """Execute the supplied function inside context specific caches.
 
         Parameters
@@ -186,4 +191,9 @@ def run(self, function: Callable[..., _T], *args, **kwargs) -> _T:
             The result returned by executing the supplied `Callable`
         """
         self._ctx = copy_context()
-        return self._ctx.run(self._functionRunner, function, *args, **kwargs)
+        # Type checkers seem to have trouble with a second layer nesting of
+        # parameter specs in callables, so ignore the call here and explicitly
+        # cast the result as we know this is exactly what the return type will
+        # be.
+        result = self._ctx.run(self._functionRunner, function, *args, **kwargs)  # type: ignore
+        return cast(_T, result)
diff --git a/python/lsst/daf/butler/core/quantum.py b/python/lsst/daf/butler/core/quantum.py
index 25357c45f3..f3cf45e7f3 100644
--- a/python/lsst/daf/butler/core/quantum.py
+++ b/python/lsst/daf/butler/core/quantum.py
@@ -103,7 +103,7 @@ def direct(
 
         """
         node = SerializedQuantum.__new__(cls)
         setter = object.__setattr__
-        setter(node, "taskName", sys.intern(taskName))
+        setter(node, "taskName", sys.intern(taskName or ""))
         setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
         setter(
diff --git a/python/lsst/daf/butler/transfers/_yaml.py b/python/lsst/daf/butler/transfers/_yaml.py
index e1309ef72f..ae0cba4f8e 100644
--- a/python/lsst/daf/butler/transfers/_yaml.py
+++ b/python/lsst/daf/butler/transfers/_yaml.py
@@ -28,7 +28,7 @@
 from collections import defaultdict
 from collections.abc import Iterable, Mapping
 from datetime import datetime
-from typing import IO, TYPE_CHECKING, Any, cast
+from typing import IO, TYPE_CHECKING, Any
 
 import astropy.time
 import yaml
@@ -341,25 +341,20 @@ def __init__(self, stream: IO, registry: Registry):
                 collectionType = CollectionType.from_name(data["collection_type"])
                 if collectionType is CollectionType.TAGGED:
                     self.tagAssociations[data["collection"]].extend(
-                        [
-                            x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
-                            for x in data["dataset_ids"]
-                        ]
+                        [x if not isinstance(x, int) else _refIntId2UUID[x] for x in data["dataset_ids"]]
                     )
                 elif collectionType is CollectionType.CALIBRATION:
                     assocsByTimespan = self.calibAssociations[data["collection"]]
                     for d in data["validity_ranges"]:
                         if "timespan" in d:
                             assocsByTimespan[d["timespan"]] = [
-                                x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
-                                for x in d["dataset_ids"]
+                                x if not isinstance(x, int) else _refIntId2UUID[x] for x in d["dataset_ids"]
                             ]
                         else:
                             # TODO: this is for backward compatibility, should
                             # be removed at some point.
                             assocsByTimespan[Timespan(begin=d["begin"], end=d["end"])] = [
-                                x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
-                                for x in d["dataset_ids"]
+                                x if not isinstance(x, int) else _refIntId2UUID[x] for x in d["dataset_ids"]
                             ]
                 else:
                     raise ValueError(f"Unexpected calibration type for association: {collectionType.name}.")
@@ -379,9 +374,7 @@ def __init__(self, stream: IO, registry: Registry):
                             datasetType,
                             dataId,
                             run=data["run"],
-                            id=refid
-                            if not isinstance(refid, int)
-                            else cast(DatasetId, _refIntId2UUID[refid]),
+                            id=refid if not isinstance(refid, int) else _refIntId2UUID[refid],
                         )
                         for dataId, refid in zip(
                             ensure_iterable(d["data_id"]), ensure_iterable(d["dataset_id"])
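The `ParamSpec` changes in persistenceContext.py are the core typing improvement: `run()` now propagates the wrapped callable's exact signature to its call sites instead of erasing it behind `Callable[..., _T]`. A minimal sketch of the same pattern using only the standard library; the `scale` function is a made-up example, not part of daf_butler:

```python
from collections.abc import Callable
from contextvars import copy_context
from typing import ParamSpec, TypeVar, cast

_P = ParamSpec("_P")
_V = TypeVar("_V")


def run(function: Callable[_P, _V], *args: _P.args, **kwargs: _P.kwargs) -> _V:
    # Context.run erases the ParamSpec; the explicit cast mirrors the patch,
    # where an extra layer of callable nesting defeated inference (depending
    # on the checker's stubs it may be redundant here).
    return cast(_V, copy_context().run(function, *args, **kwargs))


def scale(value: float, factor: float = 2.0) -> float:
    return value * factor


assert run(scale, 3.0) == 6.0
# run(scale, "3.0") is now rejected by mypy instead of failing at runtime.
```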