diff --git a/beanie/__init__.py b/beanie/__init__.py index e889ab54..edcb9043 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -27,7 +27,7 @@ from beanie.odm.views import View from beanie.odm.union_doc import UnionDoc -__version__ = "1.18.0" +__version__ = "1.18.1" __all__ = [ # ODM "Document", diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index d3f32c0b..751a536b 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -25,6 +25,7 @@ from pydantic.main import BaseModel from pymongo import InsertOne from pymongo.client_session import ClientSession +from pymongo.errors import DuplicateKeyError from pymongo.results import ( DeleteResult, InsertManyResult, @@ -70,10 +71,12 @@ CurrentDate, Inc, Set as SetOperator, + Unset, + SetRevisionId, ) from beanie.odm.queries.update import UpdateMany, UpdateResponse from beanie.odm.settings.document import DocumentSettings -from beanie.odm.utils.dump import get_dict +from beanie.odm.utils.dump import get_dict, get_top_level_nones from beanie.odm.utils.parsing import merge_models from beanie.odm.utils.self_validation import validate_self_before from beanie.odm.utils.state import ( @@ -211,7 +214,10 @@ async def insert( await obj.save(link_rule=WriteRules.WRITE) result = await self.get_motor_collection().insert_one( - get_dict(self, to_db=True), session=session + get_dict( + self, to_db=True, keep_nulls=self.get_settings().keep_nulls + ), + session=session, ) new_id = result.inserted_id if not isinstance(new_id, self.__fields__["id"].type_): @@ -259,7 +265,11 @@ async def insert_one( bulk_writer.add_operation( Operation( operation=InsertOne, - first_query=get_dict(document, to_db=True), + first_query=get_dict( + document, + to_db=True, + keep_nulls=document.get_settings().keep_nulls, + ), object_class=type(document), ) ) @@ -287,7 +297,12 @@ async def insert_many( "Cascade insert not supported for insert many method" ) documents_list = [ - get_dict(document, to_db=True) for document in documents + get_dict( + document, + to_db=True, + keep_nulls=document.get_settings().keep_nulls, + ) + for document in documents ] return await cls.get_motor_collection().insert_many( documents_list, session=session, **pymongo_kwargs @@ -374,6 +389,7 @@ async def save( self: DocType, session: Optional[ClientSession] = None, link_rule: WriteRules = WriteRules.DO_NOTHING, + ignore_revision: bool = False, **kwargs, ) -> None: """ @@ -382,6 +398,7 @@ async def save( :param session: Optional[ClientSession] - pymongo session. :param link_rule: WriteRules - rules how to deal with links on writing + :param ignore_revision: bool - do force save. :return: None """ if link_rule == WriteRules.WRITE: @@ -407,9 +424,36 @@ async def save( await obj.save( link_rule=link_rule, session=session ) - await self.set( - get_dict(self, to_db=True), session=session, upsert=True, **kwargs - ) + + if self.get_settings().keep_nulls is False: + await self.update( + SetOperator( + get_dict( + self, + to_db=True, + keep_nulls=self.get_settings().keep_nulls, + ) + ), + Unset(get_top_level_nones(self)), + session=session, + ignore_revision=ignore_revision, + upsert=True, + **kwargs, + ) + else: + await self.update( + SetOperator( + get_dict( + self, + to_db=True, + keep_nulls=self.get_settings().keep_nulls, + ) + ), + session=session, + ignore_revision=ignore_revision, + upsert=True, + **kwargs, + ) @saved_state_needed @wrap_with_actions(EventTypes.SAVE_CHANGES) @@ -432,12 +476,21 @@ async def save_changes( if not self.is_changed: return None changes = self.get_changes() - await self.set( - changes, # type: ignore #TODO fix typing - ignore_revision=ignore_revision, - session=session, - bulk_writer=bulk_writer, - ) + if self.get_settings().keep_nulls is False: + await self.update( + SetOperator(changes), + Unset(get_top_level_nones(self)), + ignore_revision=ignore_revision, + session=session, + bulk_writer=bulk_writer, + ) + else: + await self.set( + changes, # type: ignore #TODO fix typing + ignore_revision=ignore_revision, + session=session, + bulk_writer=bulk_writer, + ) @classmethod async def replace_many( @@ -482,6 +535,9 @@ async def update( :param pymongo_kwargs: pymongo native parameters for update operation :return: None """ + + arguments = list(args) + if skip_sync is not None: raise DeprecationWarning( "skip_sync parameter is not supported. The document get synced always using atomic operation." @@ -496,17 +552,21 @@ async def update( if use_revision_id and not ignore_revision: find_query["revision_id"] = self._previous_revision_id - result = await self.find_one(find_query).update( - *args, - session=session, - response_type=UpdateResponse.NEW_DOCUMENT, - bulk_writer=bulk_writer, - **pymongo_kwargs, - ) + if use_revision_id: + arguments.append(SetRevisionId(self.revision_id)) + try: + result = await self.find_one(find_query).update( + *arguments, + session=session, + response_type=UpdateResponse.NEW_DOCUMENT, + bulk_writer=bulk_writer, + **pymongo_kwargs, + ) + except DuplicateKeyError: + raise RevisionIdWasChanged if bulk_writer is None: if use_revision_id and not ignore_revision and result is None: raise RevisionIdWasChanged - merge_models(self, result) @classmethod @@ -737,7 +797,9 @@ def _save_state(self) -> None: if self.state_management_save_previous(): self._previous_saved_state = self._saved_state - self._saved_state = get_dict(self) + self._saved_state = get_dict( + self, to_db=True, keep_nulls=self.get_settings().keep_nulls + ) def get_saved_state(self) -> Optional[Dict[str, Any]]: """ @@ -756,7 +818,9 @@ def get_previous_saved_state(self) -> Optional[Dict[str, Any]]: @property # type: ignore @saved_state_needed def is_changed(self) -> bool: - if self._saved_state == get_dict(self, to_db=True): + if self._saved_state == get_dict( + self, to_db=True, keep_nulls=self.get_settings().keep_nulls + ): return False return True @@ -784,35 +848,37 @@ def _collect_updates( """ updates = {} - - for field_name, field_value in new_dict.items(): - if field_value != old_dict.get(field_name): - if not self.state_management_replace_objects() and ( - isinstance(field_value, dict) - and isinstance(old_dict.get(field_name), dict) - ): - if old_dict.get(field_name) is None: - updates[field_name] = field_value - elif isinstance(field_value, dict) and isinstance( - old_dict.get(field_name), dict + if old_dict.keys() - new_dict.keys(): + updates = new_dict + else: + for field_name, field_value in new_dict.items(): + if field_value != old_dict.get(field_name): + if not self.state_management_replace_objects() and ( + isinstance(field_value, dict) + and isinstance(old_dict.get(field_name), dict) ): + if old_dict.get(field_name) is None: + updates[field_name] = field_value + elif isinstance(field_value, dict) and isinstance( + old_dict.get(field_name), dict + ): + + field_data = self._collect_updates( + old_dict.get(field_name), # type: ignore + field_value, + ) - field_data = self._collect_updates( - old_dict.get(field_name), # type: ignore - field_value, - ) - - for k, v in field_data.items(): - updates[f"{field_name}.{k}"] = v - else: - updates[field_name] = field_value + for k, v in field_data.items(): + updates[f"{field_name}.{k}"] = v + else: + updates[field_name] = field_value return updates @saved_state_needed def get_changes(self) -> Dict[str, Any]: return self._collect_updates( - self._saved_state, get_dict(self, to_db=True) # type: ignore + self._saved_state, get_dict(self, to_db=True, keep_nulls=self.get_settings().keep_nulls) # type: ignore ) @saved_state_needed diff --git a/beanie/odm/operators/update/general.py b/beanie/odm/operators/update/general.py index 7352363d..f18441d7 100644 --- a/beanie/odm/operators/update/general.py +++ b/beanie/odm/operators/update/general.py @@ -38,6 +38,39 @@ class Sample(Document): operator = "$set" +class SetRevisionId: + """ + `$set` update query operator + + Example: + + ```python + class Sample(Document): + one: int + + Set({Sample.one: 2}) + ``` + + Will return query object like + + ```python + {"$set": {"one": 2}} + ``` + + MongoDB doc: + + """ + + def __init__(self, revision_id): + self.revision_id = revision_id + self.operator = "$set" + self.expression = {"revision_id": self.revision_id} + + @property + def query(self): + return {self.operator: self.expression} + + class CurrentDate(BaseUpdateGeneralOperator): """ `$currentDate` update query operator diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py index 524686bc..a4b1c911 100644 --- a/beanie/odm/queries/find.py +++ b/beanie/odm/queries/find.py @@ -923,7 +923,12 @@ async def replace_one( result: UpdateResult = ( await self.document_model.get_motor_collection().replace_one( self.get_filter_query(), - get_dict(document, to_db=True, exclude={"_id"}), + get_dict( + document, + to_db=True, + exclude={"_id"}, + keep_nulls=document.get_settings().keep_nulls, + ), session=self.session, ) ) diff --git a/beanie/odm/queries/update.py b/beanie/odm/queries/update.py index d24d9c5b..42cd2290 100644 --- a/beanie/odm/queries/update.py +++ b/beanie/odm/queries/update.py @@ -3,6 +3,7 @@ from beanie.odm.bulk import BulkWriter, Operation from beanie.odm.interfaces.clone import CloneInterface +from beanie.odm.operators.update.general import SetRevisionId from beanie.odm.utils.encoder import Encoder from typing import ( Callable, @@ -74,6 +75,10 @@ def update_query(self) -> Dict[str, Any]: query.update(expression.query) elif isinstance(expression, dict): query.update(expression) + elif isinstance(expression, SetRevisionId): + set_query = query.get("$set", {}) + set_query.update(expression.query.get("$set", {})) + query["$set"] = set_query else: raise TypeError("Wrong expression type") return Encoder(custom_encoders=self.encoders).encode(query) @@ -339,7 +344,6 @@ def __await__( Run the query :return: """ - update_result = yield from self._update().__await__() if self.upsert_insert_doc is None: return update_result diff --git a/beanie/odm/settings/document.py b/beanie/odm/settings/document.py index d1ed5dc3..53f0b455 100644 --- a/beanie/odm/settings/document.py +++ b/beanie/odm/settings/document.py @@ -33,5 +33,7 @@ class DocumentSettings(ItemSettings): lazy_parsing: bool = False + keep_nulls: bool = True + class Config: arbitrary_types_allowed = True diff --git a/beanie/odm/utils/dump.py b/beanie/odm/utils/dump.py index d93285aa..af8bc4f3 100644 --- a/beanie/odm/utils/dump.py +++ b/beanie/odm/utils/dump.py @@ -10,6 +10,7 @@ def get_dict( document: "Document", to_db: bool = False, exclude: Optional[Set[str]] = None, + keep_nulls: bool = True, ): if exclude is None: exclude = set() @@ -17,6 +18,34 @@ def get_dict( exclude.add("_id") if not document.get_settings().use_revision: exclude.add("revision_id") - return Encoder(by_alias=True, exclude=exclude, to_db=to_db).encode( - document - ) + return Encoder( + by_alias=True, exclude=exclude, to_db=to_db, keep_nulls=keep_nulls + ).encode(document) + + +def get_nulls( + document: "Document", + exclude: Optional[Set[str]] = None, +): + dictionary = get_dict(document, exclude=exclude, keep_nulls=True) + return filter_none(dictionary) + + +def get_top_level_nones( + document: "Document", + exclude: Optional[Set[str]] = None, +): + dictionary = get_dict(document, exclude=exclude, keep_nulls=True) + return {k: v for k, v in dictionary.items() if v is None} + + +def filter_none(d): + result = {} + for k, v in d.items(): + if isinstance(v, dict): + filtered = filter_none(v) + if filtered: + result[k] = filtered + elif v is None: + result[k] = v + return result diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index 6ec92cc5..62b9556f 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -67,11 +67,13 @@ def __init__( custom_encoders: Optional[Dict[Type, Callable]] = None, by_alias: bool = True, to_db: bool = False, + keep_nulls: bool = True, ): self.exclude = exclude or {} self.by_alias = by_alias self.custom_encoders = custom_encoders or {} self.to_db = to_db + self.keep_nulls = keep_nulls def encode(self, obj: Any): """ @@ -89,6 +91,7 @@ def encode_document(self, obj): custom_encoders=obj.get_settings().bson_encoders, by_alias=self.by_alias, to_db=self.to_db, + keep_nulls=self.keep_nulls, ) link_fields = obj.get_link_fields() @@ -101,7 +104,9 @@ def encode_document(self, obj): obj_dict[obj.get_settings().class_id] = obj._class_id for k, o in obj._iter(to_dict=False, by_alias=self.by_alias): - if k not in self.exclude: + if k not in self.exclude and ( + self.keep_nulls is True or o is not None + ): if link_fields and k in link_fields: if link_fields[k].link_type == LinkTypes.LIST: obj_dict[k] = [link.to_ref() for link in o] @@ -128,7 +133,9 @@ def encode_base_model(self, obj): """ obj_dict = {} for k, o in obj._iter(to_dict=False, by_alias=self.by_alias): - if k not in self.exclude: + if k not in self.exclude and ( + self.keep_nulls is True or o is not None + ): obj_dict[k] = self._encode(o) return obj_dict diff --git a/beanie/odm/utils/parsing.py b/beanie/odm/utils/parsing.py index 68019422..52728f08 100644 --- a/beanie/odm/utils/parsing.py +++ b/beanie/odm/utils/parsing.py @@ -14,6 +14,8 @@ def merge_models(left: BaseModel, right: BaseModel) -> None: from beanie.odm.fields import Link + if hasattr(left, "_previous_revision_id"): + left._previous_revision_id = right._previous_revision_id for k, right_value in right.__iter__(): left_value = left.__getattribute__(k) if isinstance(right_value, BaseModel) and isinstance( diff --git a/docs/changelog.md b/docs/changelog.md index 18657495..132e3428 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,14 @@ Beanie project +## [1.18.1] - 2023-05-04 + +### Keep nulls config +- Author - [Roman Right](https://github.com/roman-right) +- PR + +[1.18.1]: https://pypi.org/project/beanie/1.18.1 + ## [1.18.0] - 2023-03-31 ### Prevent the models returned from find_all to be modified in the middle of modifying diff --git a/docs/tutorial/defining-a-document.md b/docs/tutorial/defining-a-document.md index 88bd1a0e..32c926b7 100644 --- a/docs/tutorial/defining-a-document.md +++ b/docs/tutorial/defining-a-document.md @@ -115,6 +115,7 @@ The inner class `Settings` is used to configure: - Use of cache - Use of state management - Validation on save +- Configure if nulls should be saved to the database ### Collection name @@ -209,3 +210,19 @@ class Sample(Document): IPv4Address: ipv4address_to_int } ``` + +### Keep nulls + +By default, Beanie saves fields with `None` value as `null` in the database. + +But if you don't want to save `null` values, you can set `keep_nulls` to `False` in the `Settings` class: + +```python +class Sample(Document): + num: int + description: Optional[str] = None + + class Settings: + keep_nulls = False +``` + diff --git a/docs/tutorial/revision.md b/docs/tutorial/revision.md index b6a34564..491836c1 100644 --- a/docs/tutorial/revision.md +++ b/docs/tutorial/revision.md @@ -6,6 +6,11 @@ If the application with an older local copy of the document tries to change it, Only when the local copy is synced with the database, the application will be allowed to change the data. This helps to avoid data losses. +### Be aware +revision id feature may wor incorrectly with BulkWriter. + +### Usage + This feature must be explicitly turned on in the `Settings` inner class: ```python diff --git a/pyproject.toml b/pyproject.toml index 8207f0d9..0e2319f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "beanie" -version = "1.18.0" +version = "1.18.1" description = "Asynchronous Python ODM for MongoDB" authors = ["Roman "] license = "Apache-2.0" diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 8e44d5aa..fb857ec8 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -72,6 +72,7 @@ LoopedLinksB, DocumentWithTurnedOnStateManagementWithCustomId, DocumentWithDecimalField, + DocumentWithKeepNullsFalse, ) from tests.odm.views import TestView, TestViewWithLink @@ -229,6 +230,7 @@ async def init(loop, db): LoopedLinksB, DocumentWithTurnedOnStateManagementWithCustomId, DocumentWithDecimalField, + DocumentWithKeepNullsFalse, ] await init_beanie( database=db, diff --git a/tests/odm/documents/test_create.py b/tests/odm/documents/test_create.py index b2e499f0..4746ba70 100644 --- a/tests/odm/documents/test_create.py +++ b/tests/odm/documents/test_create.py @@ -2,7 +2,11 @@ from pymongo.errors import DuplicateKeyError from beanie.odm.fields import PydanticObjectId -from tests.odm.models import DocumentTestModel +from tests.odm.models import ( + DocumentTestModel, + ModelWithOptionalField, + DocumentWithKeepNullsFalse, +) async def test_insert_one(document_not_inserted): @@ -53,3 +57,64 @@ async def test_insert_many_with_session(documents_not_inserted, session): async def test_create_with_session(document_not_inserted, session): await document_not_inserted.insert(session=session) assert isinstance(document_not_inserted.id, PydanticObjectId) + + +async def test_insert_keep_nulls_false(): + model = ModelWithOptionalField(i=10) + doc = DocumentWithKeepNullsFalse(m=model) + + await doc.insert() + + new_doc = await DocumentWithKeepNullsFalse.get(doc.id) + + assert new_doc.m.i == 10 + assert new_doc.m.s is None + assert new_doc.o is None + + raw_data = ( + await DocumentWithKeepNullsFalse.get_motor_collection().find_one( + {"_id": doc.id} + ) + ) + assert raw_data == { + "_id": doc.id, + "m": {"i": 10}, + } + + +async def test_insert_many_keep_nulls_false(): + models = [ModelWithOptionalField(i=10), ModelWithOptionalField(i=11)] + docs = [DocumentWithKeepNullsFalse(m=m) for m in models] + + await DocumentWithKeepNullsFalse.insert_many(docs) + + new_docs = await DocumentWithKeepNullsFalse.find_all().to_list() + + assert len(new_docs) == 2 + + assert new_docs[0].m.i == 10 + assert new_docs[0].m.s is None + assert new_docs[0].o is None + + assert new_docs[1].m.i == 11 + assert new_docs[1].m.s is None + assert new_docs[1].o is None + + raw_data = ( + await DocumentWithKeepNullsFalse.get_motor_collection().find_one( + {"_id": new_docs[0].id} + ) + ) + assert raw_data == { + "_id": new_docs[0].id, + "m": {"i": 10}, + } + raw_data = ( + await DocumentWithKeepNullsFalse.get_motor_collection().find_one( + {"_id": new_docs[1].id} + ) + ) + assert raw_data == { + "_id": new_docs[1].id, + "m": {"i": 11}, + } diff --git a/tests/odm/documents/test_revision.py b/tests/odm/documents/test_revision.py index 65c4421b..d8c88f4f 100644 --- a/tests/odm/documents/test_revision.py +++ b/tests/odm/documents/test_revision.py @@ -1,7 +1,10 @@ import pytest +from beanie import BulkWriter from beanie.exceptions import RevisionIdWasChanged +from beanie.odm.operators.update.general import Inc from tests.odm.models import DocumentWithRevisionTurnedOn +from pymongo.errors import BulkWriteError async def test_replace(): @@ -25,10 +28,33 @@ async def test_replace(): await doc.replace() await doc.replace(ignore_revision=True) + await doc.replace() async def test_update(): doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) + + await doc.insert() + + await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) + await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) + + for i in range(5): + found_doc = await DocumentWithRevisionTurnedOn.get(doc.id) + await found_doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) + + doc._previous_revision_id = "wrong" + with pytest.raises(RevisionIdWasChanged): + await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) + + await doc.update( + Inc({DocumentWithRevisionTurnedOn.num_1: 1}), ignore_revision=True + ) + await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) + + +async def test_save_changes(): + doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) await doc.insert() doc.num_1 = 2 @@ -48,6 +74,62 @@ async def test_update(): await doc.save_changes() await doc.save_changes(ignore_revision=True) + await doc.save_changes() + + +async def test_save(): + doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) + + doc.num_1 = 2 + await doc.save() + + doc.num_2 = 3 + await doc.save() + + for i in range(5): + found_doc = await DocumentWithRevisionTurnedOn.get(doc.id) + found_doc.num_1 += 1 + await found_doc.save() + + doc._previous_revision_id = "wrong" + doc.num_1 = 4 + with pytest.raises(RevisionIdWasChanged): + await doc.save() + + await doc.save(ignore_revision=True) + await doc.save() + + +async def test_update_bulk_writer(): + doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) + await doc.save() + + doc.num_1 = 2 + async with BulkWriter() as bulk_writer: + await doc.save(bulk_writer=bulk_writer) + + doc = await DocumentWithRevisionTurnedOn.get(doc.id) + + doc.num_2 = 3 + async with BulkWriter() as bulk_writer: + await doc.save(bulk_writer=bulk_writer) + + doc = await DocumentWithRevisionTurnedOn.get(doc.id) + + for i in range(5): + found_doc = await DocumentWithRevisionTurnedOn.get(doc.id) + found_doc.num_1 += 1 + async with BulkWriter() as bulk_writer: + await found_doc.save(bulk_writer=bulk_writer) + + doc._previous_revision_id = "wrong" + doc.num_1 = 4 + with pytest.raises(BulkWriteError): + async with BulkWriter() as bulk_writer: + await doc.save(bulk_writer=bulk_writer) + + async with BulkWriter() as bulk_writer: + await doc.save(bulk_writer=bulk_writer, ignore_revision=True) async def test_empty_update(): diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py index a4901872..adc941d4 100644 --- a/tests/odm/documents/test_update.py +++ b/tests/odm/documents/test_update.py @@ -5,7 +5,11 @@ ReplaceError, ) from beanie.odm.fields import PydanticObjectId -from tests.odm.models import DocumentTestModel +from tests.odm.models import ( + DocumentTestModel, + ModelWithOptionalField, + DocumentWithKeepNullsFalse, +) # REPLACE @@ -149,6 +153,51 @@ async def test_update_all(documents): assert len(smth_else_documetns) == 17 +async def test_save_keep_nulls_false(): + model = ModelWithOptionalField(i=10, s="TEST_MODEL") + doc = DocumentWithKeepNullsFalse(m=model, o="TEST_DOCUMENT") + + await doc.insert() + + doc.o = None + doc.m.s = None + await doc.save() + + from_db = await DocumentWithKeepNullsFalse.get(doc.id) + assert from_db.o is None + assert from_db.m.s is None + + raw_data = ( + await DocumentWithKeepNullsFalse.get_motor_collection().find_one( + {"_id": doc.id} + ) + ) + assert raw_data == {"_id": doc.id, "m": {"i": 10}} + + +async def test_save_changes_keep_nulls_false(): + model = ModelWithOptionalField(i=10, s="TEST_MODEL") + doc = DocumentWithKeepNullsFalse(m=model, o="TEST_DOCUMENT") + + await doc.insert() + + doc.o = None + doc.m.s = None + + await doc.save_changes() + + from_db = await DocumentWithKeepNullsFalse.get(doc.id) + assert from_db.o is None + assert from_db.m.s is None + + raw_data = ( + await DocumentWithKeepNullsFalse.get_motor_collection().find_one( + {"_id": doc.id} + ) + ) + assert raw_data == {"_id": doc.id, "m": {"i": 10}} + + # WITH SESSION diff --git a/tests/odm/models.py b/tests/odm/models.py index 960bd1d7..6c5e5780 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -726,3 +726,17 @@ class Settings: name="other_amt_descending", ), ] + + +class ModelWithOptionalField(BaseModel): + s: Optional[str] + i: int + + +class DocumentWithKeepNullsFalse(Document): + o: Optional[str] + m: ModelWithOptionalField + + class Settings: + keep_nulls = False + use_state_management = True diff --git a/tests/odm/test_encoder.py b/tests/odm/test_encoder.py index e3108db8..3354ca3c 100644 --- a/tests/odm/test_encoder.py +++ b/tests/odm/test_encoder.py @@ -11,6 +11,8 @@ SampleWithMutableObjects, Child, DocumentWithDecimalField, + DocumentWithKeepNullsFalse, + ModelWithOptionalField, ) @@ -114,3 +116,12 @@ async def test_decimal(): obj = await DocumentWithDecimalField.get(test_amts.id) assert obj.other_amt == 7 + + +def test_keep_nulls_false(): + model = ModelWithOptionalField(i=10) + doc = DocumentWithKeepNullsFalse(m=model) + + encoder = Encoder(keep_nulls=False, to_db=True) + encoded_doc = encoder.encode(doc) + assert encoded_doc == {"m": {"i": 10}} diff --git a/tests/test_beanie.py b/tests/test_beanie.py index 5dbdbe60..241701dc 100644 --- a/tests/test_beanie.py +++ b/tests/test_beanie.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "1.18.0" + assert __version__ == "1.18.1"