upgrading to pydantic v2
	replacing older-style validations with newer pydantic validations

	simplifying map_json to just fail out like everything else instead of
	collecting a comprehensive list of errors
benboger committed Sep 15, 2023
1 parent bcebbe5 commit 171d303
Showing 5 changed files with 37 additions and 42 deletions.
10 changes: 6 additions & 4 deletions dysql/mappers.py
@@ -23,11 +23,12 @@ class MapperError(Exception):
 
 
 class DbMapResultBase(abc.ABC):
-    _key_columns = ['id']
-
     @classmethod
-    def get_key_columns(cls):
-        return cls._key_columns
+    @property
+    def key_columns(cls):
+        return ['id']
 
 
     @classmethod
     def create_instance(cls, *args, **kwargs) -> 'DbMapResultBase':
@@ -149,7 +150,8 @@ def _get_lookup(self, record):
         Note: if the id_columns contain an invalid column, logs a warning and returns None
         """
-        key_columns = self.record_mapper.get_key_columns()
+        key_columns = self.record_mapper.key_columns
+        print(f'key_columns: {key_columns}')
         if not key_columns:
             # preserving older expectations
             return None
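For context, a minimal sketch (not part of this diff) of how the new key_columns class property is meant to be overridden; the AccountResult subclass name is hypothetical, and stacking @classmethod with @property requires Python 3.9+ (it is deprecated in 3.11 and removed in 3.13):

import abc


class DbMapResultBase(abc.ABC):
    # default key used to combine rows that belong to the same result
    @classmethod
    @property
    def key_columns(cls):
        return ['id']


class AccountResult(DbMapResultBase):
    # hypothetical subclass keyed on two columns instead of the default 'id'
    @classmethod
    @property
    def key_columns(cls):
        return ['account_id', 'region']


print(DbMapResultBase.key_columns)  # ['id']
print(AccountResult.key_columns)    # ['account_id', 'region']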
33 changes: 9 additions & 24 deletions dysql/pydantic_mappers.py
@@ -5,13 +5,10 @@
 NOTICE: Adobe permits you to use, modify, and distribute this file in accordance
 with the terms of the Adobe license agreement accompanying it.
 """
-import json
-from json import JSONDecodeError
 from typing import Any, Dict, Set
 
 import sqlalchemy
-from pydantic import BaseModel  # pylint: disable=no-name-in-module
-from pydantic.error_wrappers import ValidationError, ErrorWrapper
+from pydantic import BaseModel, TypeAdapter
 
 from .mappers import DbMapResultBase
 
@@ -47,16 +44,12 @@ def create_instance(cls, *args, **kwargs) -> 'DbMapResultModel':
         return cls.construct(*args, **kwargs)
 
     def _map_json(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
-        model_field = self.__fields__[field]
+        model_field = self.model_fields[field]
+        value = record[field]
+        if not value:
+            return
-        if not self._has_been_mapped():
-            try:
-                potential_json_data = record[field]
-                if potential_json_data:
-                    current_dict[field] = json.loads(record[field])
-            except JSONDecodeError as exc:
-                return ErrorWrapper(ValueError(
-                    f'Invalid JSON given to {model_field.alias}', exc), loc=model_field.alias)
-            return None
+        current_dict[field] = TypeAdapter(model_field.annotation).validate_json(value)
 
     def _map_list(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
         if record[field] is None:
@@ -98,11 +91,9 @@ def _map_list_from_string(self, current_dict: dict, record: sqlalchemy.engine.Ro
             list_string = str(list_string)
         values_from_string = list(map(str.strip, list_string.split(',')))
 
-        model_field = self.__fields__[field]
+        model_field = self.model_fields[field]
         # pre-validates the list we are expecting because we want to ensure all records are validated
-        values, errors_ = model_field.validate(values_from_string, current_dict, loc=model_field.alias)
-        if errors_:
-            raise ValidationError(errors_, DbMapResultModel)
+        values = TypeAdapter(model_field.annotation).validate_python(values_from_string)
 
         if self._has_been_mapped() and current_dict[field]:
             current_dict[field].extend(values)
@@ -122,17 +113,14 @@ def map_record(self, record: sqlalchemy.engine.Row) -> None:
         - Remove all DB fields that are present in _dict_value_mappings since they were likely added above
         :param record: the DB record
         """
-        errors = []
         current_dict: dict = self.__dict__
         for field in record.keys():
             if field in self._list_fields:
                 self._map_list(current_dict, record, field)
             elif field in self._csv_list_fields:
                 self._map_list_from_string(current_dict, record, field)
             elif field in self._json_fields:
-                error = self._map_json(current_dict, record, field)
-                if error:
-                    errors.append(error)
+                self._map_json(current_dict, record, field)
             elif field in self._set_fields:
                 self._map_set(current_dict, record, field)
             elif field in self._dict_key_fields:
@@ -142,12 +130,9 @@ def map_record(self, record: sqlalchemy.engine.Row) -> None:
                 if not self._has_been_mapped():
                     current_dict[field] = record[field]
 
-        if errors:
-            raise ValidationError(errors, DbMapResultModel)
         # Remove all dict value fields (if present)
         for db_field in self._dict_value_mappings.values():
             current_dict.pop(db_field, None)
 
         if self._has_been_mapped():
             # At this point, just update the previous record
             self.__dict__.update(current_dict)
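For reference, a standalone sketch (not code from this repository) of the pydantic v2 calls the mapper now leans on: TypeAdapter(...).validate_json parses and validates a raw JSON string against a field's annotation, while validate_python validates an already-split Python list, so both paths simply raise ValidationError on bad input:

from typing import List, Optional
from pydantic import TypeAdapter, ValidationError

# JSON path: parse and validate a raw JSON string against the field annotation
adapter = TypeAdapter(Optional[dict])
print(adapter.validate_json('{"a": 1}'))        # {'a': 1}

try:
    adapter.validate_json('{ "json": value')    # malformed JSON
except ValidationError as exc:
    print(exc.errors()[0]['type'])              # 'json_invalid'

# CSV-list path: validate and coerce an already-parsed Python value
print(TypeAdapter(List[int]).validate_python(['1', '2']))  # [1, 2]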
32 changes: 20 additions & 12 deletions dysql/test/test_pydantic_mappers.py
@@ -9,7 +9,7 @@
 from typing import Any, Dict, List, Set, Optional
 import pytest
 
-from pydantic.error_wrappers import ValidationError
+from pydantic import ValidationError
 
 from dysql import (
     RecordCombiningMapper,
@@ -46,7 +46,7 @@ class ListWithStringsModel(DbMapResultModel):
     _csv_list_fields: Set[str] = {'list1', 'list2'}
 
     id: int
-    list1: Optional[List[str]]
+    list1: Optional[List[str]] = None
     list2: List[int] = [] # help test empty list gets filled

@@ -55,11 +55,16 @@ class JsonModel(DbMapResultModel):
 
     id: int
     json1: dict
-    json2: Optional[dict]
+    json2: Optional[dict] = None
 
 
 class MultiKeyModel(DbMapResultModel):
-    _key_columns = ['a', 'b']
+
+    @classmethod
+    @property
+    def key_columns(cls):
+        return ['a', 'b']
+
     _list_fields = {'c'}
     a: int
     b: str
@@ -238,7 +243,7 @@ def test_csv_list_field_without_mapping_ignored():
 
 def test_csv_list_field_invalid_type():
     mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
-    with pytest.raises(ValidationError, match="value is not a valid integer"):
+    with pytest.raises(ValidationError, match="1 validation error for list"):
         mapper.map_records([{
             'id': 1,
             'list1': 'a,b',
@@ -282,17 +287,20 @@ def test_json_field():
     }
 
 
-def test_invalid_json():
-    with pytest.raises(ValidationError) as excinfo:
+@pytest.mark.parametrize('json1, json2', [
+    ('{ "json": value', None),
+    ('{ "json": value', '{ "json": value }'),
+    ('{ "json": value }', '{ "json": value'),
+    (None, '{ "json": value'),
+])
+def test_invalid_json(json1, json2):
+    with pytest.raises(ValidationError, match='Invalid JSON'):
         mapper = SingleRowMapper(record_mapper=JsonModel)
         mapper.map_records([{
             'id': 1,
-            'json1': '{ "json": value',
-            'json2': 'just a string'
+            'json1': json1,
+            'json2': json2
         }])
-    assert len(excinfo.value.args[0]) == 2
-    assert excinfo.value.args[0][0].exc.args[0] == 'Invalid JSON given to json1'
-    assert excinfo.value.args[0][1].exc.args[0] == 'Invalid JSON given to json2'
 
 
 def test_json_none():
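The new match= strings follow pydantic v2's error formatting; a small illustrative sketch (not part of the test suite) of the messages those assertions are matching against:

from typing import List
from pydantic import TypeAdapter, ValidationError

try:
    TypeAdapter(List[int]).validate_python(['1', 'b'])
except ValidationError as exc:
    # summary line looks roughly like "1 validation error for list[int]"
    print(str(exc).splitlines()[0])

try:
    TypeAdapter(dict).validate_json('{ "json": value')
except ValidationError as exc:
    # the error detail starts with "Invalid JSON", which the parametrized test matches on
    print(exc.errors()[0]['msg'])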
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
 -e .
 
 # Used in development, and as an extra
-pydantic>=1.8.2,<2
+pydantic>2
2 changes: 1 addition & 1 deletion setup.py
@@ -81,7 +81,7 @@ def get_version():
         'sqlalchemy<2',
     ),
     extras_require={
-        'pydantic': ['pydantic>=1.8.2,<2'],
+        'pydantic': ['pydantic>2'],
     },
     classifiers=[
         'Development Status :: 4 - Beta',
