Skip to content

Commit

Permalink
keep nulls config (#555)
Browse files Browse the repository at this point in the history
* feature | keep nulls
* fix | revision ids for update operations
* test | insert many keep nulls = false
  • Loading branch information
roman-right committed May 4, 2023
1 parent 2349a4b commit 0cd2c60
Show file tree
Hide file tree
Showing 20 changed files with 457 additions and 56 deletions.
2 changes: 1 addition & 1 deletion beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from beanie.odm.views import View
from beanie.odm.union_doc import UnionDoc

__version__ = "1.18.0"
__version__ = "1.18.1"
__all__ = [
# ODM
"Document",
Expand Down
154 changes: 110 additions & 44 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pydantic.main import BaseModel
from pymongo import InsertOne
from pymongo.client_session import ClientSession
from pymongo.errors import DuplicateKeyError
from pymongo.results import (
DeleteResult,
InsertManyResult,
Expand Down Expand Up @@ -70,10 +71,12 @@
CurrentDate,
Inc,
Set as SetOperator,
Unset,
SetRevisionId,
)
from beanie.odm.queries.update import UpdateMany, UpdateResponse
from beanie.odm.settings.document import DocumentSettings
from beanie.odm.utils.dump import get_dict
from beanie.odm.utils.dump import get_dict, get_top_level_nones
from beanie.odm.utils.parsing import merge_models
from beanie.odm.utils.self_validation import validate_self_before
from beanie.odm.utils.state import (
Expand Down Expand Up @@ -211,7 +214,10 @@ async def insert(
await obj.save(link_rule=WriteRules.WRITE)

result = await self.get_motor_collection().insert_one(
get_dict(self, to_db=True), session=session
get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
),
session=session,
)
new_id = result.inserted_id
if not isinstance(new_id, self.__fields__["id"].type_):
Expand Down Expand Up @@ -259,7 +265,11 @@ async def insert_one(
bulk_writer.add_operation(
Operation(
operation=InsertOne,
first_query=get_dict(document, to_db=True),
first_query=get_dict(
document,
to_db=True,
keep_nulls=document.get_settings().keep_nulls,
),
object_class=type(document),
)
)
Expand Down Expand Up @@ -287,7 +297,12 @@ async def insert_many(
"Cascade insert not supported for insert many method"
)
documents_list = [
get_dict(document, to_db=True) for document in documents
get_dict(
document,
to_db=True,
keep_nulls=document.get_settings().keep_nulls,
)
for document in documents
]
return await cls.get_motor_collection().insert_many(
documents_list, session=session, **pymongo_kwargs
Expand Down Expand Up @@ -374,6 +389,7 @@ async def save(
self: DocType,
session: Optional[ClientSession] = None,
link_rule: WriteRules = WriteRules.DO_NOTHING,
ignore_revision: bool = False,
**kwargs,
) -> None:
"""
Expand All @@ -382,6 +398,7 @@ async def save(
:param session: Optional[ClientSession] - pymongo session.
:param link_rule: WriteRules - rules how to deal with links on writing
:param ignore_revision: bool - do force save.
:return: None
"""
if link_rule == WriteRules.WRITE:
Expand All @@ -407,9 +424,36 @@ async def save(
await obj.save(
link_rule=link_rule, session=session
)
await self.set(
get_dict(self, to_db=True), session=session, upsert=True, **kwargs
)

if self.get_settings().keep_nulls is False:
await self.update(
SetOperator(
get_dict(
self,
to_db=True,
keep_nulls=self.get_settings().keep_nulls,
)
),
Unset(get_top_level_nones(self)),
session=session,
ignore_revision=ignore_revision,
upsert=True,
**kwargs,
)
else:
await self.update(
SetOperator(
get_dict(
self,
to_db=True,
keep_nulls=self.get_settings().keep_nulls,
)
),
session=session,
ignore_revision=ignore_revision,
upsert=True,
**kwargs,
)

@saved_state_needed
@wrap_with_actions(EventTypes.SAVE_CHANGES)
Expand All @@ -432,12 +476,21 @@ async def save_changes(
if not self.is_changed:
return None
changes = self.get_changes()
await self.set(
changes, # type: ignore #TODO fix typing
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
)
if self.get_settings().keep_nulls is False:
await self.update(
SetOperator(changes),
Unset(get_top_level_nones(self)),
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
)
else:
await self.set(
changes, # type: ignore #TODO fix typing
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
)

@classmethod
async def replace_many(
Expand Down Expand Up @@ -482,6 +535,9 @@ async def update(
:param pymongo_kwargs: pymongo native parameters for update operation
:return: None
"""

arguments = list(args)

if skip_sync is not None:
raise DeprecationWarning(
"skip_sync parameter is not supported. The document get synced always using atomic operation."
Expand All @@ -496,17 +552,21 @@ async def update(
if use_revision_id and not ignore_revision:
find_query["revision_id"] = self._previous_revision_id

result = await self.find_one(find_query).update(
*args,
session=session,
response_type=UpdateResponse.NEW_DOCUMENT,
bulk_writer=bulk_writer,
**pymongo_kwargs,
)
if use_revision_id:
arguments.append(SetRevisionId(self.revision_id))
try:
result = await self.find_one(find_query).update(
*arguments,
session=session,
response_type=UpdateResponse.NEW_DOCUMENT,
bulk_writer=bulk_writer,
**pymongo_kwargs,
)
except DuplicateKeyError:
raise RevisionIdWasChanged
if bulk_writer is None:
if use_revision_id and not ignore_revision and result is None:
raise RevisionIdWasChanged

merge_models(self, result)

@classmethod
Expand Down Expand Up @@ -737,7 +797,9 @@ def _save_state(self) -> None:
if self.state_management_save_previous():
self._previous_saved_state = self._saved_state

self._saved_state = get_dict(self)
self._saved_state = get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
)

def get_saved_state(self) -> Optional[Dict[str, Any]]:
"""
Expand All @@ -756,7 +818,9 @@ def get_previous_saved_state(self) -> Optional[Dict[str, Any]]:
@property # type: ignore
@saved_state_needed
def is_changed(self) -> bool:
if self._saved_state == get_dict(self, to_db=True):
if self._saved_state == get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
):
return False
return True

Expand Down Expand Up @@ -784,35 +848,37 @@ def _collect_updates(
"""
updates = {}

for field_name, field_value in new_dict.items():
if field_value != old_dict.get(field_name):
if not self.state_management_replace_objects() and (
isinstance(field_value, dict)
and isinstance(old_dict.get(field_name), dict)
):
if old_dict.get(field_name) is None:
updates[field_name] = field_value
elif isinstance(field_value, dict) and isinstance(
old_dict.get(field_name), dict
if old_dict.keys() - new_dict.keys():
updates = new_dict
else:
for field_name, field_value in new_dict.items():
if field_value != old_dict.get(field_name):
if not self.state_management_replace_objects() and (
isinstance(field_value, dict)
and isinstance(old_dict.get(field_name), dict)
):
if old_dict.get(field_name) is None:
updates[field_name] = field_value
elif isinstance(field_value, dict) and isinstance(
old_dict.get(field_name), dict
):

field_data = self._collect_updates(
old_dict.get(field_name), # type: ignore
field_value,
)

field_data = self._collect_updates(
old_dict.get(field_name), # type: ignore
field_value,
)

for k, v in field_data.items():
updates[f"{field_name}.{k}"] = v
else:
updates[field_name] = field_value
for k, v in field_data.items():
updates[f"{field_name}.{k}"] = v
else:
updates[field_name] = field_value

return updates

@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
return self._collect_updates(
self._saved_state, get_dict(self, to_db=True) # type: ignore
self._saved_state, get_dict(self, to_db=True, keep_nulls=self.get_settings().keep_nulls) # type: ignore
)

@saved_state_needed
Expand Down
33 changes: 33 additions & 0 deletions beanie/odm/operators/update/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ class Sample(Document):
operator = "$set"


class SetRevisionId:
"""
`$set` update query operator
Example:
```python
class Sample(Document):
one: int
Set({Sample.one: 2})
```
Will return query object like
```python
{"$set": {"one": 2}}
```
MongoDB doc:
<https://docs.mongodb.com/manual/reference/operator/update/set/>
"""

def __init__(self, revision_id):
self.revision_id = revision_id
self.operator = "$set"
self.expression = {"revision_id": self.revision_id}

@property
def query(self):
return {self.operator: self.expression}


class CurrentDate(BaseUpdateGeneralOperator):
"""
`$currentDate` update query operator
Expand Down
7 changes: 6 additions & 1 deletion beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,12 @@ async def replace_one(
result: UpdateResult = (
await self.document_model.get_motor_collection().replace_one(
self.get_filter_query(),
get_dict(document, to_db=True, exclude={"_id"}),
get_dict(
document,
to_db=True,
exclude={"_id"},
keep_nulls=document.get_settings().keep_nulls,
),
session=self.session,
)
)
Expand Down
6 changes: 5 additions & 1 deletion beanie/odm/queries/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from beanie.odm.bulk import BulkWriter, Operation
from beanie.odm.interfaces.clone import CloneInterface
from beanie.odm.operators.update.general import SetRevisionId
from beanie.odm.utils.encoder import Encoder
from typing import (
Callable,
Expand Down Expand Up @@ -74,6 +75,10 @@ def update_query(self) -> Dict[str, Any]:
query.update(expression.query)
elif isinstance(expression, dict):
query.update(expression)
elif isinstance(expression, SetRevisionId):
set_query = query.get("$set", {})
set_query.update(expression.query.get("$set", {}))
query["$set"] = set_query
else:
raise TypeError("Wrong expression type")
return Encoder(custom_encoders=self.encoders).encode(query)
Expand Down Expand Up @@ -339,7 +344,6 @@ def __await__(
Run the query
:return:
"""

update_result = yield from self._update().__await__()
if self.upsert_insert_doc is None:
return update_result
Expand Down
2 changes: 2 additions & 0 deletions beanie/odm/settings/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,7 @@ class DocumentSettings(ItemSettings):

lazy_parsing: bool = False

keep_nulls: bool = True

class Config:
arbitrary_types_allowed = True
Loading

0 comments on commit 0cd2c60

Please sign in to comment.