diff --git a/api/app.py b/api/app.py index 78a5737..5f0c95f 100644 --- a/api/app.py +++ b/api/app.py @@ -124,7 +124,12 @@ async def get_sub_sources( ) # Add metadata to the response - response.headers["X-Total-Count"] = str(await db.get_sources_sub_table_count(engine=get_engine(), table_id=table_id)) + response.headers["X-Total-Count"] = str( + await db.get_sources_sub_table_count( + engine=get_engine(), + query_params=filter_query_params, + table_id=table_id) + ) return result.to_dict() diff --git a/api/database.py b/api/database.py index 5576a7e..db37e3c 100644 --- a/api/database.py +++ b/api/database.py @@ -172,7 +172,7 @@ async def get_polygon_table_name(engine: AsyncEngine, table_id: int) -> str: raise NoSuchTableError(e) -async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int) -> int: +async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int, query_params: list = None) -> int: async with engine.begin() as conn: # Grabbing a table from the database as it is metadata = MetaData(schema="sources") @@ -181,7 +181,29 @@ async def get_sources_sub_table_count(engine: AsyncEngine, table_id: int) -> int lambda sync_conn: Table(polygon_table, metadata, autoload_with=sync_conn) ) - stmt = select(func.count()).select_from(table) + # Extract filters from the query parameters + query_parser = QueryParser(columns=table.columns, query_params=query_params) + + stmt = None + if query_parser.get_group_by_column() is not None: + + sub_stmt = ( + select(query_parser.get_group_by_column()) + .where(query_parser.where_expressions()) + .group_by(query_parser.get_group_by_column()) + ) + + stmt = select(func.count("*")).select_from(sub_stmt) + else: + stmt = ( + select(func.count()) + .select_from(table) + .where(query_parser.where_expressions()) + ) + + x = str(stmt.compile(compile_kwargs={ + "literal_binds": True + })) result = await conn.execute(stmt) @@ -197,8 +219,6 @@ async def select_sources_sub_table( ) -> SQLResponse: async with engine.begin() as conn: - query_params = query_params if query_params is not None else [] - # Grabbing a table from the database as it is metadata = MetaData(schema="sources") polygon_table = await get_polygon_table_name(engine, table_id) @@ -248,8 +268,6 @@ async def patch_sources_sub_table( engine: AsyncEngine, table_id: int, update_values: dict, query_params: list = None ) -> CursorResult: - query_params = query_params if query_params is not None else [] - async with engine.begin() as conn: # Grabbing a table from the database as it is metadata = MetaData(schema="sources") diff --git a/api/models/geometries.py b/api/models/geometries.py index 39557fd..1c73423 100644 --- a/api/models/geometries.py +++ b/api/models/geometries.py @@ -21,7 +21,6 @@ class PolygonModel(CommonModel): comments: Optional[str] = None t_interval: Optional[Union[int | str]] = None b_interval: Optional[Union[int | str]] = None - geom: Optional[Polygon] = None confidence: Optional[float] = None t_age: Optional[float] = None b_age: Optional[float] = None diff --git a/api/query_parser.py b/api/query_parser.py index df56986..72814ea 100644 --- a/api/query_parser.py +++ b/api/query_parser.py @@ -6,6 +6,7 @@ import urllib.parse from dataclasses import dataclass +from functools import lru_cache import starlette.requests import logging @@ -61,7 +62,12 @@ class QueryParser: VALID_OPERATORS = ["not", "eq", "lt", "le", "gt", "ge", "ne", "like", "in", "is"] - def __init__(self, columns: list[Column], query_params: list[dict]): + def __init__(self, columns: list[Column], query_params: list[dict] | None): + + # If no query params, then set to empty list + if query_params is None: + query_params = [] + self.columns = {c.name: c for c in columns} self.query_params = query_params self.decomposed_query_params = self._decompose_query_params() @@ -85,6 +91,7 @@ def where_expressions(self): else: return and_(*where_expressions) + @lru_cache def get_group_by_column(self): """Returns the group by expressions for the query"""