diff --git a/python/lsst/daf/butler/queries/expression_factory.py b/python/lsst/daf/butler/queries/expression_factory.py index a55504f2b3..668f5476bc 100644 --- a/python/lsst/daf/butler/queries/expression_factory.py +++ b/python/lsst/daf/butler/queries/expression_factory.py @@ -463,7 +463,7 @@ def __repr__(self) -> str: # to include Datastore record fields. def __getattr__(self, field: str) -> ScalarExpressionProxy: - if field not in tree.DATASET_FIELD_NAMES: + if not tree.is_dataset_field(field): raise AttributeError(field) expression = tree.DatasetFieldReference(dataset_type=self._dataset_type, field=field) return ResolvedScalarExpressionProxy(expression) diff --git a/python/lsst/daf/butler/queries/tree/_base.py b/python/lsst/daf/butler/queries/tree/_base.py index d8627c8c7f..e641dff84b 100644 --- a/python/lsst/daf/butler/queries/tree/_base.py +++ b/python/lsst/daf/butler/queries/tree/_base.py @@ -32,10 +32,11 @@ "ColumnExpressionBase", "DatasetFieldName", "DATASET_FIELD_NAMES", + "is_dataset_field", ) from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeVar, cast, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, TypeGuard, TypeVar, cast, get_args import pydantic @@ -63,6 +64,22 @@ _O = TypeVar("_O") +def is_dataset_field(s: str) -> TypeGuard[DatasetFieldName]: + """Validate a field name. + + Parameters + ---------- + s : `str` + The field name to test. + + Returns + ------- + is_field : `bool` + Whether or not this is a dataset field. + """ + return s in DATASET_FIELD_NAMES + + class QueryTreeBase(pydantic.BaseModel): """Base class for all non-primitive types in a query tree."""