Skip to content

Commit

Permalink
Fix join error for position merge TableChain
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 31, 2024
1 parent b4e770e commit fa6d70e
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 354 deletions.
97 changes: 85 additions & 12 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from functools import cached_property
from typing import List
from typing import List, Union

import datajoint as dj
import networkx as nx
from datajoint.expression import QueryExpression
from datajoint.table import Table
from datajoint.utils import get_master

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.logging import logger

Expand All @@ -16,6 +17,38 @@ class TableChains:
Functions as a plural version of TableChain, allowing a single `join`
call across all chains from parent -> Merge table.
Attributes
----------
parent : Table
Parent or origin of chains.
child : Table
Merge table or destination of chains.
connection : datajoint.Connection, optional
Connection to database used to create FreeTable objects. Defaults to
parent.connection.
part_names : List[str]
List of full table names of child parts.
chains : List[TableChain]
List of TableChain objects for each part in child.
has_link : bool
Cached attribute to store whether parent is linked to child via any of
child parts. False if (a) child is not in parent.descendants or (b)
nx.NetworkXNoPath is raised by nx.shortest_path for all chains.
Methods
-------
__init__(parent, child, connection=None)
Initialize TableChains with parent and child tables.
__repr__()
Return full representation of chains.
Multiline parent -> child for each chain.
__len__()
Return number of chains with links.
__getitem__(index: Union[int, str])
Return TableChain object at index, or use substring of table name.
join(restriction: str = None)
Return list of joins for each chain in self.chains.
"""

def __init__(self, parent, child, connection=None):
Expand All @@ -33,6 +66,14 @@ def __repr__(self):
def __len__(self):
return len([c for c in self.chains if c.has_link])

def __getitem__(self, index: Union[int, str]) -> TableChain:
"""Return FreeTable object at index."""
if isinstance(index, str):
for i, part in enumerate(self.part_names):
if index in part:
return self.chains[i]
return self.chains[index]

def join(self, restriction=None) -> List[QueryExpression]:
"""Return list of joins for each chain in self.chains."""
restriction = restriction or self.parent.restriction or True
Expand Down Expand Up @@ -79,6 +120,8 @@ class TableChain:
Return full representation of chain: parent -> {links} -> child.
__len__()
Return number of tables in chain.
__getitem__(index: Union[int, str])
Return FreeTable object at index, or use substring of table name.
join(restriction: str = None)
Return join of tables in chain with restriction applied to parent.
"""
Expand All @@ -98,16 +141,14 @@ def __init__(self, parent: Table, child: Table, connection=None):
self.parent = parent
self.child = child
self._has_link = child.full_table_name in parent.descendants()
self._errors = []

def __str__(self):
"""Return string representation of chain: parent -> child."""
if not self._has_link:
return "No link"
return (
"Chain: "
+ self.parent.table_name
+ self._link_symbol
+ self.child.table_name
self.parent.table_name + self._link_symbol + self.child.table_name
)

def __repr__(self):
Expand All @@ -123,6 +164,14 @@ def __len__(self):
"""Return number of tables in chain."""
return len(self.names)

def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
"""Return FreeTable object at index."""
if isinstance(index, str):
for i, name in enumerate(self.names):
if index in name:
return self.objects[i]
return self.objects[index]

@property
def has_link(self) -> bool:
"""Return True if parent is linked to child.
Expand All @@ -132,6 +181,12 @@ def has_link(self) -> bool:
"""
return self._has_link

def pk_link(self, src, trg, data) -> float:
"""Return 1 if data["primary"] else float("inf").
Currently unused. Preserved for future debugging."""
return 1 if data["primary"] else float("inf")

@cached_property
def names(self) -> List[str]:
"""Return list of full table names in chain.
Expand All @@ -141,11 +196,17 @@ def names(self) -> List[str]:
if not self._has_link:
return None
try:
return nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
)
return [
name
for name in nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
# weight: optional callable to determine edge weight
# weight=self.pk_link,
)
if not name.isdigit()
]
except nx.NetworkXNoPath:
self._has_link = False
return None
Expand All @@ -159,10 +220,22 @@ def objects(self) -> List[dj.FreeTable]:
else None
)

def errors(self) -> List[str]:
"""Return list of errors for each table in chain."""
return self._errors

def join(self, restricton: str = None) -> dj.expression.QueryExpression:
"""Return join of tables in chain with restriction applied to parent."""
if not self._has_link:
return None
restriction = restricton or self.parent.restriction or True
join = self.objects[0] & restriction
for table in self.objects[1:]:
join = join * table
return join if join else None
try:
join = join.proj() * table
except dj.DataJointError as e:
logger.error(
f"{str(self)} at {table.table_name} with {attribute}"
)
return None
return join
Loading

0 comments on commit fa6d70e

Please sign in to comment.