diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py
index b76132551..2cb952d08 100644
--- a/src/spyglass/utils/dj_chains.py
+++ b/src/spyglass/utils/dj_chains.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List
+from typing import List, Union
 
 import datajoint as dj
 import networkx as nx
@@ -7,6 +7,7 @@
 from datajoint.table import Table
 from datajoint.utils import get_master
 
+from spyglass.utils.database_settings import SHARED_MODULES
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
 from spyglass.utils.logging import logger
 
@@ -16,6 +17,38 @@ class TableChains:
 
     Functions as a plural version of TableChain, allowing a single `join`
     call across all chains from parent -> Merge table.
+
+    Attributes
+    ----------
+    parent : Table
+        Parent or origin of chains.
+    child : Table
+        Merge table or destination of chains.
+    connection : datajoint.Connection, optional
+        Connection to database used to create FreeTable objects. Defaults to
+        parent.connection.
+    part_names : List[str]
+        List of full table names of child parts.
+    chains : List[TableChain]
+        List of TableChain objects for each part in child.
+    has_link : bool
+        Cached attribute to store whether parent is linked to child via any of
+        child parts. False if (a) child is not in parent.descendants or (b)
+        nx.NetworkXNoPath is raised by nx.shortest_path for all chains.
+
+    Methods
+    -------
+    __init__(parent, child, connection=None)
+        Initialize TableChains with parent and child tables.
+    __repr__()
+        Return full representation of chains.
+        Multiline parent -> child for each chain.
+    __len__()
+        Return number of chains with links.
+    __getitem__(index: Union[int, str])
+        Return TableChain object at index, or use substring of table name.
+    join(restriction: str = None)
+        Return list of joins for each chain in self.chains.
     """
 
     def __init__(self, parent, child, connection=None):
@@ -33,6 +66,14 @@ def __repr__(self):
     def __len__(self):
         return len([c for c in self.chains if c.has_link])
 
+    def __getitem__(self, index: Union[int, str]) -> TableChain:
+        """Return TableChain object at index."""
+        if isinstance(index, str):
+            for i, part in enumerate(self.part_names):
+                if index in part:
+                    return self.chains[i]
+        return self.chains[index]
+
     def join(self, restriction=None) -> List[QueryExpression]:
         """Return list of joins for each chain in self.chains."""
         restriction = restriction or self.parent.restriction or True
@@ -79,6 +120,8 @@ class TableChain:
         Return full representation of chain: parent -> {links} -> child.
     __len__()
         Return number of tables in chain.
+    __getitem__(index: Union[int, str])
+        Return FreeTable object at index, or use substring of table name.
     join(restriction: str = None)
         Return join of tables in chain with restriction applied to parent.
""" @@ -98,16 +141,14 @@ def __init__(self, parent: Table, child: Table, connection=None): self.parent = parent self.child = child self._has_link = child.full_table_name in parent.descendants() + self._errors = [] def __str__(self): """Return string representation of chain: parent -> child.""" if not self._has_link: return "No link" return ( - "Chain: " - + self.parent.table_name - + self._link_symbol - + self.child.table_name + self.parent.table_name + self._link_symbol + self.child.table_name ) def __repr__(self): @@ -123,6 +164,14 @@ def __len__(self): """Return number of tables in chain.""" return len(self.names) + def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: + """Return FreeTable object at index.""" + if isinstance(index, str): + for i, name in enumerate(self.names): + if index in name: + return self.objects[i] + return self.objects[index] + @property def has_link(self) -> bool: """Return True if parent is linked to child. @@ -132,6 +181,12 @@ def has_link(self) -> bool: """ return self._has_link + def pk_link(self, src, trg, data) -> float: + """Return 1 if data["primary"] else float("inf"). + + Currently unused. Preserved for future debugging.""" + return 1 if data["primary"] else float("inf") + @cached_property def names(self) -> List[str]: """Return list of full table names in chain. @@ -141,11 +196,17 @@ def names(self) -> List[str]: if not self._has_link: return None try: - return nx.shortest_path( - self.parent.connection.dependencies, - self.parent.full_table_name, - self.child.full_table_name, - ) + return [ + name + for name in nx.shortest_path( + self.parent.connection.dependencies, + self.parent.full_table_name, + self.child.full_table_name, + # weight: optional callable to determine edge weight + # weight=self.pk_link, + ) + if not name.isdigit() + ] except nx.NetworkXNoPath: self._has_link = False return None @@ -159,10 +220,22 @@ def objects(self) -> List[dj.FreeTable]: else None ) + def errors(self) -> List[str]: + """Return list of errors for each table in chain.""" + return self._errors + def join(self, restricton: str = None) -> dj.expression.QueryExpression: """Return join of tables in chain with restriction applied to parent.""" + if not self._has_link: + return None restriction = restricton or self.parent.restriction or True join = self.objects[0] & restriction for table in self.objects[1:]: - join = join * table - return join if join else None + try: + join = join.proj() * table + except dj.DataJointError as e: + logger.error( + f"{str(self)} at {table.table_name} with {attribute}" + ) + return None + return join diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 884608d13..4ef6717a0 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,18 +1,3 @@ -<<<<<<< HEAD -from time import time -from typing import Dict, List - -import datajoint as dj -import networkx as nx -from datajoint.table import logger as dj_logger -from datajoint.user_tables import Table, TableMeta -from datajoint.utils import get_master, user_choice - -from spyglass.utils.database_settings import SHARED_MODULES -from spyglass.utils.dj_helper_fn import fetch_nwb -from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK -from spyglass.utils.dj_merge_tables import Merge -======= from functools import cached_property from time import time from typing import Dict, List, Union @@ -26,7 +11,6 @@ from spyglass.utils.dj_chains import TableChain, TableChains from spyglass.utils.dj_helper_fn import fetch_nwb 
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 from spyglass.utils.logging import logger @@ -61,32 +45,13 @@ class SpyglassMixin: raised. `force_permission` can be set to True to bypass permission check. cdel(*args, **kwargs) Alias for cautious_delete. - delte_downstream_merge(restriction=None, dry_run=True, reload_cache=False) - Delete downstream merge table entries associated with restricton. - Requires caching of merge tables and links, which is slow on first call. - `restriction` can be set to a string to restrict the delete. `dry_run` - can be set to False to commit the delete. `reload_cache` can be set to - True to reload the merge cache. - ddm(*args, **kwargs) - Alias for delete_downstream_merge. """ # _nwb_table = None # NWBFile table class, defined at the table level -<<<<<<< HEAD - _nwb_table_resolved = None # NWBFiletable class, resolved here from above - _delete_dependencies = [] # Session, LabMember, LabTeam, delay import - # pks for delete permission check, assumed to be on field -======= # pks for delete permission check, assumed to be one field for each ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure - _merge_table_cache = {} # Cache of merge tables downstream of self - _merge_chains_cache = {} # Cache of table chains to merges - _session_connection_cache = None # Cache of path from Session to self - _test_mode_cache = None # Cache of test mode setting for delete - _usage_table_cache = None # Temporary inclusion for usage tracking # ------------------------------- fetch_nwb ------------------------------- @@ -96,10 +61,7 @@ def _nwb_table_tuple(self) -> tuple: Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. Implemented as a cached_property to avoid circular imports.""" - from spyglass.common.common_nwbfile import ( - AnalysisNwbfile, - Nwbfile, - ) # noqa F401 + from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile # noqa F401 table_dict = { AnalysisNwbfile: "analysis_file_abs_path", @@ -109,7 +71,9 @@ def _nwb_table_tuple(self) -> tuple: resolved = getattr(self, "_nwb_table", None) or ( AnalysisNwbfile if "-> AnalysisNwbfile" in self.definition - else Nwbfile if "-> Nwbfile" in self.definition else None + else Nwbfile + if "-> Nwbfile" in self.definition + else None ) if not resolved: @@ -179,6 +143,12 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: merge_chains[name] = chains return merge_chains + def _get_chain(self, substring) -> TableChains: + """Return chain from self to merge table with substring in name.""" + for name, chain in self._merge_chains.items(): + if substring.lower() in name: + return chain + def _commit_merge_deletes( self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs ) -> None: @@ -190,74 +160,6 @@ def _commit_merge_deletes( Dictionary of merge tables and their joins. Uses 'merge_id' primary key to restrict delete. -<<<<<<< HEAD - @property - def _test_mode(self) -> bool: - """Return True if test mode is enabled.""" - if not self._test_mode_cache: - from spyglass.settings import test_mode - - self._test_mode_cache = test_mode - return self._test_mode_cache - - @property - def _merge_tables(self) -> Dict[str, dj.FreeTable]: - """Dict of merge tables downstream of self. 
- - Cache of items in parents of self.descendants(as_objects=True) that - have a merge primary key. - """ - if self._merge_table_cache: - return self._merge_table_cache - - def has_merge_pk(table): - return MERGE_PK in table.heading.names - - self.connection.dependencies.load() - for desc in self.descendants(as_objects=True): - if not has_merge_pk(desc): - continue - if not (master_name := get_master(desc.full_table_name)): - continue - master = dj.FreeTable(self.connection, master_name) - if has_merge_pk(master): - self._merge_table_cache[master_name] = master - logger.info( - f"Building merge cache for {self.table_name}.\n\t" - + f"Found {len(self._merge_table_cache)} downstream merge tables" - ) - - return self._merge_table_cache - - @property - def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: - """Dict of merge links downstream of self. - - For each merge table found in _merge_tables, find the path from self to - merge. If the path is valid, add it to the dict. Cache prevents need - to recompute whenever delete_downstream_merge is called with a new - restriction. - """ - if self._merge_chains_cache: - return self._merge_chains_cache - - for name, merge_table in self._merge_tables.items(): - chains = TableChains(self, merge_table, connection=self.connection) - if len(chains): - self._merge_chains_cache[name] = chains - return self._merge_chains_cache - - def _commit_merge_deletes(self, merge_join_dict, **kwargs): - """Commit merge deletes. - - Extracted for use in cautious_delete and delete_downstream_merge.""" - for table_name, part_restr in merge_join_dict.items(): - table = self._merge_tables[table_name] - keys = [part.fetch(MERGE_PK, as_dict=True) for part in part_restr] - (table & keys).delete(**kwargs) - - def delete_downstream_merge( -======= Extracted for use in cautious_delete and delete_downstream_merge.""" for table_name, part_restr in merge_join_dict.items(): table = self._merge_tables[table_name] @@ -309,64 +211,9 @@ def delete_downstream_merge( if not merge_join_dict and not disable_warning: logger.warning( - f"No merge tables found downstream of {self.full_table_name}." - + "\n\tIf this is unexpected, try running with `reload_cache`." - ) - - if dry_run: - return merge_join_dict.values() if return_parts else merge_join_dict - - self._commit_merge_deletes(merge_join_dict, **kwargs) - - def ddm( ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 - self, - restriction: str = None, - dry_run: bool = True, - reload_cache: bool = False, - disable_warning: bool = False, - return_parts: bool = True, -<<<<<<< HEAD - **kwargs, - ) -> List[dj.expression.QueryExpression]: - """Delete downstream merge table entries associated with restricton. - - Requires caching of merge tables and links, which is slow on first call. - - Parameters - ---------- - restriction : str, optional - Restriction to apply to merge tables. Default None. Will attempt to - use table restriction if None. - dry_run : bool, optional - If True, return list of merge part entries to be deleted. Default - True. - reload_cache : bool, optional - If True, reload merge cache. Default False. - disable_warning : bool, optional - If True, do not warn if no merge tables found. Default False. - return_parts : bool, optional - If True, return list of merge part entries to be deleted. Default - True. If False, return dictionary of merge tables and their joins. - **kwargs : Any - Passed to datajoint.table.Table.delete. 
- """ - if reload_cache: - self._merge_table_cache = {} - self._merge_chains_cache = {} - - restriction = restriction or self.restriction or True - - merge_join_dict = {} - for name, chain in self._merge_chains.items(): - join = chain.join(restriction) - if join: - merge_join_dict[name] = join - - if not merge_join_dict and not disable_warning: - logger.warning( - f"No merge tables found downstream of {self.full_table_name}." - + "\n\tIf this is unexpected, try running with `reload_cache`." + f"No merge deletes found w/ {self.table_name} & " + + f"{restriction}.\n\tIf this is unexpected, try running with " + + "`reload_cache`." ) if dry_run: @@ -383,20 +230,6 @@ def ddm( return_parts: bool = True, *args, **kwargs, - ): - """Alias for delete_downstream_merge.""" - return self.delete_downstream_merge( - restriction=restriction, - dry_run=dry_run, - reload_cache=reload_cache, - disable_warning=disable_warning, - return_parts=return_parts, - *args, - **kwargs, - ) -======= - *args, - **kwargs, ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]: """Alias for delete_downstream_merge.""" return self.delete_downstream_merge( @@ -425,7 +258,6 @@ def _delete_deps(self) -> List[Table]: self._session_pk = Session.primary_key[0] self._member_pk = LabMember.primary_key[0] return [LabMember, LabTeam, Session] ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. @@ -452,20 +284,6 @@ def _get_exp_summary(self): return exp_missing + exp_present -<<<<<<< HEAD - @property - def _session_connection(self) -> dj.expression.QueryExpression: - """Path from Session table to self. - - None is not yet cached, False if no connection found. - """ - if self._session_connection_cache is None: - connection = TableChain(parent=self._delete_deps[-1], child=self) - self._session_connection_cache = ( - connection if connection.has_link else False - ) - return self._session_connection_cache -======= @cached_property def _session_connection(self) -> Union[TableChain, bool]: """Path from Session table to self. False if no connection found.""" @@ -480,7 +298,6 @@ def _test_mode(self) -> bool: from spyglass.settings import test_mode return test_mode ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 def _check_delete_permission(self) -> None: """Check user name against lab team assoc. w/ self -> Session. @@ -536,23 +353,12 @@ def _check_delete_permission(self) -> None: ) logger.info(f"Queueing delete for session(s):\n{sess_summary}") -<<<<<<< HEAD - @property - def _usage_table(self): - """Temporary inclusion for usage tracking.""" - if not self._usage_table_cache: - from spyglass.common.common_usage import CautiousDelete - - self._usage_table_cache = CautiousDelete - return self._usage_table_cache -======= @cached_property def _usage_table(self): """Temporary inclusion for usage tracking.""" from spyglass.common.common_usage import CautiousDelete return CautiousDelete ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 def _log_use(self, start, merge_deletes=None): """Log use of cautious_delete.""" @@ -566,10 +372,6 @@ def _log_use(self, start, merge_deletes=None): ) ) -<<<<<<< HEAD - # Rename to `delete` when we're ready to use it -======= ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5 # TODO: Intercept datajoint delete confirmation prompt for merge deletes def cautious_delete(self, force_permission: bool = False, *args, **kwargs): """Delete table rows after checking user permission. 
@@ -627,137 +429,6 @@ def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) -<<<<<<< HEAD - -class TableChains: - """Class for representing chains from parent to Merge table via parts.""" - - def __init__(self, parent, child, connection=None): - self.parent = parent - self.child = child - self.connection = connection or parent.connection - parts = child.parts(as_objects=True) - self.part_names = [part.full_table_name for part in parts] - self.chains = [TableChain(parent, part) for part in parts] - self.has_link = any([chain.has_link for chain in self.chains]) - - def __repr__(self): - return "\n".join([str(chain) for chain in self.chains]) - - def __len__(self): - return len([c for c in self.chains if c.has_link]) - - def join(self, restriction=None): - restriction = restriction or self.parent.restriction or True - joins = [] - for chain in self.chains: - if joined := chain.join(restriction): - joins.append(joined) - return joins - - -class TableChain: - """Class for representing a chain of tables. - - Note: Parent -> Merge should use TableChains instead. - """ - - def __init__(self, parent: Table, child: Table, connection=None): - self._connection = connection or parent.connection - if not self._connection.dependencies._loaded: - self._connection.dependencies.load() - - if ( # if child is a merge table - get_master(child.full_table_name) == "" - and MERGE_PK in child.heading.names - ): - logger.error("Child is a merge table. Use TableChains instead.") - - self._link_symbol = " -> " - self.parent = parent - self.child = child - self._repr = None - self._names = None # full table names of tables in chain - self._objects = None # free tables in chain - self._has_link = child.full_table_name in parent.descendants() - - def __str__(self): - """Return string representation of chain: parent -> child.""" - if not self._has_link: - return "No link" - return ( - "Chain: " - + self.parent.table_name - + self._link_symbol - + self.child.table_name - ) - - def __repr__(self): - """Return full representation of chain: parent -> {links} -> child.""" - if self._repr: - return self._repr - self._repr = ( - "Chain: " - + self._link_symbol.join([t.table_name for t in self.objects]) - if self.names - else "No link" - ) - return self._repr - - def __len__(self): - """Return number of tables in chain.""" - return len(self.names) - - @property - def has_link(self) -> bool: - """Return True if parent is linked to child. - - Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath - is raised by nx.shortest_path. - """ - return self._has_link - - @property - def names(self) -> List[str]: - """Return list of full table names in chain. - - Uses networkx.shortest_path. 
- """ - if not self._has_link: - return None - if self._names: - return self._names - try: - self._names = nx.shortest_path( - self.parent.connection.dependencies, - self.parent.full_table_name, - self.child.full_table_name, - ) - return self._names - except nx.NetworkXNoPath: - self._has_link = False - return None - - @property - def objects(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain.""" - if not self._objects: - self._objects = ( - [dj.FreeTable(self._connection, name) for name in self.names] - if self.names - else None - ) - return self._objects - - def join(self, restricton: str = None) -> dj.expression.QueryExpression: - """Return join of tables in chain with restriction applied to parent.""" - restriction = restricton or self.parent.restriction or True - join = self.objects[0] & restriction - for table in self.objects[1:]: - join = join * table - return join if join else None -======= def delete(self, *args, **kwargs): """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" self.cautious_delete(*args, **kwargs) ->>>>>>> b42432f0884ef3e990e111c53102a64529598ae5