Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement support for only and exclude options #141

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions marshmallow_oneofschema/one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,46 @@ def get_obj_type(self, obj):

type_field = "type"
type_field_remove = True
type_field_dump = True
type_schemas = {}

def __init__(self, **kwargs):
only = kwargs.get("only")
exclude = kwargs.get("exclude", ())
if only is not None:
self.type_field_dump = self.type_field in only
kwargs["only"] = ()
if exclude:
self.type_field_dump = self.type_field not in exclude
kwargs["exclude"] = ()
super().__init__(**kwargs)
self._init_type_schemas(only, exclude)

def _init_type_schemas(self, only, exclude):
self.type_schemas = {
k: self._create_type_schema_instance(v, only, exclude)
for k, v in self.type_schemas.items()
}

def _create_type_schema_instance(self, SchemaCls, only, exclude):
if only or exclude:
if SchemaCls.opts.fields:
available_field_names = self.set_class(SchemaCls.opts.fields)
else:
available_field_names = self.set_class(
SchemaCls._declared_fields.keys()
)
if SchemaCls.opts.additional:
available_field_names |= self.set_class(
SchemaCls.opts.additional
)
if only:
only = self.set_class(only) & available_field_names
if exclude:
exclude = self.set_class(exclude) & available_field_names

return SchemaCls(only=only, exclude=exclude)

def get_obj_type(self, obj):
"""Returns name of object schema"""
return obj.__class__.__name__
Expand Down Expand Up @@ -96,16 +134,14 @@ def _dump(self, obj, *, update_fields=True, **kwargs):
{"_schema": "Unknown object class: %s" % obj.__class__.__name__},
)

type_schema = self.type_schemas.get(obj_type)
if not type_schema:
schema = self.type_schemas.get(obj_type)
if not schema:
return None, {"_schema": "Unsupported object type: %s" % obj_type}

schema = type_schema if isinstance(type_schema, Schema) else type_schema()

schema.context.update(getattr(self, "context", {}))

result = schema.dump(obj, many=False, **kwargs)
if result is not None:
if result is not None and self.type_field_dump:
result[self.type_field] = obj_type
return result

Expand Down Expand Up @@ -160,17 +196,17 @@ def _load(self, data, *, partial=None, unknown=None, **kwargs):
)

try:
type_schema = self.type_schemas.get(data_type)
schema = self.type_schemas.get(data_type)
except TypeError:
# data_type could be unhashable
raise ValidationError({self.type_field: ["Invalid value: %s" % data_type]})
if not type_schema:
raise ValidationError(
{self.type_field: ["Invalid value: %s" % data_type]}
)
if not schema:
raise ValidationError(
{self.type_field: ["Unsupported value: %s" % data_type]}
)

schema = type_schema if isinstance(type_schema, Schema) else type_schema()

schema.context.update(getattr(self, "context", {}))

return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ def test_dump_with_empty_keeps_type(self):
result = MySchema().dump(Empty())
assert {"type": "Empty"} == result

def test_dump_only(self):
result = MySchema(only=("type", "value1")).dump(
[Foo("hello"), Bar(123), Baz(456, 789)], many=True
)
assert [
{"type": "Foo"},
{"type": "Bar"},
{"type": "Baz", "value1": 456},
] == result

def test_dump_exclude(self):
result = MySchema(exclude=("type", "value2")).dump(
[Foo("hello"), Bar(123), Baz(456, 789)], many=True
)
assert [{"value": "hello"}, {"value": 123}, {"value1": 456}] == result

def test_load(self):
foo_result = MySchema().load({"type": "Foo", "value": "world"})
assert Foo("world") == foo_result
Expand Down