Skip to content

Commit

Permalink
Add Support for counting grouped and filtered rows
Browse files Browse the repository at this point in the history
  • Loading branch information
CannonLock committed Mar 13, 2024
1 parent 9424a47 commit ad06ea0
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
7 changes: 6 additions & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ async def get_sub_sources(
)

# Add metadata to the response
response.headers["X-Total-Count"] = str(await db.get_sources_sub_table_count(engine=get_engine(), table_id=table_id))
response.headers["X-Total-Count"] = str(
await db.get_sources_sub_table_count(
engine=get_engine(),
query_params=filter_query_params,
table_id=table_id)
)

return result.to_dict()

Expand Down
30 changes: 24 additions & 6 deletions api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def get_polygon_table_name(engine: AsyncEngine, table_id: int) -> str:
raise NoSuchTableError(e)


async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int) -> int:
async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int, query_params: list = None) -> int:
async with engine.begin() as conn:
# Grabbing a table from the database as it is
metadata = MetaData(schema="sources")
Expand All @@ -181,7 +181,29 @@ async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int) -> int
lambda sync_conn: Table(polygon_table, metadata, autoload_with=sync_conn)
)

stmt = select(func.count()).select_from(table)
# Extract filters from the query parameters
query_parser = QueryParser(columns=table.columns, query_params=query_params)

stmt = None
if query_parser.get_group_by_column() is not None:

sub_stmt = (
select(query_parser.get_group_by_column())
.where(query_parser.where_expressions())
.group_by(query_parser.get_group_by_column())
)

stmt = select(func.count("*")).select_from(sub_stmt)
else:
stmt = (
select(func.count())
.select_from(table)
.where(query_parser.where_expressions())
)

x = str(stmt.compile(compile_kwargs={
"literal_binds": True
}))

result = await conn.execute(stmt)

Expand All @@ -197,8 +219,6 @@ async def select_sources_sub_table(
) -> SQLResponse:
async with engine.begin() as conn:

query_params = query_params if query_params is not None else []

# Grabbing a table from the database as it is
metadata = MetaData(schema="sources")
polygon_table = await get_polygon_table_name(engine, table_id)
Expand Down Expand Up @@ -248,8 +268,6 @@ async def patch_sources_sub_table(
engine: AsyncEngine, table_id: int, update_values: dict, query_params: list = None
) -> CursorResult:

query_params = query_params if query_params is not None else []

async with engine.begin() as conn:
# Grabbing a table from the database as it is
metadata = MetaData(schema="sources")
Expand Down
1 change: 0 additions & 1 deletion api/models/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class PolygonModel(CommonModel):
comments: Optional[str] = None
t_interval: Optional[Union[int | str]] = None
b_interval: Optional[Union[int | str]] = None
geom: Optional[Polygon] = None
confidence: Optional[float] = None
t_age: Optional[float] = None
b_age: Optional[float] = None
Expand Down
9 changes: 8 additions & 1 deletion api/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import urllib.parse
from dataclasses import dataclass
from functools import lru_cache

import starlette.requests
import logging
Expand Down Expand Up @@ -61,7 +62,12 @@ class QueryParser:

VALID_OPERATORS = ["not", "eq", "lt", "le", "gt", "ge", "ne", "like", "in", "is"]

def __init__(self, columns: list[Column], query_params: list[dict]):
def __init__(self, columns: list[Column], query_params: list[dict] | None):

# If no query params, then set to empty list
if query_params is None:
query_params = []

self.columns = {c.name: c for c in columns}
self.query_params = query_params
self.decomposed_query_params = self._decompose_query_params()
Expand All @@ -85,6 +91,7 @@ def where_expressions(self):
else:
return and_(*where_expressions)

@lru_cache
def get_group_by_column(self):
"""Returns the group by expressions for the query"""

Expand Down

0 comments on commit ad06ea0

Please sign in to comment.