From 42f1a1acbd49270042d13d8108f17a8dac244762 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Wed, 31 Jan 2024 14:04:36 -0600 Subject: [PATCH] Address join-compatibility issue for long chains (#811) * Fix join error for position merge TableChain * Address failing tests from delete overwrite * Add gap node note --- src/spyglass/common/common_session.py | 12 ++- src/spyglass/common/common_usage.py | 2 +- src/spyglass/utils/dj_chains.py | 101 ++++++++++++++++++++++---- src/spyglass/utils/dj_mixin.py | 30 ++++++-- tests/conftest.py | 9 ++- 5 files changed, 128 insertions(+), 26 deletions(-) diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index f6f783262..acb4a0826 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -8,7 +8,7 @@ from spyglass.common.common_lab import Institution, Lab, LabMember from spyglass.common.common_nwbfile import Nwbfile from spyglass.common.common_subject import Subject -from spyglass.settings import config, debug_mode +from spyglass.settings import config, debug_mode, test_mode from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file @@ -214,6 +214,8 @@ def add_session_to_group( *, skip_duplicates: bool = False, ): + if test_mode: + skip_duplicates = True SessionGroupSession.insert1( { "session_group_name": session_group_name, @@ -230,12 +232,16 @@ def remove_session_from_group( "session_group_name": session_group_name, "nwb_file_name": nwb_file_name, } - (SessionGroupSession & query).delete(*args, **kwargs) + (SessionGroupSession & query).delete( + force_permission=test_mode, *args, **kwargs + ) @staticmethod def delete_group(session_group_name: str, *args, **kwargs): query = {"session_group_name": session_group_name} - (SessionGroup & query).delete(*args, **kwargs) + (SessionGroup & query).delete( + force_permission=test_mode, *args, **kwargs + ) @staticmethod def get_group_sessions(session_group_name: str): diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 8b110cbc2..ccd58091d 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -18,6 +18,6 @@ class CautiousDelete(dj.Manual): dj_user: varchar(64) duration: float origin: varchar(64) - restriction: varchar(64) + restriction: varchar(255) merge_deletes = null: blob """ diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index b76132551..d8152d020 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List +from typing import List, Union import datajoint as dj import networkx as nx @@ -16,6 +16,38 @@ class TableChains: Functions as a plural version of TableChain, allowing a single `join` call across all chains from parent -> Merge table. + + Attributes + ---------- + parent : Table + Parent or origin of chains. + child : Table + Merge table or destination of chains. + connection : datajoint.Connection, optional + Connection to database used to create FreeTable objects. Defaults to + parent.connection. + part_names : List[str] + List of full table names of child parts. + chains : List[TableChain] + List of TableChain objects for each part in child. + has_link : bool + Cached attribute to store whether parent is linked to child via any of + child parts. False if (a) child is not in parent.descendants or (b) + nx.NetworkXNoPath is raised by nx.shortest_path for all chains. + + Methods + ------- + __init__(parent, child, connection=None) + Initialize TableChains with parent and child tables. + __repr__() + Return full representation of chains. + Multiline parent -> child for each chain. + __len__() + Return number of chains with links. + __getitem__(index: Union[int, str]) + Return TableChain object at index, or use substring of table name. + join(restriction: str = None) + Return list of joins for each chain in self.chains. """ def __init__(self, parent, child, connection=None): @@ -33,6 +65,14 @@ def __repr__(self): def __len__(self): return len([c for c in self.chains if c.has_link]) + def __getitem__(self, index: Union[int, str]): + """Return FreeTable object at index.""" + if isinstance(index, str): + for i, part in enumerate(self.part_names): + if index in part: + return self.chains[i] + return self.chains[index] + def join(self, restriction=None) -> List[QueryExpression]: """Return list of joins for each chain in self.chains.""" restriction = restriction or self.parent.restriction or True @@ -79,6 +119,8 @@ class TableChain: Return full representation of chain: parent -> {links} -> child. __len__() Return number of tables in chain. + __getitem__(index: Union[int, str]) + Return FreeTable object at index, or use substring of table name. join(restriction: str = None) Return join of tables in chain with restriction applied to parent. """ @@ -98,16 +140,14 @@ def __init__(self, parent: Table, child: Table, connection=None): self.parent = parent self.child = child self._has_link = child.full_table_name in parent.descendants() + self._errors = [] def __str__(self): """Return string representation of chain: parent -> child.""" if not self._has_link: return "No link" return ( - "Chain: " - + self.parent.table_name - + self._link_symbol - + self.child.table_name + self.parent.table_name + self._link_symbol + self.child.table_name ) def __repr__(self): @@ -123,6 +163,14 @@ def __len__(self): """Return number of tables in chain.""" return len(self.names) + def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: + """Return FreeTable object at index.""" + if isinstance(index, str): + for i, name in enumerate(self.names): + if index in name: + return self.objects[i] + return self.objects[index] + @property def has_link(self) -> bool: """Return True if parent is linked to child. @@ -132,20 +180,34 @@ def has_link(self) -> bool: """ return self._has_link + def pk_link(self, src, trg, data) -> float: + """Return 1 if data["primary"] else float("inf"). + + Currently unused. Preserved for future debugging.""" + return 1 if data["primary"] else float("inf") + @cached_property def names(self) -> List[str]: """Return list of full table names in chain. - Uses networkx.shortest_path. + Uses networkx.shortest_path. Ignores numeric table names, which are + 'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph + source code for comments on alias nodes. """ if not self._has_link: return None try: - return nx.shortest_path( - self.parent.connection.dependencies, - self.parent.full_table_name, - self.child.full_table_name, - ) + return [ + name + for name in nx.shortest_path( + self.parent.connection.dependencies, + self.parent.full_table_name, + self.child.full_table_name, + # weight: optional callable to determine edge weight + # weight=self.pk_link, + ) + if not name.isdigit() + ] except nx.NetworkXNoPath: self._has_link = False return None @@ -159,10 +221,23 @@ def objects(self) -> List[dj.FreeTable]: else None ) + def errors(self) -> List[str]: + """Return list of errors for each table in chain.""" + return self._errors + def join(self, restricton: str = None) -> dj.expression.QueryExpression: """Return join of tables in chain with restriction applied to parent.""" + if not self._has_link: + return None restriction = restricton or self.parent.restriction or True join = self.objects[0] & restriction for table in self.objects[1:]: - join = join * table - return join if join else None + try: + join = join.proj() * table + except dj.DataJointError as e: + attribute = str(e).split("attribute ")[-1] + logger.error( + f"{str(self)} at {table.table_name} with {attribute}" + ) + return None + return join diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 03f0ec08b..7632e0249 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -144,6 +144,12 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: merge_chains[name] = chains return merge_chains + def _get_chain(self, substring) -> TableChains: + """Return chain from self to merge table with substring in name.""" + for name, chain in self._merge_chains.items(): + if substring.lower() in name: + return chain + def _commit_merge_deletes( self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs ) -> None: @@ -206,8 +212,9 @@ def delete_downstream_merge( if not merge_join_dict and not disable_warning: logger.warning( - f"No merge tables found downstream of {self.full_table_name}." - + "\n\tIf this is unexpected, try running with `reload_cache`." + f"No merge deletes found w/ {self.table_name} & " + + f"{restriction}.\n\tIf this is unexpected, try running with " + + "`reload_cache`." ) if dry_run: @@ -352,7 +359,7 @@ def _usage_table(self): """Temporary inclusion for usage tracking.""" from spyglass.common.common_usage import CautiousDelete - return CautiousDelete + return CautiousDelete() def _log_use(self, start, merge_deletes=None): """Log use of cautious_delete.""" @@ -361,7 +368,9 @@ def _log_use(self, start, merge_deletes=None): duration=time() - start, dj_user=dj.config["database.user"], origin=self.full_table_name, - restriction=self.restriction, + restriction=( + str(self.restriction)[:255] if self.restriction else "None" + ), merge_deletes=merge_deletes, ) ) @@ -419,10 +428,15 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): self._log_use(start=start, merge_deletes=merge_deletes) - def cdel(self, *args, **kwargs): + def cdel(self, force_permission=False, *args, **kwargs): """Alias for cautious_delete.""" - self.cautious_delete(*args, **kwargs) + self.cautious_delete(force_permission=force_permission, *args, **kwargs) - def delete(self, *args, **kwargs): + def delete(self, force_permission=False, *args, **kwargs): """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" - self.cautious_delete(*args, **kwargs) + self.cautious_delete(force_permission=force_permission, *args, **kwargs) + + def super_delete(self, *args, **kwargs): + """Alias for datajoint.table.Table.delete.""" + logger.warning("!! Using super_delete. Bypassing cautious_delete !!") + super().delete(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 759ca43fa..60df55b3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -243,11 +243,18 @@ def mini_closed(mini_path): @pytest.fixture(autouse=True, scope="session") def mini_insert(mini_path, teardown, server, dj_conn): - from spyglass.common import Nwbfile, Session # noqa: E402 + from spyglass.common import LabMember, Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.spikesorting.merge import SpikeSortingOutput # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 + LabMember().insert1( + ["Root User", "Root", "User"], skip_duplicates=not teardown + ) + LabMember.LabMemberInfo().insert1( + ["Root User", "email", "root", 1], skip_duplicates=not teardown + ) + dj_logger.info("Inserting test data.") if not server.connected: