diff --git a/beanie/__init__.py b/beanie/__init__.py index b79822ea..e889ab54 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -27,7 +27,7 @@ from beanie.odm.views import View from beanie.odm.union_doc import UnionDoc -__version__ = "1.18.0b1" +__version__ = "1.18.0" __all__ = [ # ODM "Document", diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py index 778a83be..524686bc 100644 --- a/beanie/odm/queries/find.py +++ b/beanie/odm/queries/find.py @@ -660,6 +660,39 @@ async def first_or_none(self) -> Optional[FindQueryResultType]: return None return res[0] + async def count(self) -> int: + """ + Number of found documents + :return: int + """ + if self.fetch_links: + aggregation_pipeline: List[ + Dict[str, Any] + ] = construct_lookup_queries(self.document_model) + + aggregation_pipeline.append({"$match": self.get_filter_query()}) + + if self.skip_number != 0: + aggregation_pipeline.append({"$skip": self.skip_number}) + if self.limit_number != 0: + aggregation_pipeline.append({"$limit": self.limit_number}) + + aggregation_pipeline.append({"$count": "count"}) + + result = ( + await self.document_model.get_motor_collection() + .aggregate( + aggregation_pipeline, + session=self.session, + **self.pymongo_kwargs, + ) + .to_list(length=1) + ) + + return result[0]["count"] if result else 0 + + return await super(FindMany, self).count() + class FindOne(FindQuery[FindQueryResultType]): """ @@ -964,3 +997,17 @@ def __await__( return cast( FindQueryResultType, parse_obj(self.projection_model, document) ) + + async def count(self) -> int: + """ + Count the number of documents matching the query + :return: int + """ + if self.fetch_links: + return await self.document_model.find_many( + *self.find_expressions, + session=self.session, + fetch_links=self.fetch_links, + **self.pymongo_kwargs, + ).count() + return await super(FindOne, self).count() diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index 612ec98a..6ec92cc5 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -94,7 +94,9 @@ def encode_document(self, obj): link_fields = obj.get_link_fields() obj_dict: Dict[str, Any] = {} if obj.get_settings().union_doc is not None: - obj_dict[obj.get_settings().class_id] = obj.get_settings().union_doc_alias or obj.__class__.__name__ + obj_dict[obj.get_settings().class_id] = ( + obj.get_settings().union_doc_alias or obj.__class__.__name__ + ) if obj._inheritance_inited: obj_dict[obj.get_settings().class_id] = obj._class_id @@ -170,7 +172,17 @@ def _encode( return self.encode_iterable(obj) if isinstance( - obj, (str, int, float, ObjectId, datetime, type(None), DBRef) + obj, + ( + str, + int, + float, + ObjectId, + datetime, + type(None), + DBRef, + Decimal128, + ), ): return obj diff --git a/docs/changelog.md b/docs/changelog.md index 61463303..18657495 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,30 @@ Beanie project +## [1.18.0] - 2023-03-31 + +### Prevent the models returned from find_all to be modified in the middle of modifying +- Author - [Harris Tsim](https://github.com/harris) +- PR + +### Allow change class_id and use name settings in uniondoc +- Author - [설원준(Wonjoon Seol)/Dispatch squad](https://github.com/wonjoonSeol-WS) +- PR + +### Fix: make `revision_id` not show in schema +- Author - [Ivan GJ](https://github.com/ivan-gj) +- PR + +### Fix: added re.pattern support to common encoder suite +- Author - [Ilia](https://github.com/Abashinos) +- PR + +### Fix other issues +- Author - [Roman Right](https://github.com/roman-right) +- PR + +[1.18.0]: https://pypi.org/project/beanie/1.18.0 + ## [1.18.0b1] - 2023-02-09 ### Fix diff --git a/pyproject.toml b/pyproject.toml index 24121c93..8207f0d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "beanie" -version = "1.18.0b1" +version = "1.18.0" description = "Asynchronous Python ODM for MongoDB" authors = ["Roman "] license = "Apache-2.0" diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 304b1ac4..8e44d5aa 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -71,6 +71,7 @@ LoopedLinksA, LoopedLinksB, DocumentWithTurnedOnStateManagementWithCustomId, + DocumentWithDecimalField, ) from tests.odm.views import TestView, TestViewWithLink @@ -227,6 +228,7 @@ async def init(loop, db): LoopedLinksA, LoopedLinksB, DocumentWithTurnedOnStateManagementWithCustomId, + DocumentWithDecimalField, ] await init_beanie( database=db, diff --git a/tests/odm/models.py b/tests/odm/models.py index b9be85a9..960bd1d7 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -13,6 +13,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union from uuid import UUID, uuid4 +import pydantic import pymongo from pydantic import ( BaseModel, @@ -701,3 +702,27 @@ class DocWithCollectionInnerClass(Document): class Collection: name = "test" + + +class DocumentWithDecimalField(Document): + amt: decimal.Decimal + other_amt: pydantic.condecimal( + decimal_places=1, multiple_of=decimal.Decimal("0.5") + ) = 0 + + class Config: + validate_assignment = True + + class Settings: + name = "amounts" + use_revision = True + use_state_management = True + indexes = [ + pymongo.IndexModel( + keys=[("amt", pymongo.ASCENDING)], name="amt_ascending" + ), + pymongo.IndexModel( + keys=[("other_amt", pymongo.DESCENDING)], + name="other_amt_descending", + ), + ] diff --git a/tests/odm/test_encoder.py b/tests/odm/test_encoder.py index 75173faf..e3108db8 100644 --- a/tests/odm/test_encoder.py +++ b/tests/odm/test_encoder.py @@ -10,6 +10,7 @@ DocumentWithStringField, SampleWithMutableObjects, Child, + DocumentWithDecimalField, ) @@ -92,3 +93,24 @@ async def test_mutable_objects_on_save(): await instance.save() assert isinstance(instance.d["Bar"], Child) assert isinstance(instance.lst[0], Child) + + +async def test_decimal(): + test_amts = DocumentWithDecimalField(amt=1, other_amt=2) + await test_amts.insert() + obj = await DocumentWithDecimalField.get(test_amts.id) + assert obj.amt == 1 + assert obj.other_amt == 2 + + test_amts.amt = 6 + await test_amts.save_changes() + + obj = await DocumentWithDecimalField.get(test_amts.id) + assert obj.amt == 6 + + test_amts = (await DocumentWithDecimalField.find_all().to_list())[0] + test_amts.other_amt = 7 + await test_amts.save_changes() + + obj = await DocumentWithDecimalField.get(test_amts.id) + assert obj.other_amt == 7 diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index d6bbb9bf..3cdf0430 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -228,6 +228,13 @@ async def test_prefetch_find_many(self, houses): assert len(houses) == 3 + async def test_prefect_count(self, houses): + c = await House.find(House.door.t > 5, fetch_links=True).count() + assert c == 3 + + c = await House.find_one(House.door.t > 5, fetch_links=True).count() + assert c == 3 + async def test_prefetch_find_one(self, house): house = await House.find_one(House.name == "test") for window in house.windows: diff --git a/tests/test_beanie.py b/tests/test_beanie.py index 1d00b7b6..5dbdbe60 100644 --- a/tests/test_beanie.py +++ b/tests/test_beanie.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "1.18.0b1" + assert __version__ == "1.18.0"