diff --git a/beanie/__init__.py b/beanie/__init__.py
index e889ab54..edcb9043 100644
--- a/beanie/__init__.py
+++ b/beanie/__init__.py
@@ -27,7 +27,7 @@
from beanie.odm.views import View
from beanie.odm.union_doc import UnionDoc
-__version__ = "1.18.0"
+__version__ = "1.18.1"
__all__ = [
# ODM
"Document",
diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py
index d3f32c0b..751a536b 100644
--- a/beanie/odm/documents.py
+++ b/beanie/odm/documents.py
@@ -25,6 +25,7 @@
from pydantic.main import BaseModel
from pymongo import InsertOne
from pymongo.client_session import ClientSession
+from pymongo.errors import DuplicateKeyError
from pymongo.results import (
DeleteResult,
InsertManyResult,
@@ -70,10 +71,12 @@
CurrentDate,
Inc,
Set as SetOperator,
+ Unset,
+ SetRevisionId,
)
from beanie.odm.queries.update import UpdateMany, UpdateResponse
from beanie.odm.settings.document import DocumentSettings
-from beanie.odm.utils.dump import get_dict
+from beanie.odm.utils.dump import get_dict, get_top_level_nones
from beanie.odm.utils.parsing import merge_models
from beanie.odm.utils.self_validation import validate_self_before
from beanie.odm.utils.state import (
@@ -211,7 +214,10 @@ async def insert(
await obj.save(link_rule=WriteRules.WRITE)
result = await self.get_motor_collection().insert_one(
- get_dict(self, to_db=True), session=session
+ get_dict(
+ self, to_db=True, keep_nulls=self.get_settings().keep_nulls
+ ),
+ session=session,
)
new_id = result.inserted_id
if not isinstance(new_id, self.__fields__["id"].type_):
@@ -259,7 +265,11 @@ async def insert_one(
bulk_writer.add_operation(
Operation(
operation=InsertOne,
- first_query=get_dict(document, to_db=True),
+ first_query=get_dict(
+ document,
+ to_db=True,
+ keep_nulls=document.get_settings().keep_nulls,
+ ),
object_class=type(document),
)
)
@@ -287,7 +297,12 @@ async def insert_many(
"Cascade insert not supported for insert many method"
)
documents_list = [
- get_dict(document, to_db=True) for document in documents
+ get_dict(
+ document,
+ to_db=True,
+ keep_nulls=document.get_settings().keep_nulls,
+ )
+ for document in documents
]
return await cls.get_motor_collection().insert_many(
documents_list, session=session, **pymongo_kwargs
@@ -374,6 +389,7 @@ async def save(
self: DocType,
session: Optional[ClientSession] = None,
link_rule: WriteRules = WriteRules.DO_NOTHING,
+ ignore_revision: bool = False,
**kwargs,
) -> None:
"""
@@ -382,6 +398,7 @@ async def save(
:param session: Optional[ClientSession] - pymongo session.
:param link_rule: WriteRules - rules how to deal with links on writing
+ :param ignore_revision: bool - do force save.
:return: None
"""
if link_rule == WriteRules.WRITE:
@@ -407,9 +424,36 @@ async def save(
await obj.save(
link_rule=link_rule, session=session
)
- await self.set(
- get_dict(self, to_db=True), session=session, upsert=True, **kwargs
- )
+
+ if self.get_settings().keep_nulls is False:
+ await self.update(
+ SetOperator(
+ get_dict(
+ self,
+ to_db=True,
+ keep_nulls=self.get_settings().keep_nulls,
+ )
+ ),
+ Unset(get_top_level_nones(self)),
+ session=session,
+ ignore_revision=ignore_revision,
+ upsert=True,
+ **kwargs,
+ )
+ else:
+ await self.update(
+ SetOperator(
+ get_dict(
+ self,
+ to_db=True,
+ keep_nulls=self.get_settings().keep_nulls,
+ )
+ ),
+ session=session,
+ ignore_revision=ignore_revision,
+ upsert=True,
+ **kwargs,
+ )
@saved_state_needed
@wrap_with_actions(EventTypes.SAVE_CHANGES)
@@ -432,12 +476,21 @@ async def save_changes(
if not self.is_changed:
return None
changes = self.get_changes()
- await self.set(
- changes, # type: ignore #TODO fix typing
- ignore_revision=ignore_revision,
- session=session,
- bulk_writer=bulk_writer,
- )
+ if self.get_settings().keep_nulls is False:
+ await self.update(
+ SetOperator(changes),
+ Unset(get_top_level_nones(self)),
+ ignore_revision=ignore_revision,
+ session=session,
+ bulk_writer=bulk_writer,
+ )
+ else:
+ await self.set(
+ changes, # type: ignore #TODO fix typing
+ ignore_revision=ignore_revision,
+ session=session,
+ bulk_writer=bulk_writer,
+ )
@classmethod
async def replace_many(
@@ -482,6 +535,9 @@ async def update(
:param pymongo_kwargs: pymongo native parameters for update operation
:return: None
"""
+
+ arguments = list(args)
+
if skip_sync is not None:
raise DeprecationWarning(
"skip_sync parameter is not supported. The document get synced always using atomic operation."
@@ -496,17 +552,21 @@ async def update(
if use_revision_id and not ignore_revision:
find_query["revision_id"] = self._previous_revision_id
- result = await self.find_one(find_query).update(
- *args,
- session=session,
- response_type=UpdateResponse.NEW_DOCUMENT,
- bulk_writer=bulk_writer,
- **pymongo_kwargs,
- )
+ if use_revision_id:
+ arguments.append(SetRevisionId(self.revision_id))
+ try:
+ result = await self.find_one(find_query).update(
+ *arguments,
+ session=session,
+ response_type=UpdateResponse.NEW_DOCUMENT,
+ bulk_writer=bulk_writer,
+ **pymongo_kwargs,
+ )
+ except DuplicateKeyError:
+ raise RevisionIdWasChanged
if bulk_writer is None:
if use_revision_id and not ignore_revision and result is None:
raise RevisionIdWasChanged
-
merge_models(self, result)
@classmethod
@@ -737,7 +797,9 @@ def _save_state(self) -> None:
if self.state_management_save_previous():
self._previous_saved_state = self._saved_state
- self._saved_state = get_dict(self)
+ self._saved_state = get_dict(
+ self, to_db=True, keep_nulls=self.get_settings().keep_nulls
+ )
def get_saved_state(self) -> Optional[Dict[str, Any]]:
"""
@@ -756,7 +818,9 @@ def get_previous_saved_state(self) -> Optional[Dict[str, Any]]:
@property # type: ignore
@saved_state_needed
def is_changed(self) -> bool:
- if self._saved_state == get_dict(self, to_db=True):
+ if self._saved_state == get_dict(
+ self, to_db=True, keep_nulls=self.get_settings().keep_nulls
+ ):
return False
return True
@@ -784,35 +848,37 @@ def _collect_updates(
"""
updates = {}
-
- for field_name, field_value in new_dict.items():
- if field_value != old_dict.get(field_name):
- if not self.state_management_replace_objects() and (
- isinstance(field_value, dict)
- and isinstance(old_dict.get(field_name), dict)
- ):
- if old_dict.get(field_name) is None:
- updates[field_name] = field_value
- elif isinstance(field_value, dict) and isinstance(
- old_dict.get(field_name), dict
+ if old_dict.keys() - new_dict.keys():
+ updates = new_dict
+ else:
+ for field_name, field_value in new_dict.items():
+ if field_value != old_dict.get(field_name):
+ if not self.state_management_replace_objects() and (
+ isinstance(field_value, dict)
+ and isinstance(old_dict.get(field_name), dict)
):
+ if old_dict.get(field_name) is None:
+ updates[field_name] = field_value
+ elif isinstance(field_value, dict) and isinstance(
+ old_dict.get(field_name), dict
+ ):
+
+ field_data = self._collect_updates(
+ old_dict.get(field_name), # type: ignore
+ field_value,
+ )
- field_data = self._collect_updates(
- old_dict.get(field_name), # type: ignore
- field_value,
- )
-
- for k, v in field_data.items():
- updates[f"{field_name}.{k}"] = v
- else:
- updates[field_name] = field_value
+ for k, v in field_data.items():
+ updates[f"{field_name}.{k}"] = v
+ else:
+ updates[field_name] = field_value
return updates
@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
return self._collect_updates(
- self._saved_state, get_dict(self, to_db=True) # type: ignore
+ self._saved_state, get_dict(self, to_db=True, keep_nulls=self.get_settings().keep_nulls) # type: ignore
)
@saved_state_needed
diff --git a/beanie/odm/operators/update/general.py b/beanie/odm/operators/update/general.py
index 7352363d..f18441d7 100644
--- a/beanie/odm/operators/update/general.py
+++ b/beanie/odm/operators/update/general.py
@@ -38,6 +38,39 @@ class Sample(Document):
operator = "$set"
+class SetRevisionId:
+ """
+ `$set` update query operator
+
+ Example:
+
+ ```python
+ class Sample(Document):
+ one: int
+
+ Set({Sample.one: 2})
+ ```
+
+ Will return query object like
+
+ ```python
+ {"$set": {"one": 2}}
+ ```
+
+ MongoDB doc:
+
+ """
+
+ def __init__(self, revision_id):
+ self.revision_id = revision_id
+ self.operator = "$set"
+ self.expression = {"revision_id": self.revision_id}
+
+ @property
+ def query(self):
+ return {self.operator: self.expression}
+
+
class CurrentDate(BaseUpdateGeneralOperator):
"""
`$currentDate` update query operator
diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py
index 524686bc..a4b1c911 100644
--- a/beanie/odm/queries/find.py
+++ b/beanie/odm/queries/find.py
@@ -923,7 +923,12 @@ async def replace_one(
result: UpdateResult = (
await self.document_model.get_motor_collection().replace_one(
self.get_filter_query(),
- get_dict(document, to_db=True, exclude={"_id"}),
+ get_dict(
+ document,
+ to_db=True,
+ exclude={"_id"},
+ keep_nulls=document.get_settings().keep_nulls,
+ ),
session=self.session,
)
)
diff --git a/beanie/odm/queries/update.py b/beanie/odm/queries/update.py
index d24d9c5b..42cd2290 100644
--- a/beanie/odm/queries/update.py
+++ b/beanie/odm/queries/update.py
@@ -3,6 +3,7 @@
from beanie.odm.bulk import BulkWriter, Operation
from beanie.odm.interfaces.clone import CloneInterface
+from beanie.odm.operators.update.general import SetRevisionId
from beanie.odm.utils.encoder import Encoder
from typing import (
Callable,
@@ -74,6 +75,10 @@ def update_query(self) -> Dict[str, Any]:
query.update(expression.query)
elif isinstance(expression, dict):
query.update(expression)
+ elif isinstance(expression, SetRevisionId):
+ set_query = query.get("$set", {})
+ set_query.update(expression.query.get("$set", {}))
+ query["$set"] = set_query
else:
raise TypeError("Wrong expression type")
return Encoder(custom_encoders=self.encoders).encode(query)
@@ -339,7 +344,6 @@ def __await__(
Run the query
:return:
"""
-
update_result = yield from self._update().__await__()
if self.upsert_insert_doc is None:
return update_result
diff --git a/beanie/odm/settings/document.py b/beanie/odm/settings/document.py
index d1ed5dc3..53f0b455 100644
--- a/beanie/odm/settings/document.py
+++ b/beanie/odm/settings/document.py
@@ -33,5 +33,7 @@ class DocumentSettings(ItemSettings):
lazy_parsing: bool = False
+ keep_nulls: bool = True
+
class Config:
arbitrary_types_allowed = True
diff --git a/beanie/odm/utils/dump.py b/beanie/odm/utils/dump.py
index d93285aa..af8bc4f3 100644
--- a/beanie/odm/utils/dump.py
+++ b/beanie/odm/utils/dump.py
@@ -10,6 +10,7 @@ def get_dict(
document: "Document",
to_db: bool = False,
exclude: Optional[Set[str]] = None,
+ keep_nulls: bool = True,
):
if exclude is None:
exclude = set()
@@ -17,6 +18,34 @@ def get_dict(
exclude.add("_id")
if not document.get_settings().use_revision:
exclude.add("revision_id")
- return Encoder(by_alias=True, exclude=exclude, to_db=to_db).encode(
- document
- )
+ return Encoder(
+ by_alias=True, exclude=exclude, to_db=to_db, keep_nulls=keep_nulls
+ ).encode(document)
+
+
+def get_nulls(
+ document: "Document",
+ exclude: Optional[Set[str]] = None,
+):
+ dictionary = get_dict(document, exclude=exclude, keep_nulls=True)
+ return filter_none(dictionary)
+
+
+def get_top_level_nones(
+ document: "Document",
+ exclude: Optional[Set[str]] = None,
+):
+ dictionary = get_dict(document, exclude=exclude, keep_nulls=True)
+ return {k: v for k, v in dictionary.items() if v is None}
+
+
+def filter_none(d):
+ result = {}
+ for k, v in d.items():
+ if isinstance(v, dict):
+ filtered = filter_none(v)
+ if filtered:
+ result[k] = filtered
+ elif v is None:
+ result[k] = v
+ return result
diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py
index 6ec92cc5..62b9556f 100644
--- a/beanie/odm/utils/encoder.py
+++ b/beanie/odm/utils/encoder.py
@@ -67,11 +67,13 @@ def __init__(
custom_encoders: Optional[Dict[Type, Callable]] = None,
by_alias: bool = True,
to_db: bool = False,
+ keep_nulls: bool = True,
):
self.exclude = exclude or {}
self.by_alias = by_alias
self.custom_encoders = custom_encoders or {}
self.to_db = to_db
+ self.keep_nulls = keep_nulls
def encode(self, obj: Any):
"""
@@ -89,6 +91,7 @@ def encode_document(self, obj):
custom_encoders=obj.get_settings().bson_encoders,
by_alias=self.by_alias,
to_db=self.to_db,
+ keep_nulls=self.keep_nulls,
)
link_fields = obj.get_link_fields()
@@ -101,7 +104,9 @@ def encode_document(self, obj):
obj_dict[obj.get_settings().class_id] = obj._class_id
for k, o in obj._iter(to_dict=False, by_alias=self.by_alias):
- if k not in self.exclude:
+ if k not in self.exclude and (
+ self.keep_nulls is True or o is not None
+ ):
if link_fields and k in link_fields:
if link_fields[k].link_type == LinkTypes.LIST:
obj_dict[k] = [link.to_ref() for link in o]
@@ -128,7 +133,9 @@ def encode_base_model(self, obj):
"""
obj_dict = {}
for k, o in obj._iter(to_dict=False, by_alias=self.by_alias):
- if k not in self.exclude:
+ if k not in self.exclude and (
+ self.keep_nulls is True or o is not None
+ ):
obj_dict[k] = self._encode(o)
return obj_dict
diff --git a/beanie/odm/utils/parsing.py b/beanie/odm/utils/parsing.py
index 68019422..52728f08 100644
--- a/beanie/odm/utils/parsing.py
+++ b/beanie/odm/utils/parsing.py
@@ -14,6 +14,8 @@
def merge_models(left: BaseModel, right: BaseModel) -> None:
from beanie.odm.fields import Link
+ if hasattr(left, "_previous_revision_id"):
+ left._previous_revision_id = right._previous_revision_id
for k, right_value in right.__iter__():
left_value = left.__getattribute__(k)
if isinstance(right_value, BaseModel) and isinstance(
diff --git a/docs/changelog.md b/docs/changelog.md
index 18657495..132e3428 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -2,6 +2,14 @@
Beanie project
+## [1.18.1] - 2023-05-04
+
+### Keep nulls config
+- Author - [Roman Right](https://github.com/roman-right)
+- PR
+
+[1.18.1]: https://pypi.org/project/beanie/1.18.1
+
## [1.18.0] - 2023-03-31
### Prevent the models returned from find_all to be modified in the middle of modifying
diff --git a/docs/tutorial/defining-a-document.md b/docs/tutorial/defining-a-document.md
index 88bd1a0e..32c926b7 100644
--- a/docs/tutorial/defining-a-document.md
+++ b/docs/tutorial/defining-a-document.md
@@ -115,6 +115,7 @@ The inner class `Settings` is used to configure:
- Use of cache
- Use of state management
- Validation on save
+- Configure if nulls should be saved to the database
### Collection name
@@ -209,3 +210,19 @@ class Sample(Document):
IPv4Address: ipv4address_to_int
}
```
+
+### Keep nulls
+
+By default, Beanie saves fields with `None` value as `null` in the database.
+
+But if you don't want to save `null` values, you can set `keep_nulls` to `False` in the `Settings` class:
+
+```python
+class Sample(Document):
+ num: int
+ description: Optional[str] = None
+
+ class Settings:
+ keep_nulls = False
+```
+
diff --git a/docs/tutorial/revision.md b/docs/tutorial/revision.md
index b6a34564..491836c1 100644
--- a/docs/tutorial/revision.md
+++ b/docs/tutorial/revision.md
@@ -6,6 +6,11 @@ If the application with an older local copy of the document tries to change it,
Only when the local copy is synced with the database, the application will be allowed to change the data.
This helps to avoid data losses.
+### Be aware
+revision id feature may wor incorrectly with BulkWriter.
+
+### Usage
+
This feature must be explicitly turned on in the `Settings` inner class:
```python
diff --git a/pyproject.toml b/pyproject.toml
index 8207f0d9..0e2319f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "beanie"
-version = "1.18.0"
+version = "1.18.1"
description = "Asynchronous Python ODM for MongoDB"
authors = ["Roman "]
license = "Apache-2.0"
diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py
index 8e44d5aa..fb857ec8 100644
--- a/tests/odm/conftest.py
+++ b/tests/odm/conftest.py
@@ -72,6 +72,7 @@
LoopedLinksB,
DocumentWithTurnedOnStateManagementWithCustomId,
DocumentWithDecimalField,
+ DocumentWithKeepNullsFalse,
)
from tests.odm.views import TestView, TestViewWithLink
@@ -229,6 +230,7 @@ async def init(loop, db):
LoopedLinksB,
DocumentWithTurnedOnStateManagementWithCustomId,
DocumentWithDecimalField,
+ DocumentWithKeepNullsFalse,
]
await init_beanie(
database=db,
diff --git a/tests/odm/documents/test_create.py b/tests/odm/documents/test_create.py
index b2e499f0..4746ba70 100644
--- a/tests/odm/documents/test_create.py
+++ b/tests/odm/documents/test_create.py
@@ -2,7 +2,11 @@
from pymongo.errors import DuplicateKeyError
from beanie.odm.fields import PydanticObjectId
-from tests.odm.models import DocumentTestModel
+from tests.odm.models import (
+ DocumentTestModel,
+ ModelWithOptionalField,
+ DocumentWithKeepNullsFalse,
+)
async def test_insert_one(document_not_inserted):
@@ -53,3 +57,64 @@ async def test_insert_many_with_session(documents_not_inserted, session):
async def test_create_with_session(document_not_inserted, session):
await document_not_inserted.insert(session=session)
assert isinstance(document_not_inserted.id, PydanticObjectId)
+
+
+async def test_insert_keep_nulls_false():
+ model = ModelWithOptionalField(i=10)
+ doc = DocumentWithKeepNullsFalse(m=model)
+
+ await doc.insert()
+
+ new_doc = await DocumentWithKeepNullsFalse.get(doc.id)
+
+ assert new_doc.m.i == 10
+ assert new_doc.m.s is None
+ assert new_doc.o is None
+
+ raw_data = (
+ await DocumentWithKeepNullsFalse.get_motor_collection().find_one(
+ {"_id": doc.id}
+ )
+ )
+ assert raw_data == {
+ "_id": doc.id,
+ "m": {"i": 10},
+ }
+
+
+async def test_insert_many_keep_nulls_false():
+ models = [ModelWithOptionalField(i=10), ModelWithOptionalField(i=11)]
+ docs = [DocumentWithKeepNullsFalse(m=m) for m in models]
+
+ await DocumentWithKeepNullsFalse.insert_many(docs)
+
+ new_docs = await DocumentWithKeepNullsFalse.find_all().to_list()
+
+ assert len(new_docs) == 2
+
+ assert new_docs[0].m.i == 10
+ assert new_docs[0].m.s is None
+ assert new_docs[0].o is None
+
+ assert new_docs[1].m.i == 11
+ assert new_docs[1].m.s is None
+ assert new_docs[1].o is None
+
+ raw_data = (
+ await DocumentWithKeepNullsFalse.get_motor_collection().find_one(
+ {"_id": new_docs[0].id}
+ )
+ )
+ assert raw_data == {
+ "_id": new_docs[0].id,
+ "m": {"i": 10},
+ }
+ raw_data = (
+ await DocumentWithKeepNullsFalse.get_motor_collection().find_one(
+ {"_id": new_docs[1].id}
+ )
+ )
+ assert raw_data == {
+ "_id": new_docs[1].id,
+ "m": {"i": 11},
+ }
diff --git a/tests/odm/documents/test_revision.py b/tests/odm/documents/test_revision.py
index 65c4421b..d8c88f4f 100644
--- a/tests/odm/documents/test_revision.py
+++ b/tests/odm/documents/test_revision.py
@@ -1,7 +1,10 @@
import pytest
+from beanie import BulkWriter
from beanie.exceptions import RevisionIdWasChanged
+from beanie.odm.operators.update.general import Inc
from tests.odm.models import DocumentWithRevisionTurnedOn
+from pymongo.errors import BulkWriteError
async def test_replace():
@@ -25,10 +28,33 @@ async def test_replace():
await doc.replace()
await doc.replace(ignore_revision=True)
+ await doc.replace()
async def test_update():
doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2)
+
+ await doc.insert()
+
+ await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))
+ await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))
+
+ for i in range(5):
+ found_doc = await DocumentWithRevisionTurnedOn.get(doc.id)
+ await found_doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))
+
+ doc._previous_revision_id = "wrong"
+ with pytest.raises(RevisionIdWasChanged):
+ await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))
+
+ await doc.update(
+ Inc({DocumentWithRevisionTurnedOn.num_1: 1}), ignore_revision=True
+ )
+ await doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1}))
+
+
+async def test_save_changes():
+ doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2)
await doc.insert()
doc.num_1 = 2
@@ -48,6 +74,62 @@ async def test_update():
await doc.save_changes()
await doc.save_changes(ignore_revision=True)
+ await doc.save_changes()
+
+
+async def test_save():
+ doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2)
+
+ doc.num_1 = 2
+ await doc.save()
+
+ doc.num_2 = 3
+ await doc.save()
+
+ for i in range(5):
+ found_doc = await DocumentWithRevisionTurnedOn.get(doc.id)
+ found_doc.num_1 += 1
+ await found_doc.save()
+
+ doc._previous_revision_id = "wrong"
+ doc.num_1 = 4
+ with pytest.raises(RevisionIdWasChanged):
+ await doc.save()
+
+ await doc.save(ignore_revision=True)
+ await doc.save()
+
+
+async def test_update_bulk_writer():
+ doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2)
+ await doc.save()
+
+ doc.num_1 = 2
+ async with BulkWriter() as bulk_writer:
+ await doc.save(bulk_writer=bulk_writer)
+
+ doc = await DocumentWithRevisionTurnedOn.get(doc.id)
+
+ doc.num_2 = 3
+ async with BulkWriter() as bulk_writer:
+ await doc.save(bulk_writer=bulk_writer)
+
+ doc = await DocumentWithRevisionTurnedOn.get(doc.id)
+
+ for i in range(5):
+ found_doc = await DocumentWithRevisionTurnedOn.get(doc.id)
+ found_doc.num_1 += 1
+ async with BulkWriter() as bulk_writer:
+ await found_doc.save(bulk_writer=bulk_writer)
+
+ doc._previous_revision_id = "wrong"
+ doc.num_1 = 4
+ with pytest.raises(BulkWriteError):
+ async with BulkWriter() as bulk_writer:
+ await doc.save(bulk_writer=bulk_writer)
+
+ async with BulkWriter() as bulk_writer:
+ await doc.save(bulk_writer=bulk_writer, ignore_revision=True)
async def test_empty_update():
diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py
index a4901872..adc941d4 100644
--- a/tests/odm/documents/test_update.py
+++ b/tests/odm/documents/test_update.py
@@ -5,7 +5,11 @@
ReplaceError,
)
from beanie.odm.fields import PydanticObjectId
-from tests.odm.models import DocumentTestModel
+from tests.odm.models import (
+ DocumentTestModel,
+ ModelWithOptionalField,
+ DocumentWithKeepNullsFalse,
+)
# REPLACE
@@ -149,6 +153,51 @@ async def test_update_all(documents):
assert len(smth_else_documetns) == 17
+async def test_save_keep_nulls_false():
+ model = ModelWithOptionalField(i=10, s="TEST_MODEL")
+ doc = DocumentWithKeepNullsFalse(m=model, o="TEST_DOCUMENT")
+
+ await doc.insert()
+
+ doc.o = None
+ doc.m.s = None
+ await doc.save()
+
+ from_db = await DocumentWithKeepNullsFalse.get(doc.id)
+ assert from_db.o is None
+ assert from_db.m.s is None
+
+ raw_data = (
+ await DocumentWithKeepNullsFalse.get_motor_collection().find_one(
+ {"_id": doc.id}
+ )
+ )
+ assert raw_data == {"_id": doc.id, "m": {"i": 10}}
+
+
+async def test_save_changes_keep_nulls_false():
+ model = ModelWithOptionalField(i=10, s="TEST_MODEL")
+ doc = DocumentWithKeepNullsFalse(m=model, o="TEST_DOCUMENT")
+
+ await doc.insert()
+
+ doc.o = None
+ doc.m.s = None
+
+ await doc.save_changes()
+
+ from_db = await DocumentWithKeepNullsFalse.get(doc.id)
+ assert from_db.o is None
+ assert from_db.m.s is None
+
+ raw_data = (
+ await DocumentWithKeepNullsFalse.get_motor_collection().find_one(
+ {"_id": doc.id}
+ )
+ )
+ assert raw_data == {"_id": doc.id, "m": {"i": 10}}
+
+
# WITH SESSION
diff --git a/tests/odm/models.py b/tests/odm/models.py
index 960bd1d7..6c5e5780 100644
--- a/tests/odm/models.py
+++ b/tests/odm/models.py
@@ -726,3 +726,17 @@ class Settings:
name="other_amt_descending",
),
]
+
+
+class ModelWithOptionalField(BaseModel):
+ s: Optional[str]
+ i: int
+
+
+class DocumentWithKeepNullsFalse(Document):
+ o: Optional[str]
+ m: ModelWithOptionalField
+
+ class Settings:
+ keep_nulls = False
+ use_state_management = True
diff --git a/tests/odm/test_encoder.py b/tests/odm/test_encoder.py
index e3108db8..3354ca3c 100644
--- a/tests/odm/test_encoder.py
+++ b/tests/odm/test_encoder.py
@@ -11,6 +11,8 @@
SampleWithMutableObjects,
Child,
DocumentWithDecimalField,
+ DocumentWithKeepNullsFalse,
+ ModelWithOptionalField,
)
@@ -114,3 +116,12 @@ async def test_decimal():
obj = await DocumentWithDecimalField.get(test_amts.id)
assert obj.other_amt == 7
+
+
+def test_keep_nulls_false():
+ model = ModelWithOptionalField(i=10)
+ doc = DocumentWithKeepNullsFalse(m=model)
+
+ encoder = Encoder(keep_nulls=False, to_db=True)
+ encoded_doc = encoder.encode(doc)
+ assert encoded_doc == {"m": {"i": 10}}
diff --git a/tests/test_beanie.py b/tests/test_beanie.py
index 5dbdbe60..241701dc 100644
--- a/tests/test_beanie.py
+++ b/tests/test_beanie.py
@@ -2,4 +2,4 @@
def test_version():
- assert __version__ == "1.18.0"
+ assert __version__ == "1.18.1"