Skip to content

Commit

Permalink
Merge pull request #865 from materialsproject/pydantic-2.0
Browse files Browse the repository at this point in the history
Pydantic 2.0
  • Loading branch information
Jason Munro authored Sep 26, 2023
2 parents 145d0ed + 13a0dc7 commit fbc1ac5
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 74 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
install_requires=[
"setuptools",
"ruamel.yaml<0.18",
"pydantic<2.0",
"pydantic>=0.32.2",
"pydantic>=2.0",
"pydantic-settings>=2.0.3",
"pymongo>=4.2.0",
"monty>=1.0.2",
"monty>=2023.9.25",
"mongomock>=3.10.0",
"pydash>=4.1.0",
"jsonschema>=3.1.1",
Expand Down
102 changes: 64 additions & 38 deletions src/maggma/api/query_operator/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import inspect
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from fastapi.params import Query
from monty.json import MontyDecoder
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.fields import FieldInfo

from maggma.api.query_operator import QueryOperator
from maggma.api.utils import STORE_PARAMS
Expand All @@ -25,8 +25,10 @@ def __init__(
self.fields = fields
self.excluded_fields = excluded_fields

all_fields: Dict[str, ModelField] = model.__fields__
param_fields = fields or list(set(all_fields.keys()) - set(excluded_fields or []))
all_fields: Dict[str, FieldInfo] = model.model_fields
param_fields = fields or list(
set(all_fields.keys()) - set(excluded_fields or [])
)

# Convert the fields into operator tuples
ops = [
Expand All @@ -47,7 +49,9 @@ def query(**kwargs) -> STORE_PARAMS:
try:
criteria.append(self.mapping[k](v))
except KeyError:
raise KeyError(f"Cannot find key {k} in current query to database mapping")
raise KeyError(
f"Cannot find key {k} in current query to database mapping"
)

final_crit = {}
for entry in criteria:
Expand Down Expand Up @@ -78,9 +82,11 @@ def query(self):
"Stub query function for abstract class"

@abstractmethod
def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
def field_to_operator(
self, name: str, field: FieldInfo
) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic ModelField into a Tuple with the
Converts a PyDantic FieldInfo into a Tuple with the
- query param name,
- query param type
- FastAPI Query object,
Expand All @@ -107,81 +113,91 @@ def as_dict(self) -> Dict:
class NumericQuery(DynamicQueryOperator):
"Query Operator to enable searching on numeric fields"

def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
def field_to_operator(
self, name: str, field: FieldInfo
) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic ModelField into a Tuple with the
Converts a PyDantic FieldInfo into a Tuple with the
query_param name,
default value,
Query object,
and callable to convert it into a query dict
"""

ops = []
field_type = field.type_
field_type = field.annotation

if field_type in [int, float]:
title: str = field.field_info.title or field.name
if field_type in [int, float, Union[float, None], Union[int, None]]:
title: str = name or field.alias

ops = [
(
f"{field.name}_max",
f"{title}_max",
field_type,
Query(
default=None,
description=f"Query for maximum value of {title}",
),
lambda val: {f"{field.name}": {"$lte": val}},
lambda val: {f"{title}": {"$lte": val}},
),
(
f"{field.name}_min",
f"{title}_min",
field_type,
Query(
default=None,
description=f"Query for minimum value of {title}",
),
lambda val: {f"{field.name}": {"$gte": val}},
lambda val: {f"{title}": {"$gte": val}},
),
]

if field_type is int:
if field_type in [int, Union[int, None]]:
ops.extend(
[
(
f"{field.name}",
f"{title}",
field_type,
Query(
default=None,
description=f"Query for {title} being equal to an exact value",
),
lambda val: {f"{field.name}": val},
lambda val: {f"{title}": val},
),
(
f"{field.name}_not_eq",
f"{title}_not_eq",
field_type,
Query(
default=None,
description=f"Query for {title} being not equal to an exact value",
),
lambda val: {f"{field.name}": {"$ne": val}},
lambda val: {f"{title}": {"$ne": val}},
),
(
f"{field.name}_eq_any",
f"{title}_eq_any",
str, # type: ignore
Query(
default=None,
description=f"Query for {title} being any of these values. Provide a comma separated list.",
),
lambda val: {f"{field.name}": {"$in": [int(entry.strip()) for entry in val.split(",")]}},
lambda val: {
f"{title}": {
"$in": [int(entry.strip()) for entry in val.split(",")]
}
},
),
(
f"{field.name}_neq_any",
f"{title}_neq_any",
str, # type: ignore
Query(
default=None,
description=f"Query for {title} being not any of these values. \
Provide a comma separated list.",
),
lambda val: {f"{field.name}": {"$nin": [int(entry.strip()) for entry in val.split(",")]}},
lambda val: {
f"{title}": {
"$nin": [int(entry.strip()) for entry in val.split(",")]
}
},
),
]
)
Expand All @@ -192,57 +208,67 @@ def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any
class StringQueryOperator(DynamicQueryOperator):
"Query Operator to enable searching on numeric fields"

def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
def field_to_operator(
self, name: str, field: FieldInfo
) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]:
"""
Converts a PyDantic ModelField into a Tuple with the
Converts a PyDantic FieldInfo into a Tuple with the
query_param name,
default value,
Query object,
and callable to convert it into a query dict
"""

ops = []
field_type: type = field.type_
field_type: type = field.annotation

if field_type in [str]:
title: str = field.field_info.title or field.name
if field_type in [str, Union[str, None]]:
title: str = name

ops = [
(
f"{field.name}",
f"{title}",
field_type,
Query(
default=None,
description=f"Query for {title} being equal to a value",
),
lambda val: {f"{field.name}": val},
lambda val: {f"{title}": val},
),
(
f"{field.name}_not_eq",
f"{title}_not_eq",
field_type,
Query(
default=None,
description=f"Query for {title} being not equal to a value",
),
lambda val: {f"{field.name}": {"$ne": val}},
lambda val: {f"{title}": {"$ne": val}},
),
(
f"{field.name}_eq_any",
f"{title}_eq_any",
str, # type: ignore
Query(
default=None,
description=f"Query for {title} being any of these values. Provide a comma separated list.",
),
lambda val: {f"{field.name}": {"$in": [entry.strip() for entry in val.split(",")]}},
lambda val: {
f"{title}": {
"$in": [entry.strip() for entry in val.split(",")]
}
},
),
(
f"{field.name}_neq_any",
f"{title}_neq_any",
str, # type: ignore
Query(
default=None,
description=f"Query for {title} being not any of these values. Provide a comma separated list",
),
lambda val: {f"{field.name}": {"$nin": [entry.strip() for entry in val.split(",")]}},
lambda val: {
f"{title}": {
"$nin": [entry.strip() for entry in val.split(",")]
}
},
),
]

Expand Down
69 changes: 41 additions & 28 deletions src/maggma/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from bson.objectid import ObjectId
from monty.json import MSONable
from pydantic import BaseModel
from pydantic.schema import get_flat_models_from_model
from pydantic.utils import lenient_issubclass
from typing_extensions import Literal
from pydantic.fields import FieldInfo
from maggma.utils import get_flat_models_from_model
from pydantic._internal._utils import lenient_issubclass
from typing_extensions import Literal, Union

if sys.version_info >= (3, 8):
from typing import get_args
Expand Down Expand Up @@ -42,7 +43,12 @@ def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS:
if "properties" in sub_query:
properties.extend(sub_query["properties"])

remainder = {k: v for query in queries for k, v in query.items() if k not in ["criteria", "properties"]}
remainder = {
k: v
for query in queries
for k, v in query.items()
if k not in ["criteria", "properties"]
}

return {
"criteria": criteria,
Expand Down Expand Up @@ -86,50 +92,55 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict):


def api_sanitize(
pydantic_model: Type[BaseModel],
fields_to_leave: Optional[List[str]] = None,
pydantic_model: BaseModel,
fields_to_leave: Union[str, None] = None,
allow_dict_msonable=False,
):
"""
Function to clean up pydantic models for the API by:
"""Function to clean up pydantic models for the API by:
1.) Making fields optional
2.) Allowing dictionaries in-place of the objects for MSONable quantities
2.) Allowing dictionaries in-place of the objects for MSONable quantities.
WARNING: This works in place, so it mutates the model and all sub-models
Args:
fields_to_leave: list of strings for model fields as "model__name__.field"
pydantic_model (BaseModel): Pydantic model to alter
fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
Defaults to None.
allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
Defaults to False
"""

models: List[Type[BaseModel]] = [
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
]
models = [
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: list[BaseModel]

fields_to_leave = fields_to_leave or []
fields_tuples = [f.split(".") for f in fields_to_leave]
assert all(len(f) == 2 for f in fields_tuples)

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name, field in model.__fields__.items():
field_type = field.type_

if name not in model_fields_to_leave:
field.required = False
field.default = None
field.default_factory = None
field.allow_none = True
field.field_info.default = None
field.field_info.default_factory = None
for name in model.model_fields:
field = model.model_fields[name]
field_type = field.annotation

if field_type is not None and allow_dict_msonable:
if lenient_issubclass(field_type, MSONable):
field.type_ = allow_msonable_dict(field_type)
field_type = allow_msonable_dict(field_type)
else:
for sub_type in get_args(field_type):
if lenient_issubclass(sub_type, MSONable):
allow_msonable_dict(sub_type)
field.populate_validators()

if name not in model_fields_to_leave:
new_field = FieldInfo.from_annotated_attribute(
Optional[field_type], None
)
model.model_fields[name] = new_field

model.model_rebuild(force=True)

return pydantic_model

Expand All @@ -139,7 +150,7 @@ def allow_msonable_dict(monty_cls: Type[MSONable]):
Patch Monty to allow for dict values for MSONable
"""

def validate_monty(cls, v):
def validate_monty(cls, v, _):
"""
Stub validator for MSONable as a dictionary only
"""
Expand All @@ -155,13 +166,15 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

monty_cls.validate_monty = classmethod(validate_monty)
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls

Expand Down
Loading

0 comments on commit fbc1ac5

Please sign in to comment.