diff --git a/greens/routers/v1/vegs.py b/greens/routers/v1/vegs.py index 74bbf1d..2273cf7 100644 --- a/greens/routers/v1/vegs.py +++ b/greens/routers/v1/vegs.py @@ -16,7 +16,7 @@ "", status_code=HTTP_201_CREATED, response_description="Document created", - # response_model=DocumentResponse, + response_model=DocumentResponse, ) async def add_document(payload: Document): """ @@ -45,7 +45,7 @@ async def get_document(object_id: ObjectIdField): """ try: return await retrieve_document(object_id, collection) - except ValueError as exception: + except (ValueError, TypeError) as exception: raise NotFoundHTTPException(msg=str(exception)) from exception diff --git a/greens/schemas/vegs.py b/greens/schemas/vegs.py index 1b36bdf..1e74b6f 100644 --- a/greens/schemas/vegs.py +++ b/greens/schemas/vegs.py @@ -1,27 +1,46 @@ from bson import ObjectId as _ObjectId from bson.errors import InvalidId -from pydantic import BaseModel, Field, ConfigDict, AfterValidator +from pydantic import BaseModel, Field, ConfigDict, AfterValidator, BeforeValidator from typing_extensions import Annotated -def check_object_id(value: str) -> str: +# def check_object_id(value: str) -> str: +# if not _ObjectId.is_valid(value): +# raise ValueError('Invalid ObjectId') +# return value + + +def check_object_id(value: _ObjectId) -> str: + """ + Checks if the given _ObjectId is valid and returns it as a string. + + Args: + value: The _ObjectId to be checked. + + Returns: + str: The _ObjectId as a string. + + Raises: + ValueError: If the _ObjectId is invalid. + """ + if not _ObjectId.is_valid(value): - raise ValueError('Invalid ObjectId') - return value + raise ValueError("Invalid ObjectId") + return str(value) -ObjectIdField = Annotated[str, AfterValidator(check_object_id)] +ObjectIdField = Annotated[str, BeforeValidator(check_object_id)] config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) class Document(BaseModel): - # model_config = config + model_config = config - name: str = Field(...) - desc: str = Field(...) + name: str + desc: str class DocumentResponse(BaseModel): - id: ObjectIdField = Field(...) + id: ObjectIdField diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 77ec27d..64c45ee 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -30,7 +30,7 @@ def test_document_response_with_valid_id(test_id, object_id): document_response = DocumentResponse(id=object_id) # Assert - assert document_response.id == object_id, f"Test case {test_id} failed: The id field did not match the input ObjectId." + assert document_response.id == str(object_id), f"Test case {test_id} failed: The id field did not match the input ObjectId." # Edge cases