diff --git a/packages/api-server/api_server/app.py b/packages/api-server/api_server/app.py index e6721a1bc..0928a0acf 100644 --- a/packages/api-server/api_server/app.py +++ b/packages/api-server/api_server/app.py @@ -134,6 +134,7 @@ async def on_sio_connect( app.include_router( routes.fleets_router, prefix="/fleets", dependencies=[Depends(user_dep)] ) +app.include_router(routes.rios_router, prefix="/rios", dependencies=[Depends(user_dep)]) app.include_router( routes.admin_router, prefix="/admin", dependencies=[Depends(user_dep)] ) diff --git a/packages/api-server/api_server/models/__init__.py b/packages/api-server/api_server/models/__init__.py index d3d9eb544..72016de82 100644 --- a/packages/api-server/api_server/models/__init__.py +++ b/packages/api-server/api_server/models/__init__.py @@ -10,6 +10,7 @@ from .labels import * from .lifts import * from .pagination import * +from .rio import * from .rmf_api.activity_discovery_request import ActivityDiscoveryRequest from .rmf_api.activity_discovery_response import ActivityDiscovery from .rmf_api.cancel_task_request import CancelTaskRequest diff --git a/packages/api-server/api_server/models/rio.py b/packages/api-server/api_server/models/rio.py new file mode 100644 index 000000000..4916ef05e --- /dev/null +++ b/packages/api-server/api_server/models/rio.py @@ -0,0 +1,9 @@ +from typing import Any + +from pydantic import BaseModel + + +class Rio(BaseModel): + id: str + type: str + data: dict[str, Any] diff --git a/packages/api-server/api_server/models/tortoise_models/__init__.py b/packages/api-server/api_server/models/tortoise_models/__init__.py index 384dd978d..1b9821064 100644 --- a/packages/api-server/api_server/models/tortoise_models/__init__.py +++ b/packages/api-server/api_server/models/tortoise_models/__init__.py @@ -17,6 +17,7 @@ from .ingestor_state import IngestorState from .lift_state import LiftState from .log import LogMixin +from .rio import * from .scheduled_task import * from .tasks import ( TaskEventLog, diff --git a/packages/api-server/api_server/models/tortoise_models/rio.py b/packages/api-server/api_server/models/tortoise_models/rio.py new file mode 100644 index 000000000..a05ba02e8 --- /dev/null +++ b/packages/api-server/api_server/models/tortoise_models/rio.py @@ -0,0 +1,8 @@ +from tortoise.fields import CharField, JSONField +from tortoise.models import Model + + +class Rio(Model): + id = CharField(max_length=255, pk=True) + type = CharField(max_length=255, index=True) + data = JSONField() diff --git a/packages/api-server/api_server/rmf_io/__init__.py b/packages/api-server/api_server/rmf_io/__init__.py index 0d3e3a2e8..770026be3 100644 --- a/packages/api-server/api_server/rmf_io/__init__.py +++ b/packages/api-server/api_server/rmf_io/__init__.py @@ -4,6 +4,7 @@ TaskEvents, alert_events, fleet_events, + rio_events, rmf_events, task_events, ) diff --git a/packages/api-server/api_server/rmf_io/events.py b/packages/api-server/api_server/rmf_io/events.py index d92117501..44b3d394a 100644 --- a/packages/api-server/api_server/rmf_io/events.py +++ b/packages/api-server/api_server/rmf_io/events.py @@ -51,3 +51,11 @@ def __init__(self): alert_events = AlertEvents() + + +class RioEvents: + def __init__(self): + self.rios = Subject() + + +rio_events = RioEvents() diff --git a/packages/api-server/api_server/routes/__init__.py b/packages/api-server/api_server/routes/__init__.py index ac3fbd34c..98eb57367 100644 --- a/packages/api-server/api_server/routes/__init__.py +++ b/packages/api-server/api_server/routes/__init__.py @@ -10,4 +10,5 @@ from .internal import router as internal_router from .lifts import router as lifts_router from .main import router as main_router +from .rios import router as rios_router from .tasks import * diff --git a/packages/api-server/api_server/routes/rios.py b/packages/api-server/api_server/routes/rios.py new file mode 100644 index 000000000..79971ca1f --- /dev/null +++ b/packages/api-server/api_server/routes/rios.py @@ -0,0 +1,44 @@ +from typing import Annotated + +from fastapi import Query, Response + +from api_server.fast_io import FastIORouter, SubscriptionRequest +from api_server.models import Rio +from api_server.models.tortoise_models import Rio as DbRio +from api_server.rmf_io import rio_events + +router = FastIORouter(tags=["RIOs"]) + + +@router.get("", response_model=list[Rio]) +async def query_rios( + id_: Annotated[ + str | None, Query(alias="id", description="comma separated list of ids") + ] = None, + type_: Annotated[ + str | None, Query(alias="type", description="comma separated list of types") + ] = None, +): + filters = {} + if id_: + filters["id__in"] = id_.split(",") + if type_: + filters["type__in"] = type_.split(",") + + rios = await DbRio.filter(**filters) + return rios + + +@router.sub("", response_model=Rio) +async def sub_rio(_req: SubscriptionRequest): + return rio_events.rios + + +@router.put("", response_model=None) +async def put_rio(rio: Rio, resp: Response): + rio_dict = rio.dict() + del rio_dict["id"] + _, created = await DbRio.update_or_create(rio_dict, id=rio.id) + if created: + resp.status_code = 201 + rio_events.rios.on_next(rio) diff --git a/packages/api-server/api_server/routes/test_rios.py b/packages/api-server/api_server/routes/test_rios.py new file mode 100644 index 000000000..a13809e46 --- /dev/null +++ b/packages/api-server/api_server/routes/test_rios.py @@ -0,0 +1,68 @@ +# import pydantic + +# from api_server.models import Rio +# from api_server.models.tortoise_models import Rio as DbRio +# from api_server.rmf_io import rio_events +# from api_server.test import AppFixture + + +# @AppFixture.reset_app_before_test +# class TestRiosRoute(AppFixture): +# def test_get_rios(self): +# self.portal.call( +# DbRio(id="test_rio", type="test_type", data={"battery": 1}).save +# ) +# self.portal.call( +# DbRio(id="test_rio2", type="test_type", data={"battery": 0.5}).save +# ) +# self.portal.call( +# DbRio(id="test_rio3", type="test_type3", data={"battery": 0}).save +# ) + +# test_cases = [ +# ("id=test_rio,test_rio2", 2), +# ("id=test_rio,test_rio4", 1), +# ("type=test_type,test_type3", 3), +# ("type=test_type,test_rio", 2), +# ("id=test_rio,test_rio3&type=test_type3", 1), +# ] + +# for tc in test_cases: +# resp = self.client.get(f"/rios?{tc[0]}") +# self.assertEqual(200, resp.status_code, tc) +# rios = pydantic.TypeAdapter(list[Rio]).validate_json(resp.content) +# self.assertEqual(tc[1], len(rios)) + +# def test_sub_rios(self): +# with self.subscribe_sio("/rios") as sub: +# rio_events.rios.on_next( +# Rio(id="test_rio", type="test_type", data={"battery": 1}) +# ) +# rio = Rio(**next(sub)) +# self.assertEqual("test_rio", rio.id) + +# def test_put_rios(self): +# resp = self.client.put( +# "/rios", +# content=Rio( +# id="test_rio", type="test_type", data={"battery": 1} +# ).model_dump_json(), +# ) +# self.assertEqual(201, resp.status_code) + +# rios = self.portal.call(DbRio.all) +# self.assertEqual(1, len(rios)) + +# resp = self.client.put( +# "/rios", +# content=Rio( +# id="test_rio", type="test_type", data={"battery": 0.5} +# ).model_dump_json(), +# ) +# # should return 200 if an existing resource is updated +# self.assertEqual(200, resp.status_code) +# rios = self.portal.call(DbRio.all) +# self.assertEqual(1, len(rios)) +# if not isinstance(rios[0].data, dict): +# self.fail("data should be a dict") +# self.assertEqual(0.5, rios[0].data["battery"]) diff --git a/packages/api-server/api_server/test/test_fixtures.py b/packages/api-server/api_server/test/test_fixtures.py index c63c9695d..0b1201d2c 100644 --- a/packages/api-server/api_server/test/test_fixtures.py +++ b/packages/api-server/api_server/test/test_fixtures.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import enum import inspect import os import os.path @@ -83,8 +84,31 @@ async def async_try_until( class AppFixture(unittest.TestCase): + class InitMode(enum.Enum): + SETUP_CLASS = enum.auto() + SETUP_TEST = enum.auto() + + _init_mode = InitMode.SETUP_CLASS + + @staticmethod + def reset_app_before_test(testcase: type["AppFixture"]): + """ + By default, the app is setup once and remains for the entire test case, + use this to change it so that it resets the app and database before every test. + + Example usage: + ```python3 + @AppFixture.reset_app_before_test + class MyTest(AppFixture): + ... + ``` + """ + # pylint: disable=protected-access + testcase._init_mode = AppFixture.InitMode.SETUP_TEST + return testcase + @classmethod - def setUpClass(cls): + def setUpApp(cls): async def clean_db(): # connect to the db to drop it await Tortoise.init(db_url=app_config.db_url, modules={"models": []}) @@ -101,7 +125,20 @@ async def clean_db(): cls.client = TestClient() cls.client.headers["Content-Type"] = "application/json" cls.client.__enter__() - cls.addClassCleanup(cls.client.__exit__) + + @classmethod + def setUpClass(cls): + if cls._init_mode == AppFixture.InitMode.SETUP_CLASS: + cls.setUpApp() + cls.addClassCleanup(cls.client.__exit__) + + def setUp(self): + if self._init_mode == AppFixture.InitMode.SETUP_TEST: + self.setUpApp() + self.addCleanup(self.client.__exit__) + + self.test_time = 0 + self.portal = self.get_portal() @classmethod def get_portal(cls) -> BlockingPortal: @@ -146,7 +183,7 @@ async def wait_for_msgs(): async def handle_resp(emit_room, msg, *_args, **_kwargs): if emit_room == "subscribe" and not msg["success"]: - raise Exception("Failed to subscribe") + raise Exception("Failed to subscribe", msg) if emit_room == room: async with condition: if isinstance(msg, pydantic.BaseModel): @@ -167,9 +204,6 @@ async def handle_resp(emit_room, msg, *_args, **_kwargs): if connected: portal.call(on_disconnect, "test") - def setUp(self): - self.test_time = 0 - def create_user(self, admin: bool = False): username = f"user_{uuid4().hex}" resp = self.client.post(