Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Feb 7, 2024
1 parent a9fced6 commit 67183de
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
27 changes: 18 additions & 9 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def __init__(self, parent: Table, child: Table, connection=None):
and MERGE_PK in child.heading.names
):
logger.error("Child is a merge table. Use TableChains instead.")
return

self._link_symbol = " -> "
self.parent = parent
self.child = child
self._has_link = child.full_table_name in parent.descendants()
self._errors = []

def __str__(self):
"""Return string representation of chain: parent -> child."""
Expand Down Expand Up @@ -221,17 +221,26 @@ def objects(self) -> List[dj.FreeTable]:
else None
)

def errors(self) -> List[str]:
"""Return list of errors for each table in chain."""
return self._errors

def join(self, restricton: str = None) -> dj.expression.QueryExpression:
"""Return join of tables in chain with restriction applied to parent."""
def join(
self, restricton: str = None, reverse_order: bool = False
) -> dj.expression.QueryExpression:
"""Return join of tables in chain with restriction applied to parent.
Parameters
----------
restriction : str, optional
Restriction to apply to first table in the order.
Defaults to self.parent.restriction.
reverse_order : bool, optional
If True, join tables in reverse order. Defaults to False.
"""
if not self._has_link:
return None

objects = self.objects[::-1] if reverse_order else self.objects
restriction = restricton or self.parent.restriction or True
join = self.objects[0] & restriction
for table in self.objects[1:]:
join = objects[0] & restriction
for table in objects[1:]:
try:
join = join.proj() * table
except dj.DataJointError as e:
Expand Down
13 changes: 7 additions & 6 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def _nwb_table_tuple(self) -> tuple:
Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb.
Implemented as a cached_property to avoid circular imports."""
from spyglass.common.common_nwbfile import (
AnalysisNwbfile,
Nwbfile,
) # noqa F401
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile # noqa F401

table_dict = {
AnalysisNwbfile: "analysis_file_abs_path",
Expand All @@ -74,7 +71,9 @@ def _nwb_table_tuple(self) -> tuple:
resolved = getattr(self, "_nwb_table", None) or (
AnalysisNwbfile
if "-> AnalysisNwbfile" in self.definition
else Nwbfile if "-> Nwbfile" in self.definition else None
else Nwbfile
if "-> Nwbfile" in self.definition
else None
)

if not resolved:
Expand Down Expand Up @@ -278,7 +277,9 @@ def _get_exp_summary(self):
empty_pk = {self._member_pk: "NULL"}

format = dj.U(self._session_pk, self._member_pk)
sess_link = self._session_connection.join(self.restriction)
sess_link = self._session_connection.join(
self.restriction, reverse_order=True
)

exp_missing = format & (sess_link - SesExp).proj(**empty_pk)
exp_present = format & (sess_link * SesExp - exp_missing).proj()
Expand Down

0 comments on commit 67183de

Please sign in to comment.