diff --git a/Pipfile b/Pipfile index f9ec9fcae..fd38eed33 100644 --- a/Pipfile +++ b/Pipfile @@ -9,7 +9,7 @@ isort = "==5.13.2" pylint = "==3.1.0" coverage = "~=5.5" # api-server -api-server = {editable = true, path = "./packages/api-server"} +api-server = {editable = true, path = "./packages/api-server", extras = ["postgres"]} httpx = "~=0.26.0" datamodel-code-generator = "==0.25.4" requests = "~=2.25" diff --git a/Pipfile.lock b/Pipfile.lock index 28d7b9e07..3d4a558e9 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "c4d71f36224c943d87d65effedecd697e9e028be42c49bcf26d2e4668d9cd00a" + "sha256": "361c917805a1df51cf540ee7171e6bfabad657eb0afc2bbb3d192b718547d034" }, "pipfile-spec": 6, "requires": { @@ -51,15 +51,18 @@ }, "api-server": { "editable": true, + "extras": [ + "postgres" + ], "path": "./packages/api-server" }, "argcomplete": { "hashes": [ - "sha256:c168c3723482c031df3c207d4ba8fa702717ccb9fc0bfe4117166c1f537b4a54", - "sha256:fd03ff4a5b9e6580569d34b273f741e85cd9e072f3feeeee3eba4891c70eda62" + "sha256:69a79e083a716173e5532e0fa3bef45f793f4e61096cf52b5a42c0211c8b8aa5", + "sha256:c2abcdfe1be8ace47ba777d4fce319eb13bf8ad9dace8d085dcad6eded88057f" ], "markers": "python_version >= '3.8'", - "version": "==3.3.0" + "version": "==3.4.0" }, "astroid": { "hashes": [ @@ -69,6 +72,52 @@ "markers": "python_full_version >= '3.8.0'", "version": "==3.1.0" }, + "asyncpg": { + "hashes": [ + "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9", + "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7", + "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548", + "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23", + "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3", + "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675", + "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe", + "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175", + "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83", + "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385", + "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da", + "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106", + "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870", + "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449", + "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc", + "sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178", + "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9", + "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b", + "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169", + "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610", + "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772", + "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2", + "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c", + "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb", + "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac", + "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408", + "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22", + "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb", + "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02", + "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59", + "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8", + "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3", + "sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e", + "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4", + "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364", + "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f", + "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775", + "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3", + "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090", + "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810", + "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397" + ], + "version": "==0.29.0" + }, "bidict": { "hashes": [ "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", @@ -400,10 +449,10 @@ }, "email-validator": { "hashes": [ - "sha256:200a70680ba08904be6d1eef729205cc0d687634399a5924d842533efb824b84", - "sha256:97d882d174e2a65732fb43bfce81a3a834cbc1bde8bf419e30ef5ea976370a05" + "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631", + "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7" ], - "version": "==2.1.1" + "version": "==2.2.0" }, "fastapi": { "hashes": [ @@ -782,11 +831,11 @@ }, "python-socketio": { "hashes": [ - "sha256:ae6a1de5c5209ca859dc574dccc8931c4be17ee003e74ce3b8d1306162bb4a37", - "sha256:b9f22a8ff762d7a6e123d16a43ddb1a27d50f07c3c88ea999334f2f89b0ad52b" + "sha256:194af8cdbb7b0768c2e807ba76c7abc288eb5bb85559b7cddee51a6bc7a65737", + "sha256:2a923a831ff70664b7c502df093c423eb6aa93c1ce68b8319e840227a26d8b69" ], "markers": "python_version >= '3.8'", - "version": "==5.11.2" + "version": "==5.11.3" }, "pytz": { "hashes": [ @@ -875,6 +924,7 @@ }, "schedule": { "hashes": [ + "sha256:15fe9c75fe5fd9b9627f3f19cc0ef1420508f9f9a46f45cd0769ef75ede5f0b7", "sha256:5bef4a2a0183abf44046ae0d164cadcac21b1db011bdd8102e4a0c1e91e06a7d" ], "markers": "python_version >= '3.7'", @@ -930,11 +980,11 @@ }, "urllib3": { "hashes": [ - "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d", - "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19" + "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472", + "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168" ], "markers": "python_version >= '3.8'", - "version": "==2.2.1" + "version": "==2.2.2" }, "uvicorn": { "extras": [ diff --git a/packages/api-server/README.md b/packages/api-server/README.md index 405ef6aa3..3c2548a6c 100644 --- a/packages/api-server/README.md +++ b/packages/api-server/README.md @@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the databse should be reflected. ### Running unit tests ```bash -npm test +pnpm test +``` + +By default in-memory sqlite database is used for testing, to test on another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable. + +```bash +RMF_API_SERVER_TEST_DB_URL= pnpm test ``` ### Collecting code coverage ```bash -npm run test:cov +pnpm run test:cov ``` Generate coverage report ```bash -npm run test:report +pnpm run test:report ``` ## Live reload diff --git a/packages/api-server/api_server/dependencies.py b/packages/api-server/api_server/dependencies.py index e21cae0bc..4c429cd73 100644 --- a/packages/api-server/api_server/dependencies.py +++ b/packages/api-server/api_server/dependencies.py @@ -20,7 +20,11 @@ def pagination_query( ) -> Pagination: limit = limit or 100 offset = offset or 0 - return Pagination(limit=limit, offset=offset, order_by=order_by) + return Pagination( + limit=limit, + offset=offset, + order_by=order_by.split(",") if order_by else [], + ) # hacky way to get the sio user diff --git a/packages/api-server/api_server/models/pagination.py b/packages/api-server/api_server/models/pagination.py index b88a78b4c..9832a379d 100644 --- a/packages/api-server/api_server/models/pagination.py +++ b/packages/api-server/api_server/models/pagination.py @@ -1,9 +1,7 @@ -from typing import Optional - from pydantic import BaseModel class Pagination(BaseModel): limit: int offset: int - order_by: Optional[str] + order_by: list[str] diff --git a/packages/api-server/api_server/query.py b/packages/api-server/api_server/query.py index 4a6779ab4..c0ea9db33 100644 --- a/packages/api-server/api_server/query.py +++ b/packages/api-server/api_server/query.py @@ -1,5 +1,3 @@ -import tortoise.functions as tfuncs -from tortoise.expressions import Q from tortoise.queryset import MODEL, QuerySet from api_server.models.pagination import Pagination @@ -8,47 +6,10 @@ def add_pagination( query: QuerySet[MODEL], pagination: Pagination, - field_mappings: dict[str, str] | None = None, - group_by: str | None = None, ) -> QuerySet[MODEL]: - """ - Adds pagination and ordering to a query. If the order field starts with `label=`, it is - assumed to be a label and label sorting will used. In this case, the model must have - a reverse relation named "labels" and the `group_by` param is required. - - :param field_mapping: A dict mapping the order fields to the fields used to build the - query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}` - will order the query result according to `db_field`. - :param group_by: Required when sorting by labels, must be the foreign key column of the label table. - """ - field_mappings = field_mappings or {} - annotations = {} - query = query.limit(pagination.limit).offset(pagination.offset) - if pagination.order_by is not None: - order_fields = [] - order_values = pagination.order_by.split(",") - for v in order_values: - # perform the mapping after stripping the order prefix - order_prefix = "" - order_field = v - if v[0] in ["-", "+"]: - order_prefix = v[0] - order_field = v[1:] - order_field = field_mappings.get(order_field, order_field) - - # add annotations required for sorting by labels - if order_field.startswith("label="): - f = order_field[6:] - annotations[f"label_sort_{f}"] = tfuncs.Max( - "labels__label_value", - _filter=Q(labels__label_name=f), - ) - order_field = f"label_sort_{f}" - - order_fields.append(order_prefix + order_field) - - query = query.annotate(**annotations) - if group_by is not None: - query = query.group_by(group_by) - query = query.order_by(*order_fields) - return query + """Adds pagination and ordering to a query""" + return ( + query.limit(pagination.limit) + .offset(pagination.offset) + .order_by(*pagination.order_by) + ) diff --git a/packages/api-server/api_server/repositories/tasks.py b/packages/api-server/api_server/repositories/tasks.py index 9d52bb276..144bf5ec6 100644 --- a/packages/api-server/api_server/repositories/tasks.py +++ b/packages/api-server/api_server/repositories/tasks.py @@ -2,10 +2,11 @@ from datetime import datetime from typing import Dict, List, Optional, Sequence, Tuple +import tortoise.functions as tfuncs from fastapi import Depends, HTTPException from tortoise.exceptions import FieldError, IntegrityError +from tortoise.expressions import Expression, Q from tortoise.query_utils import Prefetch -from tortoise.queryset import QuerySet from tortoise.transactions import in_transaction from api_server.authenticator import user_dep @@ -18,6 +19,7 @@ TaskEventLog, TaskRequest, TaskState, + TaskStatus, User, ) from api_server.models import tortoise_models as ttm @@ -25,7 +27,6 @@ from api_server.models.rmf_api.task_state import Category, Id, Phase from api_server.models.tortoise_models import TaskRequest as DbTaskRequest from api_server.models.tortoise_models import TaskState as DbTaskState -from api_server.query import add_pagination from api_server.rmf_io import task_events @@ -96,11 +97,85 @@ async def save_task_state(self, task_state: TaskState) -> None: await self.save_task_labels(db_task_state, labels) async def query_task_states( - self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None + self, + task_id: list[str] | None = None, + category: list[str] | None = None, + assigned_to: list[str] | None = None, + start_time_between: tuple[datetime, datetime] | None = None, + finish_time_between: tuple[datetime, datetime] | None = None, + status: list[str] | None = None, + label: Labels | None = None, + pagination: Optional[Pagination] = None, ) -> List[TaskState]: + filters = {} + if task_id is not None: + filters["id___in"] = task_id + if category is not None: + filters["category__in"] = category + if assigned_to is not None: + filters["assigned_to__in"] = assigned_to + if start_time_between is not None: + filters["unix_millis_start_time__gte"] = start_time_between[0] + filters["unix_millis_start_time__lte"] = start_time_between[1] + if finish_time_between is not None: + filters["unix_millis_finish_time__gte"] = finish_time_between[0] + filters["unix_millis_finish_time__lte"] = finish_time_between[1] + if status is not None: + valid_values = [member.value for member in TaskStatus] + filters["status__in"] = [] + for status_string in status: + if status_string not in valid_values: + continue + filters["status__in"].append(TaskStatus(status_string)) + query = DbTaskState.filter(**filters) + + need_group_by = False + label_filters = {} + if label is not None: + label_filters.update( + { + f"label_filter_{k}": tfuncs.Count( + "id_", + _filter=Q(labels__label_name=k, labels__label_value=v), + ) + for k, v in label.root.items() + } + ) + + if len(label_filters) > 0: + filter_gt = {f"{f}__gt": 0 for f in label_filters} + query = query.annotate(**label_filters).filter(**filter_gt) + need_group_by = True + + if pagination: + order_fields: list[str] = [] + annotations: dict[str, Expression] = {} + # add annotations required for sorting by labels + for f in pagination.order_by: + order_prefix = f[0] if f[0] == "-" else "" + order_field = f[1:] if order_prefix == "-" else f + if order_field.startswith("label="): + f = order_field[6:] + annotations[f"label_sort_{f}"] = tfuncs.Max( + "labels__label_value", + _filter=Q(labels__label_name=f), + ) + order_field = f"label_sort_{f}" + + order_fields.append(order_prefix + order_field) + + query = ( + query.annotate(**annotations) + .limit(pagination.limit) + .offset(pagination.offset) + .order_by(*order_fields) + ) + need_group_by = True + + if need_group_by: + query = query.group_by("id_", "labels__state_id") + try: - if pagination: - query = add_pagination(query, pagination, group_by="labels__state_id") # TODO: enforce with authz results = await query.values_list("data") return [TaskState(**r[0]) for r in results] diff --git a/packages/api-server/api_server/routes/tasks/scheduled_tasks.py b/packages/api-server/api_server/routes/tasks/scheduled_tasks.py index 40164d812..0b5b61a3f 100644 --- a/packages/api-server/api_server/routes/tasks/scheduled_tasks.py +++ b/packages/api-server/api_server/routes/tasks/scheduled_tasks.py @@ -134,7 +134,7 @@ async def get_scheduled_tasks( .offset(pagination.offset) ) if pagination.order_by: - q.order_by(*pagination.order_by.split(",")) + q.order_by(*pagination.order_by) results = await q await ttm.ScheduledTask.fetch_for_list(results) return [ScheduledTask.model_validate(x) for x in results] diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index 831b350d3..77059d8e3 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -1,10 +1,8 @@ from datetime import datetime from typing import List, Optional, Tuple, cast -import tortoise.functions as tfuncs from fastapi import Body, Depends, HTTPException, Path, Query from reactivex import operators as rxops -from tortoise.expressions import Q from api_server import models as mdl from api_server.dependencies import ( @@ -15,7 +13,6 @@ start_time_between_query, ) from api_server.fast_io import FastIORouter, SubscriptionRequest -from api_server.models.tortoise_models import TaskState as DbTaskState from api_server.repositories import TaskRepository, task_repo_dep from api_server.response import RawJSONResponse from api_server.rmf_io import task_events, tasks_service @@ -60,51 +57,16 @@ async def query_task_states( ), pagination: mdl.Pagination = Depends(pagination_query), ): - filters = {} - if task_id is not None: - filters["id___in"] = task_id.split(",") - if category is not None: - filters["category__in"] = category.split(",") - if assigned_to is not None: - filters["assigned_to__in"] = assigned_to.split(",") - if start_time_between is not None: - filters["unix_millis_start_time__gte"] = start_time_between[0] - filters["unix_millis_start_time__lte"] = start_time_between[1] - if finish_time_between is not None: - filters["unix_millis_finish_time__gte"] = finish_time_between[0] - filters["unix_millis_finish_time__lte"] = finish_time_between[1] - if status is not None: - valid_values = [member.value for member in mdl.TaskStatus] - filters["status__in"] = [] - for status_string in status.split(","): - if status_string not in valid_values: - continue - filters["status__in"].append(mdl.TaskStatus(status_string)) - query = DbTaskState.filter(**filters) - - label_filters = {} - if label is not None: - labels = mdl.Labels.from_strings(label.split(",")) - label_filters.update( - { - f"label_filter_{k}": tfuncs.Count( - "id_", _filter=Q(labels__label_name=k, labels__label_value=v) - ) - for k, v in labels.root.items() - } - ) - - if len(label_filters) > 0: - filter_gt = {f"{f}__gt": 0 for f in label_filters} - query = ( - query.annotate(**label_filters) - .group_by( - "labels__state_id" - ) # need to group by a related field to make tortoise-orm generate joins - .filter(**filter_gt) - ) - - return await task_repo.query_task_states(query, pagination) + return await task_repo.query_task_states( + task_id=task_id.split(",") if task_id else None, + category=category.split(",") if category else None, + assigned_to=assigned_to.split(",") if assigned_to else None, + start_time_between=start_time_between, + finish_time_between=finish_time_between, + status=status.split(",") if status else None, + label=mdl.Labels.from_strings(label.split(",")) if label else None, + pagination=pagination, + ) @router.get("/{task_id}/state", response_model=mdl.TaskState) diff --git a/packages/api-server/api_server/routes/tasks/test_tasks.py b/packages/api-server/api_server/routes/tasks/test_tasks.py index 6534ccc4f..9a12067a7 100644 --- a/packages/api-server/api_server/routes/tasks/test_tasks.py +++ b/packages/api-server/api_server/routes/tasks/test_tasks.py @@ -36,15 +36,12 @@ def setUpClass(cls): cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids] cls.clsSetupErr: str | None = None - if cls.client.portal is None: - cls.clsSetupErr = "missing client portal, is the client context entered?" - return - + portal = cls.get_portal() repo = TaskRepository(cls.admin_user) for x in cls.task_states: - cls.client.portal.call(repo.save_task_state, x) + portal.call(repo.save_task_state, x) for x in cls.task_logs: - cls.client.portal.call(repo.save_task_log, x) + portal.call(repo.save_task_log, x) def setUp(self): super().setUp() diff --git a/packages/api-server/api_server/routes/test_building_map.py b/packages/api-server/api_server/routes/test_building_map.py index 44765b8b3..1aa1490e8 100644 --- a/packages/api-server/api_server/routes/test_building_map.py +++ b/packages/api-server/api_server/routes/test_building_map.py @@ -1,11 +1,11 @@ -from api_server.rmf_io import rmf_events from api_server.test import AppFixture, make_building_map, try_until class TestBuildingMapRoute(AppFixture): def test_get_building_map(self): building_map = make_building_map() - rmf_events.building_map.on_next(building_map) + portal = self.get_portal() + portal.call(building_map.save) resp = try_until( lambda: self.client.get("/building_map"), lambda x: x.status_code == 200 diff --git a/packages/api-server/api_server/routes/test_dispensers.py b/packages/api-server/api_server/routes/test_dispensers.py index 72b1ee91f..0e135b17b 100644 --- a/packages/api-server/api_server/routes/test_dispensers.py +++ b/packages/api-server/api_server/routes/test_dispensers.py @@ -1,4 +1,3 @@ -import asyncio from typing import List from uuid import uuid4 @@ -12,8 +11,9 @@ def setUpClass(cls): super().setUpClass() cls.dispenser_states = [make_dispenser_state(f"test_{uuid4()}")] + portal = cls.get_portal() for x in cls.dispenser_states: - asyncio.run(x.save()) + portal.call(x.save) def test_get_dispensers(self): resp = self.client.get("/dispensers") diff --git a/packages/api-server/api_server/routes/test_doors.py b/packages/api-server/api_server/routes/test_doors.py index 677f407ee..ec4a29f28 100644 --- a/packages/api-server/api_server/routes/test_doors.py +++ b/packages/api-server/api_server/routes/test_doors.py @@ -1,4 +1,3 @@ -import asyncio from uuid import uuid4 from rmf_door_msgs.msg import DoorMode as RmfDoorMode @@ -12,12 +11,13 @@ class TestDoorsRoute(AppFixture): def setUpClass(cls): super().setUpClass() cls.building_map = make_building_map() - asyncio.run(cls.building_map.save()) + portal = cls.get_portal() + portal.call(cls.building_map.save) cls.door_states = [make_door_state(f"test_{uuid4()}")] for x in cls.door_states: - asyncio.run(x.save()) + portal.call(x.save) def test_get_doors(self): resp = self.client.get("/doors") diff --git a/packages/api-server/api_server/routes/test_ingestors.py b/packages/api-server/api_server/routes/test_ingestors.py index 20916f697..810180508 100644 --- a/packages/api-server/api_server/routes/test_ingestors.py +++ b/packages/api-server/api_server/routes/test_ingestors.py @@ -1,4 +1,3 @@ -import asyncio from typing import List from uuid import uuid4 @@ -12,8 +11,9 @@ def setUpClass(cls): super().setUpClass() cls.ingestor_states = [make_ingestor_state(f"test_{uuid4()}")] + portal = cls.get_portal() for x in cls.ingestor_states: - asyncio.run(x.save()) + portal.call(x.save) def test_get_ingestors(self): resp = self.client.get("/ingestors") diff --git a/packages/api-server/api_server/routes/test_lifts.py b/packages/api-server/api_server/routes/test_lifts.py index 0c8f2c898..516f8533c 100644 --- a/packages/api-server/api_server/routes/test_lifts.py +++ b/packages/api-server/api_server/routes/test_lifts.py @@ -1,4 +1,3 @@ -import asyncio from uuid import uuid4 from rmf_lift_msgs.msg import LiftRequest as RmfLiftRequest @@ -12,11 +11,12 @@ class TestLiftsRoute(AppFixture): def setUpClass(cls): super().setUpClass() cls.building_map = make_building_map() - asyncio.run(cls.building_map.save()) + portal = cls.get_portal() + portal.call(cls.building_map.save) cls.lift_states = [make_lift_state(f"test_{uuid4()}")] for x in cls.lift_states: - asyncio.run(x.save()) + portal.call(x.save) def test_get_lifts(self): resp = self.client.get("/lifts") diff --git a/packages/api-server/api_server/test/__init__.py b/packages/api-server/api_server/test/__init__.py index f7725fb21..cfeecb29b 100644 --- a/packages/api-server/api_server/test/__init__.py +++ b/packages/api-server/api_server/test/__init__.py @@ -5,6 +5,5 @@ from .test_client import TestClient from .test_data import * from .test_fixtures import * -from .test_utils import * test_user = User(username="test_user", is_admin=True) diff --git a/packages/api-server/api_server/test/test_fixtures.py b/packages/api-server/api_server/test/test_fixtures.py index 29d599650..16e96caeb 100644 --- a/packages/api-server/api_server/test/test_fixtures.py +++ b/packages/api-server/api_server/test/test_fixtures.py @@ -11,8 +11,10 @@ from uuid import uuid4 import pydantic +from anyio.abc import BlockingPortal +from tortoise import Tortoise -from api_server.app import app +from api_server.app import app, app_config from api_server.models import User from api_server.routes.admin import PostUsers @@ -85,12 +87,32 @@ async def async_try_until( class AppFixture(unittest.TestCase): @classmethod def setUpClass(cls): + async def clean_db(): + # connect to the db to drop it + await Tortoise.init(db_url=app_config.db_url, modules={"models": []}) + await Tortoise._drop_databases() # pylint: disable=protected-access + # connect to it again to recreate it + await Tortoise.init( + db_url=app_config.db_url, modules={"models": []}, _create_db=True + ) + await Tortoise.close_connections() + + asyncio.run(clean_db()) + cls.admin_user = User(username="admin", is_admin=True) cls.client = TestClient() cls.client.headers["Content-Type"] = "application/json" cls.client.__enter__() cls.addClassCleanup(cls.client.__exit__) + @classmethod + def get_portal(cls) -> BlockingPortal: + if not cls.client.portal: + raise AssertionError( + "missing client portal, is the client context entered?" + ) + return cls.client.portal + @contextlib.contextmanager def subscribe_sio(self, room: str, *, user="admin"): """ diff --git a/packages/api-server/api_server/test/test_utils.py b/packages/api-server/api_server/test/test_utils.py deleted file mode 100644 index 71ae280f1..000000000 --- a/packages/api-server/api_server/test/test_utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Optional, Sequence - -from tortoise import Tortoise - - -async def init_db(models: Optional[Sequence[str]] = None): - models = models or ["api_server.models.tortoise_models"] - await Tortoise.init( - db_url="sqlite://:memory:", - modules={"models": models}, - ) - await Tortoise.generate_schemas() diff --git a/packages/api-server/scripts/sqlite_test_config.py b/packages/api-server/scripts/sqlite_test_config.py deleted file mode 100644 index 70e85c659..000000000 --- a/packages/api-server/scripts/sqlite_test_config.py +++ /dev/null @@ -1,3 +0,0 @@ -from base_test_config import config - -config.update({"db_url": "sqlite://:memory:"}) diff --git a/packages/api-server/scripts/test.py b/packages/api-server/scripts/test.py index 8c60d946b..89df3df30 100644 --- a/packages/api-server/scripts/test.py +++ b/packages/api-server/scripts/test.py @@ -1,9 +1,7 @@ import os import sys -os.environ[ - "RMF_API_SERVER_CONFIG" -] = f"{os.path.dirname(__file__)}/sqlite_test_config.py" +os.environ["RMF_API_SERVER_CONFIG"] = f"{os.path.dirname(__file__)}/test_config.py" import unittest diff --git a/packages/api-server/scripts/base_test_config.py b/packages/api-server/scripts/test_config.py similarity index 65% rename from packages/api-server/scripts/base_test_config.py rename to packages/api-server/scripts/test_config.py index 35fce9a44..6afd866ae 100644 --- a/packages/api-server/scripts/base_test_config.py +++ b/packages/api-server/scripts/test_config.py @@ -4,7 +4,7 @@ here = os.path.dirname(__file__) -test_port = os.environ.get("RMF_SERVER_TEST_PORT", "8000") +test_port = os.environ.get("RMF_API_SERVER_TEST_PORT", "8000") config.update( { "host": "127.0.0.1", @@ -12,5 +12,6 @@ "log_level": "CRITICAL", "jwt_public_key": f"{here}/test.pub", "iss": "test", + "db_url": os.environ.get("RMF_API_SERVER_TEST_DB_URL", "sqlite://:memory:"), } )