Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass constructor arguments to new schema instances #15

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions marshmallow_oneofschema/one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def get_obj_type(self, obj):
type_field_remove = True
type_schemas = []

def __init__(self, *args, **kwargs):
self._schema_args = args
self._schema_kwargs = kwargs

super(OneOfSchema, self).__init__(*args, **kwargs)

def get_obj_type(self, obj):
"""Returns name of object schema"""
return obj.__class__.__name__
Expand Down Expand Up @@ -102,7 +108,7 @@ def _dump(self, obj, update_fields=True, **kwargs):

schema = (
type_schema if isinstance(type_schema, Schema)
else type_schema()
else type_schema(*self._schema_args, **self._schema_kwargs)
)

schema.context.update(getattr(self, 'context', {}))
Expand Down Expand Up @@ -174,7 +180,8 @@ def _load(self, data, partial=None):
})

schema = (
type_schema if isinstance(type_schema, Schema) else type_schema()
type_schema if isinstance(type_schema, Schema) else
type_schema(*self._schema_args, **self._schema_kwargs)
)

schema.context.update(getattr(self, 'context', {}))
Expand Down
41 changes: 41 additions & 0 deletions tests/test_one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,36 @@ def make_baz(self, data):
return Baz(**data)


class Bat(object):
def __init__(self, value1=None, value2=None, value3=None):
self.value1 = value1
self.value2 = value2
self.value3 = value3

def __repr__(self):
return '<Bat value1=%s value2=%s>' % (self.value1, self.value2)

def __eq__(self, other):
return isinstance(other, self.__class__) and self.value1 == other.value1 \
and self.value2 == other.value2


class BatSchema(m.Schema):
value1 = f.Integer(required=True)
value2 = f.String(required=True)
value3 = f.String()

@m.post_load
def make_bat(self, data):
return Bat(**data)


class MySchema(OneOfSchema):
type_schemas = {
'Foo': FooSchema,
'Bar': BarSchema,
'Baz': BazSchema,
'Bat': BatSchema,
}


Expand All @@ -83,6 +108,22 @@ def test_dump(self):
bar_result = MySchema().dump(Bar(123))
assert {'type': 'Bar', 'value': 123} == bar_result

def test_dump_exclude(self):
bat_result = MySchema().dump(Bat(1, 'hello', 'i like turtles'))
assert {
'type': 'Bat',
'value1': 1,
'value2': 'hello',
'value3': 'i like turtles'
} == bat_result

exclude_schema = MySchema(exclude=('value3',))
bat_exclude_result = exclude_schema.dump(Bat(1, 'hello', 'i like turtles'))
assert {'type': 'Bat', 'value1': 1, 'value2': 'hello'} == bat_exclude_result

foo_result = exclude_schema.dump(Foo('hello'))
assert {'type': 'Foo', 'value': 'hello'} == foo_result

def test_dump_many(self):
result = MySchema().dump([Foo('hello'), Bar(123)], many=True)
assert [{'type': 'Foo', 'value': 'hello'},
Expand Down