From ece77d6a745d13b0ba616b5958e85e7729675d6c Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 10 May 2023 19:26:03 +0000 Subject: [PATCH 01/71] initial implementation attempt --- datajoint/__init__.py | 3 +- datajoint/condition.py | 18 +++++++++++ datajoint/expression.py | 70 ++++++++++++++++++++++++++--------------- 3 files changed, 65 insertions(+), 26 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index b73ade94a..a1b2befd8 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -37,6 +37,7 @@ "Part", "Not", "AndList", + "Top", "U", "Diagram", "Di", @@ -61,7 +62,7 @@ from .schemas import VirtualModule, list_schemas from .table import Table, FreeTable from .user_tables import Manual, Lookup, Imported, Computed, Part -from .expression import Not, AndList, U +from .expression import Not, AndList, U, Top from .diagram import Diagram from .admin import set_password, kill from .blob import MatCell, MatStruct diff --git a/datajoint/condition.py b/datajoint/condition.py index 80786c84c..680663bce 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -61,6 +61,13 @@ def append(self, restriction): super().append(restriction) +class Top: + def __init__(self, order_by, limit, offset=0): + self.order_by = order_by + self.limit = limit + self.offset = offset + + class Not: """invert restriction""" @@ -183,6 +190,17 @@ def combine_conditions(negate, conditions): return not negate # and empty AndList is True return combine_conditions(negate, conditions=items) + # restrict by Top + if isinstance(condition, Top): + query_expression.top_restriction = dict( + limit=condition.limit, + offset=condition.offset, + order_by=[condition.order_by] + if isinstance(condition.order_by, str) + else condition.order_by, + ) + return True + # restriction by dj.U evaluates to True if isinstance(condition, U): return not negate diff --git a/datajoint/expression.py b/datajoint/expression.py index 25dd2fe40..99a0ec5e4 100644 --- 
a/datajoint/expression.py +++ b/datajoint/expression.py @@ -9,6 +9,7 @@ from .preview import preview, repr_html from .condition import ( AndList, + Top, Not, make_condition, assert_join_compatibility, @@ -119,17 +120,34 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) ) - def make_sql(self, fields=None): + def sorting_clauses(self, limit=None, offset=None, order_by=None, no_offset=False): + if hasattr(self, "top_restriction") and self.top_restriction: + limit = self.top_restriction["limit"] + offset = self.top_restriction["offset"] + order_by = self.top_restriction["order_by"] + if offset and limit is None: + raise DataJointError("limit is required when offset is set") + clause = "" + if order_by is not None: + clause += " ORDER BY " + ", ".join(order_by) + if limit is not None: + clause += " LIMIT %d" % limit + ( + " OFFSET %d" % offset if offset and not no_offset else "" + ) + return clause + + def make_sql(self, fields=None, sorting_params={}): """ Make the SQL SELECT statement. 
:param fields: used to explicitly set the select attributes """ - return "SELECT {distinct}{fields} FROM {from_}{where}".format( + return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), from_=self.from_clause(), where=self.where_clause(), + sorting=self.sorting_clauses(**sorting_params), ) # --------- query operators ----------- @@ -624,11 +642,9 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ if offset and limit is None: raise DataJointError("limit is required when offset is set") - sql = self.make_sql() - if order_by is not None: - sql += " ORDER BY " + ", ".join(order_by) - if limit is not None: - sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") + sql = self.make_sql( + sorting_params=dict(offset=offset, limit=limit, order_by=order_by) + ) logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) @@ -695,25 +711,28 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict) ) - def make_sql(self, fields=None): + def make_sql(self, fields=None, sorting_params={}): fields = self.heading.as_sql(fields or self.heading.names) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) - return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format( - distinct="DISTINCT " if distinct else "", - fields=fields, - from_=self.from_clause(), - where=self.where_clause(), - group_by="" - if not self.primary_key - else ( - " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) - + ( - "" - if not self.restriction - else " HAVING (%s)" % ")AND(".join(self.restriction) - ) - ), + return ( + "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + 
group_by="" + if not self.primary_key + else ( + " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) + + ( + "" + if not self.restriction + else " HAVING (%s)" % ")AND(".join(self.restriction) + ) + ), + sorting=self.sorting_clauses(**sorting_params), + ) ) def __len__(self): @@ -764,7 +783,7 @@ def create(cls, arg1, arg2): result._support = [arg1, arg2] return result - def make_sql(self): + def make_sql(self, sorting_params={}): arg1, arg2 = self._support if ( not arg1.heading.secondary_attributes @@ -772,7 +791,7 @@ def make_sql(self): ): # no secondary attributes: use UNION DISTINCT fields = arg1.primary_key - return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format( + return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`{sorting}".format( sql1=arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields), @@ -780,6 +799,7 @@ def make_sql(self): if isinstance(arg2, Union) else arg2.make_sql(fields), alias=next(self.__count), + sorting=self.sorting_clauses(**sorting_params), ) # with secondary attributes, use union of left join with antijoin fields = self.heading.names From e32ed8f7f4ce242b316eec70cb38cbbbe1f46da1 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 11 May 2023 18:19:45 +0000 Subject: [PATCH 02/71] subquery --- datajoint/expression.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 99a0ec5e4..31aa5ce0a 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -120,7 +120,7 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) ) - def sorting_clauses(self, limit=None, offset=None, order_by=None, no_offset=False): + def sorting_clauses(self, limit=None, offset=None, order_by=None): if hasattr(self, "top_restriction") and self.top_restriction: limit = self.top_restriction["limit"] offset = self.top_restriction["offset"] @@ -131,9 +131,7 @@ def 
sorting_clauses(self, limit=None, offset=None, order_by=None, no_offset=Fals if order_by is not None: clause += " ORDER BY " + ", ".join(order_by) if limit is not None: - clause += " LIMIT %d" % limit + ( - " OFFSET %d" % offset if offset and not no_offset else "" - ) + clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause def make_sql(self, fields=None, sorting_params={}): @@ -142,12 +140,21 @@ def make_sql(self, fields=None, sorting_params={}): :param fields: used to explicitly set the select attributes """ - return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( + subquery = None + if hasattr(self, "top_restriction") and self.top_restriction: + subquery = ( + "(SELECT {distinct}{fields} FROM {from_}{sorting}) AS subquery".format( + distinct="DISTINCT " if self._distinct else "", + fields=self.heading.as_sql(fields or self.heading.names), + from_=self.from_clause(), + sorting=self.sorting_clauses(), + ) + ) + return "SELECT {distinct}{fields} FROM {from_}{where}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), - from_=self.from_clause(), + from_=subquery or self.from_clause(), where=self.where_clause(), - sorting=self.sorting_clauses(**sorting_params), ) # --------- query operators ----------- @@ -555,6 +562,15 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" + subquery = None + if hasattr(self, "top_restriction") and self.top_restriction: + subquery = ( + "(SELECT DISTINCT {fields} FROM {from_}{sorting}) AS subquery".format( + fields=self.heading.as_sql(self.heading.names), + from_=self.from_clause(), + sorting=self.sorting_clauses(), + ) + ) return self.connection.query( "SELECT {select_} FROM {from_}{where}".format( select_=( @@ -566,7 +582,7 @@ def __len__(self): ) ) ), - from_=self.from_clause(), + from_=subquery or self.from_clause(), where=self.where_clause(), ) ).fetchone()[0] From fb4dfc0ae66d9b9448010163e852f57ee7e42293 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 11 May 2023 19:14:49 +0000 Subject: [PATCH 03/71] left and right hand restrictions --- datajoint/condition.py | 7 +++++ datajoint/expression.py | 66 ++++++++++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 680663bce..83602c8d1 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -192,6 +192,13 @@ def combine_conditions(negate, conditions): # restrict by Top if isinstance(condition, Top): + if ( + hasattr(query_expression, "top_restriction") + and query_expression.top_restriction + ): + raise DataJointError( + "A QueryExpression may only contain a single dj.Top restriction" + ) query_expression.top_restriction = dict( limit=condition.limit, offset=condition.offset, diff --git a/datajoint/expression.py b/datajoint/expression.py index 31aa5ce0a..b410ff1a4 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -45,7 +45,9 @@ class QueryExpression: """ _restriction = None + _restriction_right = None _restriction_attributes = None + _restriction_right_attributes = None _left = [] # list of booleans True for left joins, False for inner joins _original_heading = None # heading before projections @@ -86,6 +88,13 @@ def restriction(self): self._restriction = AndList() return self._restriction + @property + def 
restriction_right(self): + """a AndList object of restrictions applied to a dj.Top to produce the result""" + if self._restriction_right is None: + self._restriction_right = AndList() + return self._restriction_right + @property def restriction_attributes(self): """the set of attribute names invoked in the WHERE clause""" @@ -93,6 +102,13 @@ def restriction_attributes(self): self._restriction_attributes = set() return self._restriction_attributes + @property + def restriction_right_attributes(self): + """the set of attribute names invoked in the WHERE clause""" + if self._restriction_right_attributes is None: + self._restriction_right_attributes = set() + return self._restriction_right_attributes + @property def primary_key(self): return self.heading.primary_key @@ -113,9 +129,13 @@ def from_clause(self): ) return clause - def where_clause(self): + def where_clause(self, right=False): return ( "" + if right and not self.restriction_right + else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction_right) + if right + else "" if not self.restriction else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) ) @@ -142,19 +162,20 @@ def make_sql(self, fields=None, sorting_params={}): """ subquery = None if hasattr(self, "top_restriction") and self.top_restriction: - subquery = ( - "(SELECT {distinct}{fields} FROM {from_}{sorting}) AS subquery".format( - distinct="DISTINCT " if self._distinct else "", - fields=self.heading.as_sql(fields or self.heading.names), - from_=self.from_clause(), - sorting=self.sorting_clauses(), - ) + subquery = "(SELECT {distinct}{fields} FROM {from_}{where}{sorting}) AS subquery".format( + distinct="DISTINCT " if self._distinct else "", + fields=self.heading.as_sql(fields or self.heading.names), + from_=self.from_clause(), + sorting=self.sorting_clauses(), + where=self.where_clause(), ) return "SELECT {distinct}{fields} FROM {from_}{where}".format( distinct="DISTINCT " if self._distinct else "", 
fields=self.heading.as_sql(fields or self.heading.names), from_=subquery or self.from_clause(), - where=self.where_clause(), + where=self.where_clause() + if not subquery + else self.where_clause(right=True), ) # --------- query operators ----------- @@ -235,8 +256,16 @@ def restrict(self, restriction): result._restriction = AndList( self.restriction ) # copy to preserve the original - result.restriction.append(new_condition) - result.restriction_attributes.update(attributes) + result._restriction_right = AndList( + self.restriction_right + ) # copy to preserve the original + # Distinguish between inner and outer restrictions for queries involving dj.Top + if hasattr(self, "top_restriction") and self.top_restriction: + result.restriction_right.append(new_condition) + result.restriction_right_attributes.update(attributes) + else: + result.restriction.append(new_condition) + result.restriction_attributes.update(attributes) return result def restrict_in_place(self, restriction): @@ -564,12 +593,11 @@ def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" subquery = None if hasattr(self, "top_restriction") and self.top_restriction: - subquery = ( - "(SELECT DISTINCT {fields} FROM {from_}{sorting}) AS subquery".format( - fields=self.heading.as_sql(self.heading.names), - from_=self.from_clause(), - sorting=self.sorting_clauses(), - ) + subquery = "(SELECT DISTINCT {fields} FROM {from_}{where}{sorting}) AS subquery".format( + fields=self.heading.as_sql(self.heading.names), + from_=self.from_clause(), + sorting=self.sorting_clauses(), + where=self.where_clause(), ) return self.connection.query( "SELECT {select_} FROM {from_}{where}".format( @@ -583,7 +611,9 @@ def __len__(self): ) ), from_=subquery or self.from_clause(), - where=self.where_clause(), + where=self.where_clause() + if not subquery + else self.where_clause(right=True), ) ).fetchone()[0] From b601439e5307e415cd445cdb17c67629943e5fb0 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 11 May 2023 20:02:59 +0000 Subject: [PATCH 04/71] remove distinct, limit=None --- datajoint/expression.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index b410ff1a4..866d842e3 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -593,11 +593,13 @@ def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" subquery = None if hasattr(self, "top_restriction") and self.top_restriction: - subquery = "(SELECT DISTINCT {fields} FROM {from_}{where}{sorting}) AS subquery".format( - fields=self.heading.as_sql(self.heading.names), - from_=self.from_clause(), - sorting=self.sorting_clauses(), - where=self.where_clause(), + subquery = ( + "(SELECT {fields} FROM {from_}{where}{sorting}) AS subquery".format( + fields=self.heading.as_sql(self.heading.names), + from_=self.from_clause(), + sorting=self.sorting_clauses(), + where=self.where_clause(), + ) ) return self.connection.query( "SELECT {select_} FROM {from_}{where}".format( From 134095860652dcb48c79d9edb522510a2bbe6162 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 11 May 2023 20:10:40 +0000 Subject: [PATCH 05/71] optional limit --- datajoint/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 83602c8d1..fe6058eee 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -62,7 +62,7 @@ def append(self, restriction): class Top: - def __init__(self, order_by, limit, offset=0): + def __init__(self, order_by, limit=None, offset=0): self.order_by = order_by self.limit = limit self.offset = offset From b67f1fef94b05c0013a894758c6ef5bb2a869f39 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 12 May 2023 20:09:34 +0000 Subject: [PATCH 06/71] recursive top subqueries --- datajoint/condition.py | 20 +++---- datajoint/expression.py | 113 +++++++++++++++++++--------------------- 2 files changed, 61 insertions(+), 72 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index fe6058eee..c65172f57 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -192,19 +192,15 @@ def combine_conditions(negate, conditions): # restrict by Top if isinstance(condition, Top): - if ( - hasattr(query_expression, "top_restriction") - and query_expression.top_restriction - ): - raise DataJointError( - "A QueryExpression 
may only contain a single dj.Top restriction" + query_expression.top_restriction.append( + dict( + limit=condition.limit, + offset=condition.offset, + order_by=[condition.order_by] + if isinstance(condition.order_by, str) + else condition.order_by, + restriction_index=len(query_expression.restriction), ) - query_expression.top_restriction = dict( - limit=condition.limit, - offset=condition.offset, - order_by=[condition.order_by] - if isinstance(condition.order_by, str) - else condition.order_by, ) return True diff --git a/datajoint/expression.py b/datajoint/expression.py index 866d842e3..6bde3c36c 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -45,11 +45,10 @@ class QueryExpression: """ _restriction = None - _restriction_right = None _restriction_attributes = None - _restriction_right_attributes = None _left = [] # list of booleans True for left joins, False for inner joins _original_heading = None # heading before projections + _top_restriction = None # subclasses or instantiators must provide values _connection = None @@ -88,13 +87,6 @@ def restriction(self): self._restriction = AndList() return self._restriction - @property - def restriction_right(self): - """a AndList object of restrictions applied to a dj.Top to produce the result""" - if self._restriction_right is None: - self._restriction_right = AndList() - return self._restriction_right - @property def restriction_attributes(self): """the set of attribute names invoked in the WHERE clause""" @@ -103,11 +95,11 @@ def restriction_attributes(self): return self._restriction_attributes @property - def restriction_right_attributes(self): - """the set of attribute names invoked in the WHERE clause""" - if self._restriction_right_attributes is None: - self._restriction_right_attributes = set() - return self._restriction_right_attributes + def top_restriction(self): + """the list of top restrictions to be subqeuried""" + if self._top_restriction is None: + self._top_restriction = AndList() + 
return self._top_restriction @property def primary_key(self): @@ -129,22 +121,16 @@ def from_clause(self): ) return clause - def where_clause(self, right=False): + def where_clause(self, restriction_list=None): return ( - "" - if right and not self.restriction_right - else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction_right) - if right - else "" - if not self.restriction + " WHERE (%s)" % ")AND(".join(str(s) for s in restriction_list) + if restriction_list else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) + if self.restriction + else "" ) def sorting_clauses(self, limit=None, offset=None, order_by=None): - if hasattr(self, "top_restriction") and self.top_restriction: - limit = self.top_restriction["limit"] - offset = self.top_restriction["offset"] - order_by = self.top_restriction["order_by"] if offset and limit is None: raise DataJointError("limit is required when offset is set") clause = "" @@ -154,28 +140,50 @@ def sorting_clauses(self, limit=None, offset=None, order_by=None): clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause + def make_top_subquery(self, tops, fields=None, i=0): + if not tops: + return self.from_clause() + top = tops.pop() + if tops: + start = tops[-1]["restriction_index"] + else: + start = 0 + return "(SELECT {distinct}{fields} FROM {from_}{where}{sorting}) AS top_subquery_{i}".format( + distinct="DISTINCT " if self._distinct else "", + fields=self.heading.as_sql(fields or self.heading.names), + from_=self.make_top_subquery(tops, fields, i + 1), + where=self.where_clause(self.restriction[start : top["restriction_index"]]) + if top["restriction_index"] + else "", + sorting=self.sorting_clauses( + limit=top["limit"], offset=top["offset"], order_by=top["order_by"] + ), + i=i, + ) + def make_sql(self, fields=None, sorting_params={}): """ Make the SQL SELECT statement. 
:param fields: used to explicitly set the select attributes """ - subquery = None - if hasattr(self, "top_restriction") and self.top_restriction: - subquery = "(SELECT {distinct}{fields} FROM {from_}{where}{sorting}) AS subquery".format( - distinct="DISTINCT " if self._distinct else "", - fields=self.heading.as_sql(fields or self.heading.names), - from_=self.from_clause(), - sorting=self.sorting_clauses(), - where=self.where_clause(), + top_subquery = None + if self.top_restriction: + top_subquery = self.make_top_subquery( + copy.copy(self.top_restriction), fields ) - return "SELECT {distinct}{fields} FROM {from_}{where}".format( + return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), - from_=subquery or self.from_clause(), + from_=top_subquery or self.from_clause(), where=self.where_clause() - if not subquery - else self.where_clause(right=True), + if not top_subquery + else self.where_clause( + self.restriction[self.top_restriction[-1]["restriction_index"] : :] + ) + if self.top_restriction[-1]["restriction_index"] < len(self.restriction) + else "", + sorting=self.sorting_clauses(**sorting_params), ) # --------- query operators ----------- @@ -256,16 +264,8 @@ def restrict(self, restriction): result._restriction = AndList( self.restriction ) # copy to preserve the original - result._restriction_right = AndList( - self.restriction_right - ) # copy to preserve the original - # Distinguish between inner and outer restrictions for queries involving dj.Top - if hasattr(self, "top_restriction") and self.top_restriction: - result.restriction_right.append(new_condition) - result.restriction_right_attributes.update(attributes) - else: - result.restriction.append(new_condition) - result.restriction_attributes.update(attributes) + result.restriction.append(new_condition) + result.restriction_attributes.update(attributes) return result def 
restrict_in_place(self, restriction): @@ -591,15 +591,10 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. ``len(q1)``.""" - subquery = None - if hasattr(self, "top_restriction") and self.top_restriction: - subquery = ( - "(SELECT {fields} FROM {from_}{where}{sorting}) AS subquery".format( - fields=self.heading.as_sql(self.heading.names), - from_=self.from_clause(), - sorting=self.sorting_clauses(), - where=self.where_clause(), - ) + top_subquery = None + if self.top_restriction: + top_subquery = self.make_top_subquery( + copy.copy(self.top_restriction), ) return self.connection.query( "SELECT {select_} FROM {from_}{where}".format( @@ -612,10 +607,8 @@ def __len__(self): ) ) ), - from_=subquery or self.from_clause(), - where=self.where_clause() - if not subquery - else self.where_clause(right=True), + from_=top_subquery or self.from_clause(), + where=self.where_clause(), ) ).fetchone()[0] From 90dedd970a18dd2fb05787948144c1c844b2d966 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 15 May 2023 19:53:07 +0000 Subject: [PATCH 07/71] imeplement with `make_subquery` --- datajoint/condition.py | 12 +----- datajoint/expression.py | 89 ++++++++++++++++------------------------- 2 files changed, 35 insertions(+), 66 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index c65172f57..7d635f5af 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -192,17 +192,7 @@ def combine_conditions(negate, conditions): # restrict by Top if isinstance(condition, Top): - query_expression.top_restriction.append( - dict( - limit=condition.limit, - offset=condition.offset, - order_by=[condition.order_by] - if isinstance(condition.order_by, str) - else condition.order_by, - restriction_index=len(query_expression.restriction), - ) - ) - return True + return query_expression.make_subquery(top_restriction=condition) # restriction by dj.U evaluates to True if isinstance(condition, U): diff 
--git a/datajoint/expression.py b/datajoint/expression.py index 6bde3c36c..aedcf259b 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -48,12 +48,12 @@ class QueryExpression: _restriction_attributes = None _left = [] # list of booleans True for left joins, False for inner joins _original_heading = None # heading before projections - _top_restriction = None # subclasses or instantiators must provide values _connection = None _heading = None _support = None + _top = None # If the query will be using distinct _distinct = False @@ -75,6 +75,11 @@ def heading(self): """a dj.Heading object, reflects the effects of the projection operator .proj""" return self._heading + @property + def top(self): + """a dj.top object, reflects the effects of order by, limit, and offset""" + return self._top + @property def original_heading(self): """a dj.Heading object reflecting the attributes before projection""" @@ -94,13 +99,6 @@ def restriction_attributes(self): self._restriction_attributes = set() return self._restriction_attributes - @property - def top_restriction(self): - """the list of top restrictions to be subqeuried""" - if self._top_restriction is None: - self._top_restriction = AndList() - return self._top_restriction - @property def primary_key(self): return self.heading.primary_key @@ -109,7 +107,17 @@ def primary_key(self): def from_clause(self): support = ( - "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) + "(" + + src.make_sql( + sorting_params=dict( + order_by=self.top["order_by"], + limit=self.top["limit"], + offset=self.top["offset"], + ) + if self.top + else {} + ) + + ") as `$%x`" % next(self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support @@ -121,13 +129,11 @@ def from_clause(self): ) return clause - def where_clause(self, restriction_list=None): + def where_clause(self): return ( - " WHERE (%s)" % ")AND(".join(str(s) for s in restriction_list) - if restriction_list + "" + if 
not self.restriction else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) - if self.restriction - else "" ) def sorting_clauses(self, limit=None, offset=None, order_by=None): @@ -140,59 +146,35 @@ def sorting_clauses(self, limit=None, offset=None, order_by=None): clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause - def make_top_subquery(self, tops, fields=None, i=0): - if not tops: - return self.from_clause() - top = tops.pop() - if tops: - start = tops[-1]["restriction_index"] - else: - start = 0 - return "(SELECT {distinct}{fields} FROM {from_}{where}{sorting}) AS top_subquery_{i}".format( - distinct="DISTINCT " if self._distinct else "", - fields=self.heading.as_sql(fields or self.heading.names), - from_=self.make_top_subquery(tops, fields, i + 1), - where=self.where_clause(self.restriction[start : top["restriction_index"]]) - if top["restriction_index"] - else "", - sorting=self.sorting_clauses( - limit=top["limit"], offset=top["offset"], order_by=top["order_by"] - ), - i=i, - ) - def make_sql(self, fields=None, sorting_params={}): """ Make the SQL SELECT statement. 
:param fields: used to explicitly set the select attributes """ - top_subquery = None - if self.top_restriction: - top_subquery = self.make_top_subquery( - copy.copy(self.top_restriction), fields - ) return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), - from_=top_subquery or self.from_clause(), - where=self.where_clause() - if not top_subquery - else self.where_clause( - self.restriction[self.top_restriction[-1]["restriction_index"] : :] - ) - if self.top_restriction[-1]["restriction_index"] < len(self.restriction) - else "", + from_=self.from_clause(), + where=self.where_clause(), sorting=self.sorting_clauses(**sorting_params), ) # --------- query operators ----------- - def make_subquery(self): + def make_subquery(self, top_restriction={}): """create a new SELECT statement where self is the FROM clause""" result = QueryExpression() result._connection = self.connection result._support = [self] result._heading = self.heading.make_subquery_heading() + if top_restriction: + result._top = dict( + limit=top_restriction.limit, + offset=top_restriction.offset, + order_by=[top_restriction.order_by] + if isinstance(top_restriction.order_by, str) + else top_restriction.order_by, + ) return result def restrict(self, restriction): @@ -242,6 +224,8 @@ def restrict(self, restriction): """ attributes = set() new_condition = make_condition(self, restriction, attributes) + if isinstance(new_condition, QueryExpression): + return new_condition if new_condition is True: return self # restriction has no effect, return the same object # check that all attributes in condition are present in the query @@ -591,11 +575,6 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" - top_subquery = None - if self.top_restriction: - top_subquery = self.make_top_subquery( - copy.copy(self.top_restriction), - ) return self.connection.query( "SELECT {select_} FROM {from_}{where}".format( select_=( @@ -607,7 +586,7 @@ def __len__(self): ) ) ), - from_=top_subquery or self.from_clause(), + from_=self.from_clause(), where=self.where_clause(), ) ).fetchone()[0] From 2df9a4605b6b740bd8bb528fae1151a27ad01a44 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 18:48:53 +0000 Subject: [PATCH 08/71] apply top from self instead of params --- datajoint/expression.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index aedcf259b..b979b178e 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -107,17 +107,7 @@ def primary_key(self): def from_clause(self): support = ( - "(" - + src.make_sql( - sorting_params=dict( - order_by=self.top["order_by"], - limit=self.top["limit"], - offset=self.top["offset"], - ) - if self.top - else {} - ) - + ") as `$%x`" % next(self._subquery_alias_count) + "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support @@ -137,6 +127,10 @@ def where_clause(self): ) def sorting_clauses(self, limit=None, offset=None, order_by=None): + if self.top and not (limit or offset or order_by): + limit = self.top["limit"] + offset = self.top["offset"] + order_by = self.top["order_by"] if offset and limit is None: raise DataJointError("limit is required when offset is set") clause = "" @@ -168,7 +162,7 @@ def make_subquery(self, top_restriction={}): result._support = [self] result._heading = self.heading.make_subquery_heading() if top_restriction: - result._top = dict( + self._top = dict( limit=top_restriction.limit, offset=top_restriction.offset, order_by=[top_restriction.order_by] From 07003b181325a90aa8c2ea41a25626af6225d8fb Mon Sep 17 
00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 18:54:27 +0000 Subject: [PATCH 09/71] remove sorting params --- datajoint/expression.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index b979b178e..1f57c77be 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -126,11 +126,13 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) ) - def sorting_clauses(self, limit=None, offset=None, order_by=None): - if self.top and not (limit or offset or order_by): + def sorting_clauses(self): + if self.top: limit = self.top["limit"] offset = self.top["offset"] order_by = self.top["order_by"] + else: + return "" if offset and limit is None: raise DataJointError("limit is required when offset is set") clause = "" @@ -140,7 +142,7 @@ def sorting_clauses(self, limit=None, offset=None, order_by=None): clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause - def make_sql(self, fields=None, sorting_params={}): + def make_sql(self, fields=None): """ Make the SQL SELECT statement. @@ -151,7 +153,7 @@ def make_sql(self, fields=None, sorting_params={}): fields=self.heading.as_sql(fields or self.heading.names), from_=self.from_clause(), where=self.where_clause(), - sorting=self.sorting_clauses(**sorting_params), + sorting=self.sorting_clauses(), ) # --------- query operators ----------- @@ -652,13 +654,16 @@ def __next__(self): def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ See expression.fetch() for input description. 
- :return: query cursor + :return: query cursor` """ if offset and limit is None: raise DataJointError("limit is required when offset is set") - sql = self.make_sql( - sorting_params=dict(offset=offset, limit=limit, order_by=order_by) + self._top = dict( + offset=offset, + limit=limit, + order_by=order_by, ) + sql = self.make_sql() logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) @@ -725,7 +730,7 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict) ) - def make_sql(self, fields=None, sorting_params={}): + def make_sql(self, fields=None): fields = self.heading.as_sql(fields or self.heading.names) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) @@ -745,7 +750,7 @@ def make_sql(self, fields=None, sorting_params={}): else " HAVING (%s)" % ")AND(".join(self.restriction) ) ), - sorting=self.sorting_clauses(**sorting_params), + sorting=self.sorting_clauses(), ) ) @@ -797,7 +802,7 @@ def create(cls, arg1, arg2): result._support = [arg1, arg2] return result - def make_sql(self, sorting_params={}): + def make_sql(self): arg1, arg2 = self._support if ( not arg1.heading.secondary_attributes @@ -813,7 +818,7 @@ def make_sql(self, sorting_params={}): if isinstance(arg2, Union) else arg2.make_sql(fields), alias=next(self.__count), - sorting=self.sorting_clauses(**sorting_params), + sorting=self.sorting_clauses(), ) # with secondary attributes, use union of left join with antijoin fields = self.heading.names From 3e518df52e73fa13691ad92f9560668556bb4704 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 20:30:44 +0000 Subject: [PATCH 10/71] simplify make_subquery --- datajoint/condition.py | 9 ++++++++- datajoint/expression.py | 10 +--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 7d635f5af..2da4baee1 100644 --- a/datajoint/condition.py +++ 
b/datajoint/condition.py @@ -192,7 +192,14 @@ def combine_conditions(negate, conditions): # restrict by Top if isinstance(condition, Top): - return query_expression.make_subquery(top_restriction=condition) + query_expression._top = dict( + limit=condition.limit, + offset=condition.offset, + order_by=[condition.order_by] + if isinstance(condition.order_by, str) + else condition.order_by, + ) + return query_expression.make_subquery() # restriction by dj.U evaluates to True if isinstance(condition, U): diff --git a/datajoint/expression.py b/datajoint/expression.py index 1f57c77be..3f9ba5980 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -157,20 +157,12 @@ def make_sql(self, fields=None): ) # --------- query operators ----------- - def make_subquery(self, top_restriction={}): + def make_subquery(self): """create a new SELECT statement where self is the FROM clause""" result = QueryExpression() result._connection = self.connection result._support = [self] result._heading = self.heading.make_subquery_heading() - if top_restriction: - self._top = dict( - limit=top_restriction.limit, - offset=top_restriction.offset, - order_by=[top_restriction.order_by] - if isinstance(top_restriction.order_by, str) - else top_restriction.order_by, - ) return result def restrict(self, restriction): From 32491a6357996849d2cf3a0260dea354ad8a251f Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 20:43:23 +0000 Subject: [PATCH 11/71] optional order_by --- datajoint/condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 2da4baee1..b0f7cc99b 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -62,7 +62,7 @@ def append(self, restriction): class Top: - def __init__(self, order_by, limit=None, offset=0): + def __init__(self, order_by=None, limit=None, offset=0): self.order_by = order_by self.limit = limit self.offset = offset From 01be0ab08b2505f4666647a10b7ab8188d406cf9 
Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 20:47:44 +0000 Subject: [PATCH 12/71] handle top in QE.restrict --- datajoint/condition.py | 11 ----------- datajoint/expression.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index b0f7cc99b..2b0464acb 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -190,17 +190,6 @@ def combine_conditions(negate, conditions): return not negate # and empty AndList is True return combine_conditions(negate, conditions=items) - # restrict by Top - if isinstance(condition, Top): - query_expression._top = dict( - limit=condition.limit, - offset=condition.offset, - order_by=[condition.order_by] - if isinstance(condition.order_by, str) - else condition.order_by, - ) - return query_expression.make_subquery() - # restriction by dj.U evaluates to True if isinstance(condition, U): return not negate diff --git a/datajoint/expression.py b/datajoint/expression.py index 3f9ba5980..0288fb892 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -211,9 +211,16 @@ def restrict(self, restriction): string, or an AndList. 
""" attributes = set() + if isinstance(restriction, Top): + self._top = dict( + limit=restriction.limit, + offset=restriction.offset, + order_by=[restriction.order_by] + if isinstance(restriction.order_by, str) + else restriction.order_by, + ) + return self.make_subquery() new_condition = make_condition(self, restriction, attributes) - if isinstance(new_condition, QueryExpression): - return new_condition if new_condition is True: return self # restriction has no effect, return the same object # check that all attributes in condition are present in the query From e86ca94ab93772117089d64fc03377b499cd2b23 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 16 May 2023 21:12:27 +0000 Subject: [PATCH 13/71] dataclass decorator --- datajoint/condition.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 2b0464acb..7feb34dd2 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -10,6 +10,8 @@ import pandas import json from .errors import DataJointError +from typing import Union, List +from dataclasses import dataclass JSON_PATTERN = re.compile( r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$" @@ -61,11 +63,11 @@ def append(self, restriction): super().append(restriction) +@dataclass class Top: - def __init__(self, order_by=None, limit=None, offset=0): - self.order_by = order_by - self.limit = limit - self.offset = offset + order_by: Union[str, List[str]] = None + limit: int = None + offset: int = 0 class Not: From 2a6ec2b8b80e8e808dfaf09c8be91d3abdfffa93 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 17 May 2023 15:28:44 +0000 Subject: [PATCH 14/71] oops --- datajoint/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 0288fb892..3b9d7cd08 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -653,7 +653,7 @@ def __next__(self): def cursor(self, offset=0, limit=None, 
order_by=None, as_dict=False): """ See expression.fetch() for input description. - :return: query cursor` + :return: query cursor """ if offset and limit is None: raise DataJointError("limit is required when offset is set") From cc5720f0cee3842eeeb4809fb03ed5aae67bbdfe Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 17 May 2023 19:17:08 +0000 Subject: [PATCH 15/71] new top defaults and order by "KEY" --- datajoint/condition.py | 8 +++- datajoint/expression.py | 10 +++-- tests/test_relational_operand.py | 67 ++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 7feb34dd2..b89d9ee12 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -65,8 +65,12 @@ def append(self, restriction): @dataclass class Top: - order_by: Union[str, List[str]] = None - limit: int = None + """ + doc string + """ + + limit: int = 10 + order_by: Union[str, List[str]] = "KEY" offset: int = 0 diff --git a/datajoint/expression.py b/datajoint/expression.py index 3b9d7cd08..2a8a9d0e8 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -78,6 +78,12 @@ def heading(self): @property def top(self): """a dj.top object, reflects the effects of order by, limit, and offset""" + if self._top and self._top["order_by"]: + if isinstance(self._top["order_by"], str): + self._top["order_by"] = [self._top["order_by"]] + if "KEY" in self._top["order_by"]: + i = self._top["order_by"].index("KEY") + self._top["order_by"][i : i + 1] = self.primary_key return self._top @property @@ -215,9 +221,7 @@ def restrict(self, restriction): self._top = dict( limit=restriction.limit, offset=restriction.offset, - order_by=[restriction.order_by] - if isinstance(restriction.order_by, str) - else restriction.order_by, + order_by=restriction.order_by, ) return self.make_subquery() new_condition = make_condition(self, restriction, attributes) diff --git a/tests/test_relational_operand.py 
b/tests/test_relational_operand.py index 0611ab267..ab2e07c45 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -14,6 +14,7 @@ ) import datajoint as dj +from datajoint.errors import DataJointError from .schema_simple import ( A, B, @@ -487,6 +488,72 @@ def test_restrictions_by_lists(): ) assert_true(len(w - y) == 0, "incorrect restriction without common attributes") + @staticmethod + def test_restrictions_by_top(): + a = L() & dj.Top() + b = L() & dj.Top(order_by=["cond_in_l", "KEY"]) + x = L() & dj.Top(5, "id_l desc", 4) & "cond_in_l=1" + y = L() & "cond_in_l=1" & dj.Top(5, "id_l desc", 4) + z = ( + L() + & dj.Top(None, order_by="id_l desc") + & "cond_in_l=1" + & dj.Top(5, "id_l desc") + & ("id_l=20", "id_l=16", "id_l=17") + & dj.Top(2, "id_l asc", 1) + ) + assert len(a) == 10 + assert len(b) == 10 + assert len(x) == 1 + assert len(y) == 5 + assert len(z) == 2 + assert a.fetch(as_dict=True) == [ + {"id_l": 0, "cond_in_l": 1}, + {"id_l": 1, "cond_in_l": 1}, + {"id_l": 2, "cond_in_l": 1}, + {"id_l": 3, "cond_in_l": 0}, + {"id_l": 4, "cond_in_l": 0}, + {"id_l": 5, "cond_in_l": 1}, + {"id_l": 6, "cond_in_l": 0}, + {"id_l": 7, "cond_in_l": 0}, + {"id_l": 8, "cond_in_l": 0}, + {"id_l": 9, "cond_in_l": 0}, + ] + assert b.fetch(as_dict=True) == [ + {"id_l": 3, "cond_in_l": 0}, + {"id_l": 4, "cond_in_l": 0}, + {"id_l": 6, "cond_in_l": 0}, + {"id_l": 7, "cond_in_l": 0}, + {"id_l": 8, "cond_in_l": 0}, + {"id_l": 9, "cond_in_l": 0}, + {"id_l": 12, "cond_in_l": 0}, + {"id_l": 13, "cond_in_l": 0}, + {"id_l": 14, "cond_in_l": 0}, + {"id_l": 18, "cond_in_l": 0}, + ] + assert x.fetch(as_dict=True) == [{"id_l": 25, "cond_in_l": 1}] + assert y.fetch(as_dict=True) == [ + {"id_l": 16, "cond_in_l": 1}, + {"id_l": 15, "cond_in_l": 1}, + {"id_l": 11, "cond_in_l": 1}, + {"id_l": 10, "cond_in_l": 1}, + {"id_l": 5, "cond_in_l": 1}, + ] + assert z.fetch(as_dict=True) == [ + {"id_l": 17, "cond_in_l": 1}, + {"id_l": 20, "cond_in_l": 1}, + ] + + 
@staticmethod + @raises(DataJointError) + def test_top_in_or_list_fails(): + L() & ("cond_in_l=1", dj.Top()) + + @staticmethod + @raises(DataJointError) + def test_top_in_and_list_fails(): + L() & dj.AndList(["cond_in_l=1", dj.Top()]) + @staticmethod def test_datetime(): """Test date retrieval""" From 95124eb20f573d9f7cea210d6f58248b4395c0fe Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 17 May 2023 19:21:24 +0000 Subject: [PATCH 16/71] docstring --- datajoint/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 2a8a9d0e8..de0c6a99e 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -77,7 +77,7 @@ def heading(self): @property def top(self): - """a dj.top object, reflects the effects of order by, limit, and offset""" + """a top object to form the ORDER BY, LIMIT, and OFFSET clauses""" if self._top and self._top["order_by"]: if isinstance(self._top["order_by"], str): self._top["order_by"] = [self._top["order_by"]] From 0fca2cc6439103b2e6b3c5406cc3c0e458b0416f Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 17 May 2023 19:33:00 +0000 Subject: [PATCH 17/71] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56d01fcc8..597505103 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## Release notes ### Upcoming +- Added - `dj.Top` restriction ([#1024](https://github.com/datajoint/datajoint-python/issues/1024)) PR [#1084](https://github.com/datajoint/datajoint-python/pull/1084) - Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991) - Fixed - `.ipynb` output in tutorials is not visible in dark mode ([#1078](https://github.com/datajoint/datajoint-python/issues/1078)) PR [#1080](https://github.com/datajoint/datajoint-python/pull/1080) - Changed - Readme to update links and include example pipeline image From 
a816b6cb26a72c55c0ae395ee7d6242c1b607d25 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 18 May 2023 16:37:32 +0000 Subject: [PATCH 18/71] docstring --- datajoint/condition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index b89d9ee12..208257d1b 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -66,7 +66,8 @@ def append(self, restriction): @dataclass class Top: """ - doc string + A "restriction" to set the sorting clauses of a query. Since it is not a true + restriction, it has no effect on the WHERE clause. """ limit: int = 10 From 4f3ef26e91dd9715304f6fa492e34b35d49b5b91 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 18 May 2023 17:15:39 +0000 Subject: [PATCH 19/71] use Top instead of dict --- datajoint/expression.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index de0c6a99e..909ca755f 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -78,12 +78,12 @@ def heading(self): @property def top(self): """a top object to form the ORDER BY, LIMIT, and OFFSET clauses""" - if self._top and self._top["order_by"]: - if isinstance(self._top["order_by"], str): - self._top["order_by"] = [self._top["order_by"]] - if "KEY" in self._top["order_by"]: - i = self._top["order_by"].index("KEY") - self._top["order_by"][i : i + 1] = self.primary_key + if self._top and self._top.order_by: + if isinstance(self._top.order_by, str): + self._top.order_by = [self._top.order_by] + if "KEY" in self._top.order_by: + i = self._top.order_by.index("KEY") + self._top.order_by[i : i + 1] = self.primary_key return self._top @property @@ -134,9 +134,9 @@ def where_clause(self): def sorting_clauses(self): if self.top: - limit = self.top["limit"] - offset = self.top["offset"] - order_by = self.top["order_by"] + limit = self.top.limit + offset = self.top.offset + order_by = 
self.top.order_by else: return "" if offset and limit is None: @@ -218,10 +218,10 @@ def restrict(self, restriction): """ attributes = set() if isinstance(restriction, Top): - self._top = dict( - limit=restriction.limit, - offset=restriction.offset, - order_by=restriction.order_by, + self._top = Top( + restriction.limit, + restriction.order_by, + restriction.offset, ) return self.make_subquery() new_condition = make_condition(self, restriction, attributes) @@ -661,10 +661,10 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ if offset and limit is None: raise DataJointError("limit is required when offset is set") - self._top = dict( - offset=offset, - limit=limit, - order_by=order_by, + self._top = Top( + limit, + order_by, + offset, ) sql = self.make_sql() logger.debug(sql) From 6547af8c58aa61792b1cd108358e2ced6a745818 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 18 May 2023 17:22:46 +0000 Subject: [PATCH 20/71] simpler --- datajoint/expression.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 909ca755f..1f1e905d6 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -218,11 +218,7 @@ def restrict(self, restriction): """ attributes = set() if isinstance(restriction, Top): - self._top = Top( - restriction.limit, - restriction.order_by, - restriction.offset, - ) + self._top = restriction return self.make_subquery() new_condition = make_condition(self, restriction, attributes) if new_condition is True: From 722e061be0285943a8677b591bdad36ff5a5bf11 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 19 May 2023 20:49:32 +0000 Subject: [PATCH 21/71] optimize subqeury usage --- datajoint/condition.py | 6 +-- datajoint/expression.py | 90 +++++++++++++++++++------------- tests/test_relational_operand.py | 22 +------- 3 files changed, 60 insertions(+), 58 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 
208257d1b..65ef88c9b 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -66,11 +66,11 @@ def append(self, restriction): @dataclass class Top: """ - A "restriction" to set the sorting clauses of a query. Since it is not a true - restriction, it has no effect on the WHERE clause. + A restriction to the top entities of a query. + In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET """ - limit: int = 10 + limit: Union[int, None] = 1 order_by: Union[str, List[str]] = "KEY" offset: int = 0 diff --git a/datajoint/expression.py b/datajoint/expression.py index 1f1e905d6..5253dc583 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -75,17 +75,6 @@ def heading(self): """a dj.Heading object, reflects the effects of the projection operator .proj""" return self._heading - @property - def top(self): - """a top object to form the ORDER BY, LIMIT, and OFFSET clauses""" - if self._top and self._top.order_by: - if isinstance(self._top.order_by, str): - self._top.order_by = [self._top.order_by] - if "KEY" in self._top.order_by: - i = self._top.order_by.index("KEY") - self._top.order_by[i : i + 1] = self.primary_key - return self._top - @property def original_heading(self): """a dj.Heading object reflecting the attributes before projection""" @@ -133,17 +122,26 @@ def where_clause(self): ) def sorting_clauses(self): - if self.top: - limit = self.top.limit - offset = self.top.offset - order_by = self.top.order_by - else: + if not self._top: return "" + limit = self._top.limit + order_by = self._top.order_by or ["KEY"] + offset = self._top.offset or 0 + + if order_by and not ( + isinstance(order_by, str) or all(isinstance(r, str) for r in order_by) + ): + raise DataJointError("All order_by attributes must be strings") if offset and limit is None: raise DataJointError("limit is required when offset is set") - clause = "" - if order_by is not None: - clause += " ORDER BY " + ", ".join(order_by) + + # if 'order_by' passed in a string, make into list 
+ if isinstance(order_by, str): + order_by = [order_by] + # expand "KEY" or "KEY DESC" + order_by = list(_flatten_attribute_list(self.primary_key, order_by)) + + clause = " ORDER BY " + ", ".join(order_by) if limit is not None: clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause @@ -219,7 +217,7 @@ def restrict(self, restriction): attributes = set() if isinstance(restriction, Top): self._top = restriction - return self.make_subquery() + return self new_condition = make_condition(self, restriction, attributes) if new_condition is True: return self # restriction has no effect, return the same object @@ -233,8 +231,10 @@ def restrict(self, restriction): pass # all ok # If the new condition uses any new attributes, a subquery is required. # However, Aggregation's HAVING statement works fine with aliased attributes. - need_subquery = isinstance(self, Union) or ( - not isinstance(self, Aggregation) and self.heading.new_attributes + need_subquery = ( + isinstance(self, Union) + or (not isinstance(self, Aggregation) and self.heading.new_attributes) + or self._top ) if need_subquery: result = self.make_subquery() @@ -570,19 +570,20 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" - return self.connection.query( + result = self.make_subquery() if self._top else copy.copy(self) + return result.connection.query( "SELECT {select_} FROM {from_}{where}".format( select_=( "count(*)" - if any(self._left) + if any(result._left) else "count(DISTINCT {fields})".format( - fields=self.heading.as_sql( - self.primary_key, include_aliases=False + fields=result.heading.as_sql( + result.primary_key, include_aliases=False ) ) ), - from_=self.from_clause(), - where=self.where_clause(), + from_=result.from_clause(), + where=result.where_clause(), ) ).fetchone()[0] @@ -657,14 +658,18 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): """ if offset and limit is None: raise DataJointError("limit is required when offset is set") - self._top = Top( - limit, - order_by, - offset, - ) - sql = self.make_sql() + if offset or order_by or limit: + result = self.make_subquery() if self._top else copy.copy(self) + result._top = Top( + limit, + order_by, + offset, + ) + else: + result = copy.copy(self) + sql = result.make_sql() logger.debug(sql) - return self.connection.query(sql, as_dict=as_dict) + return result.connection.query(sql, as_dict=as_dict) def __repr__(self): """ @@ -969,3 +974,18 @@ def aggr(self, group, **named_attributes): ) aggregate = aggr # alias for aggr + + +def _flatten_attribute_list(primary_key, attrs): + """ + :param primary_key: list of attributes in primary key + :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" + :return: generator of attributes where "KEY" is replaces with its component attributes + """ + for a in attrs: + if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): + yield from primary_key + elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): + yield from (q + " DESC" for q in primary_key) + else: + yield a diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index ab2e07c45..01f67947a 100644 --- a/tests/test_relational_operand.py +++ 
b/tests/test_relational_operand.py @@ -502,34 +502,16 @@ def test_restrictions_by_top(): & ("id_l=20", "id_l=16", "id_l=17") & dj.Top(2, "id_l asc", 1) ) - assert len(a) == 10 - assert len(b) == 10 + assert len(a) == 1 + assert len(b) == 1 assert len(x) == 1 assert len(y) == 5 assert len(z) == 2 assert a.fetch(as_dict=True) == [ {"id_l": 0, "cond_in_l": 1}, - {"id_l": 1, "cond_in_l": 1}, - {"id_l": 2, "cond_in_l": 1}, - {"id_l": 3, "cond_in_l": 0}, - {"id_l": 4, "cond_in_l": 0}, - {"id_l": 5, "cond_in_l": 1}, - {"id_l": 6, "cond_in_l": 0}, - {"id_l": 7, "cond_in_l": 0}, - {"id_l": 8, "cond_in_l": 0}, - {"id_l": 9, "cond_in_l": 0}, ] assert b.fetch(as_dict=True) == [ {"id_l": 3, "cond_in_l": 0}, - {"id_l": 4, "cond_in_l": 0}, - {"id_l": 6, "cond_in_l": 0}, - {"id_l": 7, "cond_in_l": 0}, - {"id_l": 8, "cond_in_l": 0}, - {"id_l": 9, "cond_in_l": 0}, - {"id_l": 12, "cond_in_l": 0}, - {"id_l": 13, "cond_in_l": 0}, - {"id_l": 14, "cond_in_l": 0}, - {"id_l": 18, "cond_in_l": 0}, ] assert x.fetch(as_dict=True) == [{"id_l": 25, "cond_in_l": 1}] assert y.fetch(as_dict=True) == [ From f21173cfb0034aa881f728e4b4a48d15e798d525 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 15:38:58 +0000 Subject: [PATCH 22/71] unnecessary copy --- datajoint/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 5253dc583..cb8dfe810 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -666,7 +666,7 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): offset, ) else: - result = copy.copy(self) + result = self sql = result.make_sql() logger.debug(sql) return result.connection.query(sql, as_dict=as_dict) From cade78c5719d6167f65dc35400a7eda16fd9992b Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 16:04:24 +0000 Subject: [PATCH 23/71] remove list --- datajoint/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/datajoint/expression.py b/datajoint/expression.py index cb8dfe810..9bb328ba8 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -139,7 +139,7 @@ def sorting_clauses(self): if isinstance(order_by, str): order_by = [order_by] # expand "KEY" or "KEY DESC" - order_by = list(_flatten_attribute_list(self.primary_key, order_by)) + order_by = _flatten_attribute_list(self.primary_key, order_by) clause = " ORDER BY " + ", ".join(order_by) if limit is not None: From 297077ce015755e8ed0d5277c236137cf8370a13 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 16:08:02 +0000 Subject: [PATCH 24/71] simplify --- datajoint/expression.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 9bb328ba8..d323c2666 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -134,14 +134,13 @@ def sorting_clauses(self): raise DataJointError("All order_by attributes must be strings") if offset and limit is None: raise DataJointError("limit is required when offset is set") - # if 'order_by' passed in a string, make into list if isinstance(order_by, str): order_by = [order_by] - # expand "KEY" or "KEY DESC" - order_by = _flatten_attribute_list(self.primary_key, order_by) - clause = " ORDER BY " + ", ".join(order_by) + clause = " ORDER BY " + ", ".join( + _flatten_attribute_list(self.primary_key, order_by) + ) if limit is not None: clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") return clause From 74f97620ed0ffed3e68857c26c43b3a00f721377 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 16:32:48 +0000 Subject: [PATCH 25/71] handle dj.U.aggr with no PK --- datajoint/expression.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index d323c2666..89cf9d5f2 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -138,8 +138,13 @@ def sorting_clauses(self): 
if isinstance(order_by, str): order_by = [order_by] - clause = " ORDER BY " + ", ".join( - _flatten_attribute_list(self.primary_key, order_by) + clause = ( + ( + " ORDER BY " + + ", ".join(_flatten_attribute_list(self.primary_key, order_by)) + ) + if self.primary_key + else "" ) if limit is not None: clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") From a5c4c24f1ba6774c3003c71ec185565fe542b201 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 17:00:32 +0000 Subject: [PATCH 26/71] type check in `post_init` --- datajoint/condition.py | 14 ++++++++++++++ datajoint/expression.py | 4 ---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 65ef88c9b..f4ce90f01 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -74,6 +74,20 @@ class Top: order_by: Union[str, List[str]] = "KEY" offset: int = 0 + def __post_init__(self): + if self.limit is not None and not (isinstance(self.limit, int)): + raise DataJointError("Limit must be an integer") + if not ( + isinstance(self.order_by, str) + or ( + hasattr(self.order_by, "__iter__") + and all(isinstance(r, str) for r in self.order_by) + ) + ): + raise DataJointError("All order_by attributes must be strings") + if not (isinstance(self.offset, int)): + raise DataJointError("Offset must be an integer") + class Not: """invert restriction""" diff --git a/datajoint/expression.py b/datajoint/expression.py index 89cf9d5f2..a9c212185 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -128,10 +128,6 @@ def sorting_clauses(self): order_by = self._top.order_by or ["KEY"] offset = self._top.offset or 0 - if order_by and not ( - isinstance(order_by, str) or all(isinstance(r, str) for r in order_by) - ): - raise DataJointError("All order_by attributes must be strings") if offset and limit is None: raise DataJointError("limit is required when offset is set") # if 'order_by' passed in a string, make into list From 
bab173221846acd52b3697c3f6f4c4300127908c Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 19:23:54 +0000 Subject: [PATCH 27/71] move None conversion to post_init --- datajoint/condition.py | 3 +++ datajoint/expression.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index f4ce90f01..3920d8f13 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -75,6 +75,9 @@ class Top: offset: int = 0 def __post_init__(self): + self.order_by = self.order_by or ["KEY"] + self.offset = self.offset or 0 + if self.limit is not None and not (isinstance(self.limit, int)): raise DataJointError("Limit must be an integer") if not ( diff --git a/datajoint/expression.py b/datajoint/expression.py index a9c212185..daaf9473d 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -125,8 +125,8 @@ def sorting_clauses(self): if not self._top: return "" limit = self._top.limit - order_by = self._top.order_by or ["KEY"] - offset = self._top.offset or 0 + order_by = self._top.order_by + offset = self._top.offset if offset and limit is None: raise DataJointError("limit is required when offset is set") From 16da9607e79553c39f504bb275229f8c9d502e49 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Mon, 22 May 2023 19:24:42 +0000 Subject: [PATCH 28/71] error msg --- datajoint/condition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 3920d8f13..b000246cc 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -79,7 +79,7 @@ def __post_init__(self): self.offset = self.offset or 0 if self.limit is not None and not (isinstance(self.limit, int)): - raise DataJointError("Limit must be an integer") + raise DataJointError("Top limit must be an integer") if not ( isinstance(self.order_by, str) or ( @@ -87,9 +87,9 @@ def __post_init__(self): and all(isinstance(r, str) for r in self.order_by) ) ): - raise 
DataJointError("All order_by attributes must be strings") + raise DataJointError("Top order_by attributes must all be strings") if not (isinstance(self.offset, int)): - raise DataJointError("Offset must be an integer") + raise DataJointError("Top offset must be an integer") class Not: From dee963e76a2a21016ea6d16def04fe9a3d470774 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 15:44:09 +0000 Subject: [PATCH 29/71] move error to post_init --- datajoint/condition.py | 2 ++ datajoint/expression.py | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index b000246cc..ce3808ae9 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -90,6 +90,8 @@ def __post_init__(self): raise DataJointError("Top order_by attributes must all be strings") if not (isinstance(self.offset, int)): raise DataJointError("Top offset must be an integer") + if self.offset and self.limit is None: + raise DataJointError("Top limit is required when offset is set") class Not: diff --git a/datajoint/expression.py b/datajoint/expression.py index daaf9473d..66639c6d5 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -128,8 +128,6 @@ def sorting_clauses(self): order_by = self._top.order_by offset = self._top.offset - if offset and limit is None: - raise DataJointError("limit is required when offset is set") # if 'order_by' passed in a string, make into list if isinstance(order_by, str): order_by = [order_by] @@ -656,8 +654,6 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): See expression.fetch() for input description. 
:return: query cursor """ - if offset and limit is None: - raise DataJointError("limit is required when offset is set") if offset or order_by or limit: result = self.make_subquery() if self._top else copy.copy(self) result._top = Top( From c4d07268c07f43847c29768f9d74e0f9b9656f97 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 18:12:21 +0000 Subject: [PATCH 30/71] handle sorting in fetch.py --- datajoint/expression.py | 24 ++++++++++-------------- datajoint/fetch.py | 35 ++++++++++------------------------- 2 files changed, 20 insertions(+), 39 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 66639c6d5..98525133e 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -214,8 +214,13 @@ def restrict(self, restriction): """ attributes = set() if isinstance(restriction, Top): - self._top = restriction - return self + result = ( + self.make_subquery() + if self._top and not self._top.__eq__(restriction) + else copy.copy(self) + ) # make subquery to avoid overwriting existing Top + result._top = restriction + return result new_condition = make_condition(self, restriction, attributes) if new_condition is True: return self # restriction has no effect, return the same object @@ -649,23 +654,14 @@ def __next__(self): # -- move on to next entry. return next(self) - def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): + def cursor(self, as_dict=False): """ See expression.fetch() for input description. 
:return: query cursor """ - if offset or order_by or limit: - result = self.make_subquery() if self._top else copy.copy(self) - result._top = Top( - limit, - order_by, - offset, - ) - else: - result = self - sql = result.make_sql() + sql = self.make_sql() logger.debug(sql) - return result.connection.query(sql, as_dict=as_dict) + return self.connection.query(sql, as_dict=as_dict) def __repr__(self): """ diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 750939e5e..71c2c0207 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -8,6 +8,8 @@ import numpy as np import uuid import numbers + +from datajoint.condition import Top from . import blob, hash from .errors import DataJointError from .settings import config @@ -119,21 +121,6 @@ def _get(connection, attr, data, squeeze, download_path): ) -def _flatten_attribute_list(primary_key, attrs): - """ - :param primary_key: list of attributes in primary key - :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaces with its component attributes - """ - for a in attrs: - if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): - yield from primary_key - elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): - yield from (q + " DESC" for q in primary_key) - else: - yield a - - class Fetch: """ A fetch object that handles retrieving elements from the table expression. @@ -174,13 +161,13 @@ def __call__( :param download_path: for fetches that download data, e.g. 
attachments :return: the contents of the table in the form of a structured numpy.array or a dict list """ - if order_by is not None: - # if 'order_by' passed in a string, make into list - if isinstance(order_by, str): - order_by = [order_by] - # expand "KEY" or "KEY DESC" - order_by = list( - _flatten_attribute_list(self._expression.primary_key, order_by) + if offset or order_by or limit: + self._expression = self._expression.restrict( + Top( + limit, + order_by, + offset, + ) ) attrs_as_dict = as_dict and attrs @@ -255,9 +242,7 @@ def __call__( ] ret = return_values[0] if len(attrs) == 1 else return_values else: # fetch all attributes as a numpy.record_array or pandas.DataFrame - cur = self._expression.cursor( - as_dict=as_dict, limit=limit, offset=offset, order_by=order_by - ) + cur = self._expression.cursor(as_dict=as_dict) heading = self._expression.heading if as_dict: ret = [ From 5520d0a7e4ceb3f06450f733e6e177b200ef850f Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 19:24:58 +0000 Subject: [PATCH 31/71] limit to some large number --- datajoint/condition.py | 9 ++++++++- datajoint/fetch.py | 7 ------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index ce3808ae9..8e9e90ad4 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -2,6 +2,7 @@ import inspect import collections +import logging import re import uuid import datetime @@ -13,6 +14,8 @@ from typing import Union, List from dataclasses import dataclass +logger = logging.getLogger(__name__.split(".")[0]) + JSON_PATTERN = re.compile( r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$" ) @@ -91,7 +94,11 @@ def __post_init__(self): if not (isinstance(self.offset, int)): raise DataJointError("Top offset must be an integer") if self.offset and self.limit is None: - raise DataJointError("Top limit is required when offset is set") + logger.warning( + "Offset set, but no limit. Setting limit to a large number. 
" + "Consider setting a limit explicitly." + ) + self.limit = 18446744073709551615 # Some large number class Not: diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 71c2c0207..6852426f6 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -199,13 +199,6 @@ def __call__( 'use "array" or "frame"'.format(format) ) - if limit is None and offset is not None: - logger.warning( - "Offset set, but no limit. Setting limit to a large number. " - "Consider setting a limit explicitly." - ) - limit = 8000000000 # just a very large number to effect no limit - get = partial( _get, self._expression.connection, From a520c40abbf26881a6213f426a00f71e19374f13 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 19:29:13 +0000 Subject: [PATCH 32/71] remove re import --- datajoint/fetch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 6852426f6..785e0a5ab 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -3,7 +3,6 @@ import logging import pandas import itertools -import re import json import numpy as np import uuid From b6fedc3935335742e7581f30528e104d5afdaed2 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 19:37:35 +0000 Subject: [PATCH 33/71] more tests --- tests/test_relational_operand.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 01f67947a..f15b23991 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -536,6 +536,21 @@ def test_top_in_or_list_fails(): def test_top_in_and_list_fails(): L() & dj.AndList(["cond_in_l=1", dj.Top()]) + @staticmethod + @raises(DataJointError) + def test_incorrect_limit_type(): + L() & dj.Top(limit="1") + + @staticmethod + @raises(DataJointError) + def test_incorrect_order_type(): + L() & dj.Top(order_by=1) + + @staticmethod + @raises(DataJointError) + def test_incorrect_offset_type(): + L() & dj.Top(offset="1") + @staticmethod def 
test_datetime(): """Test date retrieval""" From e737b8e8c30aa42b5921f8651188f711a4e2554b Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 23 May 2023 19:45:19 +0000 Subject: [PATCH 34/71] remove unused logger --- datajoint/condition.py | 2 +- datajoint/fetch.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 8e9e90ad4..377052ff1 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -98,7 +98,7 @@ def __post_init__(self): "Offset set, but no limit. Setting limit to a large number. " "Consider setting a limit explicitly." ) - self.limit = 18446744073709551615 # Some large number + self.limit = 999999999999 # arbitrary large number to allow query class Not: diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 785e0a5ab..49d0b14c0 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,6 +1,5 @@ from functools import partial from pathlib import Path -import logging import pandas import itertools import json @@ -14,8 +13,6 @@ from .settings import config from .utils import safe_write -logger = logging.getLogger(__name__.split(".")[0]) - class key: """ From f83046714a8bd231ed192f16e8572e34f33f752c Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 15:14:32 +0000 Subject: [PATCH 35/71] remove offset warning and warning test --- datajoint/condition.py | 4 ---- tests/test_fetch.py | 20 -------------------- 2 files changed, 24 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 377052ff1..55d037526 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -94,10 +94,6 @@ def __post_init__(self): if not (isinstance(self.offset, int)): raise DataJointError("Top offset must be an integer") if self.offset and self.limit is None: - logger.warning( - "Offset set, but no limit. Setting limit to a large number. " - "Consider setting a limit explicitly." 
- ) self.limit = 999999999999 # arbitrary large number to allow query diff --git a/tests/test_fetch.py b/tests/test_fetch.py index 684cd4846..af0156c6a 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -213,26 +213,6 @@ def test_offset(self): np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" ) - def test_limit_warning(self): - """Tests whether warning is raised if offset is used without limit.""" - log_capture = io.StringIO() - stream_handler = logging.StreamHandler(log_capture) - log_format = logging.Formatter( - "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s" - ) - stream_handler.setFormatter(log_format) - stream_handler.set_name("test_limit_warning") - logger.addHandler(stream_handler) - self.lang.fetch(offset=1) - - log_contents = log_capture.getvalue() - log_capture.close() - - for handler in logger.handlers: # Clean up handler - if handler.name == "test_limit_warning": - logger.removeHandler(handler) - assert "[WARNING]: Offset set, but no limit." 
in log_contents - def test_len(self): """Tests __len__""" assert_equal( From 5fc96a32d1203d71af06afe51c5296064d9fa331 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 16:31:04 +0000 Subject: [PATCH 36/71] better error test --- tests/test_relational_operand.py | 45 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index f15b23991..b3ac74855 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,4 +1,5 @@ import random +import re import string import pandas import datetime @@ -11,6 +12,7 @@ raises, assert_set_equal, assert_list_equal, + assert_raises, ) import datajoint as dj @@ -527,29 +529,26 @@ def test_restrictions_by_top(): ] @staticmethod - @raises(DataJointError) - def test_top_in_or_list_fails(): - L() & ("cond_in_l=1", dj.Top()) - - @staticmethod - @raises(DataJointError) - def test_top_in_and_list_fails(): - L() & dj.AndList(["cond_in_l=1", dj.Top()]) - - @staticmethod - @raises(DataJointError) - def test_incorrect_limit_type(): - L() & dj.Top(limit="1") - - @staticmethod - @raises(DataJointError) - def test_incorrect_order_type(): - L() & dj.Top(order_by=1) - - @staticmethod - @raises(DataJointError) - def test_incorrect_offset_type(): - L() & dj.Top(offset="1") + def test_top_errors(): + with assert_raises(DataJointError) as err1: + L() & ("cond_in_l=1", dj.Top()) + with assert_raises(DataJointError) as err2: + L() & dj.AndList(["cond_in_l=1", dj.Top()]) + with assert_raises(DataJointError) as err3: + L() & dj.Top(limit="1") + with assert_raises(DataJointError) as err4: + L() & dj.Top(order_by=1) + with assert_raises(DataJointError) as err5: + L() & dj.Top(offset="1") + assert "Invalid restriction type Top(limit=1, order_by='KEY', offset=0)" == str( + err1.exception + ) + assert "Invalid restriction type Top(limit=1, order_by='KEY', offset=0)" == str( + err2.exception + ) + assert "Top limit must 
be an integer" == str(err3.exception) + assert "Top order_by attributes must all be strings" == str(err4.exception) + assert "Top offset must be an integer" == str(err5.exception) @staticmethod def test_datetime(): From 0393f95631efc18711bb72051175aa828db2cd9d Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 16:31:21 +0000 Subject: [PATCH 37/71] unused import --- tests/test_relational_operand.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index b3ac74855..858e7d9e6 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -1,5 +1,4 @@ import random -import re import string import pandas import datetime From 4257459e5189015ae6a64663c95a19cf4dd2d8b1 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 16:52:04 +0000 Subject: [PATCH 38/71] datajointerror -> typeerror --- datajoint/condition.py | 10 +++++----- tests/test_relational_operand.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 55d037526..7f498866a 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -81,8 +81,8 @@ def __post_init__(self): self.order_by = self.order_by or ["KEY"] self.offset = self.offset or 0 - if self.limit is not None and not (isinstance(self.limit, int)): - raise DataJointError("Top limit must be an integer") + if self.limit is not None and not isinstance(self.limit, int): + raise TypeError("Top limit must be an integer") if not ( isinstance(self.order_by, str) or ( @@ -90,9 +90,9 @@ def __post_init__(self): and all(isinstance(r, str) for r in self.order_by) ) ): - raise DataJointError("Top order_by attributes must all be strings") - if not (isinstance(self.offset, int)): - raise DataJointError("Top offset must be an integer") + raise TypeError("Top order_by attributes must all be strings") + if not isinstance(self.offset, int): + raise TypeError("Top offset must be an 
integer") if self.offset and self.limit is None: self.limit = 999999999999 # arbitrary large number to allow query diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 858e7d9e6..a9983c752 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -533,11 +533,11 @@ def test_top_errors(): L() & ("cond_in_l=1", dj.Top()) with assert_raises(DataJointError) as err2: L() & dj.AndList(["cond_in_l=1", dj.Top()]) - with assert_raises(DataJointError) as err3: + with assert_raises(TypeError) as err3: L() & dj.Top(limit="1") - with assert_raises(DataJointError) as err4: + with assert_raises(TypeError) as err4: L() & dj.Top(order_by=1) - with assert_raises(DataJointError) as err5: + with assert_raises(TypeError) as err5: L() & dj.Top(offset="1") assert "Invalid restriction type Top(limit=1, order_by='KEY', offset=0)" == str( err1.exception From ef61f42d6beca36f2acecd414c6f3ebfbd6ff1ca Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 16:58:03 +0000 Subject: [PATCH 39/71] simplify order_by typecheck --- datajoint/condition.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 7f498866a..105913569 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -83,12 +83,8 @@ def __post_init__(self): if self.limit is not None and not isinstance(self.limit, int): raise TypeError("Top limit must be an integer") - if not ( - isinstance(self.order_by, str) - or ( - hasattr(self.order_by, "__iter__") - and all(isinstance(r, str) for r in self.order_by) - ) + if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all( + isinstance(r, str) for r in self.order_by ): raise TypeError("Top order_by attributes must all be strings") if not isinstance(self.offset, int): From 554f577c5f58897ece749f2c48321a3c9338c90d Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 24 May 2023 17:08:13 +0000 Subject: [PATCH 40/71] offset 
err msg --- datajoint/condition.py | 2 +- tests/test_relational_operand.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 105913569..322ae986f 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -88,7 +88,7 @@ def __post_init__(self): ): raise TypeError("Top order_by attributes must all be strings") if not isinstance(self.offset, int): - raise TypeError("Top offset must be an integer") + raise TypeError("The offset argument must be an integer") if self.offset and self.limit is None: self.limit = 999999999999 # arbitrary large number to allow query diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index a9983c752..2c0185275 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -547,7 +547,7 @@ def test_top_errors(): ) assert "Top limit must be an integer" == str(err3.exception) assert "Top order_by attributes must all be strings" == str(err4.exception) - assert "Top offset must be an integer" == str(err5.exception) + assert "The offset argument must be an integer" == str(err5.exception) @staticmethod def test_datetime(): From 161c7d0cb71f8d73dcee78b0ba1fbbc01ef569d9 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 25 May 2023 16:46:43 +0000 Subject: [PATCH 41/71] remove unused logger --- datajoint/condition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 322ae986f..7b4407d9b 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -2,7 +2,6 @@ import inspect import collections -import logging import re import uuid import datetime @@ -14,8 +13,6 @@ from typing import Union, List from dataclasses import dataclass -logger = logging.getLogger(__name__.split(".")[0]) - JSON_PATTERN = re.compile( r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$" ) From a48688a9307c8d33106b07e88941ac9abf9e472d Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 25 
May 2023 17:24:42 +0000 Subject: [PATCH 42/71] move order_by list conversion to post_init --- datajoint/condition.py | 4 ++++ datajoint/expression.py | 4 ---- tests/test_relational_operand.py | 10 ++++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 7b4407d9b..0f2d394d3 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -88,6 +88,10 @@ def __post_init__(self): raise TypeError("The offset argument must be an integer") if self.offset and self.limit is None: self.limit = 999999999999 # arbitrary large number to allow query + if isinstance(self.order_by, str): + self.order_by = [ + self.order_by + ] # if 'order_by' passed in a string, make into list class Not: diff --git a/datajoint/expression.py b/datajoint/expression.py index 98525133e..ef3bb9d71 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -128,10 +128,6 @@ def sorting_clauses(self): order_by = self._top.order_by offset = self._top.offset - # if 'order_by' passed in a string, make into list - if isinstance(order_by, str): - order_by = [order_by] - clause = ( ( " ORDER BY " diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 2c0185275..1b48a9370 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -539,11 +539,13 @@ def test_top_errors(): L() & dj.Top(order_by=1) with assert_raises(TypeError) as err5: L() & dj.Top(offset="1") - assert "Invalid restriction type Top(limit=1, order_by='KEY', offset=0)" == str( - err1.exception + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err1.exception) ) - assert "Invalid restriction type Top(limit=1, order_by='KEY', offset=0)" == str( - err2.exception + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err2.exception) ) assert "Top limit must be an integer" == str(err3.exception) assert "Top order_by attributes must all 
be strings" == str(err4.exception) From 63ed6b835cf11b0e4c3ce9873c1847842f9afa93 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 25 May 2023 19:27:24 +0000 Subject: [PATCH 43/71] handle edge case --- datajoint/expression.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index ef3bb9d71..074d3d180 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -129,12 +129,12 @@ def sorting_clauses(self): offset = self._top.offset clause = ( - ( + "" + if not self.primary_key and order_by == ["KEY"] + else ( " ORDER BY " + ", ".join(_flatten_attribute_list(self.primary_key, order_by)) ) - if self.primary_key - else "" ) if limit is not None: clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") @@ -968,11 +968,14 @@ def _flatten_attribute_list(primary_key, attrs): """ :param primary_key: list of attributes in primary key :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaces with its component attributes + :return: generator of attributes where "KEY" is replaced with its component attributes """ for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): - yield from primary_key + if primary_key: + yield from primary_key + else: + continue elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): yield from (q + " DESC" for q in primary_key) else: From 778c6f9323fbf2cacc5327da1fa37be84b2805a1 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 25 May 2023 19:32:32 +0000 Subject: [PATCH 44/71] redundant comment --- datajoint/condition.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index 0f2d394d3..de6372c6a 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -89,9 +89,7 @@ def __post_init__(self): if self.offset and self.limit is None: self.limit = 999999999999 # arbitrary large 
number to allow query if isinstance(self.order_by, str): - self.order_by = [ - self.order_by - ] # if 'order_by' passed in a string, make into list + self.order_by = [self.order_by] class Not: From 7ea05cf198a412e86740a1daf6f3b8991bb54ef1 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 25 May 2023 22:21:41 +0000 Subject: [PATCH 45/71] also handle "KEY desc" --- datajoint/expression.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 074d3d180..ef7c0582a 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -974,9 +974,8 @@ def _flatten_attribute_list(primary_key, attrs): if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): if primary_key: yield from primary_key - else: - continue elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): - yield from (q + " DESC" for q in primary_key) + if primary_key: + yield from (q + " DESC" for q in primary_key) else: yield a From 0002cb37671f6f3130ff83f969cfda1ef4f5985f Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 30 May 2023 16:26:10 +0000 Subject: [PATCH 46/71] fstrings --- datajoint/expression.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index ef7c0582a..2480d729c 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -124,10 +124,6 @@ def where_clause(self): def sorting_clauses(self): if not self._top: return "" - limit = self._top.limit - order_by = self._top.order_by - offset = self._top.offset - clause = ( "" if not self.primary_key and order_by == ["KEY"] @@ -135,9 +131,13 @@ def sorting_clauses(self): " ORDER BY " + ", ".join(_flatten_attribute_list(self.primary_key, order_by)) ) + else f" ORDER BY \ + {', '.join(_flatten_attribute_list(self.primary_key, self._top.order_by))}" ) - if limit is not None: - clause += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") + if self._top.limit is not None: + clause += f" 
LIMIT {self._top.limit} \ + {f' OFFSET {self._top.offset}' if self._top.offset else ''}" + return clause def make_sql(self, fields=None): @@ -739,7 +739,7 @@ def make_sql(self, fields=None): + ( "" if not self.restriction - else " HAVING (%s)" % ")AND(".join(self.restriction) + else f" HAVING ({')AND('.join(self.restriction)})" ) ), sorting=self.sorting_clauses(), From 6268ebfc8d6fa89e3742cd9861b1df5dc249e3b7 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 30 May 2023 16:26:38 +0000 Subject: [PATCH 47/71] regex matching for empty pk case --- datajoint/expression.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 2480d729c..79107c2f5 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -126,10 +126,11 @@ def sorting_clauses(self): return "" clause = ( "" - if not self.primary_key and order_by == ["KEY"] - else ( - " ORDER BY " - + ", ".join(_flatten_attribute_list(self.primary_key, order_by)) + if not self.primary_key + and all( + re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a) + or re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a) + for a in self._top.order_by ) else f" ORDER BY \ {', '.join(_flatten_attribute_list(self.primary_key, self._top.order_by))}" From 37351dbce8159b16944c72a97ee689ff09e7b5c1 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 30 May 2023 17:03:44 +0000 Subject: [PATCH 48/71] use flatten_atribute_list --- datajoint/expression.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 79107c2f5..c05317a15 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -126,12 +126,7 @@ def sorting_clauses(self): return "" clause = ( "" - if not self.primary_key - and all( - re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a) - or re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a) - for a in self._top.order_by - ) + if not any(_flatten_attribute_list(self.primary_key, 
self._top.order_by)) else f" ORDER BY \ {', '.join(_flatten_attribute_list(self.primary_key, self._top.order_by))}" ) From c7020f77177174f14d80e0fdaef6be93a2f155a8 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 30 May 2023 19:42:41 +0000 Subject: [PATCH 49/71] simplify flatten calls --- datajoint/expression.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index c05317a15..43a979c42 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -124,12 +124,11 @@ def where_clause(self): def sorting_clauses(self): if not self._top: return "" - clause = ( - "" - if not any(_flatten_attribute_list(self.primary_key, self._top.order_by)) - else f" ORDER BY \ - {', '.join(_flatten_attribute_list(self.primary_key, self._top.order_by))}" + clause = ", ".join( + _flatten_attribute_list(self.primary_key, self._top.order_by) ) + if clause: + clause = f" ORDER BY {clause}" if self._top.limit is not None: clause += f" LIMIT {self._top.limit} \ {f' OFFSET {self._top.offset}' if self._top.offset else ''}" From 59279bade58dfcd9a743281b72656423e14873d8 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 31 May 2023 17:04:32 +0000 Subject: [PATCH 50/71] formatting --- datajoint/expression.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 43a979c42..fe2cc9260 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -130,8 +130,7 @@ def sorting_clauses(self): if clause: clause = f" ORDER BY {clause}" if self._top.limit is not None: - clause += f" LIMIT {self._top.limit} \ - {f' OFFSET {self._top.offset}' if self._top.offset else ''}" + clause += f" LIMIT {self._top.limit}{f' OFFSET {self._top.offset}' if self._top.offset else ''}" return clause From 69c2e97587cdbb87934378847398356447bee96c Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 31 May 2023 20:22:40 +0000 Subject: [PATCH 51/71] escape keywords --- 
datajoint/expression.py | 49 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/datajoint/expression.py b/datajoint/expression.py index fe2cc9260..919ad5232 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -964,6 +964,53 @@ def _flatten_attribute_list(primary_key, attrs): :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" :return: generator of attributes where "KEY" is replaced with its component attributes """ + sql_keywords = [ + "SELECT", + "FROM", + "WHERE", + "AND", + "OR", + "INSERT", + "INTO", + "VALUES", + "UPDATE", + "SET", + "DELETE", + "CREATE", + "TABLE", + "ALTER", + "DROP", + "INDEX", + "JOIN", + "LEFT", + "RIGHT", + "INNER", + "OUTER", + "UNION", + "GROUP", + "BY", + "HAVING", + "ORDER", + "ASC", + "DESC", + "LIMIT", + "OFFSET", + "DISTINCT", + "CASE", + "WHEN", + "THEN", + "ELSE", + "END", + "NULL", + "NOT", + "IN", + "LIKE", + "KEY", + ] + keyword_pattern = ( + r"\b(" + "|".join(re.escape(keyword) for keyword in sql_keywords) + r")\b" + ) + for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): if primary_key: @@ -971,5 +1018,7 @@ def _flatten_attribute_list(primary_key, attrs): elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): if primary_key: yield from (q + " DESC" for q in primary_key) + elif re.match(keyword_pattern, a, re.I): + yield f"`{a}`" else: yield a From 091e4440f0fc875c1a97677f9bffc3712a4e8342 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Wed, 31 May 2023 22:17:58 +0000 Subject: [PATCH 52/71] always escape --- datajoint/expression.py | 59 +++++++---------------------------------- 1 file changed, 9 insertions(+), 50 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index 919ad5232..b9e5c753d 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -125,7 +125,9 @@ def sorting_clauses(self): if not self._top: return "" clause = ", ".join( - _flatten_attribute_list(self.primary_key, 
self._top.order_by) + _wrap_attributes( + _flatten_attribute_list(self.primary_key, self._top.order_by) + ) ) if clause: clause = f" ORDER BY {clause}" @@ -964,53 +966,6 @@ def _flatten_attribute_list(primary_key, attrs): :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" :return: generator of attributes where "KEY" is replaced with its component attributes """ - sql_keywords = [ - "SELECT", - "FROM", - "WHERE", - "AND", - "OR", - "INSERT", - "INTO", - "VALUES", - "UPDATE", - "SET", - "DELETE", - "CREATE", - "TABLE", - "ALTER", - "DROP", - "INDEX", - "JOIN", - "LEFT", - "RIGHT", - "INNER", - "OUTER", - "UNION", - "GROUP", - "BY", - "HAVING", - "ORDER", - "ASC", - "DESC", - "LIMIT", - "OFFSET", - "DISTINCT", - "CASE", - "WHEN", - "THEN", - "ELSE", - "END", - "NULL", - "NOT", - "IN", - "LIKE", - "KEY", - ] - keyword_pattern = ( - r"\b(" + "|".join(re.escape(keyword) for keyword in sql_keywords) + r")\b" - ) - for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): if primary_key: @@ -1018,7 +973,11 @@ def _flatten_attribute_list(primary_key, attrs): elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): if primary_key: yield from (q + " DESC" for q in primary_key) - elif re.match(keyword_pattern, a, re.I): - yield f"`{a}`" else: yield a + + +def _wrap_attributes(attr): + for entry in attr: + wrapped_entry = re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, re.IGNORECASE) + yield wrapped_entry # wrap attribute names in backquotes From 97d6e55e62714776ae6a2e53c24550e0e4c81496 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 1 Jun 2023 15:06:27 +0000 Subject: [PATCH 53/71] fix --- datajoint/expression.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datajoint/expression.py b/datajoint/expression.py index b9e5c753d..cce40e2e6 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -978,6 +978,5 @@ def _flatten_attribute_list(primary_key, attrs): def _wrap_attributes(attr): - for entry in 
attr: - wrapped_entry = re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, re.IGNORECASE) - yield wrapped_entry # wrap attribute names in backquotes + for entry in attr: # wrap attribute names in backquotes + yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE) From c9a8001c15468d08b62ad7e11478232c3368f2be Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 6 Jun 2023 16:21:58 +0000 Subject: [PATCH 54/71] keywork pk test cases --- datajoint/declare.py | 10 +++++++--- tests_old/schema_simple.py | 18 +++++++++++++++++ tests_old/test_relational_operand.py | 29 ++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 683e34759..c26e46a50 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -443,9 +443,13 @@ def format_attribute(attr): return f"`{attr}`" return f"({attr})" - match = re.match( - r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I - ).groupdict() + try: + match = re.match( + r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I + ).groupdict() + except AttributeError: + raise DataJointError(f'Table definition syntax error in line "{line}"') + attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"]) index_sql.append( "{unique}index ({attrs})".format( diff --git a/tests_old/schema_simple.py b/tests_old/schema_simple.py index 78f64d036..3f0c29b8d 100644 --- a/tests_old/schema_simple.py +++ b/tests_old/schema_simple.py @@ -14,6 +14,24 @@ schema = dj.Schema(PREFIX + "_relational", locals(), connection=dj.conn(**CONN_INFO)) +@schema +class SelectPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id: int + select : int + """ + contents = list(dict(id=i, select=i * j) for i in range(3) for j in range(4, 0, -1)) + + +@schema +class KeyPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id : int + key : int + """ + contents = list(dict(id=i, key=i + j) for i in range(3) for j in range(4, 0, -1)) + + @schema class 
IJ(dj.Lookup): definition = """ # tests restrictions diff --git a/tests_old/test_relational_operand.py b/tests_old/test_relational_operand.py index 1b48a9370..3ba6291da 100644 --- a/tests_old/test_relational_operand.py +++ b/tests_old/test_relational_operand.py @@ -25,6 +25,8 @@ L, DataA, DataB, + SelectPK, + KeyPK, TTestUpdate, IJ, JI, @@ -527,6 +529,33 @@ def test_restrictions_by_top(): {"id_l": 20, "cond_in_l": 1}, ] + @staticmethod + def test_top_restriction_with_keywords(): + select = SelectPK() & dj.Top(limit=9, order_by=["select desc"]) + key = KeyPK() & dj.Top(limit=9, order_by="key desc") + assert select.fetch(as_dict=True) == [ + {"id": 2, "select": 8}, + {"id": 2, "select": 6}, + {"id": 1, "select": 4}, + {"id": 2, "select": 4}, + {"id": 1, "select": 3}, + {"id": 1, "select": 2}, + {"id": 2, "select": 2}, + {"id": 1, "select": 1}, + {"id": 0, "select": 0}, + ] + assert key.fetch(as_dict=True) == [ + {"id": 2, "key": 6}, + {"id": 2, "key": 5}, + {"id": 1, "key": 5}, + {"id": 0, "key": 4}, + {"id": 1, "key": 4}, + {"id": 2, "key": 4}, + {"id": 0, "key": 3}, + {"id": 1, "key": 3}, + {"id": 2, "key": 3}, + ] + @staticmethod def test_top_errors(): with assert_raises(DataJointError) as err1: From 223ccdd649573d4adb3e4eb7b7e5f54ed5034375 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 6 Jun 2023 16:30:28 +0000 Subject: [PATCH 55/71] fix schema test --- tests_old/test_schema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests_old/test_schema.py b/tests_old/test_schema.py index 8ec24fc49..f7a18198e 100644 --- a/tests_old/test_schema.py +++ b/tests_old/test_schema.py @@ -155,6 +155,8 @@ def test_list_tables(): "#website", "profile", "profile__website", + "#select_p_k", + "#key_p_k", ] ) == set(schema_simple.list_tables()) From b19910a7b6592af4f5649e7b5c4438ce3354be40 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Tue, 6 Jun 2023 16:34:50 +0000 Subject: [PATCH 56/71] regex mismatch --- datajoint/declare.py | 8 +++----- tests_old/test_declare.py | 9 +++++++++ 
2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index c26e46a50..c99c541f0 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -443,12 +443,10 @@ def format_attribute(attr): return f"`{attr}`" return f"({attr})" - try: - match = re.match( - r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I - ).groupdict() - except AttributeError: + match = re.match(r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I) + if match is None: raise DataJointError(f'Table definition syntax error in line "{line}"') + match = match.groupdict() attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"]) index_sql.append( diff --git a/tests_old/test_declare.py b/tests_old/test_declare.py index 67f532449..bb23be276 100644 --- a/tests_old/test_declare.py +++ b/tests_old/test_declare.py @@ -341,3 +341,12 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): definition = """ -> (master) """ + + @staticmethod + @raises(dj.DataJointError) + def test_regex_mismatch(): + @schema + class IndexAttribute(dj.Manual): + definition = """ + index: int + """ From ca4736ae82063192636f4e5ab590a8ac8d56b0b0 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 22 Jun 2023 20:23:11 +0000 Subject: [PATCH 57/71] documentation --- datajoint/condition.py | 2 +- docs/src/query/operators.md | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index de6372c6a..e12ec1443 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -71,7 +71,7 @@ class Top: """ limit: Union[int, None] = 1 - order_by: Union[str, List[str]] = "KEY" + order_by: Union[str, List[str]] = ["KEY"] offset: int = 0 def __post_init__(self): diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index 9c9258442..6022a0d34 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -17,8 +17,9 @@ DataJoint implements a complete algebra of operators on tables: 
| [aggr](#aggr) | A.aggr(B, ...) | Same as projection with computations based on matching information in B | | [union](#union) | A + B | All unique entities from both A and B | | [universal set](#universal-set)\*| dj.U() | All unique entities from both A and B | +| [top](#top)\*| dj.Top() | The top rows of A -\*While not technically a query operator, it is useful to discuss Universal Set in the +\*While not technically query operators, it is useful to discuss Universal Set and Top in the same context. ??? note "Notes on relational algebra" @@ -218,6 +219,20 @@ The examples below will use the table definitions in [table tiers](../reproduce/ +## Top + +Similar to the univeral set operator, the top operator uses `dj.top` notation. It is used to +restrict a query by the given `limit`, `order_by`, and `offset` parameters: + +```python +Session & dj.top(limit=10, order_by='session_date') +``` + +The result of this expression returns the first 10 rows of `Session` and sorts them +by their `session_date` in ascending order. If the `order_by` argument was instead: `session_date DESC`, then it would be sorted in descending order. + +The default values for `dj.top` parameters are `limit=1`, `order_by=["KEY"]`, and `offset=0`. + ## Restriction `&` and `-` operators permit restriction. 
From b355bc61c158ec4c587704b52378c5526762676b Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 22 Jun 2023 20:24:11 +0000 Subject: [PATCH 58/71] bump nginx --- LNX-docker-compose.yml | 2 +- local-docker-compose.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LNX-docker-compose.yml b/LNX-docker-compose.yml index 970552860..9c0a95b78 100644 --- a/LNX-docker-compose.yml +++ b/LNX-docker-compose.yml @@ -44,7 +44,7 @@ services: interval: 15s fakeservices.datajoint.io: <<: *net - image: datajoint/nginx:v0.2.5 + image: datajoint/nginx:v0.2.6 environment: - ADD_db_TYPE=DATABASE - ADD_db_ENDPOINT=db:3306 diff --git a/local-docker-compose.yml b/local-docker-compose.yml index 8b43289d3..62b52ad66 100644 --- a/local-docker-compose.yml +++ b/local-docker-compose.yml @@ -46,7 +46,7 @@ services: interval: 15s fakeservices.datajoint.io: <<: *net - image: datajoint/nginx:v0.2.5 + image: datajoint/nginx:v0.2.6 environment: - ADD_db_TYPE=DATABASE - ADD_db_ENDPOINT=db:3306 From 5bb8da40d6bd360b1ecffadac2fedb5c729421ff Mon Sep 17 00:00:00 2001 From: A-Baji Date: Thu, 22 Jun 2023 20:44:56 +0000 Subject: [PATCH 59/71] fix --- datajoint/condition.py | 2 +- docs/src/query/operators.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/condition.py b/datajoint/condition.py index e12ec1443..de6372c6a 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -71,7 +71,7 @@ class Top: """ limit: Union[int, None] = 1 - order_by: Union[str, List[str]] = ["KEY"] + order_by: Union[str, List[str]] = "KEY" offset: int = 0 def __post_init__(self): diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index 6022a0d34..8392f63c4 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -231,7 +231,7 @@ Session & dj.top(limit=10, order_by='session_date') The result of this expression returns the first 10 rows of `Session` and sorts them by their `session_date` in ascending order. 
If the `order_by` argument was instead: `session_date DESC`, then it would be sorted in descending order. -The default values for `dj.top` parameters are `limit=1`, `order_by=["KEY"]`, and `offset=0`. +The default values for `dj.top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. ## Restriction From f303549c8eba36642bc9d9b1169f0f9e0945d909 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 23 Jun 2023 17:13:53 +0000 Subject: [PATCH 60/71] more docs --- docs/src/query/operators.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index 8392f63c4..f5ace4ec3 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -229,7 +229,15 @@ Session & dj.top(limit=10, order_by='session_date') ``` The result of this expression returns the first 10 rows of `Session` and sorts them -by their `session_date` in ascending order. If the `order_by` argument was instead: `session_date DESC`, then it would be sorted in descending order. +by their `session_date` in ascending order. + +### `order_by` + +| Example | Description | +|-------------------------------------------|-----------------------------------------------------------------------------------| +| `order_by="session_date DESC"` | Sorts by `session_date` in *descending* order | +| `order_by="KEY"` | Sorts by the primary key(s) | +| `order_by=["subject_id", "session_date"]` | Sorts by `subject_id`, then sorts matching `subject_id`s by their `session_date` | The default values for `dj.top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. 
From d6626a821f7a9be4446c60c4cce99740ecb512c0 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 23 Jun 2023 17:59:06 +0000 Subject: [PATCH 61/71] suggestiosn --- docs/src/query/operators.md | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index f5ace4ec3..c884c0108 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -221,11 +221,11 @@ The examples below will use the table definitions in [table tiers](../reproduce/ ## Top -Similar to the univeral set operator, the top operator uses `dj.top` notation. It is used to +Similar to the univeral set operator, the top operator uses `dj.Top` notation. It is used to restrict a query by the given `limit`, `order_by`, and `offset` parameters: ```python -Session & dj.top(limit=10, order_by='session_date') +Session & dj.Top(limit=10, order_by='session_date') ``` The result of this expression returns the first 10 rows of `Session` and sorts them @@ -233,13 +233,14 @@ by their `session_date` in ascending order. 
### `order_by` -| Example | Description | -|-------------------------------------------|-----------------------------------------------------------------------------------| -| `order_by="session_date DESC"` | Sorts by `session_date` in *descending* order | -| `order_by="KEY"` | Sorts by the primary key(s) | -| `order_by=["subject_id", "session_date"]` | Sorts by `subject_id`, then sorts matching `subject_id`s by their `session_date` | +| Example | Description | +|-------------------------------------------|---------------------------------------------------------------------------------| +| `order_by="session_date DESC"` | Sort by `session_date` in *descending* order | +| `order_by="KEY"` | Sort by the primary key | +| `order_by="KEY DESC"` | Sort by the primary key in descending order | +| `order_by=["subject_id", "session_date"]` | Sort by `subject_id`, then sort matching `subject_id`s by their `session_date` | -The default values for `dj.top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. +The default values for `dj.Top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. ## Restriction From 96245f03fedc9b702c132f1da123db820031c69d Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 23 Jun 2023 18:01:25 +0000 Subject: [PATCH 62/71] italicise --- docs/src/query/operators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index c884c0108..b33a53e43 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -237,7 +237,7 @@ by their `session_date` in ascending order. 
|-------------------------------------------|---------------------------------------------------------------------------------| | `order_by="session_date DESC"` | Sort by `session_date` in *descending* order | | `order_by="KEY"` | Sort by the primary key | -| `order_by="KEY DESC"` | Sort by the primary key in descending order | +| `order_by="KEY DESC"` | Sort by the primary key in *descending* order | | `order_by=["subject_id", "session_date"]` | Sort by `subject_id`, then sort matching `subject_id`s by their `session_date` | The default values for `dj.Top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. From d9aabf2cf9ae6b7596300e35b59417d74c96bdd6 Mon Sep 17 00:00:00 2001 From: A-Baji Date: Fri, 23 Jun 2023 19:52:37 +0000 Subject: [PATCH 63/71] typo --- docs/src/query/operators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index b33a53e43..550108c75 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -221,7 +221,7 @@ The examples below will use the table definitions in [table tiers](../reproduce/ ## Top -Similar to the univeral set operator, the top operator uses `dj.Top` notation. It is used to +Similar to the universal set operator, the top operator uses `dj.Top` notation. It is used to restrict a query by the given `limit`, `order_by`, and `offset` parameters: ```python From 88783f051dbad61eb9220fa1e439e2255429a13d Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:48:48 +0000 Subject: [PATCH 64/71] Disable format on save in VS Code Causes a lot of whitespace changes in the git diff. We use black and flake8 for linting already. 
--- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index efb8c58b5..00ebd4b97 100755 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "editor.formatOnPaste": false, - "editor.formatOnSave": true, + "editor.formatOnSave": false, "editor.rulers": [ 94 ], From a5eb6fbeda36611eb2db61b1a677f54437005629 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:51:47 +0000 Subject: [PATCH 65/71] Migrate most tests from #1084 to pytest --- tests/conftest.py | 2 + tests/schema_simple.py | 20 +++++++ tests/test_declare.py | 11 ++++ tests/test_relational_operand.py | 92 ++++++++++++++++++++++++++++++++ tests/test_schema.py | 2 + 5 files changed, 127 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 65d68268b..9ece6bb49 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -330,6 +330,8 @@ def schema_simp(connection_test, prefix): schema = dj.Schema( prefix + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test ) + schema(schema_simple.SelectPK) + schema(schema_simple.KeyPK) schema(schema_simple.IJ) schema(schema_simple.JI) schema(schema_simple.A) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index 9e3113c9a..77ee6849b 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -13,6 +13,26 @@ import inspect +@schema +class SelectPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id: int + select : int + """ + contents = list(dict(id=i, select=i * j) + for i in range(3) for j in range(4, 0, -1)) + + +@schema +class KeyPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id : int + key : int + """ + contents = list(dict(id=i, key=i + j) + for i in range(3) for j in range(4, 0, -1)) + + class IJ(dj.Lookup): definition = """ # tests restrictions i : int diff --git a/tests/test_declare.py b/tests/test_declare.py index 
8939000bc..a3cc3fec2 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -339,6 +339,17 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): schema_any(WhyWouldAnyoneCreateATableNameThisLong) +def test_regex_mismatch(schema_any): + + class IndexAttribute(dj.Manual): + definition = """ + index: int + """ + + with pytest.raises(dj.DataJointError): + schema_any(IndexAttribute) + + def test_table_name_with_underscores(schema_any): """ Test issue #1150 -- Reject table names containing underscores. Tables should be in strict diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 65c6a5d74..9668f1bcc 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -5,6 +5,7 @@ import datetime import numpy as np import datajoint as dj +from datajoint.errors import DataJointError from .schema_simple import * from .schema import * @@ -570,3 +571,94 @@ def test_union_multiple(schema_simp_pop): y = set(zip(*q2.fetch("i", "j"))) assert x == y assert q1.fetch(as_dict=True) == q2.fetch(as_dict=True) + + +class TestDjTop: + """TODO: migrate""" + + def test_restrictions_by_top(self): + a = L() & dj.Top() + b = L() & dj.Top(order_by=["cond_in_l", "KEY"]) + x = L() & dj.Top(5, "id_l desc", 4) & "cond_in_l=1" + y = L() & "cond_in_l=1" & dj.Top(5, "id_l desc", 4) + z = ( + L() + & dj.Top(None, order_by="id_l desc") + & "cond_in_l=1" + & dj.Top(5, "id_l desc") + & ("id_l=20", "id_l=16", "id_l=17") + & dj.Top(2, "id_l asc", 1) + ) + assert len(a) == 1 + assert len(b) == 1 + assert len(x) == 1 + assert len(y) == 5 + assert len(z) == 2 + assert a.fetch(as_dict=True) == [ + {"id_l": 0, "cond_in_l": 1}, + ] + assert b.fetch(as_dict=True) == [ + {"id_l": 3, "cond_in_l": 0}, + ] + assert x.fetch(as_dict=True) == [{"id_l": 25, "cond_in_l": 1}] + assert y.fetch(as_dict=True) == [ + {"id_l": 16, "cond_in_l": 1}, + {"id_l": 15, "cond_in_l": 1}, + {"id_l": 11, "cond_in_l": 1}, + {"id_l": 10, "cond_in_l": 1}, + {"id_l": 5, 
"cond_in_l": 1}, + ] + assert z.fetch(as_dict=True) == [ + {"id_l": 17, "cond_in_l": 1}, + {"id_l": 20, "cond_in_l": 1}, + ] + + def test_top_restriction_with_keywords(self): + select = SelectPK() & dj.Top(limit=9, order_by=["select desc"]) + key = KeyPK() & dj.Top(limit=9, order_by="key desc") + assert select.fetch(as_dict=True) == [ + {"id": 2, "select": 8}, + {"id": 2, "select": 6}, + {"id": 1, "select": 4}, + {"id": 2, "select": 4}, + {"id": 1, "select": 3}, + {"id": 1, "select": 2}, + {"id": 2, "select": 2}, + {"id": 1, "select": 1}, + {"id": 0, "select": 0}, + ] + assert key.fetch(as_dict=True) == [ + {"id": 2, "key": 6}, + {"id": 2, "key": 5}, + {"id": 1, "key": 5}, + {"id": 0, "key": 4}, + {"id": 1, "key": 4}, + {"id": 2, "key": 4}, + {"id": 0, "key": 3}, + {"id": 1, "key": 3}, + {"id": 2, "key": 3}, + ] + + def test_top_errors(self): + with assert_raises(DataJointError) as err1: + L() & ("cond_in_l=1", dj.Top()) + with assert_raises(DataJointError) as err2: + L() & dj.AndList(["cond_in_l=1", dj.Top()]) + with assert_raises(TypeError) as err3: + L() & dj.Top(limit="1") + with assert_raises(TypeError) as err4: + L() & dj.Top(order_by=1) + with assert_raises(TypeError) as err5: + L() & dj.Top(offset="1") + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err1.exception) + ) + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err2.exception) + ) + assert "Top limit must be an integer" == str(err3.exception) + assert "Top order_by attributes must all be strings" == str( + err4.exception) + assert "The offset argument must be an integer" == str(err5.exception) diff --git a/tests/test_schema.py b/tests/test_schema.py index 6407cacab..257de221c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -210,6 +210,8 @@ def test_list_tables(schema_simp): "#website", "profile", "profile__website", + "#select_p_k", + "#key_p_k", ] ) actual = set(schema_simp.list_tables()) From 
68693718f0302c0cfc3d0c23ff9d85b3742db695 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:52:41 +0000 Subject: [PATCH 66/71] Deprecate test_fetch.py::test_limit_warning Deprecated in #1084 --- tests/test_fetch.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/test_fetch.py b/tests/test_fetch.py index 4f45ae9e9..7a3cf5a11 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -202,28 +202,6 @@ def test_offset(lang, languages): assert np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" -def test_limit_warning(lang): - """Tests whether warning is raised if offset is used without limit.""" - logger = logging.getLogger("datajoint") - log_capture = io.StringIO() - stream_handler = logging.StreamHandler(log_capture) - log_format = logging.Formatter( - "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s" - ) - stream_handler.setFormatter(log_format) - stream_handler.set_name("test_limit_warning") - logger.addHandler(stream_handler) - lang.fetch(offset=1) - - log_contents = log_capture.getvalue() - log_capture.close() - - for handler in logger.handlers: # Clean up handler - if handler.name == "test_limit_warning": - logger.removeHandler(handler) - assert "[WARNING]: Offset set, but no limit." 
in log_contents - - def test_len(lang): """Tests __len__""" assert len(lang.fetch()) == len(lang), "__len__ is not behaving properly" From 7220ed09a3c703a6518138814380c3e7769bcf99 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:15:58 +0000 Subject: [PATCH 67/71] Fix schema_simp fixture --- tests/schema_simple.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index 77ee6849b..05d7aa7a8 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -13,7 +13,6 @@ import inspect -@schema class SelectPK(dj.Lookup): definition = """ # tests sql keyword escaping id: int @@ -23,7 +22,6 @@ class SelectPK(dj.Lookup): for i in range(3) for j in range(4, 0, -1)) -@schema class KeyPK(dj.Lookup): definition = """ # tests sql keyword escaping id : int From 14f8970abf49e767b1b6d775213fa6eab5a570fe Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:33:31 +0000 Subject: [PATCH 68/71] Migrate TestDjTop tests to pytest --- tests/test_relational_operand.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 9668f1bcc..7fc5127b0 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -576,7 +576,7 @@ def test_union_multiple(schema_simp_pop): class TestDjTop: """TODO: migrate""" - def test_restrictions_by_top(self): + def test_restrictions_by_top(self, schema_simp_pop): a = L() & dj.Top() b = L() & dj.Top(order_by=["cond_in_l", "KEY"]) x = L() & dj.Top(5, "id_l desc", 4) & "cond_in_l=1" @@ -613,7 +613,7 @@ def test_restrictions_by_top(self): {"id_l": 20, "cond_in_l": 1}, ] - def test_top_restriction_with_keywords(self): + def test_top_restriction_with_keywords(self, schema_simp_pop): select = SelectPK() & dj.Top(limit=9, order_by=["select desc"]) key = KeyPK() & 
dj.Top(limit=9, order_by="key desc") assert select.fetch(as_dict=True) == [ @@ -639,26 +639,26 @@ def test_top_restriction_with_keywords(self): {"id": 2, "key": 3}, ] - def test_top_errors(self): - with assert_raises(DataJointError) as err1: + def test_top_errors(self, schema_simp_pop): + with pytest.raises(DataJointError) as err1: L() & ("cond_in_l=1", dj.Top()) - with assert_raises(DataJointError) as err2: + with pytest.raises(DataJointError) as err2: L() & dj.AndList(["cond_in_l=1", dj.Top()]) - with assert_raises(TypeError) as err3: + with pytest.raises(TypeError) as err3: L() & dj.Top(limit="1") - with assert_raises(TypeError) as err4: + with pytest.raises(TypeError) as err4: L() & dj.Top(order_by=1) - with assert_raises(TypeError) as err5: + with pytest.raises(TypeError) as err5: L() & dj.Top(offset="1") assert ( - "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" - == str(err1.exception) + "datajoint.errors.DataJointError: Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err1.exconly()) ) assert ( - "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" - == str(err2.exception) + "datajoint.errors.DataJointError: Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err2.exconly()) ) - assert "Top limit must be an integer" == str(err3.exception) - assert "Top order_by attributes must all be strings" == str( - err4.exception) - assert "The offset argument must be an integer" == str(err5.exception) + assert "TypeError: Top limit must be an integer" == str(err3.exconly()) + assert "TypeError: Top order_by attributes must all be strings" == str( + err4.exconly()) + assert "TypeError: The offset argument must be an integer" == str(err5.exconly()) From 9baa3579a1f73f2996701332e5651b58c7e60e79 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:36:03 +0000 Subject: [PATCH 69/71] Format with black==24.4.2 --- tests/schema_simple.py 
| 6 ++---- tests/test_relational_operand.py | 7 +++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index 05d7aa7a8..f3e591382 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -18,8 +18,7 @@ class SelectPK(dj.Lookup): id: int select : int """ - contents = list(dict(id=i, select=i * j) - for i in range(3) for j in range(4, 0, -1)) + contents = list(dict(id=i, select=i * j) for i in range(3) for j in range(4, 0, -1)) class KeyPK(dj.Lookup): @@ -27,8 +26,7 @@ class KeyPK(dj.Lookup): id : int key : int """ - contents = list(dict(id=i, key=i + j) - for i in range(3) for j in range(4, 0, -1)) + contents = list(dict(id=i, key=i + j) for i in range(3) for j in range(4, 0, -1)) class IJ(dj.Lookup): diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 7fc5127b0..bebadb8db 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -660,5 +660,8 @@ def test_top_errors(self, schema_simp_pop): ) assert "TypeError: Top limit must be an integer" == str(err3.exconly()) assert "TypeError: Top order_by attributes must all be strings" == str( - err4.exconly()) - assert "TypeError: The offset argument must be an integer" == str(err5.exconly()) + err4.exconly() + ) + assert "TypeError: The offset argument must be an integer" == str( + err5.exconly() + ) From ad04fdecec16ec09938fb27f28bb6824fed4b798 Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:35:58 -0500 Subject: [PATCH 70/71] Remove merge artifacts --- docs/src/query/operators.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index 497872603..39f2488dd 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -94,7 +94,6 @@ of either the primary key or a foreign key. 2. 
All common attributes in the two relations must be of a compatible datatype for equality comparisons. -<<<<<<< HEAD ## Restriction The restriction operator `A & cond` selects the subset of entities from `A` that meet @@ -394,10 +393,3 @@ dj.U().aggr(Session, n="max(session)") # (3) `dj.U()`, as shown in the last example above, is often useful for integer IDs. For an example of this process, see the source code for [Element Array Electrophysiology's `insert_new_params`](https://datajoint.com/docs/elements/element-array-ephys/latest/api/element_array_ephys/ephys_acute/#element_array_ephys.ephys_acute.ClusteringParamSet.insert_new_params). -======= -These restrictions are introduced both for performance reasons and for conceptual -reasons. -For performance, they encourage queries that rely on indexes. -For conceptual reasons, they encourage database design in which entities in different -tables are related to each other by the use of primary keys and foreign keys. ->>>>>>> master From ff5765059f36ece509c08a6d4b79b8125f6fd90f Mon Sep 17 00:00:00 2001 From: Ethan Ho <53266718+ethho@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:37:14 -0500 Subject: [PATCH 71/71] Remove debug docstrings --- tests/test_relational_operand.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index bebadb8db..8ff8286e1 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -574,7 +574,6 @@ def test_union_multiple(schema_simp_pop): class TestDjTop: - """TODO: migrate""" def test_restrictions_by_top(self, schema_simp_pop): a = L() & dj.Top()