Skip to content

Commit

Permalink
Merge pull request #1178 from datajoint/dj-top-1084-continued
Browse files Browse the repository at this point in the history
dj.Top continued (#1084)
  • Loading branch information
dimitri-yatsenko authored Sep 12, 2024
2 parents 0a49595 + 10f2b9f commit ed9a520
Show file tree
Hide file tree
Showing 18 changed files with 698 additions and 145 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"editor.formatOnPaste": false,
"editor.formatOnSave": true,
"editor.formatOnSave": false,
"editor.rulers": [
94
],
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Release notes

### 0.14.3 -- TBD
- Added - `dj.Top` restriction ([#1024](https://github.com/datajoint/datajoint-python/issues/1024)) PR [#1084](https://github.com/datajoint/datajoint-python/pull/1084)
- Fixed - Added encapsulating double quotes to comply with [DOT language](https://graphviz.org/doc/info/lang.html) - PR [#1177](https://github.com/datajoint/datajoint-python/pull/1177)
- Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091)

Expand Down
3 changes: 2 additions & 1 deletion datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"Part",
"Not",
"AndList",
"Top",
"U",
"Diagram",
"Di",
Expand All @@ -61,7 +62,7 @@
from .schemas import VirtualModule, list_schemas
from .table import Table, FreeTable
from .user_tables import Manual, Lookup, Imported, Computed, Part
from .expression import Not, AndList, U
from .expression import Not, AndList, U, Top
from .diagram import Diagram
from .admin import set_password, kill
from .blob import MatCell, MatStruct
Expand Down
31 changes: 31 additions & 0 deletions datajoint/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import pandas
import json
from .errors import DataJointError
from typing import Union, List
from dataclasses import dataclass

JSON_PATTERN = re.compile(
r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$"
Expand Down Expand Up @@ -61,6 +63,35 @@ def append(self, restriction):
super().append(restriction)


@dataclass
class Top:
    """
    A restriction to the top entities of a query.
    In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET

    :param limit: maximum number of entities to keep, or None for no limit.
        When an offset is given without a limit, a very large limit is substituted.
    :param order_by: a single ordering attribute or a list of them. Falsy values
        fall back to ["KEY"]. A single string is normalized to a one-element list.
    :param offset: number of leading entities to skip. Falsy values become 0.
    :raises TypeError: if limit or offset is not an integer (bool is rejected too),
        or if order_by is not a string or a sequence of strings.
    :raises ValueError: if limit or offset is negative.
    """

    limit: Union[int, None] = 1
    order_by: Union[str, List[str]] = "KEY"
    offset: int = 0

    def __post_init__(self):
        # Falsy order_by / offset (None, "", [], 0) fall back to their defaults.
        self.order_by = self.order_by or ["KEY"]
        self.offset = self.offset or 0

        # bool is a subclass of int, so it has to be excluded explicitly;
        # otherwise True/False would be rendered verbatim into the SQL clause.
        if self.limit is not None and (
            isinstance(self.limit, bool) or not isinstance(self.limit, int)
        ):
            raise TypeError("Top limit must be an integer")
        if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all(
            isinstance(r, str) for r in self.order_by
        ):
            raise TypeError("Top order_by attributes must all be strings")
        if isinstance(self.offset, bool) or not isinstance(self.offset, int):
            raise TypeError("The offset argument must be an integer")

        # Negative values cannot form a valid LIMIT / OFFSET clause: fail early.
        if self.limit is not None and self.limit < 0:
            raise ValueError("Top limit must not be negative")
        if self.offset < 0:
            raise ValueError("The offset argument must not be negative")

        if self.offset and self.limit is None:
            self.limit = 999999999999  # arbitrary large number to allow query
        if isinstance(self.order_by, str):
            self.order_by = [self.order_by]


class Not:
"""invert restriction"""

Expand Down
8 changes: 5 additions & 3 deletions datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,11 @@ def format_attribute(attr):
return f"`{attr}`"
return f"({attr})"

match = re.match(
r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
).groupdict()
match = re.match(r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I)
if match is None:
raise DataJointError(f'Table definition syntax error in line "{line}"')
match = match.groupdict()

attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"])
index_sql.append(
"{unique}index ({attrs})".format(
Expand Down
115 changes: 82 additions & 33 deletions datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .preview import preview, repr_html
from .condition import (
AndList,
Top,
Not,
make_condition,
assert_join_compatibility,
Expand Down Expand Up @@ -52,6 +53,7 @@ class QueryExpression:
_connection = None
_heading = None
_support = None
_top = None

# If the query will be using distinct
_distinct = False
Expand Down Expand Up @@ -121,17 +123,33 @@ def where_clause(self):
else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction)
)

def sorting_clauses(self):
    """
    Build the ORDER BY / LIMIT / OFFSET part of the SQL statement from ``self._top``.

    :return: SQL fragment starting with a space, or an empty string when no Top
        restriction is set on this expression
    """
    top = self._top
    if not top:
        return ""

    # Expand "KEY" placeholders, then backquote the attribute names.
    ordering = ", ".join(
        _wrap_attributes(_flatten_attribute_list(self.primary_key, top.order_by))
    )

    pieces = []
    if ordering:
        pieces.append(f" ORDER BY {ordering}")
    if top.limit is not None:
        pieces.append(f" LIMIT {top.limit}")
        # OFFSET is only emitted together with LIMIT and only when non-zero.
        if top.offset:
            pieces.append(f" OFFSET {top.offset}")
    return "".join(pieces)

def make_sql(self, fields=None):
"""
Make the SQL SELECT statement.
:param fields: used to explicitly set the select attributes
"""
return "SELECT {distinct}{fields} FROM {from_}{where}".format(
return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format(
distinct="DISTINCT " if self._distinct else "",
fields=self.heading.as_sql(fields or self.heading.names),
from_=self.from_clause(),
where=self.where_clause(),
sorting=self.sorting_clauses(),
)

# --------- query operators -----------
Expand Down Expand Up @@ -189,6 +207,14 @@ def restrict(self, restriction):
string, or an AndList.
"""
attributes = set()
if isinstance(restriction, Top):
result = (
self.make_subquery()
if self._top and not self._top.__eq__(restriction)
else copy.copy(self)
) # make subquery to avoid overwriting existing Top
result._top = restriction
return result
new_condition = make_condition(self, restriction, attributes)
if new_condition is True:
return self # restriction has no effect, return the same object
Expand All @@ -202,8 +228,10 @@ def restrict(self, restriction):
pass # all ok
# If the new condition uses any new attributes, a subquery is required.
# However, Aggregation's HAVING statement works fine with aliased attributes.
need_subquery = isinstance(self, Union) or (
not isinstance(self, Aggregation) and self.heading.new_attributes
need_subquery = (
isinstance(self, Union)
or (not isinstance(self, Aggregation) and self.heading.new_attributes)
or self._top
)
if need_subquery:
result = self.make_subquery()
Expand Down Expand Up @@ -539,19 +567,20 @@ def tail(self, limit=25, **fetch_kwargs):

def __len__(self):
""":return: number of elements in the result set e.g. ``len(q1)``."""
return self.connection.query(
result = self.make_subquery() if self._top else copy.copy(self)
return result.connection.query(
"SELECT {select_} FROM {from_}{where}".format(
select_=(
"count(*)"
if any(self._left)
if any(result._left)
else "count(DISTINCT {fields})".format(
fields=self.heading.as_sql(
self.primary_key, include_aliases=False
fields=result.heading.as_sql(
result.primary_key, include_aliases=False
)
)
),
from_=self.from_clause(),
where=self.where_clause(),
from_=result.from_clause(),
where=result.where_clause(),
)
).fetchone()[0]

Expand Down Expand Up @@ -619,18 +648,12 @@ def __next__(self):
# -- move on to next entry.
return next(self)

def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
def cursor(self, as_dict=False):
"""
See expression.fetch() for input description.
:return: query cursor
"""
if offset and limit is None:
raise DataJointError("limit is required when offset is set")
sql = self.make_sql()
if order_by is not None:
sql += " ORDER BY " + ", ".join(order_by)
if limit is not None:
sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "")
logger.debug(sql)
return self.connection.query(sql, as_dict=as_dict)

Expand Down Expand Up @@ -701,23 +724,26 @@ def make_sql(self, fields=None):
fields = self.heading.as_sql(fields or self.heading.names)
assert self._grouping_attributes or not self.restriction
distinct = set(self.heading.names) == set(self.primary_key)
return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
distinct="DISTINCT " if distinct else "",
fields=fields,
from_=self.from_clause(),
where=self.where_clause(),
group_by=(
""
if not self.primary_key
else (
" GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
+ (
""
if not self.restriction
else " HAVING (%s)" % ")AND(".join(self.restriction)
return (
"SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format(
distinct="DISTINCT " if distinct else "",
fields=fields,
from_=self.from_clause(),
where=self.where_clause(),
group_by=(
""
if not self.primary_key
else (
" GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
+ (
""
if not self.restriction
else " HAVING (%s)" % ")AND(".join(self.restriction)
)
)
)
),
),
sorting=self.sorting_clauses(),
)
)

def __len__(self):
Expand Down Expand Up @@ -776,7 +802,7 @@ def make_sql(self):
):
# no secondary attributes: use UNION DISTINCT
fields = arg1.primary_key
return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format(
return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format(
sql1=(
arg1.make_sql()
if isinstance(arg1, Union)
Expand All @@ -788,6 +814,7 @@ def make_sql(self):
else arg2.make_sql(fields)
),
alias=next(self.__count),
sorting=self.sorting_clauses(),
)
# with secondary attributes, use union of left join with antijoin
fields = self.heading.names
Expand Down Expand Up @@ -939,3 +966,25 @@ def aggr(self, group, **named_attributes):
)

aggregate = aggr # alias for aggr


def _flatten_attribute_list(primary_key, attrs):
"""
:param primary_key: list of attributes in primary key
:param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
:return: generator of attributes where "KEY" is replaced with its component attributes
"""
for a in attrs:
if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
if primary_key:
yield from primary_key
elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
if primary_key:
yield from (q + " DESC" for q in primary_key)
else:
yield a


def _wrap_attributes(attr):
for entry in attr: # wrap attribute names in backquotes
yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE)
46 changes: 10 additions & 36 deletions datajoint/fetch.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from functools import partial
from pathlib import Path
import logging
import pandas
import itertools
import re
import json
import numpy as np
import uuid
import numbers

from datajoint.condition import Top
from . import blob, hash
from .errors import DataJointError
from .settings import config
from .utils import safe_write

logger = logging.getLogger(__name__.split(".")[0])


class key:
"""
Expand Down Expand Up @@ -119,21 +117,6 @@ def _get(connection, attr, data, squeeze, download_path):
)


def _flatten_attribute_list(primary_key, attrs):
"""
:param primary_key: list of attributes in primary key
:param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC"
:return: generator of attributes where "KEY" is replaces with its component attributes
"""
for a in attrs:
if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a):
yield from primary_key
elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a):
yield from (q + " DESC" for q in primary_key)
else:
yield a


class Fetch:
"""
A fetch object that handles retrieving elements from the table expression.
Expand Down Expand Up @@ -174,13 +157,13 @@ def __call__(
:param download_path: for fetches that download data, e.g. attachments
:return: the contents of the table in the form of a structured numpy.array or a dict list
"""
if order_by is not None:
# if 'order_by' passed in a string, make into list
if isinstance(order_by, str):
order_by = [order_by]
# expand "KEY" or "KEY DESC"
order_by = list(
_flatten_attribute_list(self._expression.primary_key, order_by)
if offset or order_by or limit:
self._expression = self._expression.restrict(
Top(
limit,
order_by,
offset,
)
)

attrs_as_dict = as_dict and attrs
Expand Down Expand Up @@ -212,13 +195,6 @@ def __call__(
'use "array" or "frame"'.format(format)
)

if limit is None and offset is not None:
logger.warning(
"Offset set, but no limit. Setting limit to a large number. "
"Consider setting a limit explicitly."
)
limit = 8000000000 # just a very large number to effect no limit

get = partial(
_get,
self._expression.connection,
Expand Down Expand Up @@ -257,9 +233,7 @@ def __call__(
]
ret = return_values[0] if len(attrs) == 1 else return_values
else: # fetch all attributes as a numpy.record_array or pandas.DataFrame
cur = self._expression.cursor(
as_dict=as_dict, limit=limit, offset=offset, order_by=order_by
)
cur = self._expression.cursor(as_dict=as_dict)
heading = self._expression.heading
if as_dict:
ret = [
Expand Down
Loading

0 comments on commit ed9a520

Please sign in to comment.