From 67183dec787ddf42ebd20fccfacb5b98d756b3ed Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Wed, 7 Feb 2024 08:29:14 -0800 Subject: [PATCH] #820 --- src/spyglass/utils/dj_chains.py | 27 ++++++++++++++++++--------- src/spyglass/utils/dj_mixin.py | 13 +++++++------ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index d8152d020..275d859c3 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -135,12 +135,12 @@ def __init__(self, parent: Table, child: Table, connection=None): and MERGE_PK in child.heading.names ): logger.error("Child is a merge table. Use TableChains instead.") + return self._link_symbol = " -> " self.parent = parent self.child = child self._has_link = child.full_table_name in parent.descendants() - self._errors = [] def __str__(self): """Return string representation of chain: parent -> child.""" @@ -221,17 +221,26 @@ def objects(self) -> List[dj.FreeTable]: else None ) - def errors(self) -> List[str]: - """Return list of errors for each table in chain.""" - return self._errors - - def join(self, restricton: str = None) -> dj.expression.QueryExpression: - """Return join of tables in chain with restriction applied to parent.""" + def join( + self, restricton: str = None, reverse_order: bool = False + ) -> dj.expression.QueryExpression: + """Return join of tables in chain with restriction applied to parent. + + Parameters + ---------- + restriction : str, optional + Restriction to apply to first table in the order. + Defaults to self.parent.restriction. + reverse_order : bool, optional + If True, join tables in reverse order. Defaults to False. + """ if not self._has_link: return None + + objects = self.objects[::-1] if reverse_order else self.objects restriction = restricton or self.parent.restriction or True - join = self.objects[0] & restriction - for table in self.objects[1:]: + join = objects[0] & restriction + for table in objects[1:]: try: join = join.proj() * table except dj.DataJointError as e: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 7632e0249..bd6e0b244 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -61,10 +61,7 @@ def _nwb_table_tuple(self) -> tuple: Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. Implemented as a cached_property to avoid circular imports.""" - from spyglass.common.common_nwbfile import ( - AnalysisNwbfile, - Nwbfile, - ) # noqa F401 + from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile # noqa F401 table_dict = { AnalysisNwbfile: "analysis_file_abs_path", @@ -74,7 +71,9 @@ def _nwb_table_tuple(self) -> tuple: resolved = getattr(self, "_nwb_table", None) or ( AnalysisNwbfile if "-> AnalysisNwbfile" in self.definition - else Nwbfile if "-> Nwbfile" in self.definition else None + else Nwbfile + if "-> Nwbfile" in self.definition + else None ) if not resolved: @@ -278,7 +277,9 @@ def _get_exp_summary(self): empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) - sess_link = self._session_connection.join(self.restriction) + sess_link = self._session_connection.join( + self.restriction, reverse_order=True + ) exp_missing = format & (sess_link - SesExp).proj(**empty_pk) exp_present = format & (sess_link * SesExp - exp_missing).proj()