Skip to content

Commit

Permalink
Merge pull request #2 from bargool/master
Browse files Browse the repository at this point in the history
Type strict field values have to override basic fields by annotations
  • Loading branch information
nonamenix authored Jun 3, 2019
2 parents c22bce0 + 9798d19 commit f44cdd4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
19 changes: 12 additions & 7 deletions jam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def is_optional(annotation: typing.Type) -> bool:

def unpack_optional_type(annotation: typing.Union) -> typing.Type:
"""Optional[Type] -> Type"""
return [t for t in annotation.__args__ if t is not NoneType][0]
return next(t for t in annotation.__args__ if t is not NoneType)


def get_marshmallow_field(annotation):
Expand Down Expand Up @@ -102,14 +102,19 @@ def get_fields_from_annotations(annotations):
}


def _skip_fields_from_annotations(annotations, attrs):
return {attr_name: attr_value for attr_name, attr_value in attrs.items()
if attr_name not in annotations or attr_value is not None}


class SchemaMeta(BaseSchemaMeta):
def __new__(mcs, name, bases, attrs):
new_class = super().__new__(
mcs,
name,
bases,
{**attrs, **get_fields_from_annotations(attrs.get("__annotations__", {}))},
)
annotations = attrs.get("__annotations__", {})

attrs = _skip_fields_from_annotations(annotations, attrs)
attrs = {**get_fields_from_annotations(annotations), **attrs}

new_class = super().__new__(mcs, name, bases, attrs)
setattr(new_class, "_dataclass", dataclass(type(name, (), attrs)))
return new_class

Expand Down
9 changes: 9 additions & 0 deletions jam/tests/test_annotation_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,12 @@ class Response(Schema):
optional_field: t.Optional[int] = None

assert repr(Response().declared_fields["optional_field"]) == repr(fields.Integer())


def test_strict_marshmallow_field():
class Response(Schema):
basic_field: int
email_field: str = fields.Email(required=True)

assert repr(Response().declared_fields["basic_field"]) == repr(fields.Integer(required=True))
assert repr(Response().declared_fields["email_field"]) == repr(fields.Email(required=True))
2 changes: 1 addition & 1 deletion project.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
url = https://github.com/nonamenix/marshmallow-jam
version = 1.0.1
version = 1.1.0
name = marshmallow-jam
description = Some extra sweets for marshmallow.

0 comments on commit f44cdd4

Please sign in to comment.