From c45a5048659b7b99f2b7200912dd0c535579dc10 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 24 May 2024 16:04:23 +0200 Subject: [PATCH 01/21] Improve docstrings. - Fix typos. - Fix broken links. - Improve readability. --- docs/source/index.rst | 2 +- trolldb/__init__.py | 2 +- trolldb/api/__init__.py | 2 +- trolldb/api/api.py | 10 +++---- trolldb/api/routes/__init__.py | 2 +- trolldb/config/config.py | 18 ++++-------- trolldb/database/errors.py | 12 ++++---- trolldb/database/mongodb.py | 17 ++++++++---- trolldb/database/piplines.py | 2 +- trolldb/errors/errors.py | 18 +++++++----- trolldb/test_utils/__init__.py | 2 +- trolldb/test_utils/common.py | 2 +- trolldb/test_utils/mongodb_database.py | 38 ++++++++++++++++---------- trolldb/test_utils/mongodb_instance.py | 6 ++-- 14 files changed, 73 insertions(+), 60 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 7091703..4aec776 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Welcome to Pytroll documentation! +Pytroll-db Documentation =========================================== .. toctree:: diff --git a/trolldb/__init__.py b/trolldb/__init__.py index f054e81..8a78256 100644 --- a/trolldb/__init__.py +++ b/trolldb/__init__.py @@ -1 +1 @@ -"""trolldb package.""" +"""The database interface of the Pytroll package.""" diff --git a/trolldb/api/__init__.py b/trolldb/api/__init__.py index 6590a31..4114977 100644 --- a/trolldb/api/__init__.py +++ b/trolldb/api/__init__.py @@ -5,5 +5,5 @@ For more information and documentation, please refer to the following sub-packages and modules: - :obj:`trolldb.api.routes`: The package which defines the API routes. - - :obj:`trollddb.api.api`: The module which defines the API server and how it is run via the given configuration. + - :obj:`trolldb.api.api`: The module which defines the API server and how it is run via the given configuration. 
""" diff --git a/trolldb/api/api.py b/trolldb/api/api.py index 85272a1..4ce42e7 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -34,7 +34,7 @@ title="pytroll-db", summary="The database API of Pytroll", description= - "The API allows you to perform CRUD operations as well as querying the database" + "The API allows you to perform CRUD operations as well as querying the database" "At the moment only MongoDB is supported. It is based on the following Python packages" "\n * **PyMongo** (https://github.com/mongodb/mongo-python-driver)" "\n * **motor** (https://github.com/mongodb/motor)", @@ -43,7 +43,7 @@ url="https://www.gnu.org/licenses/gpl-3.0.en.html" ) ) -"""These will appear int the auto-generated documentation and are passed to the ``FastAPI`` class as keyword args.""" +"""These will appear in the auto-generated documentation and are passed to the ``FastAPI`` class as keyword args.""" @validate_call @@ -65,8 +65,8 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: `FastAPI class `_ and are directly passed to it. These keyword arguments will be first concatenated with the configurations of the API server which are read from the ``config`` argument. The keyword arguments which are passed explicitly to the function - take precedence over ``config``. Finally, ``API_INFO``, which are hard-coded information for the API server, - will be concatenated and takes precedence over all. + take precedence over ``config``. Finally, :obj:`API_INFO`, which are hard-coded information for the API + server, will be concatenated and takes precedence over all. Raises: ValidationError: @@ -75,7 +75,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: Example: .. 
code-block:: python - from api.api import run_server + from trolldb.api.api import run_server if __name__ == "__main__": run_server("config.yaml") """ diff --git a/trolldb/api/routes/__init__.py b/trolldb/api/routes/__init__.py index 4c69061..7378306 100644 --- a/trolldb/api/routes/__init__.py +++ b/trolldb/api/routes/__init__.py @@ -1,4 +1,4 @@ -"""routes package.""" +"""The routes package of the API.""" from .router import api_router diff --git a/trolldb/config/config.py b/trolldb/config/config.py index b43d7fe..e8c5678 100644 --- a/trolldb/config/config.py +++ b/trolldb/config/config.py @@ -1,11 +1,6 @@ """The module which handles parsing and validating the config (YAML) file. The validation is performed using `Pydantic `_. - -Note: - Functions in this module are decorated with - `pydantic.validate_call `_ - so that their arguments can be validated using the corresponding type hints, when calling the function at runtime. """ import errno @@ -32,7 +27,7 @@ def id_must_be_valid(id_like_string: str) -> ObjectId: The string to be converted to an ObjectId. Returns: - The ObjectId object if successfully. + The ObjectId object if successful. Raises: ValueError: @@ -46,7 +41,7 @@ def id_must_be_valid(id_like_string: str) -> ObjectId: MongoObjectId = Annotated[str, AfterValidator(id_must_be_valid)] -"""Type hint validator for object IDs.""" +"""The type hint validator for object IDs.""" class MongoDocument(BaseModel): @@ -60,7 +55,7 @@ class APIServerConfig(NamedTuple): Note: The attributes herein are a subset of the keyword arguments accepted by `FastAPI class `_ and are directly passed - to the FastAPI class. + to the FastAPI class. Consult :func:`trolldb.api.api.run_server` on how these configurations are treated. """ url: AnyUrl @@ -79,7 +74,7 @@ class DatabaseConfig(NamedTuple): """ url: MongoDsn - """The URL of the MongoDB server excluding the port part, e.g. ``"mongodb://localhost:27017"``""" + """The URL of the MongoDB server including the port part, e.g. 
``"mongodb://localhost:27017"``""" timeout: Timeout """The timeout in seconds (non-negative float), after which an exception is raised if a connection with the @@ -95,7 +90,7 @@ class DatabaseConfig(NamedTuple): class AppConfig(BaseModel): - """A model to hold all the configurations of the application including both the API server and the database. + """A model to hold all the configurations of the application, i.e. the API server, the database, and the subscriber. This will be used by Pydantic to validate the parsed YAML file. """ @@ -121,9 +116,6 @@ def parse_config_yaml_file(filename: FilePath) -> AppConfig: ValidationError: If the successfully parsed file fails the validation, i.e. its schema or the content does not conform to :class:`AppConfig`. - - ValidationError: - If the function is not called with arguments of valid type. """ logger.info("Attempt to parse the YAML file ...") with open(filename, "r") as file: diff --git a/trolldb/database/errors.py b/trolldb/database/errors.py index b3830e3..983597e 100644 --- a/trolldb/database/errors.py +++ b/trolldb/database/errors.py @@ -12,7 +12,7 @@ class Client(ResponsesErrorGroup): - """Client error responses, e.g. if something goes wrong with initialization or closing the client.""" + """Database client error responses, e.g. if something goes wrong with initialization or closing the client.""" CloseNotAllowedError = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "Calling `close()` on a client which has not been initialized is not allowed!" @@ -37,12 +37,12 @@ class Client(ResponsesErrorGroup): ConnectionError = ResponseError({ status.HTTP_400_BAD_REQUEST: - "Could not connect to the database with URL." + "Could not connect to the database with the given URL." }) class Collections(ResponsesErrorGroup): - """Collections error responses, e.g. if a requested collection cannot be found.""" + """Collections error responses, e.g. 
if the requested collection cannot be found.""" NotFoundError = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given collection name inside the specified database." @@ -50,12 +50,12 @@ class Collections(ResponsesErrorGroup): WrongTypeError = ResponseError({ status.HTTP_422_UNPROCESSABLE_ENTITY: - "Both the Database and collection name must be `None` if one of them is `None`." + "Both the database and collection name must be `None` if either one is `None`." }) class Databases(ResponsesErrorGroup): - """Databases error responses, e.g. if a requested database cannot be found.""" + """Databases error responses, e.g. if the requested database cannot be found.""" NotFoundError = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given database name." @@ -68,7 +68,7 @@ class Databases(ResponsesErrorGroup): class Documents(ResponsesErrorGroup): - """Documents error responses, e.g. if a requested document cannot be found.""" + """Documents error responses, e.g. if the requested document cannot be found.""" NotFound = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find any document with the given object id." diff --git a/trolldb/database/mongodb.py b/trolldb/database/mongodb.py index 3ec5a79..214b8cf 100644 --- a/trolldb/database/mongodb.py +++ b/trolldb/database/mongodb.py @@ -51,7 +51,7 @@ async def get_id(doc: CoroutineDocument) -> str: Note: The rationale behind this method is as follows. In MongoDB, each document has a unique ID which is of type - :class:`~bson.objectid.ObjectId`. This is not suitable for purposes when a simple string is needed, hence + :class:`bson.objectid.ObjectId`. This is not suitable for purposes when a simple string is needed, hence the need for this method. Args: @@ -84,7 +84,7 @@ async def get_ids(docs: Union[AsyncIOMotorCommandCursor, AsyncIOMotorCursor]) -> class MongoDB: """A wrapper class around the `motor async driver `_ for Mongo DB. - It includes convenience methods tailored to our specific needs. 
As such, the :func:`~MongoDB.initialize()`` method + It includes convenience methods tailored to our specific needs. As such, the :func:`~MongoDB.initialize()` method returns a coroutine which needs to be awaited. Note: @@ -117,16 +117,20 @@ async def initialize(cls, database_config: DatabaseConfig): Args: database_config: - A named tuple which includes the database configurations. + An object of type :class:`~trolldb.config.config.DatabaseConfig` which includes the database + configurations. + + Warning: + The timeout is given in seconds in the configurations, while the MongoDB uses milliseconds. Returns: On success ``None``. Raises: SystemExit(errno.EIO): - If connection is not established (``ConnectionFailure``) + If connection is not established, i.e. ``ConnectionFailure``. SystemExit(errno.EIO): - If the attempt times out (``ServerSelectionTimeoutError``) + If the attempt times out, i.e. ``ServerSelectionTimeoutError``. SystemExit(errno.EIO): If one attempts reinitializing the class with new (different) database configurations without calling :func:`~close()` first. @@ -135,7 +139,8 @@ async def initialize(cls, database_config: DatabaseConfig): configurations still exist and are different from the new ones which have been just provided. SystemExit(errno.ENODATA): - If either ``database_config.main_database`` or ``database_config.main_collection`` does not exist. + If either ``database_config.main_database_name`` or ``database_config.main_collection_name`` does not + exist. """ logger.info("Attempt to initialize the MongoDB client ...") logger.info("Checking the database configs ...") diff --git a/trolldb/database/piplines.py b/trolldb/database/piplines.py index f85fa15..3e74640 100644 --- a/trolldb/database/piplines.py +++ b/trolldb/database/piplines.py @@ -112,7 +112,7 @@ class Pipelines(list): Each item in the list is a dictionary with its key being the literal string ``"$match"`` and its corresponding value being of type :class:`PipelineBooleanDict`. 
The ``"$match"`` key is what actually triggers the matching operation in the MongoDB aggregation pipeline. The condition against which the matching will be performed is given by the value - which is a simply a boolean pipeline dictionary which has a hierarchical structure. + which is a simply a boolean pipeline dictionary and has a hierarchical structure. Example: .. code-block:: python diff --git a/trolldb/errors/errors.py b/trolldb/errors/errors.py index 95fd6bf..c083b0c 100644 --- a/trolldb/errors/errors.py +++ b/trolldb/errors/errors.py @@ -31,9 +31,9 @@ def _listify(item: str | list[str]) -> list[str]: .. code-block:: python # The following evaluate to True - __listify("test") == ["test"] - __listify(["a", "b"]) = ["a", "b"] - __listify([]) == [] + _listify("test") == ["test"] + _listify(["a", "b"]) = ["a", "b"] + _listify([]) == [] """ return item if isinstance(item, list) else [item] @@ -76,8 +76,7 @@ class ResponseError(Exception): error_b = ResponseError({404: "Not Found"}) errors = error_a | error_b - # When used in a FastAPI response descriptor, - # the following string will be generated for errors + # When used in a FastAPI response descriptor, the following string is generated "Bad Request |OR| Not Found" """ @@ -101,7 +100,7 @@ def __init__(self, args_dict: OrderedDict[StatusCode, str | list[str]] | dict) - error_b = ResponseError({404: "Not Found"}) errors = error_a | error_b errors_a_or_b = ResponseError({400: "Bad Request", 404: "Not Found"}) - errors_list = ResponseError({404: ["Not Found", "Still Not Found"]}) + errors_list = ResponseError({404: ["Not Found", "Yet Not Found"]}) """ self.__dict: OrderedDict = OrderedDict(args_dict) self.extra_information: dict | None = None @@ -188,7 +187,7 @@ def get_error_details( Args: extra_information (Optional, default ``None``): - More information (if any) that wants to be added to the message string. + More information (if any) that needs to be added to the message string. 
status_code (Optional, default ``None``): The status code to retrieve. This is useful when there are several error items in the internal dictionary. In case of ``None``, the internal dictionary must include a single entry, otherwise an error @@ -233,6 +232,11 @@ def sys_exit_log( def fastapi_descriptor(self) -> dict[StatusCode, dict[str, str]]: """Gets the FastAPI descriptor (dictionary) of the error items stored in :obj:`ResponseError.__dict`. + Note: + Consult the FastAPI documentation for + `additional responses `_ to see why and how + descriptors are used. + Example: .. code-block:: python diff --git a/trolldb/test_utils/__init__.py b/trolldb/test_utils/__init__.py index e1fa351..9f3e45a 100644 --- a/trolldb/test_utils/__init__.py +++ b/trolldb/test_utils/__init__.py @@ -1 +1 @@ -"""This package provide tools to test the database and api packages.""" +"""This package provides tools to test the database and api packages.""" diff --git a/trolldb/test_utils/common.py b/trolldb/test_utils/common.py index 6e3fb21..10d801a 100644 --- a/trolldb/test_utils/common.py +++ b/trolldb/test_utils/common.py @@ -19,7 +19,7 @@ def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict: config will be an empty dictionary. Returns: - A dictionary which resembles an object of type :obj:`AppConfig`. + A dictionary which resembles an object of type :obj:`~trolldb.config.config.AppConfig`. 
""" app_config = dict( api_server=dict( diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index d8060e9..b8ce234 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -37,7 +37,7 @@ def mongodb_for_test_context(database_config: DatabaseConfig = test_app_config.d class Time: - """A static class to enclose functionalities for generating random time stamps.""" + """A static class to enclose functionalities for generating random timestamps.""" min_start_time = datetime(2019, 1, 1, 0, 0, 0) """The minimum timestamp which is allowed to appear in our data.""" @@ -86,8 +86,8 @@ def __init__(self, platform_name: str, sensor: str) -> None: def generate_dataset(self, max_count: int) -> list[dict]: """Generates the dataset for a given document. - This corresponds to the list of files which are stored in each document. The number of datasets is randomly - chosen from 1 to ``max_count`` for each document. + This corresponds to the list of files which are stored in each document. The number of items in a dataset is + randomly chosen from 1 to ``max_count`` for each document. """ dataset = [] # We suppress ruff (S311) here as we are not generating anything cryptographic here! @@ -113,27 +113,37 @@ def like_mongodb_document(self) -> dict: class TestDatabase: - """A static class which encloses functionalities to prepare and fill the test database with mock data.""" + """A static class which encloses functionalities to prepare and fill the test database with test data.""" # We suppress ruff (S311) here as we are not generating anything cryptographic here! platform_names = choices(["PA", "PB", "PC"], k=10) # noqa: S311 - """Example platform names.""" + """Example platform names. + + Warning: + The value of this variable changes randomly every time. What you see above is just an example which has been + generated as a result of building the documentation! 
+ """ # We suppress ruff (S311) here as we are not generating anything cryptographic here! sensors = choices(["SA", "SB", "SC"], k=10) # noqa: S311 - """Example sensor names.""" + """Example sensor names. + + Warning: + The value of this variable changes randomly every time. What you see above is just an example which has been + generated as a result of building the documentation! + """ database_names = [test_app_config.database.main_database_name, "another_mock_database"] """List of all database names. - The first element is the main database that will be queried by the API and includes the mock data. The second + The first element is the main database that will be queried by the API and includes the test data. The second database is for testing scenarios when one attempts to access another existing database or collection. """ collection_names = [test_app_config.database.main_collection_name, "another_mock_collection"] """List of all collection names. - The first element is the main collection that will be queried by the API and includes the mock data. The second + The first element is the main collection that will be queried by the API and includes the test data. The second collection is for testing scenarios when one attempts to access another existing collection. """ @@ -141,14 +151,14 @@ class TestDatabase: """All database names including the default ones which are automatically created by MongoDB.""" documents: list[dict] = [] - """The list of documents which include mock data.""" + """The list of documents which include test data.""" @classmethod def generate_documents(cls, random_shuffle: bool = True) -> None: """Generates test documents which for practical purposes resemble real data. Warning: - This method is not pure! The side effect is that the :obj:`TestDatabase.documents` is filled. + This method is not pure! The side effect is that the :obj:`TestDatabase.documents` is reset to new values. 
""" cls.documents = [ Document(p, s).like_mongodb_document() for p, s in zip(cls.platform_names, cls.sensors, strict=False)] @@ -159,8 +169,8 @@ def generate_documents(cls, random_shuffle: bool = True) -> None: def reset(cls): """Resets all the databases/collections. - This is done by deleting all documents in the collections and then inserting a single empty ``{}`` document - in them. + This is done by deleting all documents in the collections and then inserting a single empty document, i.e. + ``{}``, in them. """ with mongodb_for_test_context() as client: for db_name, coll_name in zip(cls.database_names, cls.collection_names, strict=False): @@ -171,7 +181,7 @@ def reset(cls): @classmethod def write_mock_date(cls): - """Fills databases/collections with mock data.""" + """Fills databases/collections with test data.""" with mongodb_for_test_context() as client: # The following function call has side effects! cls.generate_documents() @@ -184,6 +194,6 @@ def write_mock_date(cls): @classmethod def prepare(cls): - """Prepares the MongoDB instance by first resetting the database and then filling it with mock data.""" + """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" cls.reset() cls.write_mock_date() diff --git a/trolldb/test_utils/mongodb_instance.py b/trolldb/test_utils/mongodb_instance.py index 1b16f04..1d7e107 100644 --- a/trolldb/test_utils/mongodb_instance.py +++ b/trolldb/test_utils/mongodb_instance.py @@ -22,14 +22,16 @@ class TestMongoInstance: """Temp directory for logging messages by the MongoDB instance. Warning: - The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way! + The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way + every time! """ storage_dir: str = tempfile.mkdtemp("__pytroll_db_temp_test_storage") """Temp directory for storing database files by the MongoDB instance. 
Warning: - The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way! + The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way + every time! """ port: int = 28017 From 2a5f8a1cc731690c5acb0b79bf984d4738c4c071 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 24 May 2024 18:20:02 +0200 Subject: [PATCH 02/21] Filter undesirable special members from autodoc. Update conf.py with functionalities to specify which special members have to be kept in the documentation. --- docs/source/conf.py | 52 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c02a333..47ce565 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,6 +11,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import os +import re from sphinx.ext import apidoc @@ -62,14 +63,63 @@ # so a file named "default.css" will overwrite the builtin "default.css". 
# html_static_path = ["_static"] + +# Specify which special members have to be kept +special_members_dict = { + "Document": {"init"}, + "ResponseError": {"init", "or"}, + "PipelineBooleanDict": {"init", "or", "and"}, + "PipelineAttribute": {"init", "or", "and", "eq", "gt", "ge", "lt", "le"}, + "Pipelines": {"init", "add", "iadd"} +} + +# Add trailing and leading "__" to all the aforementioned members +for cls, methods in special_members_dict.items(): + special_members_dict[cls] = {f"__{method}__" for method in methods} + +# Make a set of all allowed special members +all_special_members = set() +for methods in special_members_dict.values(): + all_special_members |= methods + autodoc_default_options = { "members": True, "member-order": "bysource", "private-members": True, "special-members": True, - "undoc-members": True, + "undoc-members": False, } + +def is_special_member(member_name: str) -> bool: + """Checks if the given member is special, i.e. its name has the following format ``____``.""" + return bool(re.compile(r"^__\w+__$").match(member_name)) + + +def skip(app, typ, member_name, obj, flag, options): + """The filter function to determine whether to keep the member in the documentation. + + ``True`` means skip the member. + """ + if is_special_member(member_name): + + if member_name not in all_special_members: + return True + + obj_name = obj.__qualname__.split(".")[0] + if methods_set := special_members_dict.get(obj_name, None): + if member_name in methods_set: + return False # Keep the member + return True + + return None + + +def setup(app): + """Sets up the sphinx app.""" + app.connect("autodoc-skip-member", skip) + + root_doc = "index" output_dir = os.path.join(".") From abc8a320f8a375ee59b5ba46178dd8ae7298c485 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 24 May 2024 18:59:51 +0200 Subject: [PATCH 03/21] Fix a typo. The name of the `pipelines` module was mistakenly spelled as `piplines`. 
--- trolldb/api/routes/queries.py | 2 +- trolldb/database/{piplines.py => pipelines.py} | 0 trolldb/tests/tests_database/test_pipelines.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename trolldb/database/{piplines.py => pipelines.py} (100%) diff --git a/trolldb/api/routes/queries.py b/trolldb/api/routes/queries.py index e585101..0a35a0b 100644 --- a/trolldb/api/routes/queries.py +++ b/trolldb/api/routes/queries.py @@ -11,7 +11,7 @@ from trolldb.api.routes.common import CheckCollectionDependency from trolldb.database.errors import database_collection_error_descriptor from trolldb.database.mongodb import get_ids -from trolldb.database.piplines import PipelineAttribute, Pipelines +from trolldb.database.pipelines import PipelineAttribute, Pipelines router = APIRouter() diff --git a/trolldb/database/piplines.py b/trolldb/database/pipelines.py similarity index 100% rename from trolldb/database/piplines.py rename to trolldb/database/pipelines.py diff --git a/trolldb/tests/tests_database/test_pipelines.py b/trolldb/tests/tests_database/test_pipelines.py index b993aff..df83b01 100644 --- a/trolldb/tests/tests_database/test_pipelines.py +++ b/trolldb/tests/tests_database/test_pipelines.py @@ -1,5 +1,5 @@ """Tests for the pipelines and applying comparison operations on them.""" -from trolldb.database.piplines import PipelineAttribute, PipelineBooleanDict, Pipelines +from trolldb.database.pipelines import PipelineAttribute, PipelineBooleanDict, Pipelines from trolldb.test_utils.common import compare_by_operator_name From 603b02f0eb10cbf2862165952bcb1ce7e7da56c4 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 24 May 2024 19:17:16 +0200 Subject: [PATCH 04/21] Replace the term `mock` with `test` There were several places that we had used `mock data` to refer to data generated for testing purposes. This was misleading as we are actually generating the data and the API calls are not mocked. Example `mock_database` is renamed to `test_database`. 
--- trolldb/test_utils/common.py | 4 ++-- trolldb/test_utils/mongodb_database.py | 8 ++++---- trolldb/tests/test_recorder.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/trolldb/test_utils/common.py b/trolldb/test_utils/common.py index 10d801a..730b299 100644 --- a/trolldb/test_utils/common.py +++ b/trolldb/test_utils/common.py @@ -26,8 +26,8 @@ def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict: url="http://localhost:8080" ), database=dict( - main_database_name="mock_database", - main_collection_name="mock_collection", + main_database_name="test_database", + main_collection_name="test_collection", url="mongodb://localhost:28017", timeout=1 ), diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index b8ce234..35c3e66 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -133,14 +133,14 @@ class TestDatabase: generated as a result of building the documentation! """ - database_names = [test_app_config.database.main_database_name, "another_mock_database"] + database_names = [test_app_config.database.main_database_name, "another_test_database"] """List of all database names. The first element is the main database that will be queried by the API and includes the test data. The second database is for testing scenarios when one attempts to access another existing database or collection. """ - collection_names = [test_app_config.database.main_collection_name, "another_mock_collection"] + collection_names = [test_app_config.database.main_collection_name, "another_test_collection"] """List of all collection names. The first element is the main collection that will be queried by the API and includes the test data. 
The second @@ -180,7 +180,7 @@ def reset(cls): collection.insert_one({}) @classmethod - def write_mock_date(cls): + def write_test_data(cls): """Fills databases/collections with test data.""" with mongodb_for_test_context() as client: # The following function call has side effects! @@ -196,4 +196,4 @@ def write_mock_date(cls): def prepare(cls): """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" cls.reset() - cls.write_mock_date() + cls.write_test_data() diff --git a/trolldb/tests/test_recorder.py b/trolldb/tests/test_recorder.py index 99546c6..d54bb6e 100644 --- a/trolldb/tests/test_recorder.py +++ b/trolldb/tests/test_recorder.py @@ -60,7 +60,7 @@ def config_file(tmp_path): async def message_in_database_and_delete_count_is_one(msg) -> bool: """Checks if there is exactly one item in the database which matches the data of the message.""" async with mongodb_context(test_app_config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") + collection = await MongoDB.get_collection("test_database", "test_collection") result = await collection.find_one(dict(scan_mode="EW")) result.pop("_id") deletion_result = await collection.delete_many({"uri": msg.data["uri"]}) @@ -97,7 +97,7 @@ async def test_record_deletes_message(tmp_path, file_message, del_message): with patched_subscriber_recv([file_message, del_message]): await record_messages(config) async with mongodb_context(config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") + collection = await MongoDB.get_collection("test_database", "test_collection") result = await collection.find_one(dict(scan_mode="EW")) assert result is None From 55915f98ffcab2db9cd55f92c9c894e4928a9261 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Mon, 27 May 2024 14:10:33 +0200 Subject: [PATCH 05/21] Add a test for the `/datetime` route of the API --- trolldb/test_utils/mongodb_database.py | 43 
+++++++++++++++++++++++++- trolldb/tests/tests_api/test_api.py | 6 ++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index 35c3e66..9214786 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -1,5 +1,4 @@ """The module which provides testing utilities to make MongoDB databases/collections and fill them with test data.""" - from contextlib import contextmanager from datetime import datetime, timedelta from random import choices, randint, shuffle @@ -190,8 +189,50 @@ def write_test_data(cls): ][ test_app_config.database.main_collection_name ] + collection.delete_many({}) collection.insert_many(cls.documents) + @classmethod + def find_min_max_datetime(cls): + """Finds the minimum and the maximum for both the ``start_time`` and the ``end_time``. + + We use `brute force` for this purpose. We set the minimum to a large value (year 2100) and the maximum to a + small value (year 1900). We then iterate through all documents and update the extrema. + + Returns: + A dictionary whose schema matches the response returned by the ``/datetime`` route of the API. 
+ """ + result = dict( + start_time=dict( + _min=dict(_id=None, _time="2100-01-01T00:00:00"), + _max=dict(_id=None, _time="1900-01-01T00:00:00") + ), + end_time=dict( + _min=dict(_id=None, _time="2100-01-01T00:00:00"), + _max=dict(_id=None, _time="1900-01-01T00:00:00")) + ) + + with mongodb_for_test_context() as client: + collection = client[ + test_app_config.database.main_database_name + ][ + test_app_config.database.main_collection_name + ] + documents = collection.find({}) + + for document in documents: + for k in ["start_time", "end_time"]: + dt = document[k].isoformat() + if dt > result[k]["_max"]["_time"]: + result[k]["_max"]["_time"] = dt + result[k]["_max"]["_id"] = str(document["_id"]) + + if dt < result[k]["_min"]["_time"]: + result[k]["_min"]["_time"] = dt + result[k]["_min"]["_id"] = str(document["_id"]) + + return result + @classmethod def prepare(cls): """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" diff --git a/trolldb/tests/tests_api/test_api.py b/trolldb/tests/tests_api/test_api.py index c721345..160a0cf 100644 --- a/trolldb/tests/tests_api/test_api.py +++ b/trolldb/tests/tests_api/test_api.py @@ -79,3 +79,9 @@ def test_collections_negative(): """Checks that the non-existing collections cannot be found.""" for database_name in TestDatabase.database_names: assert http_get(f"databases/{database_name}/non_existing_collection").status == status.HTTP_404_NOT_FOUND + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_datetime(): + """Checks that the datetime route works properly.""" + assert http_get("datetime").json() == TestDatabase.find_min_max_datetime() From 9ee2e608e28af645ea7b8fec097b11da08bff6cd Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Wed, 29 May 2024 09:46:32 +0200 Subject: [PATCH 06/21] Add more tests for the `queries` route As a result, `test_utils` and the `queries` route handler have been also updated. 
--- trolldb/api/routes/queries.py | 8 +-- trolldb/test_utils/mongodb_database.py | 86 +++++++++++++++++----- trolldb/tests/tests_api/test_api.py | 98 +++++++++++++++++++++++++- 3 files changed, 167 insertions(+), 25 deletions(-) diff --git a/trolldb/api/routes/queries.py b/trolldb/api/routes/queries.py index 0a35a0b..7715aec 100644 --- a/trolldb/api/routes/queries.py +++ b/trolldb/api/routes/queries.py @@ -29,7 +29,6 @@ async def queries( time_min: datetime.datetime = Query(default=None), # noqa: B008 time_max: datetime.datetime = Query(default=None)) -> list[str]: # noqa: B008 """Please consult the auto-generated documentation by FastAPI.""" - # We pipelines = Pipelines() if platform: @@ -42,10 +41,7 @@ async def queries( start_time = PipelineAttribute("start_time") end_time = PipelineAttribute("end_time") pipelines += ( - (start_time >= time_min) | - (start_time <= time_max) | - (end_time >= time_min) | - (end_time <= time_max) + ((start_time >= time_min) & (start_time <= time_max)) | + ((end_time >= time_min) & (end_time <= time_max)) ) - return await get_ids(collection.aggregate(pipelines)) diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index 9214786..e041ea2 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -1,5 +1,6 @@ """The module which provides testing utilities to make MongoDB databases/collections and fill them with test data.""" from contextlib import contextmanager +from copy import deepcopy from datetime import datetime, timedelta from random import choices, randint, shuffle from typing import Iterator @@ -114,8 +115,11 @@ def like_mongodb_document(self) -> dict: class TestDatabase: """A static class which encloses functionalities to prepare and fill the test database with test data.""" + unique_platform_names: list[str] = ["PA", "PB", "PC"] + """The unique platform names that will be used to generate the sample of all platform names.""" + # We suppress 
ruff (S311) here as we are not generating anything cryptographic here! - platform_names = choices(["PA", "PB", "PC"], k=10) # noqa: S311 + platform_names = choices(["PA", "PB", "PC"], k=20) # noqa: S311 """Example platform names. Warning: @@ -123,8 +127,11 @@ class TestDatabase: generated as a result of building the documentation! """ + unique_sensors: list[str] = ["SA", "SB", "SC"] + """The unique sensor names that will be used to generate the sample of all sensor names.""" + # We suppress ruff (S311) here as we are not generating anything cryptographic here! - sensors = choices(["SA", "SB", "SC"], k=10) # noqa: S311 + sensors = choices(["SA", "SB", "SC"], k=20) # noqa: S311 """Example sensor names. Warning: @@ -192,6 +199,23 @@ def write_test_data(cls): collection.delete_many({}) collection.insert_many(cls.documents) + @classmethod + def get_all_documents_from_database(cls) -> list[dict]: + """Retrieves all the documents from the database. + + Returns: + A list of all documents from the database. This matches the content of :obj:`~TestDatabase.documents` with + the addition of `IDs` which are assigned by the MongoDB. + """ + with mongodb_for_test_context() as client: + collection = client[ + test_app_config.database.main_database_name + ][ + test_app_config.database.main_collection_name + ] + documents = list(collection.find({})) + return documents + @classmethod def find_min_max_datetime(cls): """Finds the minimum and the maximum for both the ``start_time`` and the ``end_time``. 
@@ -212,27 +236,53 @@ def find_min_max_datetime(cls): _max=dict(_id=None, _time="1900-01-01T00:00:00")) ) - with mongodb_for_test_context() as client: - collection = client[ - test_app_config.database.main_database_name - ][ - test_app_config.database.main_collection_name - ] - documents = collection.find({}) + documents = cls.get_all_documents_from_database() - for document in documents: - for k in ["start_time", "end_time"]: - dt = document[k].isoformat() - if dt > result[k]["_max"]["_time"]: - result[k]["_max"]["_time"] = dt - result[k]["_max"]["_id"] = str(document["_id"]) + for document in documents: + for k in ["start_time", "end_time"]: + dt = document[k].isoformat() + if dt > result[k]["_max"]["_time"]: + result[k]["_max"]["_time"] = dt + result[k]["_max"]["_id"] = str(document["_id"]) - if dt < result[k]["_min"]["_time"]: - result[k]["_min"]["_time"] = dt - result[k]["_min"]["_id"] = str(document["_id"]) + if dt < result[k]["_min"]["_time"]: + result[k]["_min"]["_time"] = dt + result[k]["_min"]["_id"] = str(document["_id"]) return result + @classmethod + def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None): + """Matches the given query. + + We first take all the documents and then progressively remove all that do not match the given queries until + we end up with those that match. When a query is ``None``, it does not have any effect on the results. 
+ """ + documents = cls.get_all_documents_from_database() + + buffer = deepcopy(documents) + for document in documents: + should_remove = False + if platform: + should_remove = document["platform_name"] not in platform + + if sensor and not should_remove: + should_remove = document["sensor"] not in sensor + + if time_min and time_max and not should_remove: + should_remove = document["end_time"] < time_min or document["start_time"] > time_max + + if time_min and not time_max and not should_remove: + should_remove = document["end_time"] < time_min + + if time_max and not time_min and not should_remove: + should_remove = document["end_time"] > time_max + + if should_remove and document in buffer: + buffer.remove(document) + + return [str(item["_id"]) for item in buffer] + @classmethod def prepare(cls): """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" diff --git a/trolldb/tests/tests_api/test_api.py b/trolldb/tests/tests_api/test_api.py index 160a0cf..b32be8f 100644 --- a/trolldb/tests/tests_api/test_api.py +++ b/trolldb/tests/tests_api/test_api.py @@ -8,13 +8,17 @@ """ from collections import Counter +from datetime import datetime import pytest from fastapi import status -from trolldb.test_utils.common import http_get +from trolldb.test_utils.common import http_get, test_app_config from trolldb.test_utils.mongodb_database import TestDatabase, mongodb_for_test_context +main_database_name = test_app_config.database.main_database_name +main_collection_name = test_app_config.database.main_collection_name + def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool: """Checks if the test and expected list of collection names match.""" @@ -26,6 +30,44 @@ def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bo return Counter(test_ids) == Counter(expected_ids) +def single_query_is_correct(key: str, value: str | datetime) -> bool: + """Checks if the 
given single query, denoted by ``key`` matches correctly against the ``value``.""" + return ( + Counter(http_get(f"queries?{key}={value}").json()) == + Counter(TestDatabase.match_query(**{key: value})) + ) + + +def query_results_are_correct(keys: list[str], values_list: list[list[str | datetime]]) -> bool: + """Checks if the retrieved result from querying the database via the API matches the expected result. + + There can be more than one query `key/value` pair. + + Args: + keys: + A list of all query keys, e.g. ``keys=["platform", "sensor"]`` + + values_list: + A list in which each element is a list of values itself. The `nth` element corresponds to the `nth` key in + the ``keys``. + + Returns: + A boolean flag indicating whether the retrieved result matches the expected result. + """ + # Make a single query string for all queries + query_buffer = [] + for label, value_list in zip(keys, values_list, strict=True): + query_buffer += [f"{label}={value}" for value in value_list] + query_string = "&".join(query_buffer) + + return ( + Counter(http_get(f"queries?{query_string}").json()) == + Counter(TestDatabase.match_query( + **{label: value_list for label, value_list in zip(keys, values_list, strict=True)} + )) + ) + + @pytest.mark.usefixtures("_test_server_fixture") def test_root(): """Checks that the server is up and running, i.e. 
the root routes responds with 200.""" @@ -85,3 +127,57 @@ def test_collections_negative(): def test_datetime(): """Checks that the datetime route works properly.""" assert http_get("datetime").json() == TestDatabase.find_min_max_datetime() + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_all(): + """Tests that the queries route returns all documents when no actual queries are given.""" + assert document_ids_are_correct( + http_get("queries").json(), + [str(doc["_id"]) for doc in TestDatabase.get_all_documents_from_database()] + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +@pytest.mark.parametrize(("key", "values"), [ + ("platform", TestDatabase.unique_platform_names), + ("sensor", TestDatabase.unique_sensors) +]) +def test_queries_platform_or_sensor(key, values): + """Tests the platform and sensor queries, one at a time. + + There is only a single key in the query, but it has multiple corresponding values. + """ + for i in range(len(values)): + assert query_results_are_correct( + [key], + [values[:i]] + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_mix_platform_sensor(): + """Tests a mix of platform and sensor queries.""" + for n_plt, n_sns in zip([1, 1, 2, 3, 3], [1, 3, 2, 1, 3], strict=False): + assert query_results_are_correct( + ["platform", "sensor"], + [TestDatabase.unique_platform_names[:n_plt], TestDatabase.unique_sensors[:n_sns]] + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_time(): + """Checks that a single time query works properly.""" + res = http_get("datetime").json() + time_min = datetime.fromisoformat(res["start_time"]["_min"]["_time"]) + time_max = datetime.fromisoformat(res["end_time"]["_max"]["_time"]) + + assert single_query_is_correct( + "time_min", + time_min + ) + + assert single_query_is_correct( + "time_max", + time_max + ) From 409df77aee9c175bce2424475d7d5ed52e62b548 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Wed, 29 May 2024 10:12:52 
+0200 Subject: [PATCH 07/21] Log insertion and deletion --- trolldb/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trolldb/cli.py b/trolldb/cli.py index 294aab9..d5e1a01 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -22,10 +22,12 @@ async def record_messages(config: AppConfig): msg = Message.decode(str(m)) if msg.type in ["file", "dataset"]: await collection.insert_one(msg.data) + logger.info(f"Inserted file with uri: {msg.data["uri"]}") elif msg.type == "del": deletion_result = await collection.delete_many({"uri": msg.data["uri"]}) + logger.info(f"Deleted document with uri: {msg.data["uri"]}") if deletion_result.deleted_count != 1: - logger.error("Recorder found multiple deletions!") # TODO: Log some data related to the msg + logger.error(f"Recorder found multiple deletions for uri: {msg.data["uri"]}!") else: logger.debug(f"Don't know what to do with {msg.type} message.") From 76fe40e1d4e1e6f99bb2063e267daa7d76934bba Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Wed, 29 May 2024 11:02:59 +0200 Subject: [PATCH 08/21] Refactor `cli` and `test_recorder` --- trolldb/cli.py | 48 +++++++++++++++++++++++++++------- trolldb/tests/test_recorder.py | 26 +++++++++++------- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/trolldb/cli.py b/trolldb/cli.py index d5e1a01..10b22e9 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -4,6 +4,7 @@ import asyncio from loguru import logger +from motor.motor_asyncio import AsyncIOMotorCollection from posttroll.message import Message from posttroll.subscriber import create_subscriber_from_dict_config from pydantic import FilePath @@ -12,6 +13,30 @@ from trolldb.database.mongodb import MongoDB, mongodb_context +async def delete_uri_from_collection(collection: AsyncIOMotorCollection, uri: str) -> int: + """Deletes a document from collection and logs the deletion. + + Args: + collection: + The collection object which includes the document to delete. 
+ uri: + The URI used to query the collection. It can be either a URI of a previously recorded file message or + a dataset message. + + Returns: + Number of deleted documents. + """ + del_result_file = await collection.delete_many({"uri": uri}) + if del_result_file.deleted_count == 1: + logger.info(f"Deleted one document (file) with uri: {uri}") + + del_result_dataset = await collection.delete_many({"dataset.uri": uri}) + if del_result_dataset.deleted_count == 1: + logger.info(f"Deleted one document (dataset) with uri: {uri}") + + return del_result_file.deleted_count + del_result_dataset.deleted_count + + async def record_messages(config: AppConfig): """Record the metadata of messages into the database.""" async with mongodb_context(config.database): @@ -20,16 +45,19 @@ async def record_messages(config: AppConfig): ) for m in create_subscriber_from_dict_config(config.subscriber).recv(): msg = Message.decode(str(m)) - if msg.type in ["file", "dataset"]: - await collection.insert_one(msg.data) - logger.info(f"Inserted file with uri: {msg.data["uri"]}") - elif msg.type == "del": - deletion_result = await collection.delete_many({"uri": msg.data["uri"]}) - logger.info(f"Deleted document with uri: {msg.data["uri"]}") - if deletion_result.deleted_count != 1: - logger.error(f"Recorder found multiple deletions for uri: {msg.data["uri"]}!") - else: - logger.debug(f"Don't know what to do with {msg.type} message.") + match msg.type: + case "file": + await collection.insert_one(msg.data) + logger.info(f"Inserted file with uri: {msg.data["uri"]}") + case "dataset": + await collection.insert_one(msg.data) + logger.info(f"Inserted dataset with {len(msg.data["dataset"])} elements.") + case "del": + deletion_count = await delete_uri_from_collection(collection, msg.data["uri"]) + if deletion_count >= 1: + logger.error(f"Recorder found multiple deletions for uri: {msg.data["uri"]}!") + case _: + logger.debug(f"Don't know what to do with {msg.type} message.") async def 
record_messages_from_config(config_file: FilePath): diff --git a/trolldb/tests/test_recorder.py b/trolldb/tests/test_recorder.py index d54bb6e..553a8f0 100644 --- a/trolldb/tests/test_recorder.py +++ b/trolldb/tests/test_recorder.py @@ -5,7 +5,12 @@ from posttroll.testing import patched_subscriber_recv from pytest_lazy_fixtures import lf -from trolldb.cli import record_messages, record_messages_from_command_line, record_messages_from_config +from trolldb.cli import ( + delete_uri_from_collection, + record_messages, + record_messages_from_command_line, + record_messages_from_config, +) from trolldb.database.mongodb import MongoDB, mongodb_context from trolldb.test_utils.common import AppConfig, create_config_file, make_test_app_config, test_app_config from trolldb.test_utils.mongodb_instance import running_prepared_database_context @@ -57,14 +62,18 @@ def config_file(tmp_path): return create_config_file(tmp_path) -async def message_in_database_and_delete_count_is_one(msg) -> bool: +async def message_in_database_and_delete_count_is_one(msg: Message) -> bool: """Checks if there is exactly one item in the database which matches the data of the message.""" async with mongodb_context(test_app_config.database): collection = await MongoDB.get_collection("test_database", "test_collection") result = await collection.find_one(dict(scan_mode="EW")) result.pop("_id") - deletion_result = await collection.delete_many({"uri": msg.data["uri"]}) - return result == msg.data and deletion_result.deleted_count == 1 + uri = msg.data.get("uri") + if not uri: + uri = msg.data["dataset"][0]["uri"] + deletion_count = await delete_uri_from_collection(collection, uri) + + return result == msg.data and deletion_count == 1 @pytest.mark.parametrize(("function", "args"), [ @@ -101,13 +110,12 @@ async def test_record_deletes_message(tmp_path, file_message, del_message): result = await collection.find_one(dict(scan_mode="EW")) assert result is None + async def test_record_dataset_messages(tmp_path, 
dataset_message): - """Test recording a dataset message and deleting the file.""" + """Tests recording a dataset message and deleting the file.""" config = AppConfig(**make_test_app_config(tmp_path)) + msg = Message.decode(dataset_message) with running_prepared_database_context(): with patched_subscriber_recv([dataset_message]): await record_messages(config) - async with mongodb_context(config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") - result = await collection.find_one(dict(scan_mode="EW")) - assert result is not None + assert await message_in_database_and_delete_count_is_one(msg) From c00a7a61a7069506bd42d8e44227b9543014c118 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Wed, 29 May 2024 17:40:25 +0200 Subject: [PATCH 09/21] Update API to use new query annotations The new version of FastAPI recommends using `Annotated` for queries. This also resolves the issue with ruff complaining about `Query(None)`. --- trolldb/api/routes/common.py | 10 +--------- trolldb/api/routes/databases.py | 13 ++++++++++--- trolldb/api/routes/queries.py | 11 +++++------ 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/trolldb/api/routes/common.py b/trolldb/api/routes/common.py index b05c86b..86bb4f8 100644 --- a/trolldb/api/routes/common.py +++ b/trolldb/api/routes/common.py @@ -2,19 +2,11 @@ from typing import Annotated, Union -from fastapi import Depends, Query, Response +from fastapi import Depends, Response from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase from trolldb.database.mongodb import MongoDB -exclude_defaults_query = Query( - True, - title="Query string", - description= - "A boolean to exclude default databases from a MongoDB instance. Refer to " - "`trolldb.database.mongodb.MongoDB.default_database_names` for more information." 
-) - async def check_database(database_name: str | None = None) -> AsyncIOMotorDatabase: """A dependency for route handlers to check for the existence of a database given its name. diff --git a/trolldb/api/routes/databases.py b/trolldb/api/routes/databases.py index e114ef4..046d57a 100644 --- a/trolldb/api/routes/databases.py +++ b/trolldb/api/routes/databases.py @@ -4,10 +4,12 @@ For more information on the API server, see the automatically generated documentation by FastAPI. """ -from fastapi import APIRouter +from typing import Annotated + +from fastapi import APIRouter, Query from pymongo.collection import _DocumentType -from trolldb.api.routes.common import CheckCollectionDependency, CheckDataBaseDependency, exclude_defaults_query +from trolldb.api.routes.common import CheckCollectionDependency, CheckDataBaseDependency from trolldb.config.config import MongoObjectId from trolldb.database.errors import ( Databases, @@ -23,7 +25,12 @@ @router.get("/", response_model=list[str], summary="Gets the list of all database names") -async def database_names(exclude_defaults: bool = exclude_defaults_query) -> list[str]: +async def database_names( + exclude_defaults: Annotated[bool, Query( + title="Query parameter", + description="A boolean to exclude default databases from a MongoDB instance. Refer to " + "`trolldb.database.mongodb.MongoDB.default_database_names` for more information." 
+ )] = True) -> list[str]: """Please consult the auto-generated documentation by FastAPI.""" db_names = await MongoDB.list_database_names() diff --git a/trolldb/api/routes/queries.py b/trolldb/api/routes/queries.py index 7715aec..4542b57 100644 --- a/trolldb/api/routes/queries.py +++ b/trolldb/api/routes/queries.py @@ -5,6 +5,7 @@ """ import datetime +from typing import Annotated from fastapi import APIRouter, Query @@ -22,12 +23,10 @@ summary="Gets the database UUIDs of the documents that match specifications determined by the query string") async def queries( collection: CheckCollectionDependency, - # We suppress ruff for the following four lines with `Query(default=None)`. - # Reason: This is the FastAPI way of defining optional queries and ruff is not happy about it! - platform: list[str] = Query(default=None), # noqa: B008 - sensor: list[str] = Query(default=None), # noqa: B008 - time_min: datetime.datetime = Query(default=None), # noqa: B008 - time_max: datetime.datetime = Query(default=None)) -> list[str]: # noqa: B008 + platform: Annotated[list[str] | None, Query()] = None, + sensor: Annotated[list[str] | None, Query()] = None, + time_min: Annotated[datetime.datetime, Query()] = None, + time_max: Annotated[datetime.datetime, Query()] = None) -> list[str]: """Please consult the auto-generated documentation by FastAPI.""" pipelines = Pipelines() From a4298f5ca513ea3e3fcd6a40cd8a08b642ea9092 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Wed, 29 May 2024 19:01:03 +0200 Subject: [PATCH 10/21] Improve type hints --- trolldb/api/api.py | 9 +++--- trolldb/cli.py | 8 +++--- trolldb/database/errors.py | 22 +++++++------- trolldb/database/mongodb.py | 17 ++++++----- trolldb/database/pipelines.py | 4 +-- trolldb/errors/errors.py | 16 +++++------ trolldb/test_utils/common.py | 2 +- trolldb/test_utils/mongodb_database.py | 40 ++++++++++++++------------ trolldb/test_utils/mongodb_instance.py | 23 ++++++++------- 9 files changed, 73 insertions(+), 68 deletions(-) diff 
--git a/trolldb/api/api.py b/trolldb/api/api.py index 4ce42e7..40f3379 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -17,7 +17,7 @@ import time from contextlib import contextmanager from multiprocessing import Process -from typing import Union +from typing import Any, Generator, NoReturn, Union import uvicorn from fastapi import FastAPI, status @@ -89,7 +89,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: app.include_router(api_router) @app.exception_handler(ResponseError) - async def auto_exception_handler(_, exc: ResponseError): + async def auto_exception_handler(_, exc: ResponseError) -> PlainTextResponse: """Catches all the exceptions raised as a ResponseError, e.g. accessing non-existing databases/collections.""" status_code, message = exc.get_error_details() info = dict( @@ -99,7 +99,7 @@ async def auto_exception_handler(_, exc: ResponseError): logger.error(f"Response error caught by the API auto exception handler: {info}") return PlainTextResponse(**info) - async def _serve(): + async def _serve() -> NoReturn: """An auxiliary coroutine to be used in the asynchronous execution of the FastAPI application.""" async with mongodb_context(config.database): logger.info("Attempt to start the uvicorn server ...") @@ -116,7 +116,8 @@ async def _serve(): @contextmanager -def api_server_process_context(config: Union[AppConfig, FilePath], startup_time: Timeout = 2): +def api_server_process_context( + config: Union[AppConfig, FilePath], startup_time: Timeout = 2) -> Generator[Process, Any, None]: """A synchronous context manager to run the API server in a separate process (non-blocking). It uses the `multiprocessing `_ package. 
The main use case diff --git a/trolldb/cli.py b/trolldb/cli.py index 10b22e9..89b4520 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -37,7 +37,7 @@ async def delete_uri_from_collection(collection: AsyncIOMotorCollection, uri: st return del_result_file.deleted_count + del_result_dataset.deleted_count -async def record_messages(config: AppConfig): +async def record_messages(config: AppConfig) -> None: """Record the metadata of messages into the database.""" async with mongodb_context(config.database): collection = await MongoDB.get_collection( @@ -60,13 +60,13 @@ async def record_messages(config: AppConfig): logger.debug(f"Don't know what to do with {msg.type} message.") -async def record_messages_from_config(config_file: FilePath): +async def record_messages_from_config(config_file: FilePath) -> None: """Record messages into the database, getting the configuration from a file.""" config = parse_config_yaml_file(config_file) await record_messages(config) -async def record_messages_from_command_line(args=None): +async def record_messages_from_command_line(args=None) -> None: """Record messages into the database, command-line interface.""" parser = argparse.ArgumentParser() parser.add_argument( @@ -77,6 +77,6 @@ async def record_messages_from_command_line(args=None): await record_messages_from_config(cmd_args.configuration_file) -def run_sync(): +def run_sync() -> None: """Runs the interface synchronously.""" asyncio.run(record_messages_from_command_line()) diff --git a/trolldb/database/errors.py b/trolldb/database/errors.py index 983597e..7a14608 100644 --- a/trolldb/database/errors.py +++ b/trolldb/database/errors.py @@ -6,6 +6,8 @@ are (expected to be) self-explanatory and require no additional documentation. """ +from typing import ClassVar + from fastapi import status from trolldb.errors.errors import ResponseError, ResponsesErrorGroup @@ -13,29 +15,29 @@ class Client(ResponsesErrorGroup): """Database client error responses, e.g. 
if something goes wrong with initialization or closing the client.""" - CloseNotAllowedError = ResponseError({ + CloseNotAllowedError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "Calling `close()` on a client which has not been initialized is not allowed!" }) - ReinitializeConfigError = ResponseError({ + ReinitializeConfigError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "The client is already initialized with a different database configuration!" }) - AlreadyOpenError = ResponseError({ + AlreadyOpenError: ClassVar[ResponseError] = ResponseError({ status.HTTP_100_CONTINUE: "The client has been already initialized with the same configuration." }) - InconsistencyError = ResponseError({ + InconsistencyError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "Something must have been wrong as we are in an inconsistent state. " "The internal database configuration is not empty and is the same as what we just " "received but the client is `None` or has been already closed!" }) - ConnectionError = ResponseError({ + ConnectionError: ClassVar[ResponseError] = ResponseError({ status.HTTP_400_BAD_REQUEST: "Could not connect to the database with the given URL." }) @@ -43,12 +45,12 @@ class Client(ResponsesErrorGroup): class Collections(ResponsesErrorGroup): """Collections error responses, e.g. if the requested collection cannot be found.""" - NotFoundError = ResponseError({ + NotFoundError: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given collection name inside the specified database." }) - WrongTypeError = ResponseError({ + WrongTypeError: ClassVar[ResponseError] = ResponseError({ status.HTTP_422_UNPROCESSABLE_ENTITY: "Both the database and collection name must be `None` if either one is `None`." }) @@ -56,12 +58,12 @@ class Collections(ResponsesErrorGroup): class Databases(ResponsesErrorGroup): """Databases error responses, e.g. 
if the requested database cannot be found.""" - NotFoundError = ResponseError({ + NotFoundError: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given database name." }) - WrongTypeError = ResponseError({ + WrongTypeError: ClassVar[ResponseError] = ResponseError({ status.HTTP_422_UNPROCESSABLE_ENTITY: "Database name must be either of type `str` or `None.`" }) @@ -69,7 +71,7 @@ class Databases(ResponsesErrorGroup): class Documents(ResponsesErrorGroup): """Documents error responses, e.g. if the requested document cannot be found.""" - NotFound = ResponseError({ + NotFound: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find any document with the given object id." }) diff --git a/trolldb/database/mongodb.py b/trolldb/database/mongodb.py index 214b8cf..b5647b5 100644 --- a/trolldb/database/mongodb.py +++ b/trolldb/database/mongodb.py @@ -7,7 +7,7 @@ import errno from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Coroutine, Optional, TypeVar, Union +from typing import Any, AsyncGenerator, ClassVar, Coroutine, Optional, TypeVar, Union from loguru import logger from motor.motor_asyncio import ( @@ -23,7 +23,6 @@ from trolldb.config.config import DatabaseConfig from trolldb.database.errors import Client, Collections, Databases -from trolldb.errors.errors import ResponseError T = TypeVar("T") CoroutineLike = Coroutine[Any, Any, T] @@ -103,12 +102,12 @@ class MongoDB: us, we would like to fail early! 
""" - __client: Optional[AsyncIOMotorClient] = None - __database_config: Optional[DatabaseConfig] = None - __main_collection: AsyncIOMotorCollection = None - __main_database: AsyncIOMotorDatabase = None + __client: ClassVar[Optional[AsyncIOMotorClient]] = None + __database_config: ClassVar[Optional[DatabaseConfig]] = None + __main_collection: ClassVar[Optional[AsyncIOMotorCollection]] = None + __main_database: ClassVar[Optional[AsyncIOMotorDatabase]] = None - default_database_names = ["admin", "config", "local"] + default_database_names: ClassVar[list[str]] = ["admin", "config", "local"] """MongoDB creates these databases by default for self usage.""" @classmethod @@ -228,7 +227,7 @@ def main_database(cls) -> AsyncIOMotorDatabase: async def get_collection( cls, database_name: str, - collection_name: str) -> Union[AsyncIOMotorCollection, ResponseError]: + collection_name: str) -> AsyncIOMotorCollection: """Gets the collection object given its name and the database name in which it resides. Args: @@ -271,7 +270,7 @@ async def get_collection( raise Collections.WrongTypeError @classmethod - async def get_database(cls, database_name: str) -> Union[AsyncIOMotorDatabase, ResponseError]: + async def get_database(cls, database_name: str) -> AsyncIOMotorDatabase: """Gets the database object given its name. Args: diff --git a/trolldb/database/pipelines.py b/trolldb/database/pipelines.py index 3e74640..237637c 100644 --- a/trolldb/database/pipelines.py +++ b/trolldb/database/pipelines.py @@ -30,11 +30,11 @@ class PipelineBooleanDict(dict): pd_or == pd_or_literal """ - def __or__(self, other: Self): + def __or__(self, other: Self) -> Self: """Implements the bitwise or operator, i.e. ``|``.""" return PipelineBooleanDict({"$or": [self, other]}) - def __and__(self, other: Self): + def __and__(self, other: Self) -> Self: """Implements the bitwise and operator, i.e. 
``&``.""" return PipelineBooleanDict({"$and": [self, other]}) diff --git a/trolldb/errors/errors.py b/trolldb/errors/errors.py index c083b0c..e0ea80b 100644 --- a/trolldb/errors/errors.py +++ b/trolldb/errors/errors.py @@ -6,7 +6,7 @@ from collections import OrderedDict from sys import exit -from typing import Self +from typing import ClassVar, NoReturn, Self from fastapi import Response from fastapi.responses import PlainTextResponse @@ -64,7 +64,7 @@ class ResponseError(Exception): messages. """ - descriptor_delimiter: str = " |OR| " + descriptor_delimiter: ClassVar[str] = " |OR| " """A delimiter to divide the message part of several error responses which have been combined into a single one. This will be shown in textual format for the response descriptors of the Fast API routes. @@ -80,7 +80,7 @@ class ResponseError(Exception): "Bad Request |OR| Not Found" """ - DefaultResponseClass: Response = PlainTextResponse + DefaultResponseClass: ClassVar[Response] = PlainTextResponse """The default type of the response which will be returned when an error occurs. This must be a valid member (class) of ``fastapi.responses``. @@ -105,7 +105,7 @@ def __init__(self, args_dict: OrderedDict[StatusCode, str | list[str]] | dict) - self.__dict: OrderedDict = OrderedDict(args_dict) self.extra_information: dict | None = None - def __or__(self, other: Self): + def __or__(self, other: Self) -> Self: """Implements the bitwise `or` ``|`` which combines the error objects into a single error response. Args: @@ -141,7 +141,7 @@ def __or__(self, other: Self): def __retrieve_one_from_some( self, - status_code: StatusCode | None = None) -> (StatusCode, str): + status_code: StatusCode | None = None) -> tuple[StatusCode, str]: """Retrieves a tuple ``(, )`` from the internal dictionary :obj:`ResponseError.__dict`. 
Args: @@ -182,7 +182,7 @@ def __retrieve_one_from_some( def get_error_details( self, extra_information: dict | None = None, - status_code: int | None = None) -> (StatusCode, str): + status_code: int | None = None) -> tuple[StatusCode, str]: """Gets the details of the error response. Args: @@ -202,7 +202,7 @@ def get_error_details( def log_as_warning( self, extra_information: dict | None = None, - status_code: int | None = None): + status_code: int | None = None) -> None: """Same as :func:`~ResponseError.get_error_details` but logs the error as a warning and returns ``None``.""" msg, _ = self.get_error_details(extra_information, status_code) logger.warning(msg) @@ -211,7 +211,7 @@ def sys_exit_log( self, exit_code: int = -1, extra_information: dict | None = None, - status_code: int | None = None) -> None: + status_code: int | None = None) -> NoReturn: """Same as :func:`~ResponseError.get_error_details` but logs the error and calls the ``sys.exit``. The arguments are the same as :func:`~ResponseError.get_error_details` with the addition of ``exit_code`` diff --git a/trolldb/test_utils/common.py b/trolldb/test_utils/common.py index 730b299..2007977 100644 --- a/trolldb/test_utils/common.py +++ b/trolldb/test_utils/common.py @@ -10,7 +10,7 @@ from trolldb.config.config import AppConfig -def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict: +def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict[str, dict]: """Makes the app configuration when used in testing. 
Args: diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index e041ea2..2e2e135 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -3,7 +3,7 @@ from copy import deepcopy from datetime import datetime, timedelta from random import choices, randint, shuffle -from typing import Iterator +from typing import Any, ClassVar, Generator from pymongo import MongoClient @@ -12,7 +12,8 @@ @contextmanager -def mongodb_for_test_context(database_config: DatabaseConfig = test_app_config.database) -> Iterator[MongoClient]: +def mongodb_for_test_context( + database_config: DatabaseConfig = test_app_config.database) -> Generator[MongoClient, Any, None]: """A context manager for the MongoDB client given test configurations. Note: @@ -39,13 +40,13 @@ def mongodb_for_test_context(database_config: DatabaseConfig = test_app_config.d class Time: """A static class to enclose functionalities for generating random timestamps.""" - min_start_time = datetime(2019, 1, 1, 0, 0, 0) + min_start_time: ClassVar[datetime] = datetime(2019, 1, 1, 0, 0, 0) """The minimum timestamp which is allowed to appear in our data.""" - max_end_time = datetime(2024, 1, 1, 0, 0, 0) + max_end_time: ClassVar[datetime] = datetime(2024, 1, 1, 0, 0, 0) """The maximum timestamp which is allowed to appear in our data.""" - delta_time = int((max_end_time - min_start_time).total_seconds()) + delta_time: ClassVar[int] = int((max_end_time - min_start_time).total_seconds()) """The difference between the maximum and minimum timestamps in seconds.""" @staticmethod @@ -115,11 +116,11 @@ def like_mongodb_document(self) -> dict: class TestDatabase: """A static class which encloses functionalities to prepare and fill the test database with test data.""" - unique_platform_names: list[str] = ["PA", "PB", "PC"] + unique_platform_names: ClassVar[list[str]] = ["PA", "PB", "PC"] """The unique platform names that will be used to generate the sample of 
all platform names.""" # We suppress ruff (S311) here as we are not generating anything cryptographic here! - platform_names = choices(["PA", "PB", "PC"], k=20) # noqa: S311 + platform_names: ClassVar[list[str]] = choices(["PA", "PB", "PC"], k=20) # noqa: S311 """Example platform names. Warning: @@ -127,11 +128,11 @@ class TestDatabase: generated as a result of building the documentation! """ - unique_sensors: list[str] = ["SA", "SB", "SC"] + unique_sensors: ClassVar[list[str]] = ["SA", "SB", "SC"] """The unique sensor names that will be used to generate the sample of all sensor names.""" # We suppress ruff (S311) here as we are not generating anything cryptographic here! - sensors = choices(["SA", "SB", "SC"], k=20) # noqa: S311 + sensors: ClassVar[list[str]] = choices(["SA", "SB", "SC"], k=20) # noqa: S311 """Example sensor names. Warning: @@ -139,24 +140,24 @@ class TestDatabase: generated as a result of building the documentation! """ - database_names = [test_app_config.database.main_database_name, "another_test_database"] + database_names: ClassVar[list[str]] = [test_app_config.database.main_database_name, "another_test_database"] """List of all database names. The first element is the main database that will be queried by the API and includes the test data. The second database is for testing scenarios when one attempts to access another existing database or collection. """ - collection_names = [test_app_config.database.main_collection_name, "another_test_collection"] + collection_names: ClassVar[list[str]] = [test_app_config.database.main_collection_name, "another_test_collection"] """List of all collection names. The first element is the main collection that will be queried by the API and includes the test data. The second collection is for testing scenarios when one attempts to access another existing collection. 
""" - all_database_names = ["admin", "config", "local", *database_names] + all_database_names: ClassVar[list[str]] = ["admin", "config", "local", *database_names] """All database names including the default ones which are automatically created by MongoDB.""" - documents: list[dict] = [] + documents: ClassVar[list[dict]] = [] """The list of documents which include test data.""" @classmethod @@ -167,12 +168,13 @@ def generate_documents(cls, random_shuffle: bool = True) -> None: This method is not pure! The side effect is that the :obj:`TestDatabase.documents` is reset to new values. """ cls.documents = [ - Document(p, s).like_mongodb_document() for p, s in zip(cls.platform_names, cls.sensors, strict=False)] + Document(p, s).like_mongodb_document() for p, s in zip(cls.platform_names, cls.sensors, strict=False) + ] if random_shuffle: shuffle(cls.documents) @classmethod - def reset(cls): + def reset(cls) -> None: """Resets all the databases/collections. This is done by deleting all documents in the collections and then inserting a single empty document, i.e. @@ -186,7 +188,7 @@ def reset(cls): collection.insert_one({}) @classmethod - def write_test_data(cls): + def write_test_data(cls) -> None: """Fills databases/collections with test data.""" with mongodb_for_test_context() as client: # The following function call has side effects! @@ -217,7 +219,7 @@ def get_all_documents_from_database(cls) -> list[dict]: return documents @classmethod - def find_min_max_datetime(cls): + def find_min_max_datetime(cls) -> dict[str, dict]: """Finds the minimum and the maximum for both the ``start_time`` and the ``end_time``. We use `brute force` for this purpose. 
We set the minimum to a large value (year 2100) and the maximum to a @@ -252,7 +254,7 @@ def find_min_max_datetime(cls): return result @classmethod - def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None): + def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None) -> list[str]: """Matches the given query. We first take all the documents and then progressively remove all that do not match the given queries until @@ -284,7 +286,7 @@ def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None): return [str(item["_id"]) for item in buffer] @classmethod - def prepare(cls): + def prepare(cls) -> None: """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" cls.reset() cls.write_test_data() diff --git a/trolldb/test_utils/mongodb_instance.py b/trolldb/test_utils/mongodb_instance.py index 1d7e107..433bf91 100644 --- a/trolldb/test_utils/mongodb_instance.py +++ b/trolldb/test_utils/mongodb_instance.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from os import mkdir, path from shutil import rmtree +from typing import Any, AnyStr, ClassVar, Generator, Optional from loguru import logger @@ -18,7 +19,7 @@ class TestMongoInstance: """A static class to enclose functionalities for running a MongoDB instance.""" - log_dir: str = tempfile.mkdtemp("__pytroll_db_temp_test_log") + log_dir: ClassVar[str] = tempfile.mkdtemp("__pytroll_db_temp_test_log") """Temp directory for logging messages by the MongoDB instance. Warning: @@ -26,7 +27,7 @@ class TestMongoInstance: every time! """ - storage_dir: str = tempfile.mkdtemp("__pytroll_db_temp_test_storage") + storage_dir: ClassVar[str] = tempfile.mkdtemp("__pytroll_db_temp_test_storage") """Temp directory for storing database files by the MongoDB instance. Warning: @@ -34,18 +35,18 @@ class TestMongoInstance: every time! """ - port: int = 28017 + port: ClassVar[int] = 28017 """The port on which the instance will run. 
Warning: This must be always hard-coded. """ - process: subprocess.Popen | None = None + process: ClassVar[Optional[subprocess.Popen]] = None """The process which is used to run the MongoDB instance.""" @classmethod - def __prepare_dir(cls, directory: str): + def __prepare_dir(cls, directory: str) -> None: """An auxiliary function to prepare a single directory. It creates a directory if it does not exist, or removes it first if it exists and then recreates it. @@ -54,13 +55,13 @@ def __prepare_dir(cls, directory: str): mkdir(directory) @classmethod - def __remove_dir(cls, directory: str): + def __remove_dir(cls, directory: str) -> None: """An auxiliary function to remove a directory and all its content recursively.""" if path.exists(directory) and path.isdir(directory): rmtree(directory) @classmethod - def run_subprocess(cls, args: list[str], wait=True): + def run_subprocess(cls, args: list[str], wait=True) -> tuple[AnyStr, AnyStr] | None: """Runs the subprocess in shell given its arguments.""" # We suppress ruff (S603) here as we are not receiving any args from outside, e.g. port is hard-coded. # Therefore, sanitization of arguments is not required. @@ -85,14 +86,14 @@ def prepare_dirs(cls) -> None: cls.__prepare_dir(d) @classmethod - def run_instance(cls): + def run_instance(cls) -> None: """Runs the MongoDB instance and does not wait for it, i.e. 
the process runs in the background.""" cls.run_subprocess( ["mongod", "--dbpath", cls.storage_dir, "--logpath", f"{cls.log_dir}/mongod.log", "--port", f"{cls.port}"] , wait=False) @classmethod - def shutdown_instance(cls): + def shutdown_instance(cls) -> None: """Shuts down the MongoDB instance by terminating its process.""" cls.process.terminate() cls.process.wait() @@ -103,7 +104,7 @@ def shutdown_instance(cls): @contextmanager def mongodb_instance_server_process_context( database_config: DatabaseConfig = test_app_config.database, - startup_time: Timeout = 2): + startup_time: Timeout = 2) -> Generator[Any, Any, None]: """A synchronous context manager to run the MongoDB instance in a separate process (non-blocking). It uses the `subprocess `_ package. The main use case is @@ -133,7 +134,7 @@ def mongodb_instance_server_process_context( @contextmanager -def running_prepared_database_context(): +def running_prepared_database_context() -> Generator[Any, Any, None]: """A synchronous context manager to start and prepare a database instance for tests.""" with mongodb_instance_server_process_context(): TestDatabase.prepare() From f0c7e16843793460f85ad0c86cae10482ec221df Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 14:06:35 +0200 Subject: [PATCH 11/21] Improve type checks and validation using Pydantic This led to the removal of some redundant type checks. 
--- trolldb/api/api.py | 9 ++++++-- trolldb/config/config.py | 17 +++++++++++++-- trolldb/database/mongodb.py | 42 ++++++++++++++++++------------------- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/trolldb/api/api.py b/trolldb/api/api.py index 40f3379..ec93fd9 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -23,7 +23,7 @@ from fastapi import FastAPI, status from fastapi.responses import PlainTextResponse from loguru import logger -from pydantic import FilePath, validate_call +from pydantic import FilePath, ValidationError, validate_call from trolldb.api.routes import api_router from trolldb.config.config import AppConfig, Timeout, parse_config_yaml_file @@ -89,7 +89,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: app.include_router(api_router) @app.exception_handler(ResponseError) - async def auto_exception_handler(_, exc: ResponseError) -> PlainTextResponse: + async def auto_handler_response_errors(_, exc: ResponseError) -> PlainTextResponse: """Catches all the exceptions raised as a ResponseError, e.g. 
accessing non-existing databases/collections.""" status_code, message = exc.get_error_details() info = dict( @@ -99,6 +99,11 @@ async def auto_exception_handler(_, exc: ResponseError) -> PlainTextResponse: logger.error(f"Response error caught by the API auto exception handler: {info}") return PlainTextResponse(**info) + @app.exception_handler(ValidationError) + async def auto_handler_pydantic_validation_errors(_, exc: ValidationError) -> PlainTextResponse: + """Catches all the exceptions raised as a Pydantic ValidationError.""" + return PlainTextResponse(str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + async def _serve() -> NoReturn: """An auxiliary coroutine to be used in the asynchronous execution of the FastAPI application.""" async with mongodb_context(config.database): diff --git a/trolldb/config/config.py b/trolldb/config/config.py index e8c5678..ff1bbc9 100644 --- a/trolldb/config/config.py +++ b/trolldb/config/config.py @@ -1,6 +1,11 @@ """The module which handles parsing and validating the config (YAML) file. The validation is performed using `Pydantic `_. + +Note: + Some functions/methods in this module are decorated with the Pydantic + `@validate_call `_ which checks the arguments during the + function calls. """ import errno @@ -10,7 +15,7 @@ from bson import ObjectId from bson.errors import InvalidId from loguru import logger -from pydantic import AnyUrl, BaseModel, Field, FilePath, MongoDsn, ValidationError +from pydantic import AnyUrl, BaseModel, Field, FilePath, MongoDsn, ValidationError, validate_call from pydantic.functional_validators import AfterValidator from typing_extensions import Annotated from yaml import safe_load @@ -19,6 +24,7 @@ """A type hint for the timeout in seconds (non-negative float).""" +@validate_call def id_must_be_valid(id_like_string: str) -> ObjectId: """Checks that the given string can be converted to a valid MongoDB ObjectId. 
@@ -30,6 +36,9 @@ def id_must_be_valid(id_like_string: str) -> ObjectId: The ObjectId object if successful. Raises: + ValidationError: + If the given argument is not of type ``str``. + ValueError: If the given string cannot be converted to a valid ObjectId. This will ultimately turn into a pydantic validation error. @@ -99,6 +108,7 @@ class AppConfig(BaseModel): subscriber: SubscriberConfig +@validate_call def parse_config_yaml_file(filename: FilePath) -> AppConfig: """Parses and validates the configurations from a YAML file. @@ -111,7 +121,10 @@ def parse_config_yaml_file(filename: FilePath) -> AppConfig: Raises: ParserError: - If the file cannot be properly parsed + If the file cannot be properly parsed. + + ValidationError: + If the ``filename`` is not of type ``FilePath``. ValidationError: If the successfully parsed file fails the validation, i.e. its schema or the content does not conform to diff --git a/trolldb/database/mongodb.py b/trolldb/database/mongodb.py index b5647b5..4221aee 100644 --- a/trolldb/database/mongodb.py +++ b/trolldb/database/mongodb.py @@ -3,6 +3,11 @@ It is based on the following libraries: - `PyMongo `_ - `motor `_. + +Note: + Some functions/methods in this module are decorated with the Pydantic + `@validate_call `_ which checks the arguments during the + function calls. """ import errno @@ -17,7 +22,7 @@ AsyncIOMotorCursor, AsyncIOMotorDatabase, ) -from pydantic import BaseModel +from pydantic import validate_call from pymongo.collection import _DocumentType from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError @@ -35,16 +40,6 @@ """Coroutine type hint for a list of strings.""" -class DatabaseName(BaseModel): - """Pydantic model for a database name.""" - name: str | None - - -class CollectionName(BaseModel): - """Pydantic model for a collection name.""" - name: str | None - - async def get_id(doc: CoroutineDocument) -> str: """Retrieves the ID of a document as a simple flat string. 
@@ -111,6 +106,7 @@ class MongoDB: """MongoDB creates these databases by default for self usage.""" @classmethod + @validate_call async def initialize(cls, database_config: DatabaseConfig): """Initializes the motor client. Note that this method has to be awaited! @@ -126,6 +122,9 @@ async def initialize(cls, database_config: DatabaseConfig): On success ``None``. Raises: + ValidationError: + If the method is not called with arguments of valid type. + SystemExit(errno.EIO): If connection is not established, i.e. ``ConnectionFailure``. SystemExit(errno.EIO): @@ -136,7 +135,6 @@ async def initialize(cls, database_config: DatabaseConfig): SystemExit(errno.EIO): If the state is not consistent, i.e. the client is closed or ``None`` but the internal database configurations still exist and are different from the new ones which have been just provided. - SystemExit(errno.ENODATA): If either ``database_config.main_database_name`` or ``database_config.main_collection_name`` does not exist. @@ -224,10 +222,11 @@ def main_database(cls) -> AsyncIOMotorDatabase: return cls.__main_database @classmethod + @validate_call async def get_collection( cls, - database_name: str, - collection_name: str) -> AsyncIOMotorCollection: + database_name: str | None, + collection_name: str | None) -> AsyncIOMotorCollection: """Gets the collection object given its name and the database name in which it resides. Args: @@ -242,7 +241,7 @@ async def get_collection( Raises: ValidationError: - If input args are invalid according to the pydantic. + If the method is not called with arguments of valid type. KeyError: If the database name exists, but it does not include any collection with the given name. @@ -254,9 +253,6 @@ async def get_collection( This method relies on :func:`get_database` to check for the existence of the database which can raise exceptions. Check its documentation for more information. 
""" - database_name = DatabaseName(name=database_name).name - collection_name = CollectionName(name=collection_name).name - match database_name, collection_name: case None, None: return cls.main_collection() @@ -270,7 +266,8 @@ async def get_collection( raise Collections.WrongTypeError @classmethod - async def get_database(cls, database_name: str) -> AsyncIOMotorDatabase: + @validate_call + async def get_database(cls, database_name: str | None) -> AsyncIOMotorDatabase: """Gets the database object given its name. Args: @@ -281,11 +278,12 @@ async def get_database(cls, database_name: str) -> AsyncIOMotorDatabase: The database object. Raises: - KeyError: + ValidationError: + If the method is not called with arguments of valid type. + + KeyError: If the database name does not exist in the list of database names. """ - database_name = DatabaseName(name=database_name).name - match database_name: case None: return cls.main_database() From 825bff76e2051274ec55a102013a6a2aea49ed39 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 14:32:01 +0200 Subject: [PATCH 12/21] Fix a bug and add a log message The bug had to do with multiple deletion in the recorder. The log message was added to the auto exception handler. 
--- trolldb/api/api.py | 1 + trolldb/cli.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/trolldb/api/api.py b/trolldb/api/api.py index ec93fd9..d50b0ea 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -102,6 +102,7 @@ async def auto_handler_response_errors(_, exc: ResponseError) -> PlainTextRespon @app.exception_handler(ValidationError) async def auto_handler_pydantic_validation_errors(_, exc: ValidationError) -> PlainTextResponse: """Catches all the exceptions raised as a Pydantic ValidationError.""" + logger.error(f"Response error caught by the API auto exception handler: {exc}") return PlainTextResponse(str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) async def _serve() -> NoReturn: diff --git a/trolldb/cli.py b/trolldb/cli.py index 89b4520..05fc58b 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -54,7 +54,7 @@ async def record_messages(config: AppConfig) -> None: logger.info(f"Inserted dataset with {len(msg.data["dataset"])} elements.") case "del": deletion_count = await delete_uri_from_collection(collection, msg.data["uri"]) - if deletion_count >= 1: + if deletion_count > 1: logger.error(f"Recorder found multiple deletions for uri: {msg.data["uri"]}!") case _: logger.debug(f"Don't know what to do with {msg.type} message.") From 5ab9e27a93e916320bbc8a2347a0b6e00c7bf877 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 14:48:54 +0200 Subject: [PATCH 13/21] Simplify `test_utils.TestDatabase.match_query()` Add two auxiliary methods which simplify the implementation of the main method. 
--- trolldb/test_utils/mongodb_database.py | 49 +++++++++++++++++--------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index 2e2e135..3856348 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -253,33 +253,50 @@ def find_min_max_datetime(cls) -> dict[str, dict]: return result + @classmethod + def _query_platform_sensor(cls, document, platform=None, sensor=None) -> bool: + """An auxiliary method to the :func:`TestDatabase.match_query`.""" + should_remove = False + + if platform: + should_remove = platform and document["platform_name"] not in platform + + if sensor and not should_remove: + should_remove = document["sensor"] not in sensor + + return should_remove + + @classmethod + def _query_time(cls, document, time_min=None, time_max=None) -> bool: + """An auxiliary method to the :func:`TestDatabase.match_query`.""" + should_remove = False + + if time_min and time_max and not should_remove: + should_remove = document["end_time"] < time_min or document["start_time"] > time_max + + if time_min and not time_max and not should_remove: + should_remove = document["end_time"] < time_min + + if time_max and not time_min and not should_remove: + should_remove = document["end_time"] > time_max + + return should_remove + @classmethod def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None) -> list[str]: """Matches the given query. We first take all the documents and then progressively remove all that do not match the given queries until we end up with those that match. When a query is ``None``, it does not have any effect on the results. + This method will be used in testing the ``/queries`` route of the API. 
""" documents = cls.get_all_documents_from_database() buffer = deepcopy(documents) for document in documents: - should_remove = False - if platform: - should_remove = document["platform_name"] not in platform - - if sensor and not should_remove: - should_remove = document["sensor"] not in sensor - - if time_min and time_max and not should_remove: - should_remove = document["end_time"] < time_min or document["start_time"] > time_max - - if time_min and not time_max and not should_remove: - should_remove = document["end_time"] < time_min - - if time_max and not time_min and not should_remove: - should_remove = document["end_time"] > time_max - + should_remove = cls._query_platform_sensor(document, platform, sensor) + if not should_remove: + should_remove = cls._query_time(document, time_min, time_max) if should_remove and document in buffer: buffer.remove(document) From fb78e7c6cf00b663b826aa3d9929579dc526acfe Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 18:37:58 +0200 Subject: [PATCH 14/21] Extract a function to make a single query string in `test_api` The function is named `make_query_string` --- trolldb/tests/tests_api/test_api.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/trolldb/tests/tests_api/test_api.py b/trolldb/tests/tests_api/test_api.py index b32be8f..0f28282 100644 --- a/trolldb/tests/tests_api/test_api.py +++ b/trolldb/tests/tests_api/test_api.py @@ -38,7 +38,15 @@ def single_query_is_correct(key: str, value: str | datetime) -> bool: ) -def query_results_are_correct(keys: list[str], values_list: list[list[str | datetime]]) -> bool: +def make_query_string(keys: list[str], values_list: list[list[str] | datetime]) -> str: + """Makes a single query string for all the given queries.""" + query_buffer = [] + for key, value_list in zip(keys, values_list, strict=True): + query_buffer += [f"{key}={value}" for value in value_list] + return "&".join(query_buffer) + + +def 
query_results_are_correct(keys: list[str], values_list: list[list[str | datetime]]) -> bool:
--- trolldb/cli.py | 5 ++-- trolldb/config/config.py | 50 +++++++++++++++++++++------------------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/trolldb/cli.py b/trolldb/cli.py index 05fc58b..cc83adf 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -9,7 +9,7 @@ from posttroll.subscriber import create_subscriber_from_dict_config from pydantic import FilePath -from trolldb.config.config import AppConfig, parse_config_yaml_file +from trolldb.config.config import AppConfig, parse_config from trolldb.database.mongodb import MongoDB, mongodb_context @@ -62,8 +62,7 @@ async def record_messages(config: AppConfig) -> None: async def record_messages_from_config(config_file: FilePath) -> None: """Record messages into the database, getting the configuration from a file.""" - config = parse_config_yaml_file(config_file) - await record_messages(config) + await record_messages(parse_config(config_file)) async def record_messages_from_command_line(args=None) -> None: diff --git a/trolldb/config/config.py b/trolldb/config/config.py index ff1bbc9..f86663f 100644 --- a/trolldb/config/config.py +++ b/trolldb/config/config.py @@ -10,17 +10,18 @@ import errno import sys +from os import PathLike from typing import Any, NamedTuple from bson import ObjectId from bson.errors import InvalidId from loguru import logger -from pydantic import AnyUrl, BaseModel, Field, FilePath, MongoDsn, ValidationError, validate_call +from pydantic import AnyUrl, BaseModel, MongoDsn, PositiveFloat, ValidationError, validate_call from pydantic.functional_validators import AfterValidator from typing_extensions import Annotated -from yaml import safe_load +from yaml import parser, safe_load -Timeout = Annotated[float, Field(ge=0)] +Timeout = PositiveFloat """A type hint for the timeout in seconds (non-negative float).""" @@ -108,37 +109,38 @@ class AppConfig(BaseModel): subscriber: SubscriberConfig -@validate_call -def parse_config_yaml_file(filename: FilePath) -> AppConfig: - """Parses and 
validates the configurations from a YAML file. +@logger.catch(onerror=lambda _: sys.exit(1)) +def parse_config(file: int | str | bytes | PathLike[str] | PathLike[bytes]) -> AppConfig: + """Parses and validates the configurations from a YAML file (descriptor). Args: - filename: - The filename of a valid YAML file which holds the configurations. + file: + A path-like object (``str`` or ``bytes``) or an integer file descriptor. This will be directly passed to the + ``open()`` function. For example, it can be the filename (absolute or relative) of a valid YAML file which + holds the configurations. Returns: An instance of :class:`AppConfig`. - - Raises: - ParserError: - If the file cannot be properly parsed. - - ValidationError: - If the ``filename`` is not of type ``FilePath``. - - ValidationError: - If the successfully parsed file fails the validation, i.e. its schema or the content does not conform to - :class:`AppConfig`. """ - logger.info("Attempt to parse the YAML file ...") - with open(filename, "r") as file: - config = safe_load(file) + try: + logger.info("Attempt to parse the YAML file ...") + with open(file, "r") as f: + config = safe_load(f) + except parser.ParserError as e: + logger.error(f"The file could not be parsed: {e}") + sys.exit(errno.EIO) + except (OSError, FileNotFoundError) as e: + logger.error(f"The file (descriptor) could not be found or opened: {e}") + sys.exit(errno.EIO) + logger.info("Parsing YAML file is successful.") + try: logger.info("Attempt to validate the parsed YAML file ...") config = AppConfig(**config) - logger.info("Validation of the parsed YAML file is successful.") - return config except ValidationError as e: logger.error(e) sys.exit(errno.EIO) + + logger.info("Validation of the parsed YAML file is successful.") + return config From 598b09f4c75c1a6dcf075ad3b64d6e1f12ef15df Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 18:49:06 +0200 Subject: [PATCH 16/21] Add `@logger.catch` to the `run_server` function 
This logs and catches unhandled exceptions. --- trolldb/api/api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/trolldb/api/api.py b/trolldb/api/api.py index d50b0ea..b3f486d 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -14,6 +14,7 @@ """ import asyncio +import sys import time from contextlib import contextmanager from multiprocessing import Process @@ -26,7 +27,7 @@ from pydantic import FilePath, ValidationError, validate_call from trolldb.api.routes import api_router -from trolldb.config.config import AppConfig, Timeout, parse_config_yaml_file +from trolldb.config.config import AppConfig, Timeout, parse_config from trolldb.database.mongodb import mongodb_context from trolldb.errors.errors import ResponseError @@ -46,6 +47,7 @@ """These will appear in the auto-generated documentation and are passed to the ``FastAPI`` class as keyword args.""" +@logger.catch(onerror=lambda _: sys.exit(1)) @validate_call def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: """Runs the API server with all the routes and connection to the database. @@ -68,10 +70,6 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: take precedence over ``config``. Finally, :obj:`API_INFO`, which are hard-coded information for the API server, will be concatenated and takes precedence over all. - Raises: - ValidationError: - If the function is not called with arguments of valid type. - Example: .. code-block:: python @@ -81,7 +79,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: """ logger.info("Attempt to run the API server ...") if not isinstance(config, AppConfig): - config = parse_config_yaml_file(config) + config = parse_config(config) # Concatenate the keyword arguments for the API server in the order of precedence (lower to higher). 
app = FastAPI(**(config.api_server._asdict() | kwargs | API_INFO)) @@ -140,7 +138,7 @@ def api_server_process_context( """ logger.info("Attempt to run the API server process in a context manager ...") if not isinstance(config, AppConfig): - config = parse_config_yaml_file(config) + config = parse_config(config) process = Process(target=run_server, args=(config,)) try: From 95f9de631d649aa0b728a6a6036dd0563b63f2bc Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Thu, 30 May 2024 18:57:42 +0200 Subject: [PATCH 17/21] Log the list of files for the `dataset` messages --- trolldb/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trolldb/cli.py b/trolldb/cli.py index cc83adf..dd5f0c2 100644 --- a/trolldb/cli.py +++ b/trolldb/cli.py @@ -51,7 +51,7 @@ async def record_messages(config: AppConfig) -> None: logger.info(f"Inserted file with uri: {msg.data["uri"]}") case "dataset": await collection.insert_one(msg.data) - logger.info(f"Inserted dataset with {len(msg.data["dataset"])} elements.") + logger.info(f"Inserted dataset with {len(msg.data["dataset"])} elements: {msg.data["dataset"]}") case "del": deletion_count = await delete_uri_from_collection(collection, msg.data["uri"]) if deletion_count > 1: From 89ccff0cba947d6293352e3b28c55baa8523e3ad Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 31 May 2024 10:45:55 +0200 Subject: [PATCH 18/21] Amend `parse_config()` - Remove type hint for `file`. - Amend docstring. - Remove explicit exception handling for parsing and reading the file. 
--- trolldb/config/config.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/trolldb/config/config.py b/trolldb/config/config.py index f86663f..99ea99e 100644 --- a/trolldb/config/config.py +++ b/trolldb/config/config.py @@ -10,7 +10,6 @@ import errno import sys -from os import PathLike from typing import Any, NamedTuple from bson import ObjectId @@ -19,7 +18,7 @@ from pydantic import AnyUrl, BaseModel, MongoDsn, PositiveFloat, ValidationError, validate_call from pydantic.functional_validators import AfterValidator from typing_extensions import Annotated -from yaml import parser, safe_load +from yaml import safe_load Timeout = PositiveFloat """A type hint for the timeout in seconds (non-negative float).""" @@ -110,29 +109,21 @@ class AppConfig(BaseModel): @logger.catch(onerror=lambda _: sys.exit(1)) -def parse_config(file: int | str | bytes | PathLike[str] | PathLike[bytes]) -> AppConfig: +def parse_config(file) -> AppConfig: """Parses and validates the configurations from a YAML file (descriptor). Args: file: - A path-like object (``str`` or ``bytes``) or an integer file descriptor. This will be directly passed to the - ``open()`` function. For example, it can be the filename (absolute or relative) of a valid YAML file which - holds the configurations. + A `path-like object `_ or an integer file + descriptor. This will be directly passed to the ``open()`` function. For example, it can be the filename + (absolute or relative) of a valid YAML file which holds the configurations. Returns: An instance of :class:`AppConfig`. 
""" - try: - logger.info("Attempt to parse the YAML file ...") - with open(file, "r") as f: - config = safe_load(f) - except parser.ParserError as e: - logger.error(f"The file could not be parsed: {e}") - sys.exit(errno.EIO) - except (OSError, FileNotFoundError) as e: - logger.error(f"The file (descriptor) could not be found or opened: {e}") - sys.exit(errno.EIO) - + logger.info("Attempt to parse the YAML file ...") + with open(file, "r") as f: + config = safe_load(f) logger.info("Parsing YAML file is successful.") try: From 506f1220eed2cf312ddf071c63da73db7fa59169 Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 31 May 2024 10:50:19 +0200 Subject: [PATCH 19/21] Remove type checks for config in `api.py` As a result, the signature of `run_server()` and `api_server_process_context()` have been changed. They now only accept a `config` object of type `AppConfig`. --- trolldb/api/api.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/trolldb/api/api.py b/trolldb/api/api.py index b3f486d..a960ffe 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -18,16 +18,16 @@ import time from contextlib import contextmanager from multiprocessing import Process -from typing import Any, Generator, NoReturn, Union +from typing import Any, Generator, NoReturn import uvicorn from fastapi import FastAPI, status from fastapi.responses import PlainTextResponse from loguru import logger -from pydantic import FilePath, ValidationError, validate_call +from pydantic import ValidationError from trolldb.api.routes import api_router -from trolldb.config.config import AppConfig, Timeout, parse_config +from trolldb.config.config import AppConfig, Timeout from trolldb.database.mongodb import mongodb_context from trolldb.errors.errors import ResponseError @@ -48,8 +48,7 @@ @logger.catch(onerror=lambda _: sys.exit(1)) -@validate_call -def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: +def run_server(config: AppConfig, 
**kwargs) -> None: """Runs the API server with all the routes and connection to the database. It first creates a FastAPI application and runs it using `uvicorn `_ which is @@ -58,9 +57,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: Args: config: - The configuration of the application which includes both the server and database configurations. Its type - should be a :class:`FilePath`, which is a valid path to an existing config file which will parsed as a - ``.YAML`` file. + The configuration of the application which includes both the server and database configurations. **kwargs: The keyword arguments are the same as those accepted by the @@ -74,12 +71,12 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: .. code-block:: python from trolldb.api.api import run_server + from trolldb.config.config import parse_config + if __name__ == "__main__": - run_server("config.yaml") + run_server(parse_config("config.yaml")) """ logger.info("Attempt to run the API server ...") - if not isinstance(config, AppConfig): - config = parse_config(config) # Concatenate the keyword arguments for the API server in the order of precedence (lower to higher). app = FastAPI(**(config.api_server._asdict() | kwargs | API_INFO)) @@ -120,8 +117,7 @@ async def _serve() -> NoReturn: @contextmanager -def api_server_process_context( - config: Union[AppConfig, FilePath], startup_time: Timeout = 2) -> Generator[Process, Any, None]: +def api_server_process_context(config: AppConfig, startup_time: Timeout = 2) -> Generator[Process, Any, None]: """A synchronous context manager to run the API server in a separate process (non-blocking). It uses the `multiprocessing `_ package. The main use case @@ -137,9 +133,6 @@ def api_server_process_context( large so that the tests will not time out. 
""" logger.info("Attempt to run the API server process in a context manager ...") - if not isinstance(config, AppConfig): - config = parse_config(config) - process = Process(target=run_server, args=(config,)) try: process.start() From c42f8cc428e0710e1ca144a84f401d5b2a533bff Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 31 May 2024 11:36:47 +0200 Subject: [PATCH 20/21] Reorder functions in test scripts. The dependencies are now brought after the first main dependent function that uses them. --- trolldb/tests/test_recorder.py | 26 +++---- trolldb/tests/tests_api/test_api.py | 104 ++++++++++++++-------------- 2 files changed, 65 insertions(+), 65 deletions(-) diff --git a/trolldb/tests/test_recorder.py b/trolldb/tests/test_recorder.py index 553a8f0..7bacd3d 100644 --- a/trolldb/tests/test_recorder.py +++ b/trolldb/tests/test_recorder.py @@ -62,6 +62,19 @@ def config_file(tmp_path): return create_config_file(tmp_path) +@pytest.mark.parametrize(("function", "args"), [ + (record_messages_from_config, lf("config_file")), + (record_messages_from_command_line, [lf("config_file")]) +]) +async def test_record_from_cli_and_config(tmp_path, file_message, tmp_data_filename, function, args): + """Tests that message recording adds a message to the database either via configs from a file or the CLI.""" + msg = Message.decode(file_message) + with running_prepared_database_context(): + with patched_subscriber_recv([file_message]): + await function(args) + assert await message_in_database_and_delete_count_is_one(msg) + + async def message_in_database_and_delete_count_is_one(msg: Message) -> bool: """Checks if there is exactly one item in the database which matches the data of the message.""" async with mongodb_context(test_app_config.database): @@ -76,19 +89,6 @@ async def message_in_database_and_delete_count_is_one(msg: Message) -> bool: return result == msg.data and deletion_count == 1 -@pytest.mark.parametrize(("function", "args"), [ - (record_messages_from_config, 
lf("config_file")), - (record_messages_from_command_line, [lf("config_file")]) -]) -async def test_record_from_cli_and_config(tmp_path, file_message, tmp_data_filename, function, args): - """Tests that message recording adds a message to the database either via configs from a file or the CLI.""" - msg = Message.decode(file_message) - with running_prepared_database_context(): - with patched_subscriber_recv([file_message]): - await function(args) - assert await message_in_database_and_delete_count_is_one(msg) - - async def test_record_messages(config_file, tmp_path, file_message, tmp_data_filename): """Tests that message recording adds a message to the database.""" config = AppConfig(**make_test_app_config(tmp_path)) diff --git a/trolldb/tests/tests_api/test_api.py b/trolldb/tests/tests_api/test_api.py index 0f28282..5e47cf2 100644 --- a/trolldb/tests/tests_api/test_api.py +++ b/trolldb/tests/tests_api/test_api.py @@ -20,58 +20,6 @@ main_collection_name = test_app_config.database.main_collection_name -def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool: - """Checks if the test and expected list of collection names match.""" - return Counter(test_collection_names) == Counter(expected_collection_name) - - -def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bool: - """Checks if the test (retrieved from the API) and expected list of (document) ids match.""" - return Counter(test_ids) == Counter(expected_ids) - - -def single_query_is_correct(key: str, value: str | datetime) -> bool: - """Checks if the given single query, denoted by ``key`` matches correctly against the ``value``.""" - return ( - Counter(http_get(f"queries?{key}={value}").json()) == - Counter(TestDatabase.match_query(**{key: value})) - ) - - -def make_query_string(keys: list[str], values_list: list[list[str] | datetime]) -> str: - """Makes a single query string for all the given queries.""" - query_buffer = [] - for key, 
value_list in zip(keys, values_list, strict=True): - query_buffer += [f"{key}={value}" for value in value_list] - return "&".join(query_buffer) - - -def query_results_are_correct(keys: list[str], values_list: list[list[str] | datetime]) -> bool: - """Checks if the retrieved result from querying the database via the API matches the expected result. - - There can be more than one query `key/value` pair. - - Args: - keys: - A list of all query keys, e.g. ``keys=["platform", "sensor"]`` - - values_list: - A list in which each element is a list of values itself. The `nth` element corresponds to the `nth` key in - the ``keys``. - - Returns: - A boolean flag indicating whether the retrieved result matches the expected result. - """ - query_string = make_query_string(keys, values_list) - - return ( - Counter(http_get(f"queries?{query_string}").json()) == - Counter(TestDatabase.match_query( - **{label: value_list for label, value_list in zip(keys, values_list, strict=True)} - )) - ) - - @pytest.mark.usefixtures("_test_server_fixture") def test_root(): """Checks that the server is up and running, i.e. 
the root routes responds with 200.""" @@ -120,6 +68,16 @@ def test_collections(): ) +def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool: + """Checks if the test and expected list of collection names match.""" + return Counter(test_collection_names) == Counter(expected_collection_name) + + +def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bool: + """Checks if the test (retrieved from the API) and expected list of (document) ids match.""" + return Counter(test_ids) == Counter(expected_ids) + + @pytest.mark.usefixtures("_test_server_fixture") def test_collections_negative(): """Checks that the non-existing collections cannot be found.""" @@ -159,6 +117,40 @@ def test_queries_platform_or_sensor(key: str, values: list[str]): ) +def make_query_string(keys: list[str], values_list: list[list[str] | datetime]) -> str: + """Makes a single query string for all the given queries.""" + query_buffer = [] + for key, value_list in zip(keys, values_list, strict=True): + query_buffer += [f"{key}={value}" for value in value_list] + return "&".join(query_buffer) + + +def query_results_are_correct(keys: list[str], values_list: list[list[str] | datetime]) -> bool: + """Checks if the retrieved result from querying the database via the API matches the expected result. + + There can be more than one query `key/value` pair. + + Args: + keys: + A list of all query keys, e.g. ``keys=["platform", "sensor"]`` + + values_list: + A list in which each element is a list of values itself. The `nth` element corresponds to the `nth` key in + the ``keys``. + + Returns: + A boolean flag indicating whether the retrieved result matches the expected result. 
+ """ + query_string = make_query_string(keys, values_list) + + return ( + Counter(http_get(f"queries?{query_string}").json()) == + Counter(TestDatabase.match_query( + **{label: value_list for label, value_list in zip(keys, values_list, strict=True)} + )) + ) + + @pytest.mark.usefixtures("_test_server_fixture") def test_queries_mix_platform_sensor(): """Tests a mix of platform and sensor queries.""" @@ -185,3 +177,11 @@ def test_queries_time(): "time_max", time_max ) + + +def single_query_is_correct(key: str, value: str | datetime) -> bool: + """Checks if the given single query, denoted by ``key`` matches correctly against the ``value``.""" + return ( + Counter(http_get(f"queries?{key}={value}").json()) == + Counter(TestDatabase.match_query(**{key: value})) + ) From 781891dab6538c8f21dacfabee652d93c929478c Mon Sep 17 00:00:00 2001 From: Pouria Khalaj Date: Fri, 31 May 2024 11:47:55 +0200 Subject: [PATCH 21/21] Remove `@validate_call` from `MongoDB.initialize` --- trolldb/database/mongodb.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trolldb/database/mongodb.py b/trolldb/database/mongodb.py index 4221aee..4f8f3bc 100644 --- a/trolldb/database/mongodb.py +++ b/trolldb/database/mongodb.py @@ -106,7 +106,6 @@ class MongoDB: """MongoDB creates these databases by default for self usage.""" @classmethod - @validate_call async def initialize(cls, database_config: DatabaseConfig): """Initializes the motor client. Note that this method has to be awaited! @@ -122,9 +121,6 @@ async def initialize(cls, database_config: DatabaseConfig): On success ``None``. Raises: - ValidationError: - If the method is not called with arguments of valid type. - SystemExit(errno.EIO): If connection is not established, i.e. ``ConnectionFailure``. SystemExit(errno.EIO):