Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into koonpeng/jazzy
Browse files Browse the repository at this point in the history
Signed-off-by: Teo Koon Peng <[email protected]>
  • Loading branch information
koonpeng committed Jun 27, 2024
2 parents 5b7af0a + bb88e1c commit a7feba4
Show file tree
Hide file tree
Showing 21 changed files with 217 additions and 159 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ isort = "==5.13.2"
pylint = "==3.1.0"
coverage = "~=5.5"
# api-server
api-server = {editable = true, path = "./packages/api-server"}
api-server = {editable = true, path = "./packages/api-server", extras = ["postgres"]}
httpx = "~=0.26.0"
datamodel-code-generator = "==0.25.4"
requests = "~=2.25"
Expand Down
76 changes: 63 additions & 13 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions packages/api-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the databse should be reflected.
### Running unit tests

```bash
npm test
pnpm test
```

By default in-memory sqlite database is used for testing, to test on another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable.

```bash
RMF_API_SERVER_TEST_DB_URL=<db_url> pnpm test
```

### Collecting code coverage

```bash
npm run test:cov
pnpm run test:cov
```

Generate coverage report
```bash
npm run test:report
pnpm run test:report
```

## Live reload
Expand Down
6 changes: 5 additions & 1 deletion packages/api-server/api_server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def pagination_query(
) -> Pagination:
limit = limit or 100
offset = offset or 0
return Pagination(limit=limit, offset=offset, order_by=order_by)
return Pagination(
limit=limit,
offset=offset,
order_by=order_by.split(",") if order_by else [],
)


# hacky way to get the sio user
Expand Down
4 changes: 1 addition & 3 deletions packages/api-server/api_server/models/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from pydantic import BaseModel


class Pagination(BaseModel):
limit: int
offset: int
order_by: Optional[str]
order_by: list[str]
51 changes: 6 additions & 45 deletions packages/api-server/api_server/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import tortoise.functions as tfuncs
from tortoise.expressions import Q
from tortoise.queryset import MODEL, QuerySet

from api_server.models.pagination import Pagination
Expand All @@ -8,47 +6,10 @@
def add_pagination(
query: QuerySet[MODEL],
pagination: Pagination,
field_mappings: dict[str, str] | None = None,
group_by: str | None = None,
) -> QuerySet[MODEL]:
"""
Adds pagination and ordering to a query. If the order field starts with `label=`, it is
assumed to be a label and label sorting will used. In this case, the model must have
a reverse relation named "labels" and the `group_by` param is required.
:param field_mapping: A dict mapping the order fields to the fields used to build the
query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}`
will order the query result according to `db_field`.
:param group_by: Required when sorting by labels, must be the foreign key column of the label table.
"""
field_mappings = field_mappings or {}
annotations = {}
query = query.limit(pagination.limit).offset(pagination.offset)
if pagination.order_by is not None:
order_fields = []
order_values = pagination.order_by.split(",")
for v in order_values:
# perform the mapping after stripping the order prefix
order_prefix = ""
order_field = v
if v[0] in ["-", "+"]:
order_prefix = v[0]
order_field = v[1:]
order_field = field_mappings.get(order_field, order_field)

# add annotations required for sorting by labels
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = query.annotate(**annotations)
if group_by is not None:
query = query.group_by(group_by)
query = query.order_by(*order_fields)
return query
"""Adds pagination and ordering to a query"""
return (
query.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*pagination.order_by)
)
85 changes: 80 additions & 5 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple

import tortoise.functions as tfuncs
from fastapi import Depends, HTTPException
from tortoise.exceptions import FieldError, IntegrityError
from tortoise.expressions import Expression, Q
from tortoise.query_utils import Prefetch
from tortoise.queryset import QuerySet
from tortoise.transactions import in_transaction

from api_server.authenticator import user_dep
Expand All @@ -18,14 +19,14 @@
TaskEventLog,
TaskRequest,
TaskState,
TaskStatus,
User,
)
from api_server.models import tortoise_models as ttm
from api_server.models.rmf_api.log_entry import Tier
from api_server.models.rmf_api.task_state import Category, Id, Phase
from api_server.models.tortoise_models import TaskRequest as DbTaskRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.query import add_pagination
from api_server.rmf_io import task_events


Expand Down Expand Up @@ -96,11 +97,85 @@ async def save_task_state(self, task_state: TaskState) -> None:
await self.save_task_labels(db_task_state, labels)

async def query_task_states(
self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
self,
task_id: list[str] | None = None,
category: list[str] | None = None,
assigned_to: list[str] | None = None,
start_time_between: tuple[datetime, datetime] | None = None,
finish_time_between: tuple[datetime, datetime] | None = None,
status: list[str] | None = None,
label: Labels | None = None,
pagination: Optional[Pagination] = None,
) -> List[TaskState]:
filters = {}
if task_id is not None:
filters["id___in"] = task_id
if category is not None:
filters["category__in"] = category
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in TaskStatus]
filters["status__in"] = []
for status_string in status:
if status_string not in valid_values:
continue
filters["status__in"].append(TaskStatus(status_string))
query = DbTaskState.filter(**filters)

need_group_by = False
label_filters = {}
if label is not None:
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_",
_filter=Q(labels__label_name=k, labels__label_value=v),
)
for k, v in label.root.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = query.annotate(**label_filters).filter(**filter_gt)
need_group_by = True

if pagination:
order_fields: list[str] = []
annotations: dict[str, Expression] = {}
# add annotations required for sorting by labels
for f in pagination.order_by:
order_prefix = f[0] if f[0] == "-" else ""
order_field = f[1:] if order_prefix == "-" else f
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = (
query.annotate(**annotations)
.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*order_fields)
)
need_group_by = True

if need_group_by:
query = query.group_by("id_", "labels__state_id")

try:
if pagination:
query = add_pagination(query, pagination, group_by="labels__state_id")
# TODO: enforce with authz
results = await query.values_list("data")
return [TaskState(**r[0]) for r in results]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def get_scheduled_tasks(
.offset(pagination.offset)
)
if pagination.order_by:
q.order_by(*pagination.order_by.split(","))
q.order_by(*pagination.order_by)
results = await q
await ttm.ScheduledTask.fetch_for_list(results)
return [ScheduledTask.model_validate(x) for x in results]
Expand Down
Loading

0 comments on commit a7feba4

Please sign in to comment.