From a508b57fc507d787c1a0b8279c0cc8d28212ee99 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 09:24:19 -0700 Subject: [PATCH 01/11] Fix tests/doc build (#967) * Fix errored test teardown * Docs build pin, docstring fixes * Update overview page * Update changelog * Fix 0.X -> 1.X, remove sponsor-only option --- CHANGELOG.md | 1 + docs/build-docs.sh | 2 +- docs/mkdocs.yml | 2 ++ docs/src/misc/index.md | 4 +++- pyproject.toml | 2 +- src/spyglass/common/common_interval.py | 4 ++-- src/spyglass/position/v1/dlc_reader.py | 11 ++++++++--- src/spyglass/settings.py | 18 +++++++++--------- tests/utils/test_mixin.py | 2 -- 9 files changed, 27 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f09af8c0..5e27b7bb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 +- Pin `mkdocstring-python` to `1.9.0`, fix existing docstrings. #967 ## [0.5.2] (April 22, 2024) diff --git a/docs/build-docs.sh b/docs/build-docs.sh index bb9fa154a..50d44f511 100755 --- a/docs/build-docs.sh +++ b/docs/build-docs.sh @@ -7,7 +7,7 @@ cp ./CHANGELOG.md ./docs/src/ cp ./LICENSE ./docs/src/LICENSE.md mkdir -p ./docs/src/notebooks -rm -r ./docs/src/notebooks/* +rm -fr ./docs/src/notebooks/* cp ./notebooks/*ipynb ./docs/src/notebooks/ cp ./notebooks/*md ./docs/src/notebooks/ mv ./docs/src/notebooks/README.md ./docs/src/notebooks/index.md diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 920b646a7..1aaaa437e 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -74,6 +74,7 @@ nav: - FigURL: misc/figurl_views.md - Session Groups: misc/session_groups.md - Insert Data: misc/insert_data.md + - Mixin: misc/mixin.md - Merge Tables: misc/merge_tables.md - Database Management: misc/database_management.md - Export: misc/export.md @@ -100,6 +101,7 @@ plugins: default_handler: python handlers: python: + paths: [src] options: members_order: source group_by_category: false diff --git a/docs/src/misc/index.md b/docs/src/misc/index.md index 9b3991cb6..b9971a81c 100644 --- a/docs/src/misc/index.md +++ b/docs/src/misc/index.md @@ -3,7 +3,9 @@ This folder contains miscellaneous supporting files documentation. 
- [Database Management](./database_management.md) +- [Export](./export.md) - [figurl Views](./figurl_views.md) -- [insert Data](./insert_data.md) +- [Insert Data](./insert_data.md) - [Merge Tables](./merge_tables.md) +- [Mixin Class](./mixin.md) - [Session Groups](./session_groups.md) diff --git a/pyproject.toml b/pyproject.toml index ffb8d0df6..28cc12633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ docs = [ "mkdocs-jupyter", # Docs render notebooks "mkdocs-literate-nav", # Dynamic page list for API docs "mkdocs-material", # Docs theme - "mkdocstrings[python]", # Docs API docstrings + "mkdocstrings[python]<=1.9.0", # Docs API docstrings ] [tool.hatch.version] diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 39c676f5a..66e82bda8 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -256,6 +256,8 @@ def consolidate_intervals(interval_list): def interval_list_intersect(interval_list1, interval_list2, min_length=0): """Finds the intersections between two interval lists + Each interval is (start time, stop time) + Parameters ---------- interval_list1 : np.array, (N,2) where N = number of intervals @@ -263,8 +265,6 @@ def interval_list_intersect(interval_list1, interval_list2, min_length=0): min_length : float, optional. Minimum length of intervals to include, default 0 - Each interval is (start time, stop time) - Returns ------- interval_list: np.array, (N,2) diff --git a/src/spyglass/position/v1/dlc_reader.py b/src/spyglass/position/v1/dlc_reader.py index 8d6c18c23..c2e56063f 100644 --- a/src/spyglass/position/v1/dlc_reader.py +++ b/src/spyglass/position/v1/dlc_reader.py @@ -161,10 +161,15 @@ def read_yaml(fullpath, filename="*"): Parameters ---------- - fullpath: String or pathlib path. Directory with yaml files - filename: String. Filename, no extension. Permits wildcards. + fullpath: Union[str, pathlib.Path] + Directory with yaml files + filename: str + Filename, no extension. Permits wildcards. - Returns filepath and contents as dict + Returns + ------- + tuple + filepath and contents as dict """ from deeplabcut.utils.auxiliaryfunctions import read_config diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index be2912c9d..d9d469bba 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -20,7 +20,7 @@ class SpyglassConfig: facilitate testing. """ - def __init__(self, base_dir: str = None, **kwargs): + def __init__(self, base_dir: str = None, **kwargs) -> None: """ Initializes a new instance of the class. @@ -103,7 +103,7 @@ def load_config( force_reload=False, on_startup: bool = False, **kwargs, - ): + ) -> None: """ Loads the configuration settings for the object. @@ -223,25 +223,25 @@ def load_config( return self._config - def _load_env_vars(self): + def _load_env_vars(self) -> dict: loaded_dict = {} for var, val in self.env_defaults.items(): loaded_dict[var] = os.getenv(var, val) return loaded_dict - def _set_env_with_dict(self, env_dict): + def _set_env_with_dict(self, env_dict) -> None: # NOTE: Kept for backwards compatibility. Should be removed in future # for custom paths. Keep self.env_defaults. 
for var, val in env_dict.items(): os.environ[var] = str(val) - def _mkdirs_from_dict_vals(self, dir_dict): + def _mkdirs_from_dict_vals(self, dir_dict) -> None: if self._debug_mode: return for dir_str in dir_dict.values(): Path(dir_str).mkdir(exist_ok=True) - def _set_dj_config_stores(self, check_match=True, set_stores=True): + def _set_dj_config_stores(self, check_match=True, set_stores=True) -> None: """ Checks dj.config['stores'] match resolved dirs. Ensures stores set. @@ -287,7 +287,7 @@ def _set_dj_config_stores(self, check_match=True, set_stores=True): return - def dir_to_var(self, dir: str, dir_type: str = "spyglass"): + def dir_to_var(self, dir: str, dir_type: str = "spyglass") -> str: """Converts a dir string to an env variable name.""" return f"{dir_type.upper()}_{dir.upper()}_DIR" @@ -300,7 +300,7 @@ def _generate_dj_config( database_port: int = 3306, database_use_tls: bool = True, **kwargs, - ): + ) -> dict: """Generate a datajoint configuration file. Parameters @@ -345,7 +345,7 @@ def save_dj_config( base_dir=None, set_password=True, **kwargs, - ): + ) -> None: """Set the dj.config parameters, set password, and save config to file. Parameters diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index ac5c74bfe..faa823c8e 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -15,8 +15,6 @@ class Mixin(SpyglassMixin, dj.Manual): yield Mixin - Mixin().drop_quick() - @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_bad_prefix(caplog, dj_conn, Mixin): From 280406ba39b5be6d4255735d1805f1cfea4cee3a Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 12:18:43 -0500 Subject: [PATCH 02/11] Pin `mkdocstrings-python<=1.9.0` --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 28cc12633..2de8b3244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,8 @@ docs = [ "mkdocs-jupyter", # Docs render notebooks "mkdocs-literate-nav", # Dynamic page list for API docs "mkdocs-material", # Docs theme - "mkdocstrings[python]<=1.9.0", # Docs API docstrings + "mkdocstrings[python]", # Docs API docstrings + "mkdocstrings-python<=1.9.0" # Pinned #976 ] [tool.hatch.version] From fcde4c7f77d213896cdbfdae308d7fe8baffa67f Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 11:56:33 -0700 Subject: [PATCH 03/11] Fix relative pathing for mkdocstrings>=1.9.1 (#968) * Fix relative pathing for mkdocstrings>=1.9.1 * Update changelog --- CHANGELOG.md | 2 +- docs/mkdocs.yml | 2 +- docs/src/api/make_pages.py | 5 ++++- pyproject.toml | 1 - 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e27b7bb9..0644c3fc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 -- Pin `mkdocstring-python` to `1.9.0`, fix existing docstrings. #967 +- Fix relative pathing for `mkdocstring-python=>1.9.1`. 
#967, #968 ## [0.5.2] (April 22, 2024) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 1aaaa437e..acec4f829 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -101,12 +101,12 @@ plugins: default_handler: python handlers: python: - paths: [src] options: members_order: source group_by_category: false line_length: 80 docstring_style: numpy + paths: [../src] - literate-nav: nav_file: navigation.md - exclude-search: diff --git a/docs/src/api/make_pages.py b/docs/src/api/make_pages.py index 6886d50f4..d324919ce 100644 --- a/docs/src/api/make_pages.py +++ b/docs/src/api/make_pages.py @@ -16,8 +16,11 @@ if path.stem in ignored_stems or "cython" in path.stem: continue rel_path = path.relative_to("src/spyglass") + + # parts[0] is the src directory, ignore as of mkdocstrings-python 1.9.1 + module_path = ".".join([p for p in path.with_suffix("").parts[1:]]) + with mkdocs_gen_files.open(f"api/{rel_path.with_suffix('')}.md", "w") as f: - module_path = ".".join([p for p in path.with_suffix("").parts]) print(f"::: {module_path}", file=f) nav[rel_path.parts] = f"{rel_path.with_suffix('')}.md" diff --git a/pyproject.toml b/pyproject.toml index 2de8b3244..ffb8d0df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,6 @@ docs = [ "mkdocs-literate-nav", # Dynamic page list for API docs "mkdocs-material", # Docs theme "mkdocstrings[python]", # Docs API docstrings - "mkdocstrings-python<=1.9.0" # Pinned #976 ] [tool.hatch.version] From 2f6634b740c0c2fef71108a24ce8e34cbef473f6 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 12:01:50 -0700 Subject: [PATCH 04/11] Long distance restrictions (#949) * initial commit for restrict_from_upstream * Add tests for RestrGraph * WIP: ABC for RestrGraph * WIP: ABC for RestrGraph 2 * WIP: ABC for RestrGraph 3 * WIP: Operator for 'find upstream key' * WIP: Handle all alias cases in _bridge_restr * WIP: Add tests * WIP: Cascade through merge tables * WIP: add docs * WIP: Revise tests * WIP: Add way to ban item from search * Revert pytest options * Fix failing tests * Bail on cascade if restr empty * Update src/spyglass/utils/dj_mixin.py Co-authored-by: Samuel Bray * Permit dict/list-of-dict restr on long-distance restrict --------- Co-authored-by: Sam Bray Co-authored-by: Eric Denovellis --- CHANGELOG.md | 1 + docs/src/misc/mixin.md | 61 +- src/spyglass/utils/dj_chains.py | 373 -------- src/spyglass/utils/dj_graph.py | 1194 ++++++++++++++++++++----- src/spyglass/utils/dj_helper_fn.py | 28 +- src/spyglass/utils/dj_merge_tables.py | 71 +- src/spyglass/utils/dj_mixin.py | 231 ++++- src/spyglass/utils/nwb_helper_fn.py | 2 +- tests/common/test_device.py | 2 +- tests/conftest.py | 365 +++++++- tests/container.py | 2 +- tests/lfp/conftest.py | 133 --- tests/lfp/test_lfp.py | 5 - tests/linearization/conftest.py | 142 --- tests/linearization/test_lin.py | 2 +- tests/utils/__init__.py | 0 tests/utils/conftest.py | 241 ++++- tests/utils/test_chains.py | 25 +- tests/utils/test_graph.py | 143 +++ tests/utils/test_mixin.py | 43 +- 20 files changed, 2057 insertions(+), 1007 deletions(-) delete mode 100644 src/spyglass/utils/dj_chains.py delete mode 100644 tests/linearization/conftest.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_graph.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0644c3fc9..231e328d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 +- Add long-distance restrictions via `<<` and `>>` 
operators. #943 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 ## [0.5.2] (April 22, 2024) diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md index 747a12f9f..6b3884551 100644 --- a/docs/src/misc/mixin.md +++ b/docs/src/misc/mixin.md @@ -4,6 +4,7 @@ The Spyglass Mixin provides a way to centralize all Spyglass-specific functionalities that have been added to DataJoint tables. This includes... - Fetching NWB files +- Long-distance restrictions. - Delete functionality, including permission checks and part/master pairs - Export logging. See [export doc](export.md) for more information. @@ -11,16 +12,14 @@ To add this functionality to your own tables, simply inherit from the mixin: ```python import datajoint as dj + from spyglass.utils import SpyglassMixin -schema = dj.schema('my_schema') +schema = dj.schema("my_schema") -@schema -class MyOldTable(dj.Manual): - pass @schema -class MyNewTable(SpyglassMixin, dj.Manual):) +class MyOldTable(dj.Manual): pass ``` @@ -44,6 +43,58 @@ should be fetched from `Nwbfile` or an analysis file should be fetched from `AnalysisNwbfile`. If neither is foreign-key-referenced, the function will refer to a `_nwb_table` attribute. +## Long-Distance Restrictions + +In complicated pipelines like Spyglass, there are often tables that 'bury' their +foreign keys as secondary keys. This is done to avoid having to pass a long list +of foreign keys through the pipeline, potentially hitting SQL limits (see also +[Merge Tables](./merge_tables.md)). This burrying makes it difficult to restrict +a given table by familiar attributes. + +Spyglass provides a function, `restrict_by`, to handle this. The function takes +your restriction and checks parents/children until the restriction can be +applied. Spyglass introduces `<<` as a shorthand for `restrict_by` an upstream +key and `>>` as a shorthand for `restrict_by` a downstream key. + +```python +from spyglass.example import AnyTable + +AnyTable >> 'downsteam_attribute="value"' +AnyTable << 'upstream_attribute="value"' + +# Equivalent to +AnyTable.restrict_by('upstream_attribute="value"', direction="up") +AnyTable.restrict_by('downsteam_attribute="value"', direction="down") +``` + +Some caveats to this function: + +1. 'Peripheral' tables, like `IntervalList` and `AnalysisNwbfile` make it hard + to determine the correct parent/child relationship and have been removed + from this search. +2. This function will raise an error if it attempts to check a table that has + not been imported into the current namespace. It is best used for exploring + and debugging, not for production code. +3. It's hard to determine the attributes in a mixed dictionary/string + restriction. If you are having trouble, try using a pure string + restriction. +4. The most direct path to your restriction may not be the path took, especially + when using Merge Tables. When the result is empty see the warning about the + path used. Then, ban tables from the search to force a different path. + +```python +my_table = MyTable() # must be instantced +my_table.ban_search_table(UnwantedTable1) +my_table.ban_search_table([UnwantedTable2, UnwantedTable3]) +my_table.unban_search_table(UnwantedTable3) +my_table.see_banned_tables() + +my_table << my_restriction +``` + +When providing a restriction of the parent, use 'up' direction. When providing a +restriction of the child, use 'down' direction. 
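
Putting these pieces together, a minimal sketch of the full workflow might look
like the example below. The table and attribute names are placeholders (as in
the examples above), and it assumes the result of a long-distance restriction
can be checked for emptiness like any other restricted table:

```python
from spyglass.example import MyTable, UnwantedTable1  # placeholder names

my_table = MyTable()  # must be instanced

# The restricting attribute lives upstream, so cascade 'up'
result = my_table << 'upstream_attribute="value"'
# Equivalent long form
result = my_table.restrict_by('upstream_attribute="value"', direction="up")

if len(result) == 0:
    # Empty result: check the warning about the path used, ban the table
    # that sent the search down the wrong route, and try again
    my_table.ban_search_table(UnwantedTable1)
    result = my_table << 'upstream_attribute="value"'
```

The same pattern applies with `>>` (or `direction="down"`) when the restricting
attribute belongs to a table downstream.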
+ ## Delete Functionality The mixin overrides the default `delete` function to provide two additional diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py deleted file mode 100644 index fe9cebc02..000000000 --- a/src/spyglass/utils/dj_chains.py +++ /dev/null @@ -1,373 +0,0 @@ -from collections import OrderedDict -from functools import cached_property -from typing import List, Union - -import datajoint as dj -import networkx as nx -from datajoint.expression import QueryExpression -from datajoint.table import Table -from datajoint.utils import get_master, to_camel_case - -from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK -from spyglass.utils.logging import logger - -# Tables that should be excluded from the undirected graph when finding paths -# to maintain valid joins. -PERIPHERAL_TABLES = [ - "`common_interval`.`interval_list`", - "`common_nwbfile`.`__analysis_nwbfile_kachery`", - "`common_nwbfile`.`__nwbfile_kachery`", - "`common_nwbfile`.`analysis_nwbfile_kachery_selection`", - "`common_nwbfile`.`analysis_nwbfile_kachery`", - "`common_nwbfile`.`analysis_nwbfile`", - "`common_nwbfile`.`kachery_channel`", - "`common_nwbfile`.`nwbfile_kachery_selection`", - "`common_nwbfile`.`nwbfile_kachery`", - "`common_nwbfile`.`nwbfile`", -] - - -class TableChains: - """Class for representing chains from parent to Merge table via parts. - - Functions as a plural version of TableChain, allowing a single `join` - call across all chains from parent -> Merge table. - - Attributes - ---------- - parent : Table - Parent or origin of chains. - child : Table - Merge table or destination of chains. - connection : datajoint.Connection, optional - Connection to database used to create FreeTable objects. Defaults to - parent.connection. - part_names : List[str] - List of full table names of child parts. - chains : List[TableChain] - List of TableChain objects for each part in child. - has_link : bool - Cached attribute to store whether parent is linked to child via any of - child parts. False if (a) child is not in parent.descendants or (b) - nx.NetworkXNoPath is raised by nx.shortest_path for all chains. - - Methods - ------- - __init__(parent, child, connection=None) - Initialize TableChains with parent and child tables. - __repr__() - Return full representation of chains. - Multiline parent -> child for each chain. - __len__() - Return number of chains with links. - __getitem__(index: Union[int, str]) - Return TableChain object at index, or use substring of table name. - join(restriction: str = None) - Return list of joins for each chain in self.chains. 
- """ - - def __init__(self, parent, child, connection=None): - self.parent = parent - self.child = child - self.connection = connection or parent.connection - parts = child.parts(as_objects=True) - self.part_names = [part.full_table_name for part in parts] - self.chains = [TableChain(parent, part) for part in parts] - self.has_link = any([chain.has_link for chain in self.chains]) - - def __repr__(self): - return "\n".join([str(chain) for chain in self.chains]) - - def __len__(self): - return len([c for c in self.chains if c.has_link]) - - @property - def max_len(self): - """Return length of longest chain.""" - return max([len(chain) for chain in self.chains]) - - def __getitem__(self, index: Union[int, str]): - """Return FreeTable object at index.""" - if isinstance(index, str): - for i, part in enumerate(self.part_names): - if index in part: - return self.chains[i] - return self.chains[index] - - def join(self, restriction=None) -> List[QueryExpression]: - """Return list of joins for each chain in self.chains.""" - restriction = restriction or self.parent.restriction or True - joins = [] - for chain in self.chains: - if joined := chain.join(restriction): - joins.append(joined) - return joins - - -class TableChain: - """Class for representing a chain of tables. - - A chain is a sequence of tables from parent to child identified by - networkx.shortest_path. Parent -> Merge should use TableChains instead to - handle multiple paths to the respective parts of the Merge table. - - Attributes - ---------- - parent : Table - Parent or origin of chain. - child : Table - Child or destination of chain. - _connection : datajoint.Connection, optional - Connection to database used to create FreeTable objects. Defaults to - parent.connection. - _link_symbol : str - Symbol used to represent the link between parent and child. Hardcoded - to " -> ". - has_link : bool - Cached attribute to store whether parent is linked to child. False if - child is not in parent.descendants or nx.NetworkXNoPath is raised by - nx.shortest_path. - link_type : str - 'directed' or 'undirected' based on whether path is found with directed - or undirected graph. None if no path is found. - graph : nx.DiGraph - Directed graph of parent's dependencies from datajoint.connection. - names : List[str] - List of full table names in chain. - objects : List[dj.FreeTable] - List of FreeTable objects for each table in chain. - attr_maps : List[dict] - List of attribute maps for each link in chain. - path : OrderedDict[str, Dict[str, Union[dj.FreeTable,dict]]] - Dictionary of full table names in chain. Keys are self.names - Values are a dict of free_table (self.objects) and - attr_map (dict of new_name: old_name, self.attr_map). - - Methods - ------- - __str__() - Return string representation of chain: parent -> child. - __repr__() - Return full representation of chain: parent -> {links} -> child. - __len__() - Return number of tables in chain. - __getitem__(index: Union[int, str]) - Return FreeTable object at index, or use substring of table name. - find_path(directed=True) - Returns path OrderedDict of full table names in chain. If directed is - True, uses directed graph. If False, uses undirected graph. Undirected - excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain - valid joins. - join(restriction: str = None) - Return join of tables in chain with restriction applied to parent. 
- """ - - def __init__(self, parent: Table, child: Table, connection=None): - self._connection = connection or parent.connection - self.graph = self._connection.dependencies - self.graph.load() - - if ( # if child is a merge table - get_master(child.full_table_name) == "" - and MERGE_PK in child.heading.names - ): - raise TypeError("Child is a merge table. Use TableChains instead.") - - self._link_symbol = " -> " - self.parent = parent - self.child = child - self.link_type = None - self._searched = False - - if child.full_table_name not in self.graph.nodes: - logger.warning( - "Can't find item in graph. Try importing: " - + f"{child.full_table_name}" - ) - self._searched = True - - def __str__(self): - """Return string representation of chain: parent -> child.""" - if not self.has_link: - return "No link" - return ( - to_camel_case(self.parent.table_name) - + self._link_symbol - + to_camel_case(self.child.table_name) - ) - - def __repr__(self): - """Return full representation of chain: parent -> {links} -> child.""" - if not self.has_link: - return "No link" - return "Chain: " + self._link_symbol.join( - [t.table_name for t in self.objects] - ) - - def __len__(self): - """Return number of tables in chain.""" - if not self.has_link: - return 0 - return len(self.names) - - def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: - """Return FreeTable object at index.""" - if not self.has_link: - return None - if isinstance(index, str): - for i, name in enumerate(self.names): - if index in name: - return self.objects[i] - return self.objects[index] - - @property - def has_link(self) -> bool: - """Return True if parent is linked to child. - - If not searched, search for path. If searched and no link is found, - return False. If searched and link is found, return True. - """ - if not self._searched: - _ = self.path - return self.link_type is not None - - def pk_link(self, src, trg, data) -> float: - """Return 1 if data["primary"] else float("inf"). - - Currently unused. Preserved for future debugging. shortest_path accepts - an option weight callable parameter. - nx.shortest_path(G, source, target,weight=pk_link) - """ - return 1 if data["primary"] else float("inf") - - def find_path(self, directed=True) -> OrderedDict: - """Return list of full table names in chain. - - Parameters - ---------- - directed : bool, optional - If True, use directed graph. If False, use undirected graph. - Defaults to True. Undirected permits paths to traverse from merge - part-parent -> merge part -> merge table. Undirected excludes - PERIPHERAL_TABLES like interval_list, nwbfile, etc. - - Returns - ------- - OrderedDict - Dictionary of full table names in chain. Keys are full table names. - Values are free_table (dj.FreeTable representation) and attr_map - (dict of new_name: old_name). Attribute maps on the table upstream - of an alias node that can be used in .proj(). Returns None if no - path is found. - - Ignores numeric table names in paths, which are - 'gaps' or alias nodes in the graph. See datajoint.Diagram._make_graph - source code for comments on alias nodes. 
- """ - source, target = self.parent.full_table_name, self.child.full_table_name - if not directed: - self.graph = self.graph.to_undirected() - self.graph.remove_nodes_from(PERIPHERAL_TABLES) - try: - path = nx.shortest_path(self.graph, source, target) - except nx.NetworkXNoPath: - return None - except nx.NodeNotFound: - self._searched = True - return None - - ret = OrderedDict() - prev_table = None - for i, table in enumerate(path): - if table.isnumeric(): # get proj() attribute map for alias node - if not prev_table: - raise ValueError("Alias node found without prev table.") - try: - attr_map = self.graph[table][prev_table]["attr_map"] - except KeyError: # Why is this only DLCCentroid?? - attr_map = self.graph[prev_table][table]["attr_map"] - ret[prev_table]["attr_map"] = attr_map - else: - free_table = dj.FreeTable(self._connection, table) - ret[table] = {"free_table": free_table, "attr_map": {}} - prev_table = table - return ret - - @cached_property - def path(self) -> OrderedDict: - """Return list of full table names in chain.""" - if self._searched and not self.has_link: - return None - - link = None - if link := self.find_path(directed=True): - self.link_type = "directed" - elif link := self.find_path(directed=False): - self.link_type = "undirected" - self._searched = True - - return link - - @cached_property - def names(self) -> List[str]: - """Return list of full table names in chain.""" - if not self.has_link: - return None - return list(self.path.keys()) - - @cached_property - def objects(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain. - - Unused. Preserved for future debugging. - """ - if not self.has_link: - return None - return [v["free_table"] for v in self.path.values()] - - @cached_property - def attr_maps(self) -> List[dict]: - """Return list of attribute maps for each table in chain. - - Unused. Preserved for future debugging. - """ - if not self.has_link: - return None - return [v["attr_map"] for v in self.path.values()] - - def join( - self, restriction: str = None, reverse_order: bool = False - ) -> dj.expression.QueryExpression: - """Return join of tables in chain with restriction applied to parent. - - Parameters - ---------- - restriction : str, optional - Restriction to apply to first table in the order. - Defaults to self.parent.restriction. - reverse_order : bool, optional - If True, join tables in reverse order. Defaults to False. - """ - if not self.has_link: - return None - - restriction = restriction or self.parent.restriction or True - path = ( - OrderedDict(reversed(self.path.items())) - if reverse_order - else self.path - ).copy() - - _, first_val = path.popitem(last=False) - join = first_val["free_table"] & restriction - for i, val in enumerate(path.values()): - attr_map, free_table = val["attr_map"], val["free_table"] - try: - join = (join.proj() * free_table).proj(**attr_map) - except dj.DataJointError as e: - attribute = str(e).split("attribute ")[-1] - logger.error( - f"{str(self)} at {free_table.table_name} with {attribute}" - ) - return None - return join diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 59e7497d5..5bf3d25d0 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -3,79 +3,154 @@ NOTE: read `ft` as FreeTable and `restr` as restriction. 
""" -from typing import Dict, List, Union +from abc import ABC, abstractmethod +from collections.abc import KeysView +from enum import Enum +from functools import cached_property +from itertools import chain as iter_chain +from typing import Any, Dict, List, Set, Tuple, Union -from datajoint import FreeTable +import datajoint as dj +from datajoint import FreeTable, Table from datajoint.condition import make_condition -from datajoint.table import Table +from datajoint.dependencies import unite_master_parts +from datajoint.utils import get_master, to_camel_case +from networkx import ( + NetworkXNoPath, + NodeNotFound, + all_simple_paths, + shortest_path, +) +from networkx.algorithms.dag import topological_sort from tqdm import tqdm -from spyglass.common import AnalysisNwbfile from spyglass.utils import logger -from spyglass.utils.dj_helper_fn import unique_dicts +from spyglass.utils.dj_helper_fn import ( + PERIPHERAL_TABLES, + fuzzy_get, + unique_dicts, +) +from spyglass.utils.dj_merge_tables import is_merge_table -class RestrGraph: - def __init__( - self, - seed_table: Table, - table_name: str = None, - restriction: str = None, - leaves: List[Dict[str, str]] = None, - verbose: bool = False, - **kwargs, - ): - """Use graph to cascade restrictions up from leaves to all ancestors. +class Direction(Enum): + """Cascade direction enum. Calling Up returns True. Inverting flips.""" + + UP = "up" + DOWN = "down" + NONE = None + + def __str__(self): + return self.value + + def __invert__(self) -> "Direction": + """Invert the direction.""" + if self.value is None: + logger.warning("Inverting NONE direction") + return Direction.NONE + return Direction.UP if self.value == "down" else Direction.DOWN + + def __bool__(self) -> bool: + """Return True if direction is not None.""" + return self.value is not None + + +class AbstractGraph(ABC): + """Abstract class for graph traversal and restriction application. + + Inherited by... + - RestrGraph: Cascade restriction(s) through a graph + - TableChain: Takes parent and child nodes, finds the shortest path, + and applies a restriction across the path. If either parent or child + is a merge table, use TableChains instead. If either parent or child + are not provided, search_restr is required to find the path to the + missing table. + + Methods + ------- + cascade: Abstract method implemented by child classes + cascade1: Cascade a restriction up/down the graph, recursively + + Properties + ---------- + all_ft: Get all FreeTables for visited nodes with restrictions applied. + as_dict: Get visited nodes as a list of dictionaries of + {table_name: restriction} + """ + + def __init__(self, seed_table: Table, verbose: bool = False, **kwargs): + """Initialize graph and connection. Parameters ---------- seed_table : Table Table to use to establish connection and graph - table_name : str, optional - Table name of single leaf, default None - restriction : str, optional - Restriction to apply to leaf. default None - leaves : Dict[str, str], optional - List of dictionaries with keys table_name and restriction. One - entry per leaf node. Default None. verbose : bool, optional Whether to print verbose output. Default False """ - + self.seed_table = seed_table self.connection = seed_table.connection + + # Undirected graph may not be needed, but adding FT to the graph + # prevents `to_undirected` from working. If using undirected, remove + # PERIPHERAL_TABLES from the graph. 
self.graph = seed_table.connection.dependencies self.graph.load() self.verbose = verbose - self.cascaded = False - self.ancestors = set() - self.visited = set() self.leaves = set() - self.analysis_pk = AnalysisNwbfile().primary_key + self.visited = set() + self.to_visit = set() + self.no_visit = set() + self.cascaded = False - if table_name and restriction: - self.add_leaf(table_name, restriction) - if leaves: - self.add_leaves(leaves, show_progress=verbose) + # --------------------------- Abstract Methods --------------------------- - def __repr__(self): - l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" - processed = "Cascaded" if self.cascaded else "Uncascaded" - return f"{processed} RestrictionGraph(\n\t{l_str})" + @abstractmethod + def cascade(self): + """Cascade restrictions through graph.""" + raise NotImplementedError("Child class mut implement `cascade` method") - @property - def all_ft(self): - """Get restricted FreeTables from all visited nodes.""" - self.cascade() - return [self._get_ft(table, with_restr=True) for table in self.visited] + # ---------------------------- Logging Helpers ---------------------------- - @property - def leaf_ft(self): - """Get restricted FreeTables from graph leaves.""" - return [self._get_ft(table, with_restr=True) for table in self.leaves] + def _log_truncate(self, log_str: str, max_len: int = 80): + """Truncate log lines to max_len and print if verbose.""" + if not self.verbose: + return + logger.info( + log_str[:max_len] + "..." if len(log_str) > max_len else log_str + ) + + def _camel(self, table): + """Convert table name(s) to camel case.""" + if isinstance(table, KeysView): + table = list(table) + if not isinstance(table, list): + table = [table] + ret = [to_camel_case(t.split(".")[-1].strip("`")) for t in table] + return ret[0] if len(ret) == 1 else ret - def _get_node(self, table): + def _print_restr(self): + """Print restrictions for debugging.""" + for table in self.visited: + if restr := self._get_restr(table): + logger.info(f"{table}: {restr}") + + # ------------------------------ Graph Nodes ------------------------------ + + def _ensure_name(self, table: Union[str, Table] = None) -> str: + """Ensure table is a string.""" + if table is None: + return None + if isinstance(table, str): + return table + if isinstance(table, list): + return [self._ensure_name(t) for t in table] + return getattr(table, "full_table_name", None) + + def _get_node(self, table: Union[str, Table]): """Get node from graph.""" + table = self._ensure_name(table) if not (node := self.graph.nodes.get(table)): raise ValueError( f"Table {table} not found in graph." @@ -83,31 +158,47 @@ def _get_node(self, table): ) return node - def _set_node(self, table, attr="ft", value=None): + def _set_node(self, table, attr: str = "ft", value: Any = None): """Set attribute on node. General helper for various attributes.""" _ = self._get_node(table) # Ensure node exists self.graph.nodes[table][attr] = value - def _get_ft(self, table, with_restr=False): - """Get FreeTable from graph node. If one doesn't exist, create it.""" - table = table if isinstance(table, str) else table.full_table_name - restr = self._get_restr(table) if with_restr else True - if ft := self._get_node(table).get("ft"): - return ft & restr - ft = FreeTable(self.connection, table) - self._set_node(table, "ft", ft) - return ft & restr + def _get_edge(self, child: str, parent: str) -> Tuple[bool, Dict[str, str]]: + """Get edge data between child and parent. + + Used as a fallback for _bridge_restr. 
Required for Maser/Part links to + temporarily flip direction. + + Returns + ------- + Tuple[bool, Dict[str, str]] + Tuple of boolean indicating direction and edge data. True if child + is child of parent. + """ + child = self._ensure_name(child) + parent = self._ensure_name(parent) + + if edge := self.graph.get_edge_data(parent, child): + return False, edge + elif edge := self.graph.get_edge_data(child, parent): + return True, edge + + # Handle alias nodes. `shortest_path` doesn't work with aliases + p1 = all_simple_paths(self.graph, child, parent) + p2 = all_simple_paths(self.graph, parent, child) + paths = [p for p in iter_chain(p1, p2)] # list for error handling + for path in paths: # Ignore long and non-alias paths + if len(path) > 3 or (len(path) > 2 and not path[1].isnumeric()): + continue + return self._get_edge(path[0], path[1]) + + raise ValueError(f"{child} -> {parent} not direct path: {paths}") def _get_restr(self, table): """Get restriction from graph node.""" - table = table if isinstance(table, str) else table.full_table_name - return self._get_node(table).get("restr", "False") - - def _get_files(self, table): - """Get analysis files from graph node.""" - return self._get_node(table).get("files", []) + return self._get_node(self._ensure_name(table)).get("restr") - def _set_restr(self, table, restriction): + def _set_restr(self, table, restriction, replace=False): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) restriction = ( # Convert to condition if list or dict @@ -115,9 +206,9 @@ def _set_restr(self, table, restriction): if not isinstance(restriction, str) else restriction ) - # orig_restr = restriction - if existing := self._get_restr(table): - if existing == restriction: + existing = self._get_restr(table) + if not replace and existing: + if restriction == existing: return join = ft & [existing, restriction] if len(join) == len(ft & existing): @@ -126,168 +217,337 @@ def _set_restr(self, table, restriction): ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() ) - # if table == "`spikesorting_merge`.`spike_sorting_output`": - # __import__("pdb").set_trace() - self._set_node(table, "restr", restriction) - def get_restr_ft(self, table: Union[int, str]): - """Get restricted FreeTable from graph node. + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + table = self._ensure_name(table) + if with_restr: + if not (restr := self._get_restr(table) or False): + self._log_truncate(f"No restriction for {table}") + else: + restr = True - Currently used. May be useful for debugging. + if not (ft := self._get_node(table).get("ft")): + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) - Parameters - ---------- - table : Union[int, str] - Table name or index in visited set - """ - if isinstance(table, int): - table = list(self.visited)[table] - return self._get_ft(table, with_restr=True) + return ft & restr - def _log_truncate(self, log_str, max_len=80): - """Truncate log lines to max_len and print if verbose.""" - if not self.verbose: - return - logger.info( - log_str[:max_len] + "..." 
if len(log_str) > max_len else log_str - ) + def _and_parts(self, table): + """Return table, its master and parts.""" + ret = [table] + if master := get_master(table): + ret.append(master) + if parts := self._get_ft(table).parts(): + ret.extend(parts) + return ret + + # ---------------------------- Graph Traversal ----------------------------- - def _child_to_parent( + def _bridge_restr( self, - child, - parent, - restriction, - attr_map=None, - primary=True, + table1: str, + table2: str, + restr: str, + direction: Direction = None, + attr_map: dict = None, + aliased: bool = None, **kwargs, - ) -> List[Dict[str, str]]: - """Given a child, child's restr, and parent, get parent's restr. + ): + """Given two tables and a restriction, return restriction for table2. + + Similar to ((table1 & restr) * table2).fetch(*table2.primary_key) + but with the ability to resolve aliases across tables. One table should + be the parent of the other. If direction or attr_map are not provided, + they will be inferred from the graph. Parameters ---------- - child : str - child table name - parent : str - parent table name - restriction : str - restriction to apply to child + table1 : str + Table name. Restriction always applied to this table. + table2 : str + Table name. Restriction pulled from this table. + restr : str + Restriction to apply to table1. + direction : Direction, optional + Direction to cascade. Default None. attr_map : dict, optional - dictionary mapping aliases across parend/child, as pulled from - DataJoint-assembled graph. Default None. Func will flip this dict - to convert from child to parent fields. - primary : bool, optional - Is parent in child's primary key? Default True. Also derived from - DataJoint-assembled graph. If True, project only primary key fields - to avoid secondary key collisions. + dictionary mapping aliases across tables, as pulled from + DataJoint-assembled graph. Default None. + Returns ------- List[Dict[str, str]] - List of dicts containing primary key fields for restricted parent - table. + List of dicts containing primary key fields for restricted table2. """ + if not all([direction, attr_map]): + dir_bool, edge = self._get_edge(table1, table2) + direction = "up" if dir_bool else "down" + attr_map = edge.get("attr_map") - # Need to flip attr_map to respect parent's fields - attr_reverse = ( - {v: k for k, v in attr_map.items() if k != v} if attr_map else {} - ) - child_ft = self._get_ft(child) - parent_ft = self._get_ft(parent).proj() - restr = restriction or self._get_restr(child_ft) or True - restr_child = child_ft & restr + ft1 = self._get_ft(table1) & restr + ft2 = self._get_ft(table2) + + if len(ft1) == 0: + return ["False"] - if primary: # Project only primary key fields to avoid collisions - join = restr_child.proj(**attr_reverse) * parent_ft - else: # Include all fields - join = restr_child.proj(..., **attr_reverse) * parent_ft + if bool(set(attr_map.values()) - set(ft1.heading.names)): + attr_map = {v: k for k, v in attr_map.items()} # reverse - ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) + join = ft1.proj(**attr_map) * ft2 + ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) - if len(ret) == len(parent_ft): - self._log_truncate(f"NULL restr {parent}") + if self.verbose: # For debugging. Not required for typical use. 
+ result = ( + "EMPTY" + if len(ret) == 0 + else "FULL" if len(ft2) == len(ret) else "partial" + ) + path = f"{self._camel(table1)} -> {self._camel(table2)}" + self._log_truncate(f"Bridge Link: {path}: result {result}") return ret - def cascade_files(self): - """Set node attribute for analysis files.""" - for table in self.visited: - ft = self._get_ft(table) - if not set(self.analysis_pk).issubset(ft.heading.names): - continue - files = (ft & self._get_restr(table)).fetch(*self.analysis_pk) - self._set_node(table, "files", files) + def _get_next_tables(self, table: str, direction: Direction) -> Tuple: + """Get next tables/func based on direction. + + Used in cascade1 and cascade1_search to add master and parts. Direction + is intentionally omitted to force _get_edge to determine the edge for + this gap before resuming desired direction. Nextfunc is used to get + relevant parent/child tables after aliast node. + + Parameters + ---------- + table : str + Table name + direction : Direction + Direction to cascade + + Returns + ------- + Tuple[Dict[str, Dict[str, str]], Callable + Tuple of next tables and next function to get parent/child tables. + """ + G = self.graph + dir_dict = {"direction": direction} - def cascade1(self, table, restriction): - """Cascade a restriction up the graph, recursively on parents. + bonus = {} + direction = Direction(direction) + if direction == Direction.UP: + next_func = G.parents + bonus.update({part: {} for part in self._get_ft(table).parts()}) + elif direction == Direction.DOWN: + next_func = G.children + if (master_name := get_master(table)) != "": + bonus = {master_name: {}} + else: + raise ValueError(f"Invalid direction: {direction}") + + next_tables = { + k: {**v, **dir_dict} for k, v in next_func(table).items() + } + next_tables.update(bonus) + + return next_tables, next_func + + def cascade1( + self, + table: str, + restriction: str, + direction: Direction = Direction.UP, + replace=False, + count=0, + **kwargs, + ): + """Cascade a restriction up the graph, recursively on parents/children. Parameters ---------- table : str - table name + Table name restriction : str - restriction to apply + Restriction to apply + direction : Direction, optional + Direction to cascade. Default 'up' + replace : bool, optional + Replace existing restriction. 
Default False """ - self._set_restr(table, restriction) + if count > 100: + raise RecursionError("Cascade1: Recursion limit reached.") + + self._set_restr(table, restriction, replace=replace) self.visited.add(table) - for parent, data in self.graph.parents(table).items(): - if parent in self.visited: - continue + next_tables, next_func = self._get_next_tables(table, direction) - if parent.isnumeric(): - parent, data = self.graph.parents(parent).popitem() + self._log_truncate( + f"Checking {count:>2}: {self._camel(next_tables.keys())}" + ) + for next_table, data in next_tables.items(): + if next_table.isnumeric(): # Skip alias nodes + next_table, data = next_func(next_table).popitem() - parent_restr = self._child_to_parent( - child=table, - parent=parent, - restriction=restriction, + if ( + next_table in self.visited + or next_table in self.no_visit # Subclasses can set this + or table == next_table + ): + reason = ( + "Already saw" + if next_table in self.visited + else "Banned Tbl " + ) + self._log_truncate(f"{reason}: {self._camel(next_table)}") + continue + + next_restr = self._bridge_restr( + table1=table, + table2=next_table, + restr=restriction, **data, ) - self.cascade1(parent, parent_restr) # Parent set on recursion + if next_restr == ["False"]: # Stop cascade if empty restriction + continue - def cascade(self, show_progress=None) -> None: - """Cascade all restrictions up the graph. + self.cascade1( + table=next_table, + restriction=next_restr, + direction=direction, + replace=replace, + count=count + 1, + ) + + # ---------------------------- Graph Properties ---------------------------- + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes. + + Topological sort logic adopted from datajoint.diagram. + """ + self.cascade() + nodes = [n for n in self.visited if not n.isnumeric()] + sorted_nodes = unite_master_parts( + list(topological_sort(self.graph.subgraph(nodes))) + ) + all_ft = [ + self._get_ft(table, with_restr=True) for table in sorted_nodes + ] + return [ft for ft in all_ft if len(ft) > 0] + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.visited + if self._get_restr(table) + ] + + +class RestrGraph(AbstractGraph): + def __init__( + self, + seed_table: Table, + table_name: str = None, + restriction: str = None, + leaves: List[Dict[str, str]] = None, + direction: Direction = "up", + cascade: bool = False, + verbose: bool = False, + **kwargs, + ): + """Use graph to cascade restrictions up from leaves to all ancestors. + + 'Leaves' are nodes with restrictions applied. Restrictions are cascaded + up/down the graph to all ancestors/descendants. If cascade is desired + in both direction, leaves/cascades should be added and run separately. + Future development could allow for direction setting on a per-leaf + basis. Parameters ---------- - show_progress : bool, optional - Show tqdm progress bar. Default to verbose setting. + seed_table : Table + Table to use to establish connection and graph + table_name : str, optional + Table name of single leaf, default None + restriction : str, optional + Restriction to apply to leaf. default None + leaves : Dict[str, str], optional + List of dictionaries with keys table_name and restriction. One + entry per leaf node. Default None. + direction : Direction, optional + Direction to cascade. 
Default 'up' + cascade : bool, optional + Whether to cascade restrictions up the graph on initialization. + Default False + verbose : bool, optional + Whether to print verbose output. Default False """ - if self.cascaded: - return - to_visit = self.leaves - self.visited - for table in tqdm( - to_visit, - desc="RestrGraph: cascading restrictions", - total=len(to_visit), - disable=not (show_progress or self.verbose), - ): - restr = self._get_restr(table) - self._log_truncate(f"Start {table}: {restr}") - self.cascade1(table, restr) - if not self.visited == self.ancestors: - raise RuntimeError( - "Cascade: FAIL - incomplete cascade. Please post issue." - ) + super().__init__(seed_table, verbose=verbose) - self.cascade_files() - self.cascaded = True + self.add_leaf( + table_name=table_name, restriction=restriction, direction=direction + ) + self.add_leaves(leaves) + + if cascade: + self.cascade(direction=direction) + + # --------------------------- Dunder Properties --------------------------- + + def __repr__(self): + l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" + processed = "Cascaded" if self.cascaded else "Uncascaded" + return f"{processed} {self.__class__.__name__}(\n\t{l_str})" + + def __getitem__(self, index: Union[int, str]): + all_ft_names = [t.full_table_name for t in self.all_ft] + return fuzzy_get(index, all_ft_names, self.all_ft) + + def __len__(self): + return len(self.all_ft) + + # ---------------------------- Public Properties -------------------------- + + @property + def leaf_ft(self): + """Get restricted FreeTables from graph leaves.""" + return [self._get_ft(table, with_restr=True) for table in self.leaves] - def add_leaf(self, table_name, restriction, cascade=False) -> None: + # ------------------------------- Add Nodes ------------------------------- + + def add_leaf( + self, table_name=None, restriction=True, cascade=False, direction="up" + ) -> None: """Add leaf to graph and cascade if requested. Parameters ---------- - table_name : str - table name of leaf - restriction : str - restriction to apply to leaf + table_name : str, optional + table name of leaf. Default None, do nothing. + restriction : str, optional + restriction to apply to leaf. Default True, no restriction. + cascade : bool, optional + Whether to cascade the restrictions up the graph. Default False. 
""" - new_ancestors = set(self._get_ft(table_name).ancestors()) - self.ancestors |= new_ancestors # Add to total ancestors - self.visited -= new_ancestors # Remove from visited to revisit + if not table_name: + return + + self.cascaded = False + + new_visits = ( + set(self._get_ft(table_name).ancestors()) + if direction == "up" + else set(self._get_ft(table_name).descendants()) + ) + + self.to_visit |= new_visits # Add to total ancestors + self.visited -= new_visits # Remove from visited to revisit self.leaves.add(table_name) self._set_restr(table_name, restriction) # Redundant if cascaded @@ -297,71 +557,119 @@ def add_leaf(self, table_name, restriction, cascade=False) -> None: self.cascade_files() self.cascaded = True + def _process_leaves(self, leaves=None, default_restriction=True): + """Process leaves to ensure they are unique and have required keys.""" + if not leaves: + return [] + if not isinstance(leaves, list): + leaves = [leaves] + if all(isinstance(leaf, str) for leaf in leaves): + leaves = [ + {"table_name": leaf, "restriction": default_restriction} + for leaf in leaves + ] + if all(isinstance(leaf, dict) for leaf in leaves) and not all( + leaf.get("table_name") for leaf in leaves + ): + raise ValueError(f"All leaves must have table_name: {leaves}") + + return unique_dicts(leaves) + def add_leaves( - self, leaves: List[Dict[str, str]], cascade=False, show_progress=None + self, + leaves: Union[str, List, List[Dict[str, str]]] = None, + default_restriction: str = None, + cascade=False, ) -> None: """Add leaves to graph and cascade if requested. Parameters ---------- - leaves : List[Dict[str, str]] - list of dictionaries containing table_name and restriction + leaves : Union[str, List, List[Dict[str, str]]], optional + Table names of leaves, either as a list of strings or a list of + dictionaries with keys table_name and restriction. One entry per + leaf node. Default None, do nothing. + default_restriction : str, optional + Default restriction to apply to each leaf. Default True, no + restriction. Only used if leaf missing restriction. cascade : bool, optional Whether to cascade the restrictions up the graph. Default False + """ + leaves = self._process_leaves( + leaves=leaves, default_restriction=default_restriction + ) + for leaf in leaves: + self.add_leaf( + leaf.get("table_name"), + leaf.get("restriction"), + cascade=False, + ) + if cascade: + self.cascade() + + # ------------------------------ Graph Traversal -------------------------- + + def cascade(self, show_progress=None, direction="up") -> None: + """Cascade all restrictions up the graph. + + Parameters + ---------- show_progress : bool, optional Show tqdm progress bar. Default to verbose setting. 
""" - - if not leaves: + if self.cascaded: return - if not isinstance(leaves, list): - leaves = [leaves] - leaves = unique_dicts(leaves) - for leaf in tqdm( - leaves, - desc="RestrGraph: adding leaves", - total=len(leaves), + + to_visit = self.leaves - self.visited + + for table in tqdm( + to_visit, + desc="RestrGraph: cascading restrictions", + total=len(to_visit), disable=not (show_progress or self.verbose), ): - if not ( - (table_name := leaf.get("table_name")) - and (restriction := leaf.get("restriction")) - ): - raise ValueError( - f"Leaf must have table_name and restriction: {leaf}" - ) - self.add_leaf(table_name, restriction, cascade=False) - if cascade: - self.cascade() - self.cascade_files() + restr = self._get_restr(table) + self._log_truncate(f"Start {table}: {restr}") + self.cascade1(table, restr, direction=direction) + + self.cascade_files() + self.cascaded = True + + # ----------------------------- File Handling ----------------------------- + + def _get_files(self, table): + """Get analysis files from graph node.""" + return self._get_node(table).get("files", []) + + def cascade_files(self): + """Set node attribute for analysis files.""" + for table in self.visited: + ft = self._get_ft(table, with_restr=True) + if not set(self.analysis_pk).issubset(ft.heading.names): + continue + files = list(ft.fetch(*self.analysis_pk)) + self._set_node(table, "files", files) @property - def as_dict(self) -> List[Dict[str, str]]: - """Return as a list of dictionaries of table_name: restriction""" - self.cascade() - return [ - {"table_name": table, "restriction": self._get_restr(table)} - for table in self.ancestors - if self._get_restr(table) - ] + def analysis_file_tbl(self) -> Table: + """Return the analysis file table. Avoids circular import.""" + from spyglass.common import AnalysisNwbfile + + return AnalysisNwbfile() + + @property + def analysis_pk(self) -> List[str]: + """Return primary key fields from analysis file table.""" + return self.analysis_file_tbl.primary_key @property def file_dict(self) -> Dict[str, List[str]]: """Return dictionary of analysis files from all visited nodes. - Currently unused, but could be useful for debugging. + Included for debugging, to associate files with tables. """ - if not self.cascaded: - logger.warning("Uncascaded graph. Using leaves only.") - table_list = self.leaves - else: - table_list = self.visited - - return { - table: self._get_files(table) - for table in table_list - if any(self._get_files(table)) - } + self.cascade() + return {t: self._get_node(t).get("files", []) for t in self.visited} @property def file_paths(self) -> List[str]: @@ -371,11 +679,445 @@ def file_paths(self) -> List[str]: directly by the user. """ self.cascade() - unique_files = set( - [file for table in self.visited for file in self._get_files(table)] - ) return [ - {"file_path": AnalysisNwbfile().get_abs_path(file)} - for file in unique_files + {"file_path": self.analysis_file_tbl.get_abs_path(file)} + for file in set( + [f for files in self.file_dict.values() for f in files] + ) if file is not None ] + + +class TableChains: + """Class for representing chains from parent to Merge table via parts. + + Functions as a plural version of TableChain, allowing a single `cascade` + call across all chains from parent -> Merge table. + + Attributes + ---------- + parent : Table + Parent or origin of chains. + child : Table + Merge table or destination of chains. + connection : datajoint.Connection, optional + Connection to database used to create FreeTable objects. 
Defaults to + parent.connection. + part_names : List[str] + List of full table names of child parts. + chains : List[TableChain] + List of TableChain objects for each part in child. + has_link : bool + Cached attribute to store whether parent is linked to child via any of + child parts. False if (a) child is not in parent.descendants or (b) + nx.NetworkXNoPath is raised by nx.shortest_path for all chains. + + Methods + ------- + __init__(parent, child, connection=None) + Initialize TableChains with parent and child tables. + __repr__() + Return full representation of chains. + Multiline parent -> child for each chain. + __len__() + Return number of chains with links. + __getitem__(index: Union[int, str]) + Return TableChain object at index, or use substring of table name. + cascade(restriction: str = None) + Return list of cascade for each chain in self.chains. + """ + + def __init__(self, parent, child, direction=Direction.DOWN, verbose=False): + self.parent = parent + self.child = child + self.connection = parent.connection + self.part_names = child.parts() + self.chains = [ + TableChain(parent, part, direction=direction, verbose=verbose) + for part in self.part_names + ] + self.has_link = any([chain.has_link for chain in self.chains]) + + # --------------------------- Dunder Properties --------------------------- + + def __repr__(self): + l_str = ",\n\t".join([str(c) for c in self.chains]) + "\n" + return f"{self.__class__.__name__}(\n\t{l_str})" + + def __len__(self): + return len([c for c in self.chains if c.has_link]) + + def __getitem__(self, index: Union[int, str]): + """Return FreeTable object at index.""" + return fuzzy_get(index, self.part_names, self.chains) + + # ---------------------------- Public Properties -------------------------- + + @property + def max_len(self): + """Return length of longest chain.""" + return max([len(chain) for chain in self.chains]) + + # ------------------------------ Graph Traversal -------------------------- + + def cascade( + self, restriction: str = None, direction: Direction = Direction.DOWN + ): + """Return list of cascades for each chain in self.chains.""" + restriction = restriction or self.parent.restriction or True + cascades = [] + for chain in self.chains: + if joined := chain.cascade(restriction, direction): + cascades.append(joined) + return cascades + + +class TableChain(RestrGraph): + """Class for representing a chain of tables. + + A chain is a sequence of tables from parent to child identified by + networkx.shortest_path. Parent -> Merge should use TableChains instead to + handle multiple paths to the respective parts of the Merge table. + + Attributes + ---------- + parent : str + Parent or origin of chain. + child : str + Child or destination of chain. + has_link : bool + Cached attribute to store whether parent is linked to child. + path : List[str] + Names of tables along the path from parent to child. + all_ft : List[dj.FreeTable] + List of FreeTable objects for each table in chain with restriction + applied. + + Methods + ------- + find_path(directed=True) + Returns path OrderedDict of full table names in chain. If directed is + True, uses directed graph. If False, uses undirected graph. Undirected + excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain + valid joins. + cascade(restriction: str = None, direction: str = "up") + Given a restriction at the beginning, return a restricted FreeTable + object at the end of the chain. If direction is 'up', start at the child + and move up to the parent. 
If direction is 'down', start at the parent. + """ + + def __init__( + self, + parent: Table = None, + child: Table = None, + direction: Direction = Direction.NONE, + search_restr: str = None, + cascade: bool = False, + verbose: bool = False, + allow_merge: bool = False, + banned_tables: List[str] = None, + **kwargs, + ): + if not allow_merge and child is not None and is_merge_table(child): + raise TypeError("Child is a merge table. Use TableChains instead.") + + self.parent = self._ensure_name(parent) + self.child = self._ensure_name(child) + + if not self.parent and not self.child: + raise ValueError("Parent or child table required.") + if not search_restr and not (self.parent and self.child): + raise ValueError("Search restriction required to find path.") + + seed_table = parent if isinstance(parent, Table) else child + super().__init__(seed_table=seed_table, verbose=verbose) + + self.no_visit.update(PERIPHERAL_TABLES) + self.no_visit.update(self._ensure_name(banned_tables) or []) + self.no_visit.difference_update([self.parent, self.child]) + self.searched_tables = set() + self.found_restr = False + self.link_type = None + self.searched_path = False + self._link_symbol = " -> " + + self.search_restr = search_restr + self.direction = Direction(direction) + + self.leaf = None + if search_restr and not parent: + self.direction = Direction.UP + self.leaf = self.child + if search_restr and not child: + self.direction = Direction.DOWN + self.leaf = self.parent + if self.leaf: + self._set_find_restr(self.leaf, search_restr) + self.add_leaf(self.leaf, True, cascade=False, direction=direction) + + if cascade and search_restr: + self.cascade_search() + self.cascade(restriction=search_restr) + self.cascaded = True + + # --------------------------- Dunder Properties --------------------------- + + def __str__(self): + """Return string representation of chain: parent -> child.""" + if not self.has_link: + return "No link" + return ( + self._camel(self.parent) + + self._link_symbol + + self._camel(self.child) + ) + + def __repr__(self): + """Return full representation of chain: parent -> {links} -> child.""" + if not self.has_link: + return "No link" + return "Chain: " + self.path_str + + def __len__(self): + """Return number of tables in chain.""" + if not self.has_link: + return 0 + return len(self.path) + + def __getitem__(self, index: Union[int, str]): + return fuzzy_get(index, self.path, self.all_ft) + + # ---------------------------- Public Properties -------------------------- + + @property + def has_link(self) -> bool: + """Return True if parent is linked to child. + + If not searched, search for path. If searched and no link is found, + return False. If searched and link is found, return True. + """ + if not self.searched_path: + _ = self.path + return self.link_type is not None + + @cached_property + def all_ft(self) -> List[dj.FreeTable]: + """Return list of FreeTable objects for each table in chain. + + Unused. Preserved for future debugging. 
+ """ + if not self.has_link: + return None + return [ + self._get_ft(table, with_restr=False) + for table in self.path + if not table.isnumeric() + ] + + @property + def path_str(self) -> str: + if not self.path: + return "No link" + return self._link_symbol.join([self._camel(t) for t in self.path]) + + # ------------------------------ Graph Nodes ------------------------------ + + def _set_find_restr(self, table_name, restriction): + """Set restr to look for from leaf node.""" + if isinstance(restriction, dict): + restriction = [restriction] + + if isinstance(restriction, list) and all( + [isinstance(r, dict) for r in restriction] + ): + restr_attrs = set(key for restr in restriction for key in restr) + find_restr = restriction + elif isinstance(restriction, str): + restr_attrs = set() # modified by make_condition + table_ft = self._get_ft(table_name) + find_restr = make_condition(table_ft, restriction, restr_attrs) + else: + raise ValueError( + f"Invalid restriction type, use STR: {restriction}" + ) + + self._set_node(table_name, "restr_attrs", restr_attrs) + self._set_node(table_name, "find_restr", find_restr) + + def _get_find_restr(self, table) -> Tuple[str, Set[str]]: + """Get restr and restr_attrs from leaf node.""" + node = self._get_node(table) + return node.get("find_restr", False), node.get("restr_attrs", set()) + + # ---------------------------- Graph Traversal ---------------------------- + + def cascade_search(self) -> None: + if self.cascaded: + return + restriction, restr_attrs = self._get_find_restr(self.leaf) + self.cascade1_search( + table=self.leaf, + restriction=restriction, + restr_attrs=restr_attrs, + replace=True, + ) + if not self.found_restr: + searched = ( + "parents" if self.direction == Direction.UP else "children" + ) + logger.warning( + f"Restriction could not be applied to any {searched}.\n\t" + + f"From: {self.leaves}\n\t" + + f"Restr: {restriction}" + ) + + def _set_found_vars(self, table): + """Set found_restr and searched_tables.""" + self._set_restr(table, self.search_restr, replace=True) + self.found_restr = True + self.searched_tables.update(set(self._and_parts(table))) + + if self.direction == Direction.UP: + self.parent = table + elif self.direction == Direction.DOWN: + self.child = table + + self._log_truncate(f"FVars: {self._camel(table)}") + + self.direction = ~self.direction + _ = self.path # Reset path + + def cascade1_search( + self, + table: str = None, + restriction: str = True, + restr_attrs: Set[str] = None, + replace: bool = True, + limit: int = 100, + **kwargs, + ): + if ( + self.found_restr + or not table + or limit < 1 + or table in self.searched_tables + ): + return + + self.searched_tables.add(table) + next_tables, next_func = self._get_next_tables(table, self.direction) + + for next_table, data in next_tables.items(): + if next_table.isnumeric(): + next_table, data = next_func(next_table).popitem() + self._log_truncate( + f"Search Link: {self._camel(table)} -> {self._camel(next_table)}" + ) + + if next_table in self.no_visit or table == next_table: + reason = "Already Saw" if next_table == table else "Banned Tbl " + self._log_truncate(f"{reason}: {self._camel(next_table)}") + continue + + next_ft = self._get_ft(next_table) + if restr_attrs.issubset(set(next_ft.heading.names)): + self._log_truncate(f"Found: {self._camel(next_table)}") + self._set_found_vars(next_table) + return + + self.cascade1_search( + table=next_table, + restriction=restriction, + restr_attrs=restr_attrs, + replace=replace, + limit=limit - 1, + **data, + ) + if 
self.found_restr: + return + + # ------------------------------ Path Finding ------------------------------ + + def find_path(self, directed=True) -> List[str]: + """Return list of full table names in chain. + + Parameters + ---------- + directed : bool, optional + If True, use directed graph. If False, use undirected graph. + Defaults to True. Undirected permits paths to traverse from merge + part-parent -> merge part -> merge table. Undirected excludes + PERIPHERAL_TABLES like interval_list, nwbfile, etc. + + Returns + ------- + List[str] + List of names in the path. + """ + source, target = self.parent, self.child + search_graph = self.graph + + if not directed: + self.connection.dependencies.load() + self.undirect_graph = self.connection.dependencies.to_undirected() + search_graph = self.undirect_graph + + search_graph.remove_nodes_from(self.no_visit) + + try: + path = shortest_path(search_graph, source, target) + except NetworkXNoPath: + return None # No path found, parent func may do undirected search + except NodeNotFound: + self.searched_path = True # No path found, don't search again + return None + + self._log_truncate(f"Path Found : {path}") + + ignore_nodes = self.graph.nodes - set(path) + self.no_visit.update(ignore_nodes) + + self._log_truncate(f"Ignore : {ignore_nodes}") + return path + + @cached_property + def path(self) -> list: + """Return list of full table names in chain.""" + if self.searched_path and not self.has_link: + return None + + path = None + if path := self.find_path(directed=True): + self.link_type = "directed" + elif path := self.find_path(directed=False): + self.link_type = "undirected" + self.searched_path = True + + return path + + def cascade(self, restriction: str = None, direction: Direction = None): + if not self.has_link: + return + + _ = self.path + + direction = Direction(direction) or self.direction + if direction == Direction.UP: + start, end = self.child, self.parent + elif direction == Direction.DOWN: + start, end = self.parent, self.child + else: + raise ValueError(f"Invalid direction: {direction}") + + self.cascade1( + table=start, + restriction=restriction or self._get_restr(start), + direction=direction, + replace=True, + ) + + return self._get_ft(end, with_restr=True) + + def restrict_by(self, *args, **kwargs) -> None: + """Cascade passthrough.""" + return self.cascade(*args, **kwargs) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 7af1fb2b4..89b1950cd 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -2,16 +2,40 @@ import inspect import os -from typing import Type +from typing import List, Type, Union import datajoint as dj import numpy as np from datajoint.user_tables import UserTable -from spyglass.utils.dj_chains import PERIPHERAL_TABLES from spyglass.utils.logging import logger from spyglass.utils.nwb_helper_fn import get_nwb_file +# Tables that should be excluded from the undirected graph when finding paths +# for TableChain objects and searching for an upstream key. 
+PERIPHERAL_TABLES = [ + "`common_interval`.`interval_list`", + "`common_nwbfile`.`__analysis_nwbfile_kachery`", + "`common_nwbfile`.`__nwbfile_kachery`", + "`common_nwbfile`.`analysis_nwbfile_kachery_selection`", + "`common_nwbfile`.`analysis_nwbfile_kachery`", + "`common_nwbfile`.`analysis_nwbfile`", + "`common_nwbfile`.`kachery_channel`", + "`common_nwbfile`.`nwbfile_kachery_selection`", + "`common_nwbfile`.`nwbfile_kachery`", + "`common_nwbfile`.`nwbfile`", +] + + +def fuzzy_get(index: Union[int, str], names: List[str], sources: List[str]): + """Given lists of items/names, return item at index or by substring.""" + if isinstance(index, int): + return sources[index] + for i, part in enumerate(names): + if index in part: + return sources[i] + return None + def unique_dicts(list_of_dict): """Remove duplicate dictionaries from a list.""" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 2b8aab5ef..0b8f16de6 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -1,8 +1,8 @@ -import re from contextlib import nullcontext from inspect import getmodule from itertools import chain as iter_chain from pprint import pprint +from re import sub as re_sub from time import time from typing import Union @@ -10,7 +10,7 @@ from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.preview import repr_html -from datajoint.utils import from_camel_case, to_camel_case +from datajoint.utils import from_camel_case, get_master, to_camel_case from IPython.core.display import HTML from spyglass.utils.logging import logger @@ -25,23 +25,29 @@ def is_merge_table(table): - """Return True if table definition matches the default Merge table. + """Return True if table fields exactly match Merge table.""" - Regex removes comments and blank lines before comparison. 
- """ + def trim_def(definition): + return re_sub( + r"\n\s*\n", "\n", re_sub(r"#.*\n", "\n", definition.strip()) + ) + + if isinstance(table, str): + table = dj.FreeTable(dj.conn(), table) if not isinstance(table, dj.Table): return False - if isinstance(table, dj.FreeTable): - fields, pk = table.heading.names, table.primary_key - return fields == [ - RESERVED_PRIMARY_KEY, - RESERVED_SECONDARY_KEY, - ] and pk == [RESERVED_PRIMARY_KEY] - return MERGE_DEFINITION == re.sub( - r"\n\s*\n", - "\n", - re.sub(r"#.*\n", "\n", getattr(table, "definition", "")), - ) + if get_master(table.full_table_name): + return False # Part tables are not merge tables + if not table.is_declared: + if tbl_def := getattr(table, "definition", None): + return trim_def(MERGE_DEFINITION) == trim_def(tbl_def) + logger.warning( + f"Cannot determine merge table status for {table.table_name}" + ) + return True + return table.primary_key == [ + RESERVED_PRIMARY_KEY + ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY] class Merge(dj.Manual): @@ -62,8 +68,8 @@ def __init__(self): if not is_merge_table(self): # Check definition logger.warn( "Merge table with non-default definition\n" - + f"Expected: {MERGE_DEFINITION.strip()}\n" - + f"Actual : {self.definition.strip()}" + + f"Expected:\n{MERGE_DEFINITION.strip()}\n" + + f"Actual :\n{self.definition.strip()}" ) for part in self.parts(as_objects=True): if part.primary_key != self.primary_key: @@ -74,12 +80,6 @@ def __init__(self): ) self._source_class_dict = {} - def _remove_comments(self, definition): - """Use regular expressions to remove comments and blank lines""" - return re.sub( # First remove comments, then blank lines - r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition) - ) - @staticmethod def _part_name(part=None): """Return the CamelCase name of a part table""" @@ -141,9 +141,6 @@ def _merge_restrict_parts( cls._ensure_dependencies_loaded() - if not restriction: - restriction = True - # Normalize restriction to sql string restr_str = make_condition(cls(), restriction, set()) @@ -387,8 +384,7 @@ def _ensure_dependencies_loaded(cls) -> None: Otherwise parts returns none """ - if not dj.conn.connection.dependencies._loaded: - dj.conn.connection.dependencies.load() + dj.conn.connection.dependencies.load() def insert(self, rows: list, **kwargs): """Merges table specific insert, ensuring data exists in part parents. @@ -783,7 +779,7 @@ def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list: "No merge_fetch results.\n\t" + "If not restricting, try: `M.merge_fetch(True,'attr')\n\t" + "If restricting by source, use dict: " - + "`M.merge_fetch({'source':'X'})" + + "`M.merge_fetch({'source':'X'}" ) return results[0] if len(results) == 1 else results @@ -818,7 +814,7 @@ def super_delete(self, warn=True, *args, **kwargs): """ if warn: logger.warning("!! Bypassing cautious_delete !!") - self._log_use(start=time(), super_delete=True) + self._log_delete(start=time(), super_delete=True) super().delete(*args, **kwargs) @@ -830,10 +826,6 @@ def super_delete(self, warn=True, *args, **kwargs): def delete_downstream_merge( table: dj.Table, - restriction: str = None, - dry_run=True, - recurse_level=2, - disable_warning=False, **kwargs, ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries @@ -858,12 +850,15 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. """ + logger.warning( + "DEPRECATED: This function will be removed in `0.6`. 
" + + "Use AnyTable().delete_downstream_merge() instead." + ) + from spyglass.utils.dj_mixin import SpyglassMixin if not isinstance(table, SpyglassMixin): raise ValueError("Input must be a Spyglass Table.") table = table if isinstance(table, dj.Table) else table() - return table.delete_downstream_merge( - restriction=restriction, dry_run=dry_run, **kwargs - ) + return table.delete_downstream_merge(**kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 05f510193..08fa377b3 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -13,12 +13,11 @@ from datajoint.expression import QueryExpression from datajoint.logging import logger as dj_logger from datajoint.table import Table -from datajoint.utils import get_master, user_choice +from datajoint.utils import get_master, to_camel_case, user_choice from networkx import NetworkXError from pymysql.err import DataError from spyglass.utils.database_settings import SHARED_MODULES -from spyglass.utils.dj_chains import TableChain, TableChains from spyglass.utils.dj_helper_fn import fetch_nwb, get_nwb_table from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.dj_merge_tables import Merge, is_merge_table @@ -71,6 +70,8 @@ class SpyglassMixin: _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure + _banned_search_tables = set() # Tables to avoid in restrict_by + def __init__(self, *args, **kwargs): """Initialize SpyglassMixin. @@ -93,6 +94,33 @@ def __init__(self, *args, **kwargs): + self.full_table_name ) + # -------------------------- Misc helper methods -------------------------- + + @property + def camel_name(self): + """Return table name in camel case.""" + return to_camel_case(self.table_name) + + def _auto_increment(self, key, pk, *args, **kwargs): + """Auto-increment primary key.""" + if not key.get(pk): + key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 + return key + + def file_like(self, name=None, **kwargs): + """Convenience method for wildcard search on file name fields.""" + if not name: + return self & True + attr = None + for field in self.heading.names: + if "file" in field: + attr = field + break + if not attr: + logger.error(f"No file-like field found in {self.full_table_name}") + return + return self & f"{attr} LIKE '%{name}%'" + # ------------------------------- fetch_nwb ------------------------------- @cached_property @@ -203,6 +231,26 @@ def fetch_pynapple(self, *attrs, **kwargs): # ------------------------ delete_downstream_merge ------------------------ + def _import_merge_tables(self): + """Import all merge tables downstream of self.""" + from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401 + from spyglass.lfp.lfp_merge import LFPOutput # noqa F401 + from spyglass.linearization.merge import ( + LinearizedPositionOutput, + ) # noqa F401 + from spyglass.position.position_merge import PositionOutput # noqa F401 + from spyglass.spikesorting.spikesorting_merge import ( # noqa F401 + SpikeSortingOutput, + ) + + _ = ( + DecodingOutput(), + LFPOutput(), + LinearizedPositionOutput(), + PositionOutput(), + SpikeSortingOutput(), + ) + @cached_property def _merge_tables(self) -> Dict[str, dj.FreeTable]: """Dict of merge tables downstream of self: {full_table_name: FreeTable}. 
@@ -215,10 +263,6 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: visited = set() def search_descendants(parent): - # TODO: Add check that parents are in the graph. If not, raise error - # asking user to import the table. - # TODO: Make a `is_merge_table` helper, and check for false - # positives in the mixin init. for desc in parent.descendants(as_objects=True): if ( MERGE_PK not in desc.heading.names @@ -235,12 +279,16 @@ def search_descendants(parent): try: _ = search_descendants(self) - except NetworkXError as e: - table_name = "".join(e.args[0].split("`")[1:4]) - raise ValueError(f"Please import {table_name} and try again.") + except NetworkXError: + try: # Attempt to import missing table + self._import_merge_tables() + _ = search_descendants(self) + except NetworkXError as e: + table_name = "".join(e.args[0].split("`")[1:4]) + raise ValueError(f"Please import {table_name} and try again.") logger.info( - f"Building merge cache for {self.table_name}.\n\t" + f"Building merge cache for {self.camel_name}.\n\t" + f"Found {len(merge_tables)} downstream merge tables" ) @@ -258,9 +306,11 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: with a new restriction. To recompute, add `reload_cache=True` to delete_downstream_merge call. """ + from spyglass.utils.dj_graph import TableChains # noqa F401 + merge_chains = {} for name, merge_table in self._merge_tables.items(): - chains = TableChains(self, merge_table, connection=self.connection) + chains = TableChains(self, merge_table) if len(chains): merge_chains[name] = chains @@ -268,13 +318,14 @@ def _merge_chains(self) -> OrderedDict[str, List[dj.FreeTable]]: # that the merge table with the longest chain is the most downstream. # A more sophisticated approach would order by length from self to # each merge part independently, but this is a good first approximation. + return OrderedDict( sorted( merge_chains.items(), key=lambda x: x[1].max_len, reverse=True ) ) - def _get_chain(self, substring) -> TableChains: + def _get_chain(self, substring): """Return chain from self to merge table with substring in name.""" for name, chain in self._merge_chains.items(): if substring.lower() in name: @@ -330,20 +381,19 @@ def delete_downstream_merge( Passed to datajoint.table.Table.delete. """ if reload_cache: - del self._merge_tables - del self._merge_chains + for attr in ["_merge_tables", "_merge_chains"]: + _ = self.__dict__.pop(attr, None) restriction = restriction or self.restriction or True merge_join_dict = {} for name, chain in self._merge_chains.items(): - join = chain.join(restriction) - if join: + if join := chain.cascade(restriction, direction="down"): merge_join_dict[name] = join if not merge_join_dict and not disable_warning: logger.warning( - f"No merge deletes found w/ {self.table_name} & " + f"No merge deletes found w/ {self.camel_name} & " + f"{restriction}.\n\tIf this is unexpected, try importing " + " Merge table(s) and running with `reload_cache`." ) @@ -424,8 +474,10 @@ def _get_exp_summary(self): return exp_missing + exp_present @cached_property - def _session_connection(self) -> Union[TableChain, bool]: + def _session_connection(self): """Path from Session table to self. 
False if no connection found.""" + from spyglass.utils.dj_graph import TableChain # noqa F401 + connection = TableChain(parent=self._delete_deps[-1], child=self) return connection if connection.has_link else False @@ -716,27 +768,132 @@ def fetch1(self, *args, log_fetch=True, **kwargs): self._log_fetch(*args, **kwargs) return ret - # ------------------------- Other helper methods ------------------------- + # ------------------------------ Restrict by ------------------------------ - def _auto_increment(self, key, pk, *args, **kwargs): - """Auto-increment primary key.""" - if not key.get(pk): - key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 - return key + def __lshift__(self, restriction) -> QueryExpression: + """Restriction by upstream operator e.g. ``q1 << q2``. - def file_like(self, name=None, **kwargs): - """Convenience method for wildcard search on file name fields.""" - if not name: - return self & True - attr = None - for field in self.heading.names: - if "file" in field: - attr = field - break - if not attr: - logger.error(f"No file-like field found in {self.full_table_name}") - return - return self & f"{attr} LIKE '%{name}%'" + Returns + ------- + QueryExpression + A restricted copy of the query expression using the nearest upstream + table for which the restriction is valid. + """ + return self.restrict_by(restriction, direction="up") + + def __rshift__(self, restriction) -> QueryExpression: + """Restriction by downstream operator e.g. ``q1 >> q2``. + + Returns + ------- + QueryExpression + A restricted copy of the query expression using the nearest upstream + table for which the restriction is valid. + """ + return self.restrict_by(restriction, direction="down") + + def _ensure_names(self, tables) -> List[str]: + """Ensure table is a string in a list.""" + if not isinstance(tables, (list, tuple, set)): + tables = [tables] + for table in tables: + return [getattr(table, "full_table_name", table) for t in tables] + + def ban_search_table(self, table): + """Ban table from search in restrict_by.""" + self._banned_search_tables.update(self._ensure_names(table)) + + def unban_search_table(self, table): + """Unban table from search in restrict_by.""" + self._banned_search_tables.difference_update(self._ensure_names(table)) + + def see_banned_tables(self): + """Print banned tables.""" + logger.info(f"Banned tables: {self._banned_search_tables}") + + def restrict_by( + self, + restriction: str = True, + direction: str = "up", + return_graph: bool = False, + verbose: bool = False, + **kwargs, + ) -> QueryExpression: + """Restrict self based on up/downstream table. + + If fails to restrict table, the shortest path may not have been correct. + If there's a different path that should be taken, ban unwanted tables. + + >>> my_table = MyTable() # must be instantced + >>> my_table.ban_search_table(UnwantedTable1) + >>> my_table.ban_search_table([UnwantedTable2, UnwantedTable3]) + >>> my_table.unban_search_table(UnwantedTable3) + >>> my_table.see_banned_tables() + >>> + >>> my_table << my_restriction + + Parameters + ---------- + restriction : str + Restriction to apply to the some table up/downstream of self. + direction : str, optional + Direction to search for valid restriction. Default 'up'. + return_graph : bool, optional + If True, return FindKeyGraph object. Default False, returns + restricted version of present table. + verbose : bool, optional + If True, print verbose output. Default False. 
+ + Returns + ------- + Union[QueryExpression, FindKeyGraph] + Restricted version of present table or FindKeyGraph object. If + return_graph, use all_ft attribute to see all tables in cascade. + """ + from spyglass.utils.dj_graph import TableChain # noqa: F401 + + if restriction is True: + return self + + try: + ret = self.restrict(restriction) # Save time trying first + if len(ret) < len(self): + logger.warning("Restriction valid for this table. Using as is.") + return ret + except DataJointError: + pass # Could avoid try/except if assert_join_compatible return bool + logger.debug("Restriction not valid. Attempting to cascade.") + + if direction == "up": + parent, child = None, self + elif direction == "down": + parent, child = self, None + else: + raise ValueError("Direction must be 'up' or 'down'.") + + graph = TableChain( + parent=parent, + child=child, + direction=direction, + search_restr=restriction, + banned_tables=list(self._banned_search_tables), + allow_merge=True, + cascade=True, + verbose=verbose, + **kwargs, + ) + + if return_graph: + return graph + + ret = graph.leaf_ft[0] + if len(ret) == len(self) or len(ret) == 0: + logger.warning( + f"Failed to restrict with path: {graph.path_str}\n\t" + + "See `help(YourTable.restrict_by)`" + ) + + return ret class SpyglassMixinPart(SpyglassMixin, dj.Part): diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 43eb70aa9..de7671b42 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -513,7 +513,7 @@ def get_nwb_copy_filename(nwb_file_name): def change_group_permissions( subject_ids, set_group_name, analysis_dir="/stelmo/nwb/analysis" ): - logger.warning("This function is deprecated and will be removed soon.") + logger.warning("DEPRECATED: This function will be removed in `0.6`.") # Change to directory with analysis nwb files os.chdir(analysis_dir) # Get nwb file directories with specified subject ids diff --git a/tests/common/test_device.py b/tests/common/test_device.py index 49bbd9027..19103cf98 100644 --- a/tests/common/test_device.py +++ b/tests/common/test_device.py @@ -2,7 +2,7 @@ from numpy import array_equal -def test_invalid_device(common, populate_exception): +def test_invalid_device(common, populate_exception, mini_insert): device_dict = common.DataAcquisitionDevice.fetch(as_dict=True)[0] device_dict["other"] = "invalid" with pytest.raises(populate_exception): diff --git a/tests/conftest.py b/tests/conftest.py index 0bcb4a3fd..7950854d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,10 @@ +"""Configuration for pytest, including fixtures and command line options. + +Fixtures in this script are mad available to all tests in the test suite. +conftest.py files in subdirectories have fixtures that are only available to +tests in that subdirectory. 
+""" + import os import sys import warnings @@ -7,17 +14,19 @@ from time import sleep as tsleep import datajoint as dj +import numpy as np import pynwb import pytest from datajoint.logging import logger as dj_logger from .container import DockerMySQLManager -# ---------------------- CONSTANTS --------------------- +warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") + +# ------------------------------- TESTS CONFIG ------------------------------- # globals in pytest_configure: # BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD -warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") def pytest_addoption(parser): @@ -131,7 +140,7 @@ def pytest_unconfigure(config): SERVER.stop() -# ------------------- FIXTURES ------------------- +# ---------------------------- FIXTURES, TEST ENV ---------------------------- @pytest.fixture(scope="session") @@ -143,6 +152,27 @@ def verbose(): @pytest.fixture(scope="session", autouse=True) def verbose_context(verbose): """Verbosity context for suppressing Spyglass logging.""" + + class QuietStdOut: + """Used to quiet all prints and logging as context manager.""" + + def __init__(self): + from spyglass.utils import logger as spyglass_logger + + self.spy_logger = spyglass_logger + self.previous_level = None + + def __enter__(self): + self.previous_level = self.spy_logger.getEffectiveLevel() + self.spy_logger.setLevel("CRITICAL") + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, exc_type, exc_val, exc_tb): + self.spy_logger.setLevel(self.previous_level) + sys.stdout.close() + sys.stdout = self._original_stdout + yield nullcontext() if verbose else QuietStdOut() @@ -193,6 +223,9 @@ def raw_dir(base_dir): yield base_dir / "raw" +# ------------------------------- FIXTURES, DATA ------------------------------- + + @pytest.fixture(scope="session") def mini_path(raw_dir): path = raw_dir / TEST_FILE @@ -251,12 +284,14 @@ def load_config(dj_conn, base_dir): from spyglass.settings import SpyglassConfig yield SpyglassConfig().load_config( - base_dir=base_dir, test_mode=True, force_reload=True + base_dir=base_dir, debug_mode=False, test_mode=True, force_reload=True ) @pytest.fixture(autouse=True, scope="session") -def mini_insert(mini_path, teardown, server, load_config): +def mini_insert( + dj_conn, mini_path, mini_content, teardown, server, load_config +): from spyglass.common import LabMember, Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.spikesorting.spikesorting_merge import ( # noqa: E402 @@ -264,6 +299,8 @@ def mini_insert(mini_path, teardown, server, load_config): ) from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 + _ = SpikeSortingOutput() + LabMember().insert1( ["Root User", "Root", "User"], skip_duplicates=not teardown ) @@ -287,8 +324,7 @@ def mini_insert(mini_path, teardown, server, load_config): yield close_nwb_files() - # Note: no need to run deletes in teardown, since we are using teardown - # will remove the container + # Note: no need to run deletes in teardown, bc removing the container @pytest.fixture(scope="session") @@ -301,6 +337,9 @@ def mini_dict(mini_copy_name): yield {"nwb_file_name": mini_copy_name} +# --------------------------- FIXTURES, SUBMODULES --------------------------- + + @pytest.fixture(scope="session") def common(dj_conn): from spyglass import common @@ -322,6 +361,41 @@ def settings(dj_conn): yield settings +@pytest.fixture(scope="session") +def 
sgp(common): + from spyglass import position + + yield position + + +@pytest.fixture(scope="session") +def lfp(common): + from spyglass import lfp + + return lfp + + +@pytest.fixture(scope="session") +def lfp_band(lfp): + from spyglass.lfp.analysis.v1 import lfp_band + + return lfp_band + + +@pytest.fixture(scope="session") +def sgl(common): + from spyglass import linearization + + yield linearization + + +@pytest.fixture(scope="session") +def sgpl(sgl): + from spyglass.linearization import v1 + + yield v1 + + @pytest.fixture(scope="session") def populate_exception(): from spyglass.common.errors import PopulateException @@ -329,11 +403,7 @@ def populate_exception(): yield PopulateException -@pytest.fixture(scope="session") -def sgp(common): - from spyglass import position - - yield position +# ------------------------- FIXTURES, POSITION TABLES ------------------------- @pytest.fixture(scope="session") @@ -418,12 +488,16 @@ def trodes_pos_v1(teardown, sgp, trodes_sel_keys): def pos_merge_tables(dj_conn): """Return the merge tables as activated.""" from spyglass.common.common_position import TrackGraph + from spyglass.lfp.lfp_merge import LFPOutput from spyglass.linearization.merge import LinearizedPositionOutput from spyglass.position.position_merge import PositionOutput # must import common_position before LinOutput to avoid circular import - _ = TrackGraph() + + # import LFPOutput to use when testing mixin cascade + _ = LFPOutput() + return [PositionOutput(), LinearizedPositionOutput()] @@ -442,25 +516,258 @@ def pos_merge_key(pos_merge, trodes_pos_v1, trodes_sel_keys): yield pos_merge.merge_get_part(trodes_sel_keys[-1]).fetch1("KEY") -# ------------------ GENERAL FUNCTION ------------------ +# ---------------------- FIXTURES, LINEARIZATION TABLES ---------------------- +# ---------------------- Note: Used to test RestrGraph ----------------------- + + +@pytest.fixture(scope="session") +def pos_lin_key(trodes_sel_keys): + yield trodes_sel_keys[-1] + + +@pytest.fixture(scope="session") +def position_info(pos_merge, pos_merge_key): + yield (pos_merge & {"merge_id": pos_merge_key}).fetch1_dataframe() + + +@pytest.fixture(scope="session") +def track_graph_key(): + yield {"track_graph_name": "6 arm"} + + +@pytest.fixture(scope="session") +def track_graph(teardown, sgpl, track_graph_key): + node_positions = np.array( + [ + (79.910, 216.720), # top left well 0 + (132.031, 187.806), # top middle intersection 1 + (183.718, 217.713), # top right well 2 + (132.544, 132.158), # middle intersection 3 + (87.202, 101.397), # bottom left intersection 4 + (31.340, 126.110), # middle left well 5 + (180.337, 104.799), # middle right intersection 6 + (92.693, 42.345), # bottom left well 7 + (183.784, 45.375), # bottom right well 8 + (231.338, 136.281), # middle right well 9 + ] + ) + + edges = np.array( + [ + (0, 1), + (1, 2), + (1, 3), + (3, 4), + (4, 5), + (3, 6), + (6, 9), + (4, 7), + (6, 8), + ] + ) + + linear_edge_order = [ + (3, 6), + (6, 8), + (6, 9), + (3, 1), + (1, 2), + (1, 0), + (3, 4), + (4, 5), + (4, 7), + ] + linear_edge_spacing = 15 + + sgpl.TrackGraph.insert1( + { + **track_graph_key, + "environment": track_graph_key["track_graph_name"], + "node_positions": node_positions, + "edges": edges, + "linear_edge_order": linear_edge_order, + "linear_edge_spacing": linear_edge_spacing, + }, + skip_duplicates=True, + ) + + yield sgpl.TrackGraph & {"track_graph_name": "6 arm"} + if teardown: + sgpl.TrackGraph().delete(safemode=False) + + +@pytest.fixture(scope="session") +def lin_param_key(): + yield 
{"linearization_param_name": "default"} + + +@pytest.fixture(scope="session") +def lin_params( + teardown, + sgpl, + lin_param_key, +): + param_table = sgpl.LinearizationParameters() + param_table.insert1(lin_param_key, skip_duplicates=True) + yield param_table + + +@pytest.fixture(scope="session") +def lin_sel_key( + pos_merge_key, track_graph_key, lin_param_key, lin_params, track_graph +): + yield { + "pos_merge_id": pos_merge_key["merge_id"], + **track_graph_key, + **lin_param_key, + } + + +@pytest.fixture(scope="session") +def lin_sel(teardown, sgpl, lin_sel_key): + sel_table = sgpl.LinearizationSelection() + sel_table.insert1(lin_sel_key, skip_duplicates=True) + yield sel_table + if teardown: + sel_table.delete(safemode=False) -class QuietStdOut: - """If quiet_spy, used to quiet prints, teardowns and table.delete prints""" +@pytest.fixture(scope="session") +def lin_v1(teardown, sgpl, lin_sel): + v1 = sgpl.LinearizedPositionV1() + v1.populate() + yield v1 + if teardown: + v1.delete(safemode=False) + + +@pytest.fixture(scope="session") +def lin_merge_key(lin_merge, lin_v1, lin_sel_key): + yield lin_merge.merge_get_part(lin_sel_key).fetch1("KEY") - def __init__(self): - from spyglass.utils import logger as spyglass_logger - self.spy_logger = spyglass_logger - self.previous_level = None +# --------------------------- FIXTURES, LFP TABLES --------------------------- +# ---------------- Note: LFPOuput is used to test RestrGraph ----------------- + + +@pytest.fixture(scope="module") +def lfp_band_v1(lfp_band): + yield lfp_band.LFPBandV1() + + +@pytest.fixture(scope="session") +def firfilters_table(common): + return common.FirFilterParameters() + + +@pytest.fixture(scope="session") +def electrodegroup_table(lfp): + return lfp.v1.LFPElectrodeGroup() - def __enter__(self): - self.previous_level = self.spy_logger.getEffectiveLevel() - self.spy_logger.setLevel("CRITICAL") - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - def __exit__(self, exc_type, exc_val, exc_tb): - self.spy_logger.setLevel(self.previous_level) - sys.stdout.close() - sys.stdout = self._original_stdout +@pytest.fixture(scope="session") +def lfp_constants(common, mini_copy_name, mini_dict): + n_delay = 9 + lfp_electrode_group_name = "test" + orig_list_name = "01_s1" + orig_valid_times = ( + common.IntervalList + & mini_dict + & f"interval_list_name = '{orig_list_name}'" + ).fetch1("valid_times") + new_list_name = orig_list_name + f"_first{n_delay}" + new_list_key = { + "nwb_file_name": mini_copy_name, + "interval_list_name": new_list_name, + "valid_times": np.asarray( + [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] + ), + } + + yield dict( + lfp_electrode_ids=[0], + lfp_electrode_group_name=lfp_electrode_group_name, + lfp_eg_key={ + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + }, + n_delay=n_delay, + orig_interval_list_name=orig_list_name, + orig_valid_times=orig_valid_times, + interval_list_name=new_list_name, + interval_key=new_list_key, + filter1_name="LFP 0-400 Hz", + filter_sampling_rate=30_000, + filter2_name="Theta 5-11 Hz", + lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes + lfp_band_sampling_rate=100, # desired sampling rate + ) + + +@pytest.fixture(scope="session") +def add_electrode_group( + firfilters_table, + electrodegroup_table, + mini_copy_name, + lfp_constants, +): + firfilters_table.create_standard_filters() + group_name = lfp_constants.get("lfp_electrode_group_name") + 
electrodegroup_table.create_lfp_electrode_group( + nwb_file_name=mini_copy_name, + group_name=group_name, + electrode_list=np.array(lfp_constants.get("lfp_electrode_ids")), + ) + assert len( + electrodegroup_table & {"lfp_electrode_group_name": group_name} + ), "Failed to add LFPElectrodeGroup." + yield + + +@pytest.fixture(scope="session") +def add_interval(common, lfp_constants): + common.IntervalList.insert1( + lfp_constants.get("interval_key"), skip_duplicates=True + ) + yield lfp_constants.get("interval_list_name") + + +@pytest.fixture(scope="session") +def add_selection( + lfp, common, add_electrode_group, add_interval, lfp_constants +): + lfp_s_key = { + **lfp_constants.get("lfp_eg_key"), + "target_interval_list_name": add_interval, + "filter_name": lfp_constants.get("filter1_name"), + "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), + } + lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) + yield lfp_s_key + + +@pytest.fixture(scope="session") +def lfp_s_key(lfp_constants, mini_copy_name): + yield { + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_constants.get( + "lfp_electrode_group_name" + ), + "target_interval_list_name": lfp_constants.get("interval_list_name"), + } + + +@pytest.fixture(scope="session") +def populate_lfp(lfp, add_selection, lfp_s_key): + lfp.v1.LFPV1().populate(add_selection) + yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} + + +@pytest.fixture(scope="session") +def lfp_merge_key(populate_lfp): + yield populate_lfp + + +@pytest.fixture(scope="session") +def lfp_v1_key(lfp, lfp_s_key): + yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") diff --git a/tests/container.py b/tests/container.py index 04e176fee..fa26f1c46 100644 --- a/tests/container.py +++ b/tests/container.py @@ -193,7 +193,7 @@ def creds(self): "database.user": self.user, "database.port": int(self.port), "safemode": "false", - "custom": {"test_mode": True}, + "custom": {"test_mode": True, "debug_mode": False}, } @property diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py index 354803493..e62a03dea 100644 --- a/tests/lfp/conftest.py +++ b/tests/lfp/conftest.py @@ -1,140 +1,7 @@ -import numpy as np import pytest from pynwb import NWBHDF5IO -@pytest.fixture(scope="session") -def lfp(common): - from spyglass import lfp - - return lfp - - -@pytest.fixture(scope="session") -def lfp_band(lfp): - from spyglass.lfp.analysis.v1 import lfp_band - - return lfp_band - - -@pytest.fixture(scope="session") -def firfilters_table(common): - return common.FirFilterParameters() - - -@pytest.fixture(scope="session") -def electrodegroup_table(lfp): - return lfp.v1.LFPElectrodeGroup() - - -@pytest.fixture(scope="session") -def lfp_constants(common, mini_copy_name, mini_dict): - n_delay = 9 - lfp_electrode_group_name = "test" - orig_list_name = "01_s1" - orig_valid_times = ( - common.IntervalList - & mini_dict - & f"interval_list_name = '{orig_list_name}'" - ).fetch1("valid_times") - new_list_name = orig_list_name + f"_first{n_delay}" - new_list_key = { - "nwb_file_name": mini_copy_name, - "interval_list_name": new_list_name, - "valid_times": np.asarray( - [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] - ), - } - - yield dict( - lfp_electrode_ids=[0], - lfp_electrode_group_name=lfp_electrode_group_name, - lfp_eg_key={ - "nwb_file_name": mini_copy_name, - "lfp_electrode_group_name": lfp_electrode_group_name, - }, - n_delay=n_delay, - orig_interval_list_name=orig_list_name, - orig_valid_times=orig_valid_times, - 
interval_list_name=new_list_name, - interval_key=new_list_key, - filter1_name="LFP 0-400 Hz", - filter_sampling_rate=30_000, - filter2_name="Theta 5-11 Hz", - lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes - lfp_band_sampling_rate=100, # desired sampling rate - ) - - -@pytest.fixture(scope="session") -def add_electrode_group( - firfilters_table, - electrodegroup_table, - mini_copy_name, - lfp_constants, -): - firfilters_table.create_standard_filters() - group_name = lfp_constants.get("lfp_electrode_group_name") - electrodegroup_table.create_lfp_electrode_group( - nwb_file_name=mini_copy_name, - group_name=group_name, - electrode_list=np.array(lfp_constants.get("lfp_electrode_ids")), - ) - assert len( - electrodegroup_table & {"lfp_electrode_group_name": group_name} - ), "Failed to add LFPElectrodeGroup." - yield - - -@pytest.fixture(scope="session") -def add_interval(common, lfp_constants): - common.IntervalList.insert1( - lfp_constants.get("interval_key"), skip_duplicates=True - ) - yield lfp_constants.get("interval_list_name") - - -@pytest.fixture(scope="session") -def add_selection( - lfp, common, add_electrode_group, add_interval, lfp_constants -): - lfp_s_key = { - **lfp_constants.get("lfp_eg_key"), - "target_interval_list_name": add_interval, - "filter_name": lfp_constants.get("filter1_name"), - "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), - } - lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) - yield lfp_s_key - - -@pytest.fixture(scope="session") -def lfp_s_key(lfp_constants, mini_copy_name): - yield { - "nwb_file_name": mini_copy_name, - "lfp_electrode_group_name": lfp_constants.get( - "lfp_electrode_group_name" - ), - "target_interval_list_name": lfp_constants.get("interval_list_name"), - } - - -@pytest.fixture(scope="session") -def populate_lfp(lfp, add_selection, lfp_s_key): - lfp.v1.LFPV1().populate(add_selection) - yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} - - -@pytest.fixture(scope="session") -def lfp_merge_key(populate_lfp): - yield populate_lfp - - -@pytest.fixture(scope="session") -def lfp_v1_key(lfp, lfp_s_key): - yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") - - @pytest.fixture(scope="module") def lfp_analysis_raw(common, lfp, populate_lfp, mini_dict): abs_path = (common.AnalysisNwbfile * lfp.v1.LFPV1 & mini_dict).fetch( diff --git a/tests/lfp/test_lfp.py b/tests/lfp/test_lfp.py index 51b2e96f4..b496ae445 100644 --- a/tests/lfp/test_lfp.py +++ b/tests/lfp/test_lfp.py @@ -37,11 +37,6 @@ def test_lfp_band_dataframe(lfp_band_analysis_raw, lfp_band, lfp_band_key): assert df_raw.equals(df_fetch), "LFPBand dataframe not match." 
-@pytest.fixture(scope="module") -def lfp_band_v1(lfp_band): - yield lfp_band.LFPBandV1() - - def test_lfp_band_compute_signal_invalid(lfp_band_v1): with pytest.raises(ValueError): lfp_band_v1.compute_analytic_signal([4]) diff --git a/tests/linearization/conftest.py b/tests/linearization/conftest.py deleted file mode 100644 index 505dcc816..000000000 --- a/tests/linearization/conftest.py +++ /dev/null @@ -1,142 +0,0 @@ -import numpy as np -import pytest - - -@pytest.fixture(scope="session") -def sgl(common): - from spyglass import linearization - - yield linearization - - -@pytest.fixture(scope="session") -def sgpl(sgl): - from spyglass.linearization import v1 - - yield v1 - - -@pytest.fixture(scope="session") -def pos_lin_key(trodes_sel_keys): - yield trodes_sel_keys[-1] - - -@pytest.fixture(scope="session") -def position_info(pos_merge, pos_merge_key): - yield (pos_merge & {"merge_id": pos_merge_key}).fetch1_dataframe() - - -@pytest.fixture(scope="session") -def track_graph_key(): - yield {"track_graph_name": "6 arm"} - - -@pytest.fixture(scope="session") -def track_graph(teardown, sgpl, track_graph_key): - node_positions = np.array( - [ - (79.910, 216.720), # top left well 0 - (132.031, 187.806), # top middle intersection 1 - (183.718, 217.713), # top right well 2 - (132.544, 132.158), # middle intersection 3 - (87.202, 101.397), # bottom left intersection 4 - (31.340, 126.110), # middle left well 5 - (180.337, 104.799), # middle right intersection 6 - (92.693, 42.345), # bottom left well 7 - (183.784, 45.375), # bottom right well 8 - (231.338, 136.281), # middle right well 9 - ] - ) - - edges = np.array( - [ - (0, 1), - (1, 2), - (1, 3), - (3, 4), - (4, 5), - (3, 6), - (6, 9), - (4, 7), - (6, 8), - ] - ) - - linear_edge_order = [ - (3, 6), - (6, 8), - (6, 9), - (3, 1), - (1, 2), - (1, 0), - (3, 4), - (4, 5), - (4, 7), - ] - linear_edge_spacing = 15 - - sgpl.TrackGraph.insert1( - { - **track_graph_key, - "environment": track_graph_key["track_graph_name"], - "node_positions": node_positions, - "edges": edges, - "linear_edge_order": linear_edge_order, - "linear_edge_spacing": linear_edge_spacing, - }, - skip_duplicates=True, - ) - - yield sgpl.TrackGraph & {"track_graph_name": "6 arm"} - if teardown: - sgpl.TrackGraph().delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_param_key(): - yield {"linearization_param_name": "default"} - - -@pytest.fixture(scope="session") -def lin_params( - teardown, - sgpl, - lin_param_key, -): - param_table = sgpl.LinearizationParameters() - param_table.insert1(lin_param_key, skip_duplicates=True) - yield param_table - - -@pytest.fixture(scope="session") -def lin_sel_key( - pos_merge_key, track_graph_key, lin_param_key, lin_params, track_graph -): - yield { - "pos_merge_id": pos_merge_key["merge_id"], - **track_graph_key, - **lin_param_key, - } - - -@pytest.fixture(scope="session") -def lin_sel(teardown, sgpl, lin_sel_key): - sel_table = sgpl.LinearizationSelection() - sel_table.insert1(lin_sel_key, skip_duplicates=True) - yield sel_table - if teardown: - sel_table.delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_v1(teardown, sgpl, lin_sel): - v1 = sgpl.LinearizedPositionV1() - v1.populate() - yield v1 - if teardown: - v1.delete(safemode=False) - - -@pytest.fixture(scope="session") -def lin_merge_key(lin_merge, lin_sel_key): - yield lin_merge.merge_get_part(lin_sel_key).fetch1("KEY") diff --git a/tests/linearization/test_lin.py b/tests/linearization/test_lin.py index 4225ad5bf..a5db28d9a 100644 --- 
a/tests/linearization/test_lin.py +++ b/tests/linearization/test_lin.py @@ -9,4 +9,4 @@ def test_fetch1_dataframe(lin_v1, lin_merge, lin_merge_key): assert hash_df == hash_exp, "Dataframe differs from expected" -## Todo: Add more tests of this pipeline, not just the fetch1_dataframe method +# TODO: Add more tests of this pipeline, not just the fetch1_dataframe method diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 3503f9649..a4bc7f900 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -38,6 +38,8 @@ def chains(Nwbfile): ) # noqa: F401 from spyglass.position.position_merge import PositionOutput # noqa: F401 + _ = LFPOutput, LinearizedPositionOutput, PositionOutput + yield Nwbfile._get_chain("linear") @@ -51,6 +53,243 @@ def chain(chains): def no_link_chain(Nwbfile): """Return example TableChain object with no link.""" from spyglass.common.common_usage import InsertError - from spyglass.utils.dj_chains import TableChain + from spyglass.utils.dj_graph import TableChain yield TableChain(Nwbfile, InsertError()) + + +@pytest.fixture(scope="module") +def _Merge(): + """Return the _Merge class.""" + from spyglass.utils import _Merge + + yield _Merge + + +@pytest.fixture(scope="module") +def SpyglassMixin(): + """Return a mixin class.""" + from spyglass.utils import SpyglassMixin + + yield SpyglassMixin + + +@pytest.fixture(scope="module") +def graph_schema(SpyglassMixin, _Merge): + """ + NOTE: Must declare tables within fixture to avoid loading config defaults. + """ + parent_id = range(10) + parent_attr = [i + 10 for i in range(2, 12)] + other_id = range(9) + other_attr = [i + 10 for i in range(3, 12)] + intermediate_id = range(2, 10) + intermediate_attr = [i + 10 for i in range(4, 12)] + pk_id = range(3, 10) + pk_attr = [i + 10 for i in range(5, 12)] + sk_id = range(6) + sk_attr = [i + 10 for i in range(6, 12)] + pk_sk_id = range(5) + pk_sk_attr = [i + 10 for i in range(7, 12)] + pk_alias_id = range(4) + pk_alias_attr = [i + 10 for i in range(8, 12)] + sk_alias_id = range(3) + sk_alias_attr = [i + 10 for i in range(9, 12)] + + def offset(gen, offset): + return list(gen)[offset:] + + class ParentNode(SpyglassMixin, dj.Lookup): + definition = """ + parent_id: int + --- + parent_attr : int + """ + contents = [(i, j) for i, j in zip(parent_id, parent_attr)] + + class OtherParentNode(SpyglassMixin, dj.Lookup): + definition = """ + other_id: int + --- + other_attr : int + """ + contents = [(i, j) for i, j in zip(other_id, other_attr)] + + class IntermediateNode(SpyglassMixin, dj.Lookup): + definition = """ + intermediate_id: int + --- + -> ParentNode + intermediate_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip( + intermediate_id, offset(parent_id, 1), intermediate_attr + ) + ] + + class PkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_id: int + -> IntermediateNode + --- + pk_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip(pk_id, offset(intermediate_id, 2), pk_attr) + ] + + class SkNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_id: int + --- + -> IntermediateNode + sk_attr : int + """ + contents = [ + (i, j, k) + for i, j, k in zip(sk_id, offset(intermediate_id, 3), sk_attr) + ] + + class PkSkNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_sk_id: int + -> IntermediateNode + --- + -> OtherParentNode + pk_sk_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_sk_id, 
offset(intermediate_id, 4), other_id, pk_sk_attr + ) + ] + + class PkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + pk_alias_id: int + -> PkNode.proj(fk_pk_id='pk_id') + --- + pk_alias_attr : int + """ + contents = [ + (i, j, k, m) + for i, j, k, m in zip( + pk_alias_id, + offset(pk_id, 1), + offset(intermediate_id, 3), + pk_alias_attr, + ) + ] + + class SkAliasNode(SpyglassMixin, dj.Lookup): + definition = """ + sk_alias_id: int + --- + -> SkNode.proj(fk_sk_id='sk_id') + -> PkSkNode + sk_alias_attr : int + """ + contents = [ + (i, j, k, m, n) + for i, j, k, m, n in zip( + sk_alias_id, + offset(sk_id, 2), + offset(pk_sk_id, 1), + offset(intermediate_id, 5), + sk_alias_attr, + ) + ] + + class MergeOutput(_Merge, SpyglassMixin): + definition = """ + merge_id: uuid + --- + source: varchar(32) + """ + + class PkNode(dj.Part): + definition = """ + -> MergeOutput + --- + -> PkNode + """ + + class MergeChild(SpyglassMixin, dj.Manual): + definition = """ + -> MergeOutput + merge_child_id: int + --- + merge_child_attr: int + """ + + yield { + "ParentNode": ParentNode, + "OtherParentNode": OtherParentNode, + "IntermediateNode": IntermediateNode, + "PkNode": PkNode, + "SkNode": SkNode, + "PkSkNode": PkSkNode, + "PkAliasNode": PkAliasNode, + "SkAliasNode": SkAliasNode, + "MergeOutput": MergeOutput, + "MergeChild": MergeChild, + } + + +@pytest.fixture(scope="module") +def graph_tables(dj_conn, graph_schema): + + schema = dj.Schema(context=graph_schema) + + for table in graph_schema.values(): + schema(table) + + schema.activate("test_graph", connection=dj_conn) + + # Merge inserts after declaring tables + merge_keys = graph_schema["PkNode"].fetch("KEY", offset=1, as_dict=True) + graph_schema["MergeOutput"].insert(merge_keys, skip_duplicates=True) + merge_child_keys = graph_schema["MergeOutput"].merge_fetch( + True, "merge_id", offset=1 + ) + merge_child_inserts = [ + (i, j, k + 10) + for i, j, k in zip(merge_child_keys, range(4), range(10, 15)) + ] + graph_schema["MergeChild"].insert(merge_child_inserts, skip_duplicates=True) + + yield graph_schema + + schema.drop(force=True) + + +@pytest.fixture(scope="module") +def graph_tables_many_to_one(graph_tables): + ParentNode = graph_tables["ParentNode"] + IntermediateNode = graph_tables["IntermediateNode"] + PkSkNode = graph_tables["PkSkNode"] + + pk_sk_keys = PkSkNode().fetch(as_dict=True)[-2:] + new_inserts = [ + { + "pk_sk_id": k["pk_sk_id"] + 3, + "intermediate_id": k["intermediate_id"] + 3, + "intermediate_attr": k["intermediate_id"] + 16, + "parent_id": k["intermediate_id"] - 1, + "parent_attr": k["intermediate_id"] + 11, + "other_id": k["other_id"], # No change + "pk_sk_attr": k["pk_sk_attr"] + 10, + } + for k in pk_sk_keys + ] + + insert_kwargs = {"ignore_extra_fields": True, "skip_duplicates": True} + ParentNode.insert(new_inserts, **insert_kwargs) + IntermediateNode.insert(new_inserts, **insert_kwargs) + PkSkNode.insert(new_inserts, **insert_kwargs) + + yield graph_tables diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index 7ba4b1fa2..66d9772c3 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -4,15 +4,20 @@ @pytest.fixture(scope="session") def TableChain(): - from spyglass.utils.dj_chains import TableChain + from spyglass.utils.dj_graph import TableChain return TableChain +def full_to_camel(t): + return to_camel_case(t.split(".")[-1].strip("`")) + + def test_chains_repr(chains): """Test that the repr of a TableChains object is as expected.""" repr_got = repr(chains) - repr_exp = 
"\n".join([str(c) for c in chains.chains]) + chain_st = ",\n\t".join([str(c) for c in chains.chains]) + "\n" + repr_exp = f"TableChains(\n\t{chain_st})" assert repr_got == repr_exp, "Unexpected repr of TableChains object." @@ -32,11 +37,13 @@ def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain): def test_chain_str(chain): """Test that the str of a TableChain object is as expected.""" chain = chain - parent = to_camel_case(chain.parent.table_name) - child = to_camel_case(chain.child.table_name) str_got = str(chain) - str_exp = parent + chain._link_symbol + child + str_exp = ( + full_to_camel(chain.parent) + + chain._link_symbol + + full_to_camel(chain.child) + ) assert str_got == str_exp, "Unexpected str of TableChain object." @@ -45,25 +52,25 @@ def test_chain_repr(chain): """Test that the repr of a TableChain object is as expected.""" repr_got = repr(chain) repr_ext = "Chain: " + chain._link_symbol.join( - [t.table_name for t in chain.objects] + [full_to_camel(t) for t in chain.path] ) assert repr_got == repr_ext, "Unexpected repr of TableChain object." def test_chain_len(chain): """Test that the len of a TableChain object is as expected.""" - assert len(chain) == len(chain.names), "Unexpected len of TableChain." + assert len(chain) == len(chain.path), "Unexpected len of TableChain." def test_chain_getitem(chain): """Test getitem of TableChain object.""" by_int = chain[0] - by_str = chain[chain.names[0]] + by_str = chain[chain.path[0]] assert by_int == by_str, "Getitem by int and str not equal." def test_nolink_join(no_link_chain): - assert no_link_chain.join() is None, "Unexpected join of no link chain." + assert no_link_chain.cascade() is None, "Unexpected join of no link chain." def test_chain_str_no_link(no_link_chain): diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py new file mode 100644 index 000000000..7d5257a36 --- /dev/null +++ b/tests/utils/test_graph.py @@ -0,0 +1,143 @@ +import pytest + + +@pytest.fixture(scope="session") +def leaf(lin_merge): + yield lin_merge.LinearizedPositionV1() + + +@pytest.fixture(scope="session") +def restr_graph(leaf, verbose, lin_merge_key): + from spyglass.utils.dj_graph import RestrGraph + + _ = lin_merge_key # linearization merge table populated + + yield RestrGraph( + seed_table=leaf, + table_name=leaf.full_table_name, + restriction=True, + cascade=True, + verbose=verbose, + ) + + +def test_rg_repr(restr_graph, leaf): + """Test that the repr of a RestrGraph object is as expected.""" + repr_got = repr(restr_graph) + + assert "cascade" in repr_got.lower(), "Cascade not in repr." + assert leaf.full_table_name in repr_got, "Table name not in repr." + + +def test_rg_ft(restr_graph): + """Test FreeTable attribute of RestrGraph.""" + assert len(restr_graph.leaf_ft) == 1, "Unexpected # of leaf tables." + assert len(restr_graph["spatial"]) == 2, "Unexpected cascaded table length." + + +def test_rg_restr_ft(restr_graph): + """Test get restricted free tables.""" + ft = restr_graph["spatial_series"] + assert len(ft) == 2, "Unexpected restricted table length." + + +def test_rg_file_paths(restr_graph): + """Test collection of upstream file paths.""" + paths = [p.get("file_path") for p in restr_graph.file_paths] + assert len(paths) == 2, "Unexpected number of file paths." 
+ + +@pytest.fixture(scope="session") +def restr_graph_new_leaf(restr_graph, common): + restr_graph.add_leaf( + table_name=common.common_behav.PositionSource.full_table_name, + restriction=True, + ) + + yield restr_graph + + +def test_add_leaf_cascade(restr_graph_new_leaf): + assert ( + not restr_graph_new_leaf.cascaded + ), "Cascaded flag not set when add leaf." + + +def test_add_leaf_restr_ft(restr_graph_new_leaf): + restr_graph_new_leaf.cascade() + ft = restr_graph_new_leaf._get_ft( + "`common_interval`.`interval_list`", with_restr=True + ) + assert len(ft) == 2, "Unexpected restricted table length." + + +@pytest.fixture(scope="session") +def restr_graph_root(restr_graph, common, lfp_band, lin_v1): + from spyglass.utils.dj_graph import RestrGraph + + yield RestrGraph( + seed_table=common.Session(), + table_name=common.Session.full_table_name, + restriction="True", + direction="down", + cascade=True, + verbose=False, + ) + + +def test_rg_root(restr_graph_root): + assert ( + len(restr_graph_root["trodes_pos_v1"]) == 2 + ), "Incomplete cascade from root." + + +@pytest.mark.parametrize( + "restr, expect_n, msg", + [ + ("pk_attr > 16", 4, "pk no alias"), + ("sk_attr > 17", 3, "sk no alias"), + ("pk_alias_attr > 18", 3, "pk pk alias"), + ("sk_alias_attr > 19", 2, "sk sk alias"), + ("merge_child_attr > 21", 2, "merge child down"), + ({"merge_child_attr": 21}, 1, "dict restr"), + ], +) +def test_restr_from_upstream(graph_tables, restr, expect_n, msg): + msg = "Error in `>>` for " + msg + assert len(graph_tables["ParentNode"]() >> restr) == expect_n, msg + + +@pytest.mark.parametrize( + "table, restr, expect_n, msg", + [ + ("PkNode", "parent_attr > 15", 5, "pk no alias"), + ("SkNode", "parent_attr > 16", 4, "sk no alias"), + ("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"), + ("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"), + ("MergeChild", "parent_attr > 18", 2, "merge child"), + ("MergeChild", {"parent_attr": 18}, 1, "dict restr"), + ], +) +def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg): + msg = "Error in `<<` for " + msg + assert len(graph_tables[table]() << restr) == expect_n, msg + + +def test_restr_many_to_one(graph_tables_many_to_one): + PK = graph_tables_many_to_one["PkSkNode"]() + OP = graph_tables_many_to_one["OtherParentNode"]() + + msg_template = "Error in `%s` for many to one." + + assert len(PK << "other_attr > 14") == 4, msg_template % "<<" + assert len(PK << {"other_attr": 15}) == 2, msg_template % "<<" + assert len(OP >> "pk_sk_attr > 19") == 2, msg_template % ">>" + assert ( + len(OP >> [{"pk_sk_attr": 19}, {"pk_sk_attr": 20}]) == 2 + ), "Error accepting list of dicts for `>>` for many to one." 
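For orientation, the `>>` and `<<` operators exercised by the parametrized tests above restrict a table by attributes of tables downstream or upstream of it in the dependency graph. A minimal sketch, assuming the fixture tables declared in the `graph_tables` fixture earlier in this file (editorial illustration only, not lines added by the patch):

```python
# Sketch only -- relies on the graph_tables fixture dict defined above.
parent = graph_tables["ParentNode"]()       # upstream-most table in the fixture graph
merge_child = graph_tables["MergeChild"]()  # downstream-most table in the fixture graph

# ">>" cascades a restriction on a downstream attribute back up to this table,
# e.g. keep ParentNode rows whose downstream PkNode rows satisfy the condition.
parents_with_big_pk = parent >> "pk_attr > 16"

# "<<" cascades a restriction on an upstream attribute down to this table,
# e.g. keep MergeChild rows whose upstream ParentNode rows satisfy the condition.
children_of_big_parents = merge_child << "parent_attr > 18"
```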
+ + +def test_restr_invalid(graph_tables): + PkNode = graph_tables["PkNode"]() + with pytest.raises(ValueError): + len(PkNode << set(["parent_attr > 15", "parent_attr < 20"])) diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index faa823c8e..010abf03c 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -1,7 +1,7 @@ import datajoint as dj import pytest -from tests.conftest import VERBOSE +from tests.conftest import TEARDOWN, VERBOSE @pytest.fixture(scope="module") @@ -16,7 +16,10 @@ class Mixin(SpyglassMixin, dj.Manual): yield Mixin -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") +@pytest.mark.skipif( + not VERBOSE or not TEARDOWN, + reason="Error only on verbose or new declare.", +) def test_bad_prefix(caplog, dj_conn, Mixin): schema_bad = dj.Schema("badprefix", {}, connection=dj_conn) schema_bad(Mixin) @@ -38,6 +41,19 @@ def test_merge_detect(Nwbfile, pos_merge_tables): ), "Merges not detected by mixin." +def test_merge_chain_join(Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key): + """Test that the mixin can join merge chains.""" + _ = lin_v1, lfp_merge_key # merge tables populated + + all_chains = [ + chains.cascade(True, direction="down") + for chains in Nwbfile._merge_chains.values() + ] + end_len = [len(chain[0]) for chain in all_chains if chain] + + assert sum(end_len) == 4, "Merge chains not joined correctly." + + def test_get_chain(Nwbfile, pos_merge_tables): """Test that the mixin can get the chain of a merge.""" lin_parts = Nwbfile._get_chain("linear").part_names @@ -48,7 +64,28 @@ def test_get_chain(Nwbfile, pos_merge_tables): @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_ddm_warning(Nwbfile, caplog): """Test that the mixin warns on empty delete_downstream_merge.""" - (Nwbfile & "nwb_file_name LIKE 'BadName'").delete_downstream_merge( + (Nwbfile.file_like("BadName")).delete_downstream_merge( reload_cache=True, disable_warnings=False ) assert "No merge deletes found" in caplog.text, "No warning issued." + + +def test_ddm_dry_run(Nwbfile, common, sgp, pos_merge_tables, lin_v1): + """Test that the mixin can dry run delete_downstream_merge.""" + _ = lin_v1 # merge tables populated + pos_output_name = pos_merge_tables[0].full_table_name + + param_field = "trodes_pos_params_name" + trodes_params = sgp.v1.TrodesPosParams() + + rft = (trodes_params & f'{param_field} LIKE "%ups%"').ddm( + reload_cache=True, dry_run=True, return_parts=False + )[pos_output_name][0] + assert len(rft) == 1, "ddm did not return restricted table." + + table_name = [p for p in pos_merge_tables[0].parts() if "trode" in p][0] + assert table_name == rft.full_table_name, "ddm didn't grab right table." + + assert ( + rft.fetch1(param_field) == "single_led_upsampled" + ), "ddm didn't grab right row." 
From 042fd1cd631f2accd0ed0f25544898c628c65075 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 12:05:54 -0700 Subject: [PATCH 05/11] Transaction on `populate_all_common` (#957) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP: transaction on populate_all_common * ✅ : Seperate rollback and raise err options --- CHANGELOG.md | 1 + notebooks/01_Insert_Data.ipynb | 20 ++- notebooks/py_scripts/01_Insert_Data.py | 18 ++- notebooks/py_scripts/50_MUA_Detection.py | 111 +++++++++++++ src/spyglass/common/common_behav.py | 28 +++- src/spyglass/common/common_dio.py | 15 +- src/spyglass/common/common_ephys.py | 128 ++++++++++----- src/spyglass/common/common_nwbfile.py | 1 + src/spyglass/common/common_session.py | 7 + src/spyglass/common/common_task.py | 10 +- src/spyglass/common/populate_all_common.py | 171 +++++++++++++++----- src/spyglass/data_import/insert_sessions.py | 16 +- src/spyglass/spikesorting/imported.py | 9 +- 13 files changed, 441 insertions(+), 94 deletions(-) create mode 100644 notebooks/py_scripts/50_MUA_Detection.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 231e328d6..bf8804795 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 +- Add rollback option to `populate_all_common` #957 - Add long-distance restrictions via `<<` and `>>` operators. #943 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 diff --git a/notebooks/01_Insert_Data.ipynb b/notebooks/01_Insert_Data.ipynb index e68623e72..2a2297642 100644 --- a/notebooks/01_Insert_Data.ipynb +++ b/notebooks/01_Insert_Data.ipynb @@ -1082,8 +1082,22 @@ "- neural activity (extracellular recording of multiple brain areas)\n", "- etc.\n", "\n", - "_Note:_ this may take time as Spyglass creates the copy. You may see a prompt\n", - "about inserting device information.\n" + "_Notes:_ this may take time as Spyglass creates the copy. You may see a prompt\n", + "about inserting device information.\n", + "\n", + "By default, the session insert process is error permissive. It will log an\n", + "error and continue attempts across various tables. You have two options you can\n", + "toggle to adjust this.\n", + "\n", + "- `rollback_on_fail`: Default False. If True, errors will still be logged for\n", + " all tables and, if any are registered, the `Nwbfile` entry will be deleted.\n", + " This is helpful for knowing why your file failed, and making it easy to retry.\n", + "- `raise_err`: Default False. If True, errors will not be logged and will\n", + " instead be raised. This is useful for debugging and exploring the error stack.\n", + " The end result may be that some tables may still have entries from this file\n", + " that will need to be manually deleted after a failed attempt. 'transactions'\n", + " are used where possible to rollback sibling tables, but child table errors\n", + " will still leave entries from parent tables.\n" ] }, { @@ -1146,7 +1160,7 @@ } ], "source": [ - "sgi.insert_sessions(nwb_file_name)" + "sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False)" ] }, { diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index 975ed4ac5..870c6907a 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -198,11 +198,25 @@ # - neural activity (extracellular recording of multiple brain areas) # - etc. 
# -# _Note:_ this may take time as Spyglass creates the copy. You may see a prompt +# _Notes:_ this may take time as Spyglass creates the copy. You may see a prompt # about inserting device information. # +# By default, the session insert process is error permissive. It will log an +# error and continue attempts across various tables. You have two options you can +# toggle to adjust this. +# +# - `rollback_on_fail`: Default False. If True, errors will still be logged for +# all tables and, if any are registered, the `Nwbfile` entry will be deleted. +# This is helpful for knowing why your file failed, and making it easy to retry. +# - `raise_err`: Default False. If True, errors will not be logged and will +# instead be raised. This is useful for debugging and exploring the error stack. +# The end result may be that some tables may still have entries from this file +# that will need to be manually deleted after a failed attempt. 'transactions' +# are used where possible to rollback sibling tables, but child table errors +# will still leave entries from parent tables. +# -sgi.insert_sessions(nwb_file_name) +sgi.insert_sessions(nwb_file_name, rollback_on_fail=False, raise_error=False) # ## Inspecting the data # diff --git a/notebooks/py_scripts/50_MUA_Detection.py b/notebooks/py_scripts/50_MUA_Detection.py new file mode 100644 index 000000000..bc319ff82 --- /dev/null +++ b/notebooks/py_scripts/50_MUA_Detection.py @@ -0,0 +1,111 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: spyglass +# language: python +# name: python3 +# --- + +# + +import datajoint as dj +from pathlib import Path + +dj.config.load( + Path("../dj_local_conf.json").absolute() +) # load config for database connection info + +from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters + +# - + +MuaEventsParameters() + +MuaEventsV1() + +# + +from spyglass.position import PositionOutput + +nwb_copy_file_name = "mediumnwb20230802_.nwb" + +trodes_s_key = { + "nwb_file_name": nwb_copy_file_name, + "interval_list_name": "pos 0 valid times", + "trodes_pos_params_name": "single_led_upsampled", +} + +pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id") +pos_merge_id + +# + +from spyglass.spikesorting.analysis.v1.group import ( + SortedSpikesGroup, +) + +sorted_spikes_group_key = { + "nwb_file_name": nwb_copy_file_name, + "sorted_spikes_group_name": "test_group", + "unit_filter_params_name": "default_exclusion", +} + +SortedSpikesGroup & sorted_spikes_group_key + +# + +mua_key = { + "mua_param_name": "default", + **sorted_spikes_group_key, + "pos_merge_id": pos_merge_id, + "detection_interval": "pos 0 valid times", +} + +MuaEventsV1().populate(mua_key) +MuaEventsV1 & mua_key +# - + +mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe() +mua_times + +# + +import matplotlib.pyplot as plt +import numpy as np + +fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4)) +speed = MuaEventsV1.get_speed(mua_key).to_numpy() +time = speed.index.to_numpy() +multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time) + +time_slice = slice( + np.searchsorted(time, mua_times.loc[10].start_time) - 1_000, + np.searchsorted(time, mua_times.loc[10].start_time) + 5_000, +) + +axes[0].plot( + time[time_slice], + multiunit_firing_rate[time_slice], + color="black", +) +axes[0].set_ylabel("firing rate (Hz)") +axes[0].set_title("multiunit") +axes[1].fill_between(time[time_slice], 
speed[time_slice], color="lightgrey") +axes[1].set_ylabel("speed (cm/s)") +axes[1].set_xlabel("time (s)") + +for id, mua_time in mua_times.loc[ + np.logical_and( + mua_times["start_time"] > time[time_slice].min(), + mua_times["end_time"] < time[time_slice].max(), + ) +].iterrows(): + axes[0].axvspan( + mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5 + ) +# - + +(MuaEventsV1 & mua_key).create_figurl( + zscore_mua=True, +) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index bdb769e73..b7e8d953b 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -43,12 +43,8 @@ class SpatialSeries(SpyglassMixin, dj.Part): name=null: varchar(32) # name of spatial series """ - def populate(self, keys=None): - """Insert position source data from NWB file. - - WARNING: populate method on Manual table is not protected by transaction - protections like other DataJoint tables. - """ + def _no_transaction_make(self, keys=None): + """Insert position source data from NWB file.""" if not isinstance(keys, list): keys = [keys] if isinstance(keys[0], (dj.Table, dj.expression.QueryExpression)): @@ -227,6 +223,12 @@ def _get_column_names(rp, pos_id): return column_names def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] interval_list_name = key["interval_list_name"] @@ -238,7 +240,7 @@ def make(self, key): PositionSource.get_epoch_num(interval_list_name) ] - self.insert1(key) + self.insert1(key, allow_direct_insert=True) self.PosObject.insert( [ dict( @@ -294,6 +296,12 @@ class StateScriptFile(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" """Add a new row to the StateScriptFile table.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) @@ -309,6 +317,7 @@ def make(self, key): ) return # See #849 + script_inserts = [] for associated_file_obj in associated_files.data_interfaces.values(): if not isinstance( associated_file_obj, ndx_franklab_novela.AssociatedFiles @@ -337,10 +346,13 @@ def make(self, key): # find the file associated with this epoch if str(key["epoch"]) in epoch_list: key["file_object_id"] = associated_file_obj.object_id - self.insert1(key) + script_inserts.append(key.copy()) else: logger.info("not a statescript file") + if script_inserts: + self.insert(script_inserts, allow_direct_insert=True) + @schema class VideoFile(SpyglassMixin, dj.Imported): diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index 3db854e6a..629adef47 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -27,6 +27,12 @@ class DIOEvents(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -45,10 +51,17 @@ def make(self, key): key["interval_list_name"] = ( Raw() & {"nwb_file_name": nwb_file_name} ).fetch1("interval_list_name") + + 
dio_inserts = [] for event_series in behav_events.time_series.values(): key["dio_event_name"] = event_series.name key["dio_object_id"] = event_series.object_id - self.insert1(key, skip_duplicates=True) + dio_inserts.append(key.copy()) + self.insert( + dio_inserts, + skip_duplicates=True, + allow_direct_insert=True, + ) def plot_all_dio_events(self, return_fig=False): """Plot all DIO events in the session. diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 1880340a9..d03f6edff 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -45,6 +45,12 @@ class ElectrodeGroup(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -69,7 +75,7 @@ def make(self, key): else: # if negative x coordinate # define target location as left hemisphere key["target_hemisphere"] = "Left" - self.insert1(key, skip_duplicates=True) + self.insert1(key, skip_duplicates=True, allow_direct_insert=True) @schema @@ -95,6 +101,12 @@ class Electrode(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -108,23 +120,32 @@ def make(self, key): else: electrode_config_dicts = dict() + electrode_constants = { + "x_warped": 0, + "y_warped": 0, + "z_warped": 0, + "contacts": "", + } + + electrode_inserts = [] electrodes = nwbf.electrodes.to_dataframe() for elect_id, elect_data in electrodes.iterrows(): - key["electrode_id"] = elect_id - key["name"] = str(elect_id) - key["electrode_group_name"] = elect_data.group_name - key["region_id"] = BrainRegion.fetch_add( - region_name=elect_data.group.location + key.update( + { + "electrode_id": elect_id, + "name": str(elect_id), + "electrode_group_name": elect_data.group_name, + "region_id": BrainRegion.fetch_add( + region_name=elect_data.group.location + ), + "x": elect_data.x, + "y": elect_data.y, + "z": elect_data.z, + "filtering": elect_data.filtering, + "impedance": elect_data.get("imp"), + **electrode_constants, + } ) - key["x"] = elect_data.x - key["y"] = elect_data.y - key["z"] = elect_data.z - key["x_warped"] = 0 - key["y_warped"] = 0 - key["z_warped"] = 0 - key["contacts"] = "" - key["filtering"] = elect_data.filtering - key["impedance"] = elect_data.get("imp") # rough check of whether the electrodes table was created by # rec_to_nwb and has the appropriate custom columns used by @@ -140,13 +161,17 @@ def make(self, key): and "bad_channel" in elect_data and "ref_elect_id" in elect_data ): - key["probe_id"] = elect_data.group.device.probe_type - key["probe_shank"] = elect_data.probe_shank - key["probe_electrode"] = elect_data.probe_electrode - key["bad_channel"] = ( - "True" if elect_data.bad_channel else "False" + key.update( + { + "probe_id": elect_data.group.device.probe_type, + "probe_shank": elect_data.probe_shank, + "probe_electrode": elect_data.probe_electrode, + "bad_channel": ( + "True" if elect_data.bad_channel else "False" + ), + "original_reference_electrode": elect_data.ref_elect_id, + 
} ) - key["original_reference_electrode"] = elect_data.ref_elect_id # override with information from the config YAML based on primary # key (electrode id) @@ -163,8 +188,13 @@ def make(self, key): ) else: key.update(electrode_config_dicts[elect_id]) + electrode_inserts.append(key.copy()) - self.insert1(key, skip_duplicates=True) + self.insert1( + key, + skip_duplicates=True, + allow_direct_insert=True, # for no_transaction, pop_all_common + ) @classmethod def create_from_config(cls, nwb_file_name: str): @@ -246,10 +276,17 @@ class Raw(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) raw_interval_name = "raw data valid times" + # get the acquisition object try: # TODO this assumes there is a single item in NWBFile.acquisition @@ -261,19 +298,21 @@ def make(self, key): + f"Skipping entry in {self.full_table_name}" ) return + if rawdata.rate is not None: - sampling_rate = rawdata.rate + key["sampling_rate"] = rawdata.rate else: logger.info("Estimating sampling rate...") # NOTE: Only use first 1e6 timepoints to save time - sampling_rate = estimate_sampling_rate( + key["sampling_rate"] = estimate_sampling_rate( np.asarray(rawdata.timestamps[: int(1e6)]), 1.5, verbose=True ) - key["sampling_rate"] = sampling_rate - interval_dict = dict() - interval_dict["nwb_file_name"] = key["nwb_file_name"] - interval_dict["interval_list_name"] = raw_interval_name + interval_dict = { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": raw_interval_name, + } + if rawdata.rate is not None: interval_dict["valid_times"] = np.array( [[0, len(rawdata.data) / rawdata.rate]] @@ -291,18 +330,25 @@ def make(self, key): # now insert each of the electrodes as an individual row, but with the # same nwb_object_id - key["raw_object_id"] = rawdata.object_id - key["sampling_rate"] = sampling_rate logger.info( - f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz' + f'Importing raw data: Sampling rate:\t{key["sampling_rate"]} Hz\n' + + f'Number of valid intervals:\t{len(interval_dict["valid_times"])}' ) - logger.info( - f'Number of valid intervals:\t{len(interval_dict["valid_times"])}' + + key.update( + { + "raw_object_id": rawdata.object_id, + "interval_list_name": raw_interval_name, + "comments": rawdata.comments, + "description": rawdata.description, + } + ) + + self.insert1( + key, + skip_duplicates=True, + allow_direct_insert=True, ) - key["interval_list_name"] = raw_interval_name - key["comments"] = rawdata.comments - key["description"] = rawdata.description - self.insert1(key, skip_duplicates=True) def nwb_object(self, key): # TODO return the nwb_object; FIX: this should be replaced with a fetch @@ -330,6 +376,12 @@ class SampleCount(SpyglassMixin, dj.Imported): _nwb_table = Nwbfile def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -343,7 +395,7 @@ def make(self, key): ) return # see #849 key["sample_count_object_id"] = sample_count.object_id - self.insert1(key) + self.insert1(key, 
allow_direct_insert=True) @schema diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 19700d3b3..d5bba9e51 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -65,6 +65,7 @@ def insert_from_relative_file_name(cls, nwb_file_name): The relative path to the NWB file. """ nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) + assert os.path.exists( nwb_file_abs_path ), f"File does not exist: {nwb_file_abs_path}" diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index acb4a0826..e97934122 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -52,6 +52,12 @@ class Experimenter(SpyglassMixin, dj.Part): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" # These imports must go here to avoid cyclic dependencies # from .common_task import Task, TaskEpoch from .common_interval import IntervalList @@ -114,6 +120,7 @@ def make(self, key): "experiment_description": nwbf.experiment_description, }, skip_duplicates=True, + allow_direct_insert=True, # for populate_all_common ) logger.info("Skipping Apparatus for now...") diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index 0dffa4ac5..49fd7bb0e 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -97,6 +97,12 @@ class TaskEpoch(SpyglassMixin, dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -120,6 +126,7 @@ def make(self, key): logger.warn(f"No tasks processing module found in {nwbf}\n") return + task_inserts = [] for task in tasks_mod.data_interfaces.values(): if self.check_task_table(task): # check if the task is in the Task table and if not, add it @@ -169,7 +176,8 @@ def make(self, key): break # TODO case when interval is not found is not handled key["interval_list_name"] = interval - self.insert1(key) + task_inserts.append(key.copy()) + self.insert(task_inserts, allow_direct_insert=True) @classmethod def update_entries(cls, restrict={}): diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index 2972ed145..04df52dec 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -1,3 +1,5 @@ +from typing import List, Union + import datajoint as dj from spyglass.common.common_behav import ( @@ -20,54 +22,147 @@ from spyglass.utils import logger -def populate_all_common(nwb_file_name): - """Insert all common tables for a given NWB file.""" +def log_insert_error( + table: str, err: Exception, error_constants: dict = None +) -> None: + """Log a given error to the InsertError table. + + Parameters + ---------- + table : str + The table name where the error occurred. + err : Exception + The exception that was raised. + error_constants : dict, optional + Dictionary with keys for dj_user, connection_id, and nwb_file_name. + Defaults to checking dj.conn and using "Unknown" for nwb_file_name. 
+ """ + if error_constants is None: + error_constants = dict( + dj_user=dj.config["database.user"], + connection_id=dj.conn().connection_id, + nwb_file_name="Unknown", + ) + InsertError.insert1( + dict( + **error_constants, + table=table.__name__, + error_type=type(err).__name__, + error_message=str(err), + error_raw=str(err), + ) + ) + + +def single_transaction_make( + tables: List[dj.Table], + nwb_file_name: str, + raise_err: bool = False, + error_constants: dict = None, +): + """For each table, run the `_no_transaction_make` method. + + Requires `allow_direct_insert` set to True within each method. Uses + nwb_file_name search table key_source for relevant key. Currently assumes + all tables will have exactly one key_source entry per nwb file. + """ + file_restr = {"nwb_file_name": nwb_file_name} + with Nwbfile.connection.transaction: + for table in tables: + logger.info(f"Populating {table.__name__}...") + + # If imported/computed table, get key from key_source + key_source = getattr(table, "key_source", None) + if key_source is None: # Generate key from parents + parents = table.parents(as_objects=True) + key_source = parents[0].proj() + for parent in parents[1:]: + key_source *= parent.proj() + pop_key = (key_source & file_restr).fetch1("KEY") + + try: + table()._no_transaction_make(pop_key) + except Exception as err: + if raise_err: + raise err + log_insert_error( + table=table, err=err, error_constants=error_constants + ) + + +def populate_all_common( + nwb_file_name, rollback_on_fail=False, raise_err=False +) -> Union[List, None]: + """Insert all common tables for a given NWB file. + + Parameters + ---------- + nwb_file_name : str + The name of the NWB file to populate. + rollback_on_fail : bool, optional + If True, will delete the Session entry if any errors occur. + Defaults to False. + raise_err : bool, optional + If True, will raise any errors that occur during population. + Defaults to False. This will prevent any rollback from occurring. + + Returns + ------- + List + A list of keys for InsertError entries if any errors occurred. + """ from spyglass.spikesorting.imported import ImportedSpikeSorting - key = [(Nwbfile & f"nwb_file_name LIKE '{nwb_file_name}'").proj()] - tables = [ - Session, - # NwbfileKachery, # Not used by default - ElectrodeGroup, - Electrode, - Raw, - SampleCount, - DIOEvents, - # SensorData, # Not used by default. Generates large files - RawPosition, - TaskEpoch, - StateScriptFile, - VideoFile, - PositionSource, - RawPosition, - ImportedSpikeSorting, - ] error_constants = dict( dj_user=dj.config["database.user"], connection_id=dj.conn().connection_id, nwb_file_name=nwb_file_name, ) - for table in tables: - logger.info(f"Populating {table.__name__}...") - try: - table.populate(key) - except Exception as e: - InsertError.insert1( - dict( - **error_constants, - table=table.__name__, - error_type=type(e).__name__, - error_message=str(e), - error_raw=str(e), - ) - ) - query = InsertError & error_constants - if query: - err_tables = query.fetch("table") + table_lists = [ + [ # Tables that can be inserted in a single transaction + Session, + ElectrodeGroup, # Depends on Session + Electrode, # Depends on ElectrodeGroup + Raw, # Depends on Session + SampleCount, # Depends on Session + DIOEvents, # Depends on Session + TaskEpoch, # Depends on Session + ImportedSpikeSorting, # Depends on Session + # NwbfileKachery, # Not used by default + # SensorData, # Not used by default. 
Generates large files + ], + [ # Tables that depend on above transaction + PositionSource, # Depends on Session + VideoFile, # Depends on TaskEpoch + StateScriptFile, # Depends on TaskEpoch + ], + [ + RawPosition, # Depends on PositionSource + ], + ] + + for tables in table_lists: + single_transaction_make( + tables=tables, + nwb_file_name=nwb_file_name, + raise_err=raise_err, + error_constants=error_constants, + ) + + err_query = InsertError & error_constants + nwbfile_query = Nwbfile & {"nwb_file_name": nwb_file_name} + + if err_query and nwbfile_query and rollback_on_fail: + logger.error(f"Rolling back population for {nwb_file_name}...") + # Should this be safemode=False to prevent confirmation prompt? + nwbfile_query.super_delete(warn=False) + + if err_query: + err_tables = err_query.fetch("table") logger.error( f"Errors occurred during population for {nwb_file_name}:\n\t" + f"Failed tables {err_tables}\n\t" + "See common_usage.InsertError for more details" ) - return query.fetch("KEY") + return err_query.fetch("KEY") diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 329a7be42..a5d539e8e 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -12,7 +12,11 @@ from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename -def insert_sessions(nwb_file_names: Union[str, List[str]]): +def insert_sessions( + nwb_file_names: Union[str, List[str]], + rollback_on_fail: bool = False, + raise_err: bool = False, +): """ Populate the dj database with new sessions. @@ -23,6 +27,10 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): existing .nwb files. Each file represents a session. Also accepts strings with glob wildcards (e.g., *) so long as the wildcard specifies exactly one file. + rollback_on_fail : bool, optional + If True, undo all inserts if an error occurs. Default is False. + raise_err : bool, optional + If True, raise an error if an error occurs. Default is False. """ if not isinstance(nwb_file_names, list): @@ -66,7 +74,11 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): # the raw data in the original file copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name) Nwbfile().insert_from_relative_file_name(out_nwb_file_name) - populate_all_common(out_nwb_file_name) + return populate_all_common( + out_nwb_file_name, + rollback_on_fail=rollback_on_fail, + raise_err=raise_err, + ) def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index ca1bdc9d0..048502081 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -31,6 +31,13 @@ class Annotations(SpyglassMixin, dj.Part): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key): + """Make without transaction + + Allows populate_all_common to work within a single transaction.""" + raise RuntimeError("TEMP: This is a test error. 
Please ignore.") orig_key = copy.deepcopy(key) nwb_file_abs_path = Nwbfile.get_abs_path(key["nwb_file_name"]) @@ -49,7 +56,7 @@ def make(self, key): key["object_id"] = nwbfile.units.object_id - self.insert1(key, skip_duplicates=True) + self.insert1(key, skip_duplicates=True, allow_direct_insert=True) part_name = SpikeSortingOutput._part_name(self.table_name) SpikeSortingOutput._merge_insert( From a6e2ea6414dc1725e1afe733d3c5fe6bf1654c60 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Fri, 10 May 2024 15:15:57 -0700 Subject: [PATCH 06/11] Permit multiple restrict_by (#969) * Permit multiple restrict_by * Update Changelog * Fix typo --- CHANGELOG.md | 2 +- docs/src/misc/mixin.md | 20 ++++++++++--------- src/spyglass/decoding/v1/waveform_features.py | 2 +- src/spyglass/utils/dj_merge_tables.py | 20 +------------------ src/spyglass/utils/dj_mixin.py | 2 +- 5 files changed, 15 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf8804795..299a264b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ - Create class `SpyglassGroupPart` to aid delete propagations #899 - Fix bug report template #955 - Add rollback option to `populate_all_common` #957 -- Add long-distance restrictions via `<<` and `>>` operators. #943 +- Add long-distance restrictions via `<<` and `>>` operators. #943, #969 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 ## [0.5.2] (April 22, 2024) diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md index 6b3884551..229747402 100644 --- a/docs/src/misc/mixin.md +++ b/docs/src/misc/mixin.md @@ -59,37 +59,39 @@ key and `>>` as a shorthand for `restrict_by` a downstream key. ```python from spyglass.example import AnyTable -AnyTable >> 'downsteam_attribute="value"' -AnyTable << 'upstream_attribute="value"' +AnyTable() << 'upstream_attribute="value"' +AnyTable() >> 'downsteam_attribute="value"' # Equivalent to -AnyTable.restrict_by('upstream_attribute="value"', direction="up") -AnyTable.restrict_by('downsteam_attribute="value"', direction="down") +AnyTable().restrict_by('downsteam_attribute="value"', direction="down") +AnyTable().restrict_by('upstream_attribute="value"', direction="up") ``` Some caveats to this function: 1. 'Peripheral' tables, like `IntervalList` and `AnalysisNwbfile` make it hard to determine the correct parent/child relationship and have been removed - from this search. + from this search by default. 2. This function will raise an error if it attempts to check a table that has not been imported into the current namespace. It is best used for exploring and debugging, not for production code. 3. It's hard to determine the attributes in a mixed dictionary/string restriction. If you are having trouble, try using a pure string restriction. -4. The most direct path to your restriction may not be the path took, especially - when using Merge Tables. When the result is empty see the warning about the - path used. Then, ban tables from the search to force a different path. +4. The most direct path to your restriction may not be the path your data took, + especially when using Merge Tables. When the result is empty see the + warning about the path used. Then, ban tables from the search to force a + different path. 
```python -my_table = MyTable() # must be instantced +my_table = MyTable() # must be instanced my_table.ban_search_table(UnwantedTable1) my_table.ban_search_table([UnwantedTable2, UnwantedTable3]) my_table.unban_search_table(UnwantedTable3) my_table.see_banned_tables() my_table << my_restriction +my_table << upstream_restriction >> downstream_restriction ``` When providing a restriction of the parent, use 'up' direction. When providing a diff --git a/src/spyglass/decoding/v1/waveform_features.py b/src/spyglass/decoding/v1/waveform_features.py index 4a999accd..536ed4864 100644 --- a/src/spyglass/decoding/v1/waveform_features.py +++ b/src/spyglass/decoding/v1/waveform_features.py @@ -82,7 +82,7 @@ def supported_waveform_features(self) -> list[str]: @schema -class UnitWaveformFeaturesSelection(dj.Manual): +class UnitWaveformFeaturesSelection(SpyglassMixin, dj.Manual): definition = """ -> SpikeSortingOutput.proj(spikesorting_merge_id="merge_id") -> WaveformFeaturesParams diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 0b8f16de6..ce96fe00e 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -830,25 +830,7 @@ def delete_downstream_merge( ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries - Parameters - ---------- - table: dj.Table - DataJoint table or restriction thereof - restriction: str - Optional restriction to apply before deletion from merge/part - tables. If not provided, delete all downstream entries. - dry_run: bool - Default True. If true, return list of tuples, merge/part tables - downstream of table input. Otherwise, delete merge/part table entries. - disable_warning: bool - Default False. If True, don't warn about restrictions on table object. - kwargs: dict - Additional keyword arguments for DataJoint delete. - - Returns - ------- - List[Tuple[dj.Table, dj.Table]] - Entries in merge/part tables downstream of table input. + Passthrough to SpyglassMixin.delete_downstream_merge """ logger.warning( "DEPRECATED: This function will be removed in `0.6`. " diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 08fa377b3..cf1471ee6 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -886,7 +886,7 @@ def restrict_by( if return_graph: return graph - ret = graph.leaf_ft[0] + ret = self & graph._get_restr(self.full_table_name) if len(ret) == len(self) or len(ret) == 0: logger.warning( f"Failed to restrict with path: {graph.path_str}\n\t" From 113ce9a25a1bf2f4dfb3c3cce16c80b9251b57d5 Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Mon, 13 May 2024 12:29:43 -0700 Subject: [PATCH 07/11] Allow dlc pipeline to run without prior position tracking (#970) * fix dlc pose estimation populate if no raw position data * allow dlc pipeline to run without raw spatial data * update changelog * string formatting * fix analysis nwb create time --- CHANGELOG.md | 5 ++ .../position/v1/position_dlc_centroid.py | 19 +++++-- .../position/v1/position_dlc_orient.py | 25 ++++++--- .../v1/position_dlc_pose_estimation.py | 53 ++++++++++++------- .../position/v1/position_dlc_position.py | 9 ++-- .../position/v1/position_dlc_selection.py | 8 +-- 6 files changed, 82 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 299a264b7..4b0855fe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,11 @@ - Add long-distance restrictions via `<<` and `>>` operators. 
#943, #969 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 +### Pipelines + +- DLC + - Allow dlc without pre-existing tracking data #950 + ## [0.5.2] (April 22, 2024) ### Infrastructure diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index e989265da..5b96d19db 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -268,17 +268,26 @@ def make(self, key): ) position = pynwb.behavior.Position() velocity = pynwb.behavior.BehavioralTimeSeries() - spatial_series = (RawPosition() & key).fetch_nwb()[0][ - "raw_position" - ] + if ( + RawPosition & key + ): # if spatial series exists, get metadata from there + spatial_series = (RawPosition() & key).fetch_nwb()[0][ + "raw_position" + ] + reference_frame = spatial_series.reference_frame + comments = spatial_series.comments + else: + reference_frame = "" + comments = "no comments" + METERS_PER_CM = 0.01 position.create_spatial_series( name="position", timestamps=final_df.index.to_numpy(), conversion=METERS_PER_CM, data=final_df.loc[:, idx[("x", "y")]].to_numpy(), - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, + reference_frame=reference_frame, + comments=comments, description="x_position, y_position", ) velocity.create_timeseries( diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index e1e5c668b..0c873fff4 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -1,3 +1,5 @@ +from time import time + import datajoint as dj import numpy as np import pandas as pd @@ -85,9 +87,7 @@ class DLCOrientation(SpyglassMixin, dj.Computed): def make(self, key): # Get labels to smooth from Parameters table - key["analysis_file_name"] = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + AnalysisNwbfile()._creation_times["pre_create_time"] = time() cohort_entries = DLCSmoothInterpCohort.BodyPart & key pos_df = pd.concat( { @@ -133,15 +133,28 @@ def make(self, key): final_df = pd.DataFrame( orientation, columns=["orientation"], index=pos_df.index ) - spatial_series = (RawPosition() & key).fetch_nwb()[0]["raw_position"] + key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["nwb_file_name"] + ) + if ( + RawPosition & key + ): # if spatial series exists, get metadata from there + spatial_series = (RawPosition() & key).fetch_nwb()[0][ + "raw_position" + ] + reference_frame = spatial_series.reference_frame + comments = spatial_series.comments + else: + reference_frame = "" + comments = "no comments" orientation = pynwb.behavior.CompassDirection() orientation.create_spatial_series( name="orientation", timestamps=final_df.index.to_numpy(), conversion=1.0, data=final_df["orientation"].to_numpy(), - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, + reference_frame=reference_frame, + comments=comments, description="orientation", ) nwb_analysis_file = AnalysisNwbfile() diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index dfc6095a5..bf56fb6fd 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -232,25 +232,38 @@ def make(self, key): dlc_result.creation_time ).strftime("%Y-%m-%d %H:%M:%S") - logger.logger.info("getting raw position") - 
interval_list_name = ( - convert_epoch_interval_name_to_position_interval_name( - { - "nwb_file_name": key["nwb_file_name"], - "epoch": key["epoch"], - }, - populate_missing=False, + # get video information + _, _, meters_per_pixel, video_time = get_video_path(key) + # check if a position interval exists for this epoch + try: + interval_list_name = ( + convert_epoch_interval_name_to_position_interval_name( + { + "nwb_file_name": key["nwb_file_name"], + "epoch": key["epoch"], + }, + populate_missing=False, + ) ) - ) - spatial_series = ( - RawPosition() - & {**key, "interval_list_name": interval_list_name} - ).fetch_nwb()[0]["raw_position"] - _, _, _, video_time = get_video_path(key) - pos_time = spatial_series.timestamps - # TODO: should get timestamps from VideoFile, but need the video_frame_ind from RawPosition, - # which also has timestamps - key["meters_per_pixel"] = spatial_series.conversion + raw_position = True + except KeyError: + raw_position = False + + if raw_position: + logger.logger.info("Getting raw position") + spatial_series = ( + RawPosition() + & {**key, "interval_list_name": interval_list_name} + ).fetch_nwb()[0]["raw_position"] + pos_time = spatial_series.timestamps + reference_frame = spatial_series.reference_frame + comments = spatial_series.comments + else: + pos_time = video_time + reference_frame = "" + comments = "no comments" + + key["meters_per_pixel"] = meters_per_pixel # Insert entry into DLCPoseEstimation logger.logger.info( @@ -292,8 +305,8 @@ def make(self, key): timestamps=part_df.time.to_numpy(), conversion=METERS_PER_CM, data=part_df.loc[:, idx[("x", "y")]].to_numpy(), - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, + reference_frame=reference_frame, + comments=comments, description="x_position, y_position", ) likelihood.create_timeseries( diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 436d890d5..11c7019f3 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -1,3 +1,5 @@ +from time import time + import datajoint as dj import numpy as np import pandas as pd @@ -167,9 +169,7 @@ def make(self, key): path=f"{output_dir.as_posix()}/log.log", print_console=False, ) as logger: - key["analysis_file_name"] = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + AnalysisNwbfile()._creation_times["pre_create_time"] = time() logger.logger.info("-----------------------") idx = pd.IndexSlice # Get labels to smooth from Parameters table @@ -227,6 +227,9 @@ def make(self, key): .fetch_nwb()[0]["dlc_pose_estimation_position"] .get_spatial_series() ) + key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["nwb_file_name"] + ) # Add dataframe to AnalysisNwbfile nwb_analysis_file = AnalysisNwbfile() position = pynwb.behavior.Position() diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index 74354db31..facfb8e25 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -1,5 +1,6 @@ import copy from pathlib import Path +from time import time import datajoint as dj import numpy as np @@ -58,9 +59,7 @@ class DLCPosV1(SpyglassMixin, dj.Computed): def make(self, key): orig_key = copy.deepcopy(key) # Add to Analysis NWB file - key["analysis_file_name"] = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + 
AnalysisNwbfile()._creation_times["pre_create_time"] = time() key["pose_eval_result"] = self.evaluate_pose_estimation(key) pos_nwb = (DLCCentroid & key).fetch_nwb()[0] @@ -114,6 +113,9 @@ def make(self, key): comments=vid_frame_obj.comments, ) + key["analysis_file_name"] = AnalysisNwbfile().create( + key["nwb_file_name"] + ) nwb_analysis_file = AnalysisNwbfile() key["orientation_object_id"] = nwb_analysis_file.add_nwb_object( key["analysis_file_name"], orientation From 9d8b19a7b458b9fff0f8bfb06133315e8403b91f Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 13 May 2024 12:33:13 -0700 Subject: [PATCH 08/11] Revert "Allow dlc pipeline to run without prior position tracking (#970)" (#972) This reverts commit 113ce9a25a1bf2f4dfb3c3cce16c80b9251b57d5. Co-authored-by: Chris Brozdowski --- CHANGELOG.md | 5 -- .../position/v1/position_dlc_centroid.py | 19 ++----- .../position/v1/position_dlc_orient.py | 25 +++------ .../v1/position_dlc_pose_estimation.py | 53 +++++++------------ .../position/v1/position_dlc_position.py | 9 ++-- .../position/v1/position_dlc_selection.py | 8 ++- 6 files changed, 37 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b0855fe7..299a264b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,11 +14,6 @@ - Add long-distance restrictions via `<<` and `>>` operators. #943, #969 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 -### Pipelines - -- DLC - - Allow dlc without pre-existing tracking data #950 - ## [0.5.2] (April 22, 2024) ### Infrastructure diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index 5b96d19db..e989265da 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -268,26 +268,17 @@ def make(self, key): ) position = pynwb.behavior.Position() velocity = pynwb.behavior.BehavioralTimeSeries() - if ( - RawPosition & key - ): # if spatial series exists, get metadata from there - spatial_series = (RawPosition() & key).fetch_nwb()[0][ - "raw_position" - ] - reference_frame = spatial_series.reference_frame - comments = spatial_series.comments - else: - reference_frame = "" - comments = "no comments" - + spatial_series = (RawPosition() & key).fetch_nwb()[0][ + "raw_position" + ] METERS_PER_CM = 0.01 position.create_spatial_series( name="position", timestamps=final_df.index.to_numpy(), conversion=METERS_PER_CM, data=final_df.loc[:, idx[("x", "y")]].to_numpy(), - reference_frame=reference_frame, - comments=comments, + reference_frame=spatial_series.reference_frame, + comments=spatial_series.comments, description="x_position, y_position", ) velocity.create_timeseries( diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index 0c873fff4..e1e5c668b 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -1,5 +1,3 @@ -from time import time - import datajoint as dj import numpy as np import pandas as pd @@ -87,7 +85,9 @@ class DLCOrientation(SpyglassMixin, dj.Computed): def make(self, key): # Get labels to smooth from Parameters table - AnalysisNwbfile()._creation_times["pre_create_time"] = time() + key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["nwb_file_name"] + ) cohort_entries = DLCSmoothInterpCohort.BodyPart & key pos_df = pd.concat( { @@ -133,28 +133,15 @@ def make(self, key): final_df = pd.DataFrame( orientation, columns=["orientation"], 
index=pos_df.index
         )
-        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
-            key["nwb_file_name"]
-        )
-        if (
-            RawPosition & key
-        ):  # if spatial series exists, get metadata from there
-            spatial_series = (RawPosition() & key).fetch_nwb()[0][
-                "raw_position"
-            ]
-            reference_frame = spatial_series.reference_frame
-            comments = spatial_series.comments
-        else:
-            reference_frame = ""
-            comments = "no comments"
+        spatial_series = (RawPosition() & key).fetch_nwb()[0]["raw_position"]
         orientation = pynwb.behavior.CompassDirection()
         orientation.create_spatial_series(
             name="orientation",
             timestamps=final_df.index.to_numpy(),
             conversion=1.0,
             data=final_df["orientation"].to_numpy(),
-            reference_frame=reference_frame,
-            comments=comments,
+            reference_frame=spatial_series.reference_frame,
+            comments=spatial_series.comments,
             description="orientation",
         )
         nwb_analysis_file = AnalysisNwbfile()
diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py
index bf56fb6fd..dfc6095a5 100644
--- a/src/spyglass/position/v1/position_dlc_pose_estimation.py
+++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py
@@ -232,38 +232,25 @@ def make(self, key):
                 dlc_result.creation_time
             ).strftime("%Y-%m-%d %H:%M:%S")

-            # get video information
-            _, _, meters_per_pixel, video_time = get_video_path(key)
-            # check if a position interval exists for this epoch
-            try:
-                interval_list_name = (
-                    convert_epoch_interval_name_to_position_interval_name(
-                        {
-                            "nwb_file_name": key["nwb_file_name"],
-                            "epoch": key["epoch"],
-                        },
-                        populate_missing=False,
-                    )
+            logger.logger.info("getting raw position")
+            interval_list_name = (
+                convert_epoch_interval_name_to_position_interval_name(
+                    {
+                        "nwb_file_name": key["nwb_file_name"],
+                        "epoch": key["epoch"],
+                    },
+                    populate_missing=False,
                 )
-                raw_position = True
-            except KeyError:
-                raw_position = False
-
-            if raw_position:
-                logger.logger.info("Getting raw position")
-                spatial_series = (
-                    RawPosition()
-                    & {**key, "interval_list_name": interval_list_name}
-                ).fetch_nwb()[0]["raw_position"]
-                pos_time = spatial_series.timestamps
-                reference_frame = spatial_series.reference_frame
-                comments = spatial_series.comments
-            else:
-                pos_time = video_time
-                reference_frame = ""
-                comments = "no comments"
-
-            key["meters_per_pixel"] = meters_per_pixel
+            )
+            spatial_series = (
+                RawPosition()
+                & {**key, "interval_list_name": interval_list_name}
+            ).fetch_nwb()[0]["raw_position"]
+            _, _, _, video_time = get_video_path(key)
+            pos_time = spatial_series.timestamps
+            # TODO: should get timestamps from VideoFile, but need the video_frame_ind from RawPosition,
+            # which also has timestamps
+            key["meters_per_pixel"] = spatial_series.conversion

             # Insert entry into DLCPoseEstimation
             logger.logger.info(
@@ -305,8 +292,8 @@ def make(self, key):
                     timestamps=part_df.time.to_numpy(),
                     conversion=METERS_PER_CM,
                     data=part_df.loc[:, idx[("x", "y")]].to_numpy(),
-                    reference_frame=reference_frame,
-                    comments=comments,
+                    reference_frame=spatial_series.reference_frame,
+                    comments=spatial_series.comments,
                     description="x_position, y_position",
                 )
                 likelihood.create_timeseries(
diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py
index 11c7019f3..436d890d5 100644
--- a/src/spyglass/position/v1/position_dlc_position.py
+++ b/src/spyglass/position/v1/position_dlc_position.py
@@ -1,5 +1,3 @@
-from time import time
-
 import datajoint as dj
 import numpy as np
 import pandas as pd
@@ -169,7 +167,9 @@ def make(self, key):
             path=f"{output_dir.as_posix()}/log.log",
             print_console=False,
         ) as logger:
-            AnalysisNwbfile()._creation_times["pre_create_time"] = time()
+            key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
+                key["nwb_file_name"]
+            )
             logger.logger.info("-----------------------")
             idx = pd.IndexSlice
             # Get labels to smooth from Parameters table
@@ -227,9 +227,6 @@ def make(self, key):
                 .fetch_nwb()[0]["dlc_pose_estimation_position"]
                 .get_spatial_series()
             )
-            key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
-                key["nwb_file_name"]
-            )
             # Add dataframe to AnalysisNwbfile
             nwb_analysis_file = AnalysisNwbfile()
             position = pynwb.behavior.Position()
diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py
index facfb8e25..74354db31 100644
--- a/src/spyglass/position/v1/position_dlc_selection.py
+++ b/src/spyglass/position/v1/position_dlc_selection.py
@@ -1,6 +1,5 @@
 import copy
 from pathlib import Path
-from time import time

 import datajoint as dj
 import numpy as np
@@ -59,7 +58,9 @@ class DLCPosV1(SpyglassMixin, dj.Computed):
     def make(self, key):
         orig_key = copy.deepcopy(key)
         # Add to Analysis NWB file
-        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
+        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
+            key["nwb_file_name"]
+        )
         key["pose_eval_result"] = self.evaluate_pose_estimation(key)

         pos_nwb = (DLCCentroid & key).fetch_nwb()[0]
@@ -113,9 +114,6 @@ def make(self, key):
             comments=vid_frame_obj.comments,
         )

-        key["analysis_file_name"] = AnalysisNwbfile().create(
-            key["nwb_file_name"]
-        )
         nwb_analysis_file = AnalysisNwbfile()
         key["orientation_object_id"] = nwb_analysis_file.add_nwb_object(
             key["analysis_file_name"], orientation

From 97373ea6f0b6e67f016068fc71025f6c53bee919 Mon Sep 17 00:00:00 2001
From: Chris Brozdowski
Date: Mon, 13 May 2024 14:42:32 -0500
Subject: [PATCH 09/11] Fix test fails related to #957 (#971)

* Address failing tests
* Revert to bare make for no transaction
* Update changelog
---
 CHANGELOG.md                               |  2 +-
 src/spyglass/common/common_behav.py        | 43 +++++++++++-----------
 src/spyglass/common/common_dio.py          |  3 --
 src/spyglass/common/common_ephys.py        | 16 +-------
 src/spyglass/common/common_session.py      |  3 --
 src/spyglass/common/common_task.py         |  6 ---
 src/spyglass/common/populate_all_common.py | 22 +++++------
 src/spyglass/spikesorting/imported.py      |  4 --
 src/spyglass/utils/dj_merge_tables.py      | 10 -----
 src/spyglass/utils/dj_mixin.py             | 10 +++++
 tests/common/test_behav.py                 | 14 +++----
 tests/conftest.py                          |  2 +-
 tests/position/test_trodes.py              |  6 ---
 13 files changed, 54 insertions(+), 87 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 299a264b7..103240dca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,7 @@

 - Create class `SpyglassGroupPart` to aid delete propagations #899
 - Fix bug report template #955
-- Add rollback option to `populate_all_common` #957
+- Add rollback option to `populate_all_common` #957, #971
 - Add long-distance restrictions via `<<` and `>>` operators. #943, #969
 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968

diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py
index b7e8d953b..b206ed61f 100644
--- a/src/spyglass/common/common_behav.py
+++ b/src/spyglass/common/common_behav.py
@@ -1,7 +1,7 @@
 import pathlib
 import re
 from functools import reduce
-from typing import Dict
+from typing import Dict, List, Union

 import datajoint as dj
 import ndx_franklab_novela
@@ -43,7 +43,14 @@ class SpatialSeries(SpyglassMixin, dj.Part):
         name=null: varchar(32)  # name of spatial series
         """

-    def _no_transaction_make(self, keys=None):
+    def populate(self, *args, **kwargs):
+        logger.warning(
+            "PositionSource is a manual table with a custom `make`."
+            + " Use `make` instead."
+        )
+        self.make(*args, **kwargs)
+
+    def make(self, keys: Union[List[Dict], dj.Table]):
         """Insert position source data from NWB file."""
         if not isinstance(keys, list):
             keys = [keys]
@@ -52,10 +59,7 @@ def make(self, keys):
         for key in keys:
             nwb_file_name = key.get("nwb_file_name")
             if not nwb_file_name:
-                raise ValueError(
-                    "PositionSource.populate is an alias for a non-computed table "
-                    + "and must be passed a key with nwb_file_name"
-                )
+                raise ValueError("PositionSource.make requires nwb_file_name")
             self.insert_from_nwbfile(nwb_file_name, skip_duplicates=True)

     @classmethod
@@ -106,7 +110,7 @@ def insert_from_nwbfile(cls, nwb_file_name, skip_duplicates=False) -> None:
                 )
             )

-        with cls.connection.transaction:
+        with cls._safe_context():
             IntervalList.insert(intervals, skip_duplicates=skip_duplicates)
             cls.insert(sources, skip_duplicates=skip_duplicates)
             cls.SpatialSeries.insert(
@@ -223,9 +227,6 @@ def _get_column_names(rp, pos_id):
         return column_names

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
@@ -296,9 +297,6 @@ class StateScriptFile(SpyglassMixin, dj.Imported):
     _nwb_table = Nwbfile

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
@@ -433,7 +431,9 @@ def _no_transaction_make(self, key, verbose=True):
                     + "in CameraDevice table."
                 )
             key["video_file_object_id"] = video_obj.object_id
-            self.insert1(key)
+            self.insert1(
+                key, skip_duplicates=True, allow_direct_insert=True
+            )
             is_found = True

         if not is_found and verbose:
@@ -567,7 +567,7 @@ def _no_transaction_make(self, key):

         # Insert into table
         key["position_interval_name"] = matching_pos_intervals[0]
-        self.insert1(key, allow_direct_insert=True)
+        self.insert1(key, skip_duplicates=True, allow_direct_insert=True)
         logger.info(
             "Populated PosIntervalMap for "
             + f'{nwb_file_name}, {key["interval_list_name"]}'
@@ -660,9 +660,10 @@ def populate_position_interval_map_session(nwb_file_name: str):
     for interval_name in (TaskEpoch & {"nwb_file_name": nwb_file_name}).fetch(
         "interval_list_name"
     ):
-        PositionIntervalMap.populate(
-            {
-                "nwb_file_name": nwb_file_name,
-                "interval_list_name": interval_name,
-            }
-        )
+        with PositionIntervalMap._safe_context():
+            PositionIntervalMap().make(
+                {
+                    "nwb_file_name": nwb_file_name,
+                    "interval_list_name": interval_name,
+                }
+            )
diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py
index 629adef47..228e9caf9 100644
--- a/src/spyglass/common/common_dio.py
+++ b/src/spyglass/common/common_dio.py
@@ -27,9 +27,6 @@ class DIOEvents(SpyglassMixin, dj.Imported):
     _nwb_table = Nwbfile

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py
index d03f6edff..4cddc099d 100644
--- a/src/spyglass/common/common_ephys.py
+++ b/src/spyglass/common/common_ephys.py
@@ -45,9 +45,6 @@ class ElectrodeGroup(SpyglassMixin, dj.Imported):
     """

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
@@ -101,9 +98,6 @@ class Electrode(SpyglassMixin, dj.Imported):
     """

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
@@ -190,8 +184,8 @@ def _no_transaction_make(self, key):
                 key.update(electrode_config_dicts[elect_id])
             electrode_inserts.append(key.copy())

-        self.insert1(
-            key,
+        self.insert(
+            electrode_inserts,
             skip_duplicates=True,
             allow_direct_insert=True,  # for no_transaction, pop_all_common
         )
@@ -276,9 +270,6 @@ class Raw(SpyglassMixin, dj.Imported):
     _nwb_table = Nwbfile

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
@@ -376,9 +367,6 @@ class SampleCount(SpyglassMixin, dj.Imported):
     _nwb_table = Nwbfile

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py
index e97934122..b8139939a 100644
--- a/src/spyglass/common/common_session.py
+++ b/src/spyglass/common/common_session.py
@@ -52,9 +52,6 @@ class Experimenter(SpyglassMixin, dj.Part):
     """

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py
index 49fd7bb0e..d63901ec2 100644
--- a/src/spyglass/common/common_task.py
+++ b/src/spyglass/common/common_task.py
@@ -97,12 +97,6 @@ class TaskEpoch(SpyglassMixin, dj.Imported):
     """

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
-        """Make without transaction
-
-        Allows populate_all_common to work within a single transaction."""
         nwb_file_name = key["nwb_file_name"]
         nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
         nwbf = get_nwb_file(nwb_file_abspath)
diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py
index 04df52dec..e78b68de1 100644
--- a/src/spyglass/common/populate_all_common.py
+++ b/src/spyglass/common/populate_all_common.py
@@ -60,7 +60,7 @@ def single_transaction_make(
     raise_err: bool = False,
     error_constants: dict = None,
 ):
-    """For each table, run the `_no_transaction_make` method.
+    """For each table, run the `make` method directly instead of `populate`.

     Requires `allow_direct_insert` set to True within each method. Uses
     nwb_file_name search table key_source for relevant key. Currently assumes
@@ -78,16 +78,16 @@ def single_transaction_make(
             key_source = parents[0].proj()
             for parent in parents[1:]:
                 key_source *= parent.proj()
-            pop_key = (key_source & file_restr).fetch1("KEY")
-            try:
-                table()._no_transaction_make(pop_key)
-            except Exception as err:
-                if raise_err:
-                    raise err
-                log_insert_error(
-                    table=table, err=err, error_constants=error_constants
-                )
+            for pop_key in (key_source & file_restr).fetch("KEY"):
+                try:
+                    table().make(pop_key)
+                except Exception as err:
+                    if raise_err:
+                        raise err
+                    log_insert_error(
+                        table=table, err=err, error_constants=error_constants
+                    )


 def populate_all_common(
@@ -123,7 +123,6 @@ def populate_all_common(
         [  # Tables that can be inserted in a single transaction
             Session,
             ElectrodeGroup,  # Depends on Session
-            Electrode,  # Depends on ElectrodeGroup
             Raw,  # Depends on Session
             SampleCount,  # Depends on Session
             DIOEvents,  # Depends on Session
@@ -133,6 +132,7 @@ def populate_all_common(
             # SensorData, # Not used by default. Generates large files
         ],
         [  # Tables that depend on above transaction
+            Electrode,  # Depends on ElectrodeGroup
             PositionSource,  # Depends on Session
             VideoFile,  # Depends on TaskEpoch
             StateScriptFile,  # Depends on TaskEpoch
diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py
index 048502081..7e518d6d8 100644
--- a/src/spyglass/spikesorting/imported.py
+++ b/src/spyglass/spikesorting/imported.py
@@ -31,13 +31,9 @@ class Annotations(SpyglassMixin, dj.Part):
     """

     def make(self, key):
-        self._no_transaction_make(key)
-
-    def _no_transaction_make(self, key):
         """Make without transaction

         Allows populate_all_common to work within a single transaction."""
-        raise RuntimeError("TEMP: This is a test error. Please ignore.")
         orig_key = copy.deepcopy(key)

         nwb_file_abs_path = Nwbfile.get_abs_path(key["nwb_file_name"])
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index ce96fe00e..37a51b674 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -1,4 +1,3 @@
-from contextlib import nullcontext
 from inspect import getmodule
 from itertools import chain as iter_chain
 from pprint import pprint
@@ -369,15 +368,6 @@ def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
         for part, part_entries in parts_entries.items():
             part.insert(part_entries, **kwargs)

-    @classmethod
-    def _safe_context(cls):
-        """Return transaction if not already in one."""
-        return (
-            cls.connection.transaction
-            if not cls.connection.in_transaction
-            else nullcontext()
-        )
-
     @classmethod
     def _ensure_dependencies_loaded(cls) -> None:
         """Ensure connection dependencies loaded.
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index cf1471ee6..be6063d04 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -1,6 +1,7 @@
 from atexit import register as exit_register
 from atexit import unregister as exit_unregister
 from collections import OrderedDict
+from contextlib import nullcontext
 from functools import cached_property
 from inspect import stack as inspect_stack
 from os import environ
@@ -121,6 +122,15 @@ def file_like(self, name=None, **kwargs):
             return
         return self & f"{attr} LIKE '%{name}%'"

+    @classmethod
+    def _safe_context(cls):
+        """Return transaction if not already in one."""
+        return (
+            cls.connection.transaction
+            if not cls.connection.in_transaction
+            else nullcontext()
+        )
+
     # ------------------------------- fetch_nwb -------------------------------

     @cached_property
diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py
index 28c205442..1f4767dfb 100644
--- a/tests/common/test_behav.py
+++ b/tests/common/test_behav.py
@@ -22,15 +22,15 @@ def test_valid_epoch_num(common):
     assert epoch_num == 1, "PositionSource get_epoch_num failed"


-def test_invalid_populate(common):
-    """Test invalid populate"""
-    with pytest.raises(ValueError):
-        common.PositionSource.populate(dict())
+def test_possource_make(common):
+    """Test custom populate"""
+    common.PositionSource().make(common.Session())


-def test_custom_populate(common):
-    """Test custom populate"""
-    common.PositionSource.populate(common.Session())
+def test_possource_make_invalid(common):
+    """Test invalid populate"""
+    with pytest.raises(ValueError):
+        common.PositionSource().make(dict())


 def test_raw_position_fetchnwb(common, mini_pos, mini_pos_interval_dict):
diff --git a/tests/conftest.py b/tests/conftest.py
index 7950854d6..cd9350ff1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -316,7 +316,7 @@ def mini_insert(
     if len(Nwbfile()) != 0:
         dj_logger.warning("Skipping insert, use existing data.")
     else:
-        insert_sessions(mini_path.name)
+        insert_sessions(mini_path.name, raise_err=True)

     if len(Session()) == 0:
         raise ValueError("No sessions inserted.")
diff --git a/tests/position/test_trodes.py b/tests/position/test_trodes.py
index 81a515cf1..d4bc617f6 100644
--- a/tests/position/test_trodes.py
+++ b/tests/position/test_trodes.py
@@ -59,9 +59,3 @@ def test_fetch_df(trodes_pos_v1, trodes_params):
     )
     hash_exp = "5296e74dea2e5e68d39f81bc81723a12"
     assert hash_df == hash_exp, "Dataframe differs from expected"
-
-
-def test_null_video(sgp):
-    """Note: This will change if video is added to the test data."""
-    with pytest.raises(FileNotFoundError):
-        sgp.v1.TrodesPosVideo().populate()

From 00bd5d8f8d2896b951cdd0e6c51f85e59ef4e474 Mon Sep 17 00:00:00 2001
From: Eric Denovellis
Date: Mon, 13 May 2024 14:13:28 -0700
Subject: [PATCH 10/11] Allow dlc pipeline to run without prior position tracking (#973)

* fix dlc pose estimation populate if no raw position data
* allow dlc pipeline to run without raw spatial data
* update changelog
* string formatting
* fix analysis nwb create time
* review changes
* allow empty returns from convert_epoch_interval_name_ without error
* switch to getattr

---------

Co-authored-by: Sam Bray
Co-authored-by: Chris Brozdowski
---
 CHANGELOG.md                               |  5 +++
 src/spyglass/common/common_behav.py        |  6 +---
 .../position/v1/position_dlc_centroid.py   | 14 ++++----
 .../position/v1/position_dlc_orient.py     | 19 ++++++----
 .../v1/position_dlc_pose_estimation.py     | 36 +++++++++++--------
 .../position/v1/position_dlc_position.py   |  9 +++--
 .../position/v1/position_dlc_selection.py  |  8 +++--
 7 files changed, 59 insertions(+), 38 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 103240dca..096ec4fe9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,11 @@
 - Add long-distance restrictions via `<<` and `>>` operators. #943, #969
 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968

+### Pipelines
+
+- DLC
+  - Allow dlc without pre-existing tracking data #973
+
 ## [0.5.2] (April 22, 2024)

 ### Infrastructure
diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py
index b206ed61f..67e6e35d9 100644
--- a/src/spyglass/common/common_behav.py
+++ b/src/spyglass/common/common_behav.py
@@ -613,11 +613,7 @@ def convert_epoch_interval_name_to_position_interval_name(
     if len(pos_query) == 0:
         if populate_missing:
             PositionIntervalMap()._no_transaction_make(key)
-        else:
-            raise KeyError(
-                f"{key} must be populated in the PositionIntervalMap table "
-                + "prior to your current populate call"
-            )
+        pos_query = PositionIntervalMap & key

     if len(pos_query) == 0:
         logger.info(f"No position intervals found for {key}")
diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py
index e989265da..f1f077d6a 100644
--- a/src/spyglass/position/v1/position_dlc_centroid.py
+++ b/src/spyglass/position/v1/position_dlc_centroid.py
@@ -268,17 +268,19 @@ def make(self, key):
         )
         position = pynwb.behavior.Position()
         velocity = pynwb.behavior.BehavioralTimeSeries()
-        spatial_series = (RawPosition() & key).fetch_nwb()[0][
-            "raw_position"
-        ]
+        if query := (RawPosition & key):
+            spatial_series = query.fetch_nwb()[0]["raw_position"]
+        else:
+            spatial_series = None
+
         METERS_PER_CM = 0.01
         position.create_spatial_series(
             name="position",
             timestamps=final_df.index.to_numpy(),
             conversion=METERS_PER_CM,
             data=final_df.loc[:, idx[("x", "y")]].to_numpy(),
-            reference_frame=spatial_series.reference_frame,
-            comments=spatial_series.comments,
+            reference_frame=getattr(spatial_series, "reference_frame", ""),
+            comments=getattr(spatial_series, "comments", "no comments"),
             description="x_position, y_position",
         )
         velocity.create_timeseries(
@@ -289,7 +291,7 @@ def make(self, key):
             data=velocity_df.loc[
                 :, idx[("velocity_x", "velocity_y", "speed")]
             ].to_numpy(),
-            comments=spatial_series.comments,
+            comments=getattr(spatial_series, "comments", "no comments"),
             description="x_velocity, y_velocity, speed",
         )
         velocity.create_timeseries(
diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py
index e1e5c668b..f64802a59 100644
--- a/src/spyglass/position/v1/position_dlc_orient.py
+++ b/src/spyglass/position/v1/position_dlc_orient.py
@@ -1,3 +1,5 @@
+from time import time
+
 import datajoint as dj
 import numpy as np
 import pandas as pd
@@ -85,9 +87,7 @@ class DLCOrientation(SpyglassMixin, dj.Computed):

     def make(self, key):
         # Get labels to smooth from Parameters table
-        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
-            key["nwb_file_name"]
-        )
+        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
         cohort_entries = DLCSmoothInterpCohort.BodyPart & key
         pos_df = pd.concat(
             {
@@ -133,15 +133,22 @@ def make(self, key):
         final_df = pd.DataFrame(
             orientation, columns=["orientation"], index=pos_df.index
         )
-        spatial_series = (RawPosition() & key).fetch_nwb()[0]["raw_position"]
+        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
+            key["nwb_file_name"]
+        )
+        # if spatial series exists, get metadata from there
+        if query := (RawPosition & key):
+            spatial_series = query.fetch_nwb()[0]["raw_position"]
+        else:
+            spatial_series = None
         orientation = pynwb.behavior.CompassDirection()
         orientation.create_spatial_series(
             name="orientation",
             timestamps=final_df.index.to_numpy(),
             conversion=1.0,
             data=final_df["orientation"].to_numpy(),
-            reference_frame=spatial_series.reference_frame,
-            comments=spatial_series.comments,
+            reference_frame=getattr(spatial_series, "reference_frame", ""),
+            comments=getattr(spatial_series, "comments", "no comments"),
             description="orientation",
         )
         nwb_analysis_file = AnalysisNwbfile()
diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py
index dfc6095a5..35d21345c 100644
--- a/src/spyglass/position/v1/position_dlc_pose_estimation.py
+++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py
@@ -232,8 +232,10 @@ def make(self, key):
                 dlc_result.creation_time
             ).strftime("%Y-%m-%d %H:%M:%S")

-            logger.logger.info("getting raw position")
-            interval_list_name = (
+            # get video information
+            _, _, meters_per_pixel, video_time = get_video_path(key)
+            # check if a position interval exists for this epoch
+            if interval_list_name := (
                 convert_epoch_interval_name_to_position_interval_name(
                     {
                         "nwb_file_name": key["nwb_file_name"],
@@ -241,16 +243,16 @@ def make(self, key):
                     },
                     populate_missing=False,
                 )
-            )
-            spatial_series = (
-                RawPosition()
-                & {**key, "interval_list_name": interval_list_name}
-            ).fetch_nwb()[0]["raw_position"]
-            _, _, _, video_time = get_video_path(key)
-            pos_time = spatial_series.timestamps
-            # TODO: should get timestamps from VideoFile, but need the video_frame_ind from RawPosition,
-            # which also has timestamps
-            key["meters_per_pixel"] = spatial_series.conversion
+            ):
+                logger.logger.info("Getting raw position")
+                spatial_series = (
+                    RawPosition()
+                    & {**key, "interval_list_name": interval_list_name}
+                ).fetch_nwb()[0]["raw_position"]
+            else:
+                spatial_series = None
+
+            key["meters_per_pixel"] = meters_per_pixel

             # Insert entry into DLCPoseEstimation
             logger.logger.info(
@@ -282,7 +284,9 @@ def make(self, key):
                 part_df = convert_to_cm(part_df, meters_per_pixel)
                 logger.logger.info("adding timestamps to DataFrame")
                 part_df = add_timestamps(
-                    part_df, pos_time=pos_time, video_time=video_time
+                    part_df,
+                    pos_time=getattr(spatial_series, "timestamps", video_time),
+                    video_time=video_time,
                 )
                 key["bodypart"] = body_part
                 position = pynwb.behavior.Position()
@@ -292,8 +296,10 @@ def make(self, key):
                     timestamps=part_df.time.to_numpy(),
                     conversion=METERS_PER_CM,
                     data=part_df.loc[:, idx[("x", "y")]].to_numpy(),
-                    reference_frame=spatial_series.reference_frame,
-                    comments=spatial_series.comments,
+                    reference_frame=get_video_path(
+                        spatial_series, "reference_frame", ""
+                    ),
+                    comments=getattr(spatial_series, "comments", "no commwnts"),
                     description="x_position, y_position",
                 )
                 likelihood.create_timeseries(
diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py
index 436d890d5..11c7019f3 100644
--- a/src/spyglass/position/v1/position_dlc_position.py
+++ b/src/spyglass/position/v1/position_dlc_position.py
@@ -1,3 +1,5 @@
+from time import time
+
 import datajoint as dj
 import numpy as np
 import pandas as pd
@@ -167,9 +169,7 @@ def make(self, key):
             path=f"{output_dir.as_posix()}/log.log",
             print_console=False,
         ) as logger:
-            key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
-                key["nwb_file_name"]
-            )
+            AnalysisNwbfile()._creation_times["pre_create_time"] = time()
             logger.logger.info("-----------------------")
             idx = pd.IndexSlice
             # Get labels to smooth from Parameters table
@@ -227,9 +227,6 @@ def make(self, key):
                 .fetch_nwb()[0]["dlc_pose_estimation_position"]
                 .get_spatial_series()
             )
+            key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
+                key["nwb_file_name"]
+            )
             # Add dataframe to AnalysisNwbfile
             nwb_analysis_file = AnalysisNwbfile()
             position = pynwb.behavior.Position()
diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py
index 74354db31..facfb8e25 100644
--- a/src/spyglass/position/v1/position_dlc_selection.py
+++ b/src/spyglass/position/v1/position_dlc_selection.py
@@ -1,5 +1,6 @@
 import copy
 from pathlib import Path
+from time import time

 import datajoint as dj
 import numpy as np
@@ -58,9 +59,7 @@ class DLCPosV1(SpyglassMixin, dj.Computed):
     def make(self, key):
         orig_key = copy.deepcopy(key)
         # Add to Analysis NWB file
-        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
-            key["nwb_file_name"]
-        )
+        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
         key["pose_eval_result"] = self.evaluate_pose_estimation(key)

         pos_nwb = (DLCCentroid & key).fetch_nwb()[0]
@@ -113,9 +114,6 @@ def make(self, key):
             comments=vid_frame_obj.comments,
         )

+        key["analysis_file_name"] = AnalysisNwbfile().create(
+            key["nwb_file_name"]
+        )
         nwb_analysis_file = AnalysisNwbfile()
         key["orientation_object_id"] = nwb_analysis_file.add_nwb_object(
             key["analysis_file_name"], orientation

From 88720432f04ef9d05e1bca8a927705208cd6f0de Mon Sep 17 00:00:00 2001
From: Samuel Bray
Date: Wed, 15 May 2024 12:06:37 -0700
Subject: [PATCH 11/11] Cleanup of dlc video (#975)

* wrong function call
* get epoch directly from table key
* default to opencv for speed
* update changelog
---
 CHANGELOG.md                               |  2 +-
 .../v1/position_dlc_pose_estimation.py     |  2 +-
 .../position/v1/position_dlc_selection.py  | 21 ++-----------------
 3 files changed, 4 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 096ec4fe9..6b4a2a5b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,7 +17,7 @@ ### Pipelines

 - DLC
-  - Allow dlc without pre-existing tracking data #973
+  - Allow dlc without pre-existing tracking data #973, #975

 ## [0.5.2] (April 22, 2024)

 ### Infrastructure
diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py
index 35d21345c..6a670fc31 100644
--- a/src/spyglass/position/v1/position_dlc_pose_estimation.py
+++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py
@@ -296,7 +296,7 @@ def make(self, key):
                     timestamps=part_df.time.to_numpy(),
                     conversion=METERS_PER_CM,
                     data=part_df.loc[:, idx[("x", "y")]].to_numpy(),
-                    reference_frame=get_video_path(
+                    reference_frame=getattr(
                         spatial_series, "reference_frame", ""
                     ),
                     comments=getattr(spatial_series, "comments", "no commwnts"),
                     description="x_position, y_position",
diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py
index facfb8e25..b140111e1 100644
--- a/src/spyglass/position/v1/position_dlc_selection.py
+++ b/src/spyglass/position/v1/position_dlc_selection.py
@@ -311,24 +311,7 @@ def make(self, key):
         if "video_params" not in params:
             params["video_params"] = {}
         M_TO_CM = 100
-        interval_list_name = (
-            convert_epoch_interval_name_to_position_interval_name(
-                {
-                    "nwb_file_name": key["nwb_file_name"],
-                    "epoch": key["epoch"],
-                },
-                populate_missing=False,
-            )
-        )
-        key["interval_list_name"] = interval_list_name
-        epoch = (
-            int(
-                key["interval_list_name"]
-                .replace("pos ", "")
-                .replace(" valid times", "")
-            )
-            + 1
-        )
+        epoch = key["epoch"]
         pose_estimation_key = {
             "nwb_file_name": key["nwb_file_name"],
             "epoch": epoch,
@@ -440,7 +423,7 @@ def make(self, key):
             likelihoods=likelihoods,
             position_time=position_time,
             video_time=None,
-            processor=params.get("processor", "matplotlib"),
+            processor=params.get("processor", "opencv"),
             frames=frames_arr,
             percent_frames=percent_frames,
             output_video_filename=output_video_filename,