Commit: Address formatting/MYPY issues

natelust committed Jun 28, 2023
1 parent b849bea commit 253a01d
Showing 6 changed files with 38 additions and 31 deletions.
10 changes: 4 additions & 6 deletions python/lsst/daf/butler/core/datasets/ref.py
@@ -293,6 +293,10 @@ def __init__(

    @property
    def id(self) -> DatasetId:
        """Primary key of the dataset (`DatasetId`).

        Cannot be changed after a `DatasetRef` is constructed.
        """
        return uuid.UUID(int=self._id)

    def __eq__(self, other: Any) -> bool:
@@ -708,9 +712,3 @@ class associated with the dataset type of the other ref can be
    Cannot be changed after a `DatasetRef` is constructed.
    """

    id: DatasetId
    """Primary key of the dataset (`DatasetId`).

    Cannot be changed after a `DatasetRef` is constructed.
    """
10 changes: 5 additions & 5 deletions python/lsst/daf/butler/core/datastoreRecordData.py
@@ -214,8 +214,8 @@ def from_simple(
"""
cache = PersistenceContextVars.dataStoreRecords.get()
key = frozenset(simple.dataset_ids)
if cache is not None and (record := cache.get(key)) is not None:
return record
if cache is not None and (cachedRecord := cache.get(key)) is not None:
return cachedRecord
records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {}
# make sure that all dataset IDs appear in the dict even if they don't
# have records.
@@ -228,7 +228,7 @@ def from_simple(
                info = klass.from_record(record)
                dataset_type_records = records.setdefault(info.dataset_id, {})
                dataset_type_records.setdefault(table_name, []).append(info)
        record = cls(records=records)
        newRecord = cls(records=records)
        if cache is not None:
            cache[key] = record
        return record
            cache[key] = newRecord
        return newRecord
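The renames in this file (`record` to `cachedRecord`/`newRecord`) address a mypy pattern: mypy pins a variable to the first type it infers, so reusing `record` for both the loop items and the final class instance triggers an incompatible-assignment error. A standalone illustration of the same fix, with hypothetical names:

class Data:
    def __init__(self, rows: list[dict[str, int]]) -> None:
        self.rows = rows


def from_rows(rows: list[dict[str, int]]) -> Data:
    # mypy infers ``record: dict[str, int]`` from this loop; reusing the
    # same name for the Data instance below would be flagged as an
    # incompatible assignment.
    for record in rows:
        record.setdefault("count", 0)
    newRecord = Data(rows)  # fresh name, so each name keeps one type
    return newRecord


assert from_rows([{"a": 1}]).rows == [{"a": 1, "count": 0}]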
10 changes: 8 additions & 2 deletions python/lsst/daf/butler/core/dimensions/_records.py
@@ -167,9 +167,12 @@ def direct(
        This method should only be called when the inputs are trusted.
        """
        _recItems = record.items()
        # Type ignore because the ternary statement seems to confuse mypy
        # based on conflicting inferred types of v.
        key = (
            definition,
            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()),
            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in _recItems),  # type: ignore
        )
        cache = PersistenceContextVars.serializedDimensionRecordMapping.get()
        if cache is not None and (result := cache.get(key)) is not None:
@@ -376,9 +379,12 @@ def from_simple(
        if universe is None:
            # this is for mypy
            raise ValueError("Unable to determine a usable universe")
        _recItems = simple.record.items()
        # Type ignore because the ternary statement seems to confuse mypy
        # based on conflicting inferred types of v.
        key = (
            simple.definition,
            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in simple.record.items()),
            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in _recItems),  # type: ignore
        )
        cache = PersistenceContextVars.dimensionRecords.get()
        if cache is not None and (result := cache.get(key)) is not None:
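Both hunks here build a hashable cache key out of a record mapping whose values may include lists (which are unhashable); converting lists to tuples lets the pairs live in a `frozenset`, so equal records produce equal keys regardless of field order. A self-contained sketch of the pattern, with a hypothetical `make_cache_key` helper:

from typing import Any


def make_cache_key(name: str, record: dict[str, Any]) -> tuple:
    # frozenset members must be hashable and record values may be lists,
    # so lists become tuples; the frozenset makes the key order-independent.
    return (
        name,
        frozenset((k, tuple(v) if isinstance(v, list) else v) for k, v in record.items()),
    )


cache: dict[tuple, str] = {}
cache[make_cache_key("visit", {"id": 42, "bands": ["g", "r"]})] = "hit"
assert cache[make_cache_key("visit", {"bands": ["g", "r"], "id": 42})] == "hit"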
20 changes: 15 additions & 5 deletions python/lsst/daf/butler/core/persistenceContext.py
@@ -27,7 +27,7 @@
import uuid
from collections.abc import Callable
from contextvars import Context, ContextVar, Token, copy_context
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast

if TYPE_CHECKING:
    from .datasets.ref import DatasetRef
@@ -37,6 +37,10 @@
    from .dimensions._records import DimensionRecord, SerializedDimensionRecord

_T = TypeVar("_T")
_V = TypeVar("_V")

_P = ParamSpec("_P")
_Q = ParamSpec("_Q")


class PersistenceContextVars:
@@ -76,6 +80,7 @@ class PersistenceContextVars:
    until process completion. It was determined the runtime cost of recreating
    the `SerializedDatasetRef`\ s was worth the memory savings.
    """

    serializedDatasetTypeMapping: ContextVar[
        dict[tuple[str, str], SerializedDatasetType] | None
    ] = ContextVar("serializedDatasetTypeMapping", default=None)
@@ -141,11 +146,11 @@ def _getContextVars(cls) -> dict[str, ContextVar]:
            classAttributes[k] = v
        return classAttributes

    def __init__(self):
    def __init__(self) -> None:
        self._ctx: Context | None = None
        self._tokens: dict[str, Token] | None = None

    def _functionRunner(self, function: Callable[..., _T], *args, **kwargs) -> _T:
    def _functionRunner(self, function: Callable[_P, _V], *args: _P.args, **kwargs: _P.kwargs) -> _V:
        # create a storage space for the tokens returned from setting the
        # context variables
        self._tokens = {}
@@ -168,7 +173,7 @@ def _functionRunner(self, function: Callable[..., _T], *args, **kwargs) -> _T:
        self._tokens = None
        return result

    def run(self, function: Callable[..., _T], *args, **kwargs) -> _T:
    def run(self, function: Callable[_Q, _T], *args: _Q.args, **kwargs: _Q.kwargs) -> _T:
        """Execute the supplied function inside context specific caches.

        Parameters
@@ -186,4 +191,9 @@ def run(self, function: Callable[..., _T], *args, **kwargs) -> _T:
            The result returned by executing the supplied `Callable`
        """
        self._ctx = copy_context()
        return self._ctx.run(self._functionRunner, function, *args, **kwargs)
        # Type checkers seem to have trouble with a second layer nesting of
        # parameter specs in callables, so ignore the call here and explicitly
        # cast the result as we know this is exactly what the return type will
        # be.
        result = self._ctx.run(self._functionRunner, function, *args, **kwargs)  # type: ignore
        return cast(_T, result)
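`ParamSpec` is what lets `run` and `_functionRunner` type-check their forwarded `*args`/`**kwargs` against the wrapped function's real signature; the remaining `# type: ignore` and `cast` are only needed for the doubly nested case the comment describes. A single-layer sketch that checks cleanly (the `run_in_context` helper is hypothetical):

from contextvars import copy_context
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def run_in_context(function: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs) -> _T:
    # ParamSpec ties *args/**kwargs to ``function``'s own signature, so
    # run_in_context(f, 1, x=2) is checked exactly like f(1, x=2).
    return copy_context().run(function, *args, **kwargs)


def greet(name: str, punctuation: str = "!") -> str:
    return f"hello {name}{punctuation}"


assert run_in_context(greet, "butler") == "hello butler!"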
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/core/quantum.py
@@ -103,7 +103,7 @@ def direct(
"""
node = SerializedQuantum.__new__(cls)
setter = object.__setattr__
setter(node, "taskName", sys.intern(taskName))
setter(node, "taskName", sys.intern(taskName or ""))
setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))

setter(
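The `taskName or ""` guard matters because `sys.intern` accepts only `str` and raises `TypeError` on `None`; with the guard, a missing task name interns the empty string instead. A tiny illustration (the `intern_name` helper is hypothetical):

import sys


def intern_name(taskName: str | None) -> str:
    # sys.intern(None) raises TypeError, so substitute "" first.
    return sys.intern(taskName or "")


assert intern_name(None) == ""
assert intern_name("measure") is sys.intern("measure")  # interned strings are shared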
17 changes: 5 additions & 12 deletions python/lsst/daf/butler/transfers/_yaml.py
@@ -28,7 +28,7 @@
from collections import defaultdict
from collections.abc import Iterable, Mapping
from datetime import datetime
from typing import IO, TYPE_CHECKING, Any, cast
from typing import IO, TYPE_CHECKING, Any

import astropy.time
import yaml
@@ -341,25 +341,20 @@ def __init__(self, stream: IO, registry: Registry):
            collectionType = CollectionType.from_name(data["collection_type"])
            if collectionType is CollectionType.TAGGED:
                self.tagAssociations[data["collection"]].extend(
                    [
                        x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
                        for x in data["dataset_ids"]
                    ]
                    [x if not isinstance(x, int) else _refIntId2UUID[x] for x in data["dataset_ids"]]
                )
            elif collectionType is CollectionType.CALIBRATION:
                assocsByTimespan = self.calibAssociations[data["collection"]]
                for d in data["validity_ranges"]:
                    if "timespan" in d:
                        assocsByTimespan[d["timespan"]] = [
                            x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
                            for x in d["dataset_ids"]
                            x if not isinstance(x, int) else _refIntId2UUID[x] for x in d["dataset_ids"]
                        ]
                    else:
                        # TODO: this is for backward compatibility, should
                        # be removed at some point.
                        assocsByTimespan[Timespan(begin=d["begin"], end=d["end"])] = [
                            x if not isinstance(x, int) else cast(DatasetId, _refIntId2UUID[x])
                            for x in d["dataset_ids"]
                            x if not isinstance(x, int) else _refIntId2UUID[x] for x in d["dataset_ids"]
                        ]
            else:
                raise ValueError(f"Unexpected calibration type for association: {collectionType.name}.")
@@ -379,9 +374,7 @@ def __init__(self, stream: IO, registry: Registry):
                        datasetType,
                        dataId,
                        run=data["run"],
                        id=refid
                        if not isinstance(refid, int)
                        else cast(DatasetId, _refIntId2UUID[refid]),
                        id=refid if not isinstance(refid, int) else _refIntId2UUID[refid],
                    )
                    for dataId, refid in zip(
                        ensure_iterable(d["data_id"]), ensure_iterable(d["dataset_id"])
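The dropped `cast(DatasetId, ...)` calls were redundant once `_refIntId2UUID` is typed as mapping ints to UUIDs: the expression `x if not isinstance(x, int) else _refIntId2UUID[x]` already narrows legacy integer dataset IDs from old export files to UUIDs on its own. A sketch of the idea, where `_refIntId2UUID` is a stand-in (the real module may build its mapping differently):

import uuid
from collections import defaultdict

# Stand-in for the module's _refIntId2UUID table: each legacy integer ID
# seen during import is assigned a fresh UUID, reused on later lookups.
_refIntId2UUID: defaultdict[int, uuid.UUID] = defaultdict(uuid.uuid4)


def resolve_dataset_id(raw: int | uuid.UUID) -> uuid.UUID:
    # New-style exports already carry UUIDs; only legacy ints need mapping.
    return raw if not isinstance(raw, int) else _refIntId2UUID[raw]


legacy = resolve_dataset_id(7)
assert resolve_dataset_id(7) == legacy  # stable within one import run
assert isinstance(resolve_dataset_id(uuid.uuid4()), uuid.UUID)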
