Skip to content

Commit

Permalink
Fix other issues (#518)
Browse files Browse the repository at this point in the history
* fix | issue 340 | support decimals

* fix | issue 426 | count for case with fetch_links = True

* version | 1.18.0
  • Loading branch information
roman-right committed Mar 31, 2023
1 parent cdc4d29 commit 2349a4b
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 5 deletions.
2 changes: 1 addition & 1 deletion beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from beanie.odm.views import View
from beanie.odm.union_doc import UnionDoc

__version__ = "1.18.0b1"
__version__ = "1.18.0"
__all__ = [
# ODM
"Document",
Expand Down
47 changes: 47 additions & 0 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,39 @@ async def first_or_none(self) -> Optional[FindQueryResultType]:
return None
return res[0]

async def count(self) -> int:
"""
Number of found documents
:return: int
"""
if self.fetch_links:
aggregation_pipeline: List[
Dict[str, Any]
] = construct_lookup_queries(self.document_model)

aggregation_pipeline.append({"$match": self.get_filter_query()})

if self.skip_number != 0:
aggregation_pipeline.append({"$skip": self.skip_number})
if self.limit_number != 0:
aggregation_pipeline.append({"$limit": self.limit_number})

aggregation_pipeline.append({"$count": "count"})

result = (
await self.document_model.get_motor_collection()
.aggregate(
aggregation_pipeline,
session=self.session,
**self.pymongo_kwargs,
)
.to_list(length=1)
)

return result[0]["count"] if result else 0

return await super(FindMany, self).count()


class FindOne(FindQuery[FindQueryResultType]):
"""
Expand Down Expand Up @@ -964,3 +997,17 @@ def __await__(
return cast(
FindQueryResultType, parse_obj(self.projection_model, document)
)

async def count(self) -> int:
"""
Count the number of documents matching the query
:return: int
"""
if self.fetch_links:
return await self.document_model.find_many(
*self.find_expressions,
session=self.session,
fetch_links=self.fetch_links,
**self.pymongo_kwargs,
).count()
return await super(FindOne, self).count()
16 changes: 14 additions & 2 deletions beanie/odm/utils/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def encode_document(self, obj):
link_fields = obj.get_link_fields()
obj_dict: Dict[str, Any] = {}
if obj.get_settings().union_doc is not None:
obj_dict[obj.get_settings().class_id] = obj.get_settings().union_doc_alias or obj.__class__.__name__
obj_dict[obj.get_settings().class_id] = (
obj.get_settings().union_doc_alias or obj.__class__.__name__
)
if obj._inheritance_inited:
obj_dict[obj.get_settings().class_id] = obj._class_id

Expand Down Expand Up @@ -170,7 +172,17 @@ def _encode(
return self.encode_iterable(obj)

if isinstance(
obj, (str, int, float, ObjectId, datetime, type(None), DBRef)
obj,
(
str,
int,
float,
ObjectId,
datetime,
type(None),
DBRef,
Decimal128,
),
):
return obj

Expand Down
24 changes: 24 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@

Beanie project

## [1.18.0] - 2023-03-31
### Prevent the models returned from find_all to be modified in the middle of modifying
- Author - [Harris Tsim](https://github.com/harris)
- PR <https://github.com/roman-right/beanie/pull/502>

### Allow change class_id and use name settings in uniondoc
- Author - [설원준(Wonjoon Seol)/Dispatch squad](https://github.com/wonjoonSeol-WS)
- PR <https://github.com/roman-right/beanie/pull/466>

### Fix: make `revision_id` not show in schema
- Author - [Ivan GJ](https://github.com/ivan-gj)
- PR <https://github.com/roman-right/beanie/pull/478>

### Fix: added re.pattern support to common encoder suite
- Author - [Ilia](https://github.com/Abashinos)
- PR <https://github.com/roman-right/beanie/pull/511>

### Fix other issues
- Author - [Roman Right](https://github.com/roman-right)
- PR <https://github.com/roman-right/beanie/pull/518>

[1.18.0]: https://pypi.org/project/beanie/1.18.0

## [1.18.0b1] - 2023-02-09

### Fix
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "beanie"
version = "1.18.0b1"
version = "1.18.0"
description = "Asynchronous Python ODM for MongoDB"
authors = ["Roman <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
LoopedLinksA,
LoopedLinksB,
DocumentWithTurnedOnStateManagementWithCustomId,
DocumentWithDecimalField,
)
from tests.odm.views import TestView, TestViewWithLink

Expand Down Expand Up @@ -227,6 +228,7 @@ async def init(loop, db):
LoopedLinksA,
LoopedLinksB,
DocumentWithTurnedOnStateManagementWithCustomId,
DocumentWithDecimalField,
]
await init_beanie(
database=db,
Expand Down
25 changes: 25 additions & 0 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Dict, List, Optional, Set, Tuple, Union
from uuid import UUID, uuid4

import pydantic
import pymongo
from pydantic import (
BaseModel,
Expand Down Expand Up @@ -701,3 +702,27 @@ class DocWithCollectionInnerClass(Document):

class Collection:
name = "test"


class DocumentWithDecimalField(Document):
amt: decimal.Decimal
other_amt: pydantic.condecimal(
decimal_places=1, multiple_of=decimal.Decimal("0.5")
) = 0

class Config:
validate_assignment = True

class Settings:
name = "amounts"
use_revision = True
use_state_management = True
indexes = [
pymongo.IndexModel(
keys=[("amt", pymongo.ASCENDING)], name="amt_ascending"
),
pymongo.IndexModel(
keys=[("other_amt", pymongo.DESCENDING)],
name="other_amt_descending",
),
]
22 changes: 22 additions & 0 deletions tests/odm/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DocumentWithStringField,
SampleWithMutableObjects,
Child,
DocumentWithDecimalField,
)


Expand Down Expand Up @@ -92,3 +93,24 @@ async def test_mutable_objects_on_save():
await instance.save()
assert isinstance(instance.d["Bar"], Child)
assert isinstance(instance.lst[0], Child)


async def test_decimal():
test_amts = DocumentWithDecimalField(amt=1, other_amt=2)
await test_amts.insert()
obj = await DocumentWithDecimalField.get(test_amts.id)
assert obj.amt == 1
assert obj.other_amt == 2

test_amts.amt = 6
await test_amts.save_changes()

obj = await DocumentWithDecimalField.get(test_amts.id)
assert obj.amt == 6

test_amts = (await DocumentWithDecimalField.find_all().to_list())[0]
test_amts.other_amt = 7
await test_amts.save_changes()

obj = await DocumentWithDecimalField.get(test_amts.id)
assert obj.other_amt == 7
7 changes: 7 additions & 0 deletions tests/odm/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ async def test_prefetch_find_many(self, houses):

assert len(houses) == 3

async def test_prefect_count(self, houses):
c = await House.find(House.door.t > 5, fetch_links=True).count()
assert c == 3

c = await House.find_one(House.door.t > 5, fetch_links=True).count()
assert c == 3

async def test_prefetch_find_one(self, house):
house = await House.find_one(House.name == "test")
for window in house.windows:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_beanie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_version():
assert __version__ == "1.18.0b1"
assert __version__ == "1.18.0"

0 comments on commit 2349a4b

Please sign in to comment.