diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index ff8ce2e19..a7a31758e 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -1,11 +1,8 @@ name: Publish docs on: - pull_request: - branches: - - master - types: - - closed push: + tags: # See PEP 440 for valid version format + - "*.*.*" # For docs bump, use X.X.XaX branches: - test_branch diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e6ebb45a..2b8aafc57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,11 @@ # Change Log -## 0.4.1 (Unreleased) +## [0.4.1] (June 30, 2023) - Add mkdocs automated deployment. #527, #537, #549, #551 -- Add class for Merge Tables. #556, #564 +- Add class for Merge Tables. #556, #564, #565 -## 0.4.0 (May 22, 2023) +## [0.4.0] (May 22, 2023) - Updated call to `spikeinterface.preprocessing.whiten` to use dtype np.float16. #446, @@ -33,7 +33,7 @@ - Updated `environment_position.yml`. #502 - Renamed `FirFilter` class to `FirFilterParameters`. #512 -## 0.3.4 (March 30, 2023) +## [0.3.4] (March 30, 2023) - Fixed error in spike sorting pipeline referencing the "probe_type" column which is no longer accessible from the `Electrode` table. #437 @@ -44,18 +44,28 @@ - Fixed inconsistency between capitalized/uncapitalized versions of "Intan" for DataAcquisitionAmplifier and DataAcquisitionDevice.adc_circuit. #430, #438 -## 0.3.3 (March 29, 2023) +## [0.3.3] (March 29, 2023) - Fixed errors from referencing the changed primary key for `Probe`. #429 -## 0.3.2 (March 28, 2023) +## [0.3.2] (March 28, 2023) - Fixed import of `common_nwbfile`. #424 -## 0.3.1 (March 24, 2023) +## [0.3.1] (March 24, 2023) - Fixed import error due to `sortingview.Workspace`. #421 -## 0.3.0 (March 24, 2023) +## [0.3.0] (March 24, 2023) -To be added. +- Refactor common for non Frank Lab data, allow file-based mods #420 +- Allow creation and linkage of device metadata from YAML #400 +- Move helper functions to utils directory #386 + +[0.4.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.1 +[0.4.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.0 +[0.3.4]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.4 +[0.3.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.3 +[0.3.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.2 +[0.3.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.1 +[0.3.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.3.0 diff --git a/docs/src/misc/merge_tables.md b/docs/src/misc/merge_tables.md index 46b1cb4b2..9fcd98ef1 100644 --- a/docs/src/misc/merge_tables.md +++ b/docs/src/misc/merge_tables.md @@ -70,12 +70,19 @@ These functions are described in the ### Restricting -One quirk of these utilities is that they take restrictions as arguments, -rather than with operators. So `Table & "field='value'"` becomes -`MergeTable.merge_view(restriction={'field':'value}`). This is because -`merge_view` is a `Union` rather than a true Table. While `merge_view` can -accept all valid restrictions, `merge_get_part` and `merge_get_parent` have -additional restriction logic when supplied with `dicts`. +In short: restrict Merge Tables with arguments, not the `&` operator. + +- Normally: `Table & "field='value'"` +- Instead: `MergeTable.merge_view(restriction="field='value'"`). + +_Caution_. The `&` operator may look like it's working when using `dict`, but +this is because invalid keys will be ignored. 
`Master & {'part_field':'value'}` +is equivalent to `Master` alone +([source](https://docs.datajoint.org/python/queries/06-Restriction.html#restriction-by-a-mapping)). + +When provided as arguments, methods like `merge_get_part` and `merge_get_parent` +will override the permissive treatment of mappings described above to only +return relevant tables. ### Building Downstream @@ -171,8 +178,7 @@ There are also functions for retrieving part/parent table(s) and fetching data. the format specified by keyword arguments and one's DataJoint config. ```python -result3 = (LFPOutput & common_keys_CH[0]).merge_get_part(join_master=True) -result4 = LFPOutput().merge_get_part(restriction=common_keys_CH[0]) +result4 = LFPOutput.merge_get_part(restriction=common_keys_CH[0],join_master=True) result5 = LFPOutput.merge_get_parent(restriction='nwb_file_name LIKE "CH%"') result6 = result5.fetch('lfp_sampling_rate') # Sample rate for all CH* files result7 = LFPOutput.merge_fetch("filter_name", "nwb_file_name") diff --git a/environment.yml b/environment.yml index 7a8c2775f..d7420426f 100644 --- a/environment.yml +++ b/environment.yml @@ -5,7 +5,7 @@ channels: - franklab - edeno dependencies: - - python>=3.8,<3.10 + - python>=3.9,<3.10 - jupyterlab>=3.* - pydotplus - dask diff --git a/environment_position.yml b/environment_position.yml index ef2b725a4..f06cc2348 100644 --- a/environment_position.yml +++ b/environment_position.yml @@ -15,7 +15,7 @@ channels: - franklab - edeno dependencies: - - python>=3.8, <3.10 + - python>=3.9, <3.10 - jupyterlab>=3.* - pydotplus>=2.0.* - libgcc diff --git a/pyproject.toml b/pyproject.toml index e45acc9cb..25822535e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "spyglass-neuro" description = "Neuroscience data analysis framework for reproducible research" readme = "README.md" -requires-python = ">=3.8,<3.10" +requires-python = ">=3.9,<3.10" license = { file = "LICENSE" } authors = [ { name = "Loren Frank", email = "loren.frank@ucsf.edu" }, @@ -78,7 +78,7 @@ test = [ "kachery-cloud", ] docs = [ - "hatch", # Get version from env + "hatch", # Get version from env "mike", # Docs versioning "mkdocs", # Docs core "mkdocs-exclude", # Docs exclude files diff --git a/src/spyglass/lfp/v1/lfp_artifact.py b/src/spyglass/lfp/v1/lfp_artifact.py index ce3ea502c..3985707e7 100644 --- a/src/spyglass/lfp/v1/lfp_artifact.py +++ b/src/spyglass/lfp/v1/lfp_artifact.py @@ -96,7 +96,7 @@ def make(self, key): ).fetch1("artifact_params") artifact_detection_algorithm = artifact_params[ - "ripple_detection_algorithm" + "artifact_detection_algorithm" ] artifact_detection_params = artifact_params[ "artifact_detection_algorithm_params" @@ -121,11 +121,13 @@ def make(self, key): # set up a name for no-artifact times using recording id # we need some name here for recording_name key["artifact_removed_interval_list_name"] = "_".join( - key["nwb_file_name"], - key["target_interval_list_name"], - "LFP", - key["artifact_params_name"], - "artifact_removed_valid_times", + [ + key["nwb_file_name"], + key["target_interval_list_name"], + "LFP", + key["artifact_params_name"], + "artifact_removed_valid_times", + ] ) LFPArtifactRemovedIntervalList.insert1(key, replace=True) diff --git a/src/spyglass/position/position_merge.py b/src/spyglass/position/position_merge.py index 035745b8d..5a21b8134 100644 --- a/src/spyglass/position/position_merge.py +++ b/src/spyglass/position/position_merge.py @@ -1,19 +1,14 @@ import functools as ft import os from pathlib import 
Path -from typing import Dict import datajoint as dj import numpy as np import pandas as pd from tqdm import tqdm as tqdm -from ..common.common_interval import IntervalList -from ..common.common_nwbfile import AnalysisNwbfile -from ..common.common_position import ( - IntervalPositionInfo as CommonIntervalPositionInfo, -) -from ..utils.dj_helper_fn import fetch_nwb +from ..common.common_position import IntervalPositionInfo as CommonPos +from ..utils.dj_merge_tables import _Merge from .v1.dlc_utils import check_videofile, get_video_path, make_video from .v1.position_dlc_pose_estimation import DLCPoseEstimationSelection from .v1.position_dlc_selection import DLCPosV1 @@ -21,25 +16,19 @@ schema = dj.schema("position_merge") -_valid_data_sources = ["DLC", "Trodes", "Common"] - @schema -class PositionOutput(dj.Manual): +class PositionOutput(_Merge): """ Table to identify source of Position Information from upstream options (e.g. DLC, Trodes, etc...) To add another upstream option, a new Part table - should be added in the same syntax as DLCPos and TrodesPos and - - Note: all part tables need to be named using the source+"Pos" convention - i.e. if the source='DLC', then the table is DLCPos + should be added in the same syntax as DLCPos and TrodesPos. """ definition = """ - -> IntervalList - source: varchar(40) - version: int - position_id: int + merge_id : uuid + --- + source: varchar(32) --- """ @@ -50,22 +39,10 @@ class DLCPosV1(dj.Part): definition = """ -> PositionOutput - -> DLCPosV1 --- - -> AnalysisNwbfile - position_object_id : varchar(80) - orientation_object_id : varchar(80) - velocity_object_id : varchar(80) + -> DLCPosV1 """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - class TrodesPosV1(dj.Part): """ Table to pass-through upstream Trodes Position Tracking information @@ -73,22 +50,10 @@ class TrodesPosV1(dj.Part): definition = """ -> PositionOutput - -> TrodesPosV1 --- - -> AnalysisNwbfile - position_object_id : varchar(80) - orientation_object_id : varchar(80) - velocity_object_id : varchar(80) + -> TrodesPosV1 """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - class CommonPos(dj.Part): """ Table to pass-through upstream Trodes Position Tracking information @@ -96,119 +61,14 @@ class CommonPos(dj.Part): definition = """ -> PositionOutput - -> CommonIntervalPositionInfo --- - -> AnalysisNwbfile - position_object_id : varchar(80) - orientation_object_id : varchar(80) - velocity_object_id : varchar(80) + -> CommonPos """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - - def insert1(self, key, params: Dict = None, **kwargs): - """Overrides insert1 to also insert into specific part table. 
- - Parameters - ---------- - key : Dict - key specifying the entry to insert - params : Dict, optional - A dictionary containing all table entries - not specified by the parent table (PosMerge) - """ - assert ( - key["source"] in _valid_data_sources - ), f"source needs to be one of {_valid_data_sources}" - position_id = key.get("position_id", None) - if position_id is None: - key["position_id"] = ( - dj.U().aggr(self & key, n="max(position_id)").fetch1("n") or 0 - ) + 1 - else: - id = (self & key).fetch("position_id") - if len(id) > 0: - position_id = max(id) + 1 - else: - position_id = max(0, position_id) - key["position_id"] = position_id - super().insert1(key, **kwargs) - source = key["source"] - if source in ["Common"]: - table_name = f"{source}Pos" - else: - version = key["version"] - table_name = f"{source}PosV{version}" - part_table = getattr(self, table_name) - # TODO: The parent table to refer to is hard-coded here, expecting it to be the second - # Table in the definition. This could be more flexible. - if params: - table_query = ( - dj.FreeTable(dj.conn(), full_table_name=part_table.parents()[1]) - & key - & params - ) - else: - table_query = ( - dj.FreeTable(dj.conn(), full_table_name=part_table.parents()[1]) - & key - ) - if any( - "head" in col - for col in list(table_query.fetch().dtype.fields.keys()) - ): - ( - analysis_file_name, - position_object_id, - orientation_object_id, - velocity_object_id, - ) = table_query.fetch1( - "analysis_file_name", - "head_position_object_id", - "head_orientation_object_id", - "head_velocity_object_id", - ) - else: - ( - analysis_file_name, - position_object_id, - orientation_object_id, - velocity_object_id, - ) = table_query.fetch1( - "analysis_file_name", - "position_object_id", - "orientation_object_id", - "velocity_object_id", - ) - part_table.insert1( - { - **key, - "analysis_file_name": analysis_file_name, - "position_object_id": position_object_id, - "orientation_object_id": orientation_object_id, - "velocity_object_id": velocity_object_id, - **params, - }, - ) - - def fetch_nwb(self, *attrs, **kwargs): - source = self.fetch1("source") - if source in ["Common"]: - table_name = f"{source}Pos" - else: - version = self.fetch1("version") - table_name = f"{source}PosV{version}" - part_table = getattr(self, table_name) & self - return part_table.fetch_nwb() - def fetch1_dataframe(self): - nwb_data = self.fetch_nwb()[0] + # proj replaces operator restriction to enable + # (TableName & restriction).fetch1_dataframe() + nwb_data = self.fetch_nwb(self.proj())[0] index = pd.Index( np.asarray(nwb_data["position"].get_spatial_series().timestamps), name="time", @@ -285,7 +145,6 @@ class PositionVideoSelection(dj.Manual): plot_id : int plot : varchar(40) # Which position info to overlay on video file --- - position_ids : mediumblob output_dir : varchar(255) # directory where to save output video """ @@ -323,11 +182,15 @@ class PositionVideo(dj.Computed): """ def make(self, key): - assert key["plot"] in ["DLC", "Trodes", "Common", "All"] + raise NotImplementedError("work in progress -DPG") + + plot = key.get("plot") + if plot not in ["DLC", "Trodes", "Common", "All"]: + raise ValueError(f"Plot {key['plot']} not supported") + # CBroz: I was told only tests should `assert`, code should `raise` + M_TO_CM = 100 - output_dir, position_ids = (PositionVideoSelection & key).fetch1( - "output_dir", "position_ids" - ) + output_dir = (PositionVideoSelection & key).fetch1("output_dir") print("Loading position data...") # raw_position_df = ( @@ -337,122 +200,27 
@@ def make(self, key): # "interval_list_name": key["interval_list_name"], # } # ).fetch1_dataframe() + query = { "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["interval_list_name"], } - if key["plot"] == "DLC": - assert position_ids["dlc_position_id"] - pos_df = ( - PositionOutput() - & { - **query, - "source": "DLC", - "position_id": position_ids["dlc_position_id"], - } - ).fetch1_dataframe() - elif key["plot"] == "Trodes": - assert position_ids["trodes_position_id"] - pos_df = ( - PositionOutput() - & { - **query, - "source": "Trodes", - "position_id": position_ids["trodes_position_id"], - } - ).fetch1_dataframe() - elif key["plot"] == "Common": - assert position_ids["common_position_id"] - pos_df = ( - PositionOutput() - & { - **query, - "source": "Common", - "position_id": position_ids["common_position_id"], - } - ).fetch1_dataframe() - elif key["plot"] == "All": + merge_entries = { + "DLC": PositionOutput.DLCPosV1 & query, + "Trodes": PositionOutput.TrodesPosV1 & query, + "Common": PositionOutput.CommonPos & query, + } + + position_mean_dict = {} + if plot == "All": # Check which entries exist in PositionOutput merge_dict = {} - if "dlc_position_id" in position_ids: - if ( - len( - PositionOutput() - & { - **query, - "source": "DLC", - "position_id": position_ids["dlc_position_id"], - } - ) - > 0 - ): - dlc_df = ( - ( - PositionOutput() - & { - **query, - "source": "DLC", - "position_id": position_ids["dlc_position_id"], - } - ) - .fetch1_dataframe() - .drop(columns=["velocity_x", "velocity_y", "speed"]) - ) - merge_dict["DLC"] = dlc_df - if "trodes_position_id" in position_ids: - if ( - len( - PositionOutput() - & { - **query, - "source": "Trodes", - "position_id": position_ids["trodes_position_id"], - } + for source, entries in merge_entries.items(): + if entries: + merge_dict[source] = entries.fetch1_dataframe().drop( + columns=["velocity_x", "velocity_y", "speed"] ) - > 0 - ): - trodes_df = ( - ( - PositionOutput() - & { - **query, - "source": "Trodes", - "position_id": position_ids[ - "trodes_position_id" - ], - } - ) - .fetch1_dataframe() - .drop(columns=["velocity_x", "velocity_y", "speed"]) - ) - merge_dict["Trodes"] = trodes_df - if "common_position_id" in position_ids: - if ( - len( - PositionOutput() - & { - **query, - "source": "Common", - "position_id": position_ids["common_position_id"], - } - ) - > 0 - ): - common_df = ( - ( - PositionOutput() - & { - **query, - "source": "Common", - "position_id": position_ids[ - "common_position_id" - ], - } - ) - .fetch1_dataframe() - .drop(columns=["velocity_x", "velocity_y", "speed"]) - ) - merge_dict["Common"] = common_df + pos_df = ft.reduce( lambda left, right,: pd.merge( left[1], @@ -463,15 +231,33 @@ def make(self, key): ), merge_dict.items(), ) - print("Loading video data...") - epoch = ( - int( - key["interval_list_name"] - .replace("pos ", "") - .replace(" valid times", "") + position_mean_dict = { + source: { + "position": np.asarray( + pos_df[[f"position_x_{source}", f"position_y_{source}"]] + ), + "orientation": np.asarray( + pos_df[[f"orientation_{source}"]] + ), + } + for source in merge_dict.keys() + } + else: + if plot == "DLC": + # CBroz - why is this extra step needed for DLC? 
+ pos_df_key = merge_entries[plot].fetch1(as_dict=True) + pos_df = (PositionOutput & pos_df_key).fetch1_dataframe() + elif plot in ["Trodes", "Common"]: + pos_df = merge_entries[plot].fetch1_dataframe() + + position_mean_dict[plot]["position"] = np.asarray( + pos_df[["position_x", "position_y"]] ) - + 1 - ) + position_mean_dict[plot]["orientation"] = np.asarray( + pos_df[["orientation"]] + ) + + print("Loading video data...") ( video_path, @@ -479,20 +265,24 @@ def make(self, key): meters_per_pixel, video_time, ) = get_video_path( - {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} + { + "nwb_file_name": key["nwb_file_name"], + "epoch": int( + "".join(filter(str.isdigit, key["interval_list_name"])) + ) + + 1, + } ) video_dir = os.path.dirname(video_path) + "/" video_frame_col_name = [ col for col in pos_df.columns if "video_frame_ind" in col - ] - video_frame_inds = ( - pos_df[video_frame_col_name[0]].astype(int).to_numpy() - ) - if key["plot"] in ["DLC", "All"]: - temp_key = (PositionOutput.DLCPosV1 & key).fetch1("KEY") - video_path = (DLCPoseEstimationSelection & temp_key).fetch1( - "video_path" - ) + ][0] + video_frame_inds = pos_df[video_frame_col_name].astype(int).to_numpy() + if plot in ["DLC", "All"]: + video_path = ( + DLCPoseEstimationSelection + & (PositionOutput.DLCPosV1 & key).fetch1("KEY") + ).fetch1("video_path") else: video_path = check_videofile( video_dir, key["output_dir"], video_filename @@ -506,28 +296,7 @@ def make(self, key): # centroids = {'red': np.asarray(raw_position_df[['xloc', 'yloc']]), # 'green': np.asarray(raw_position_df[['xloc2', 'yloc2']])} - position_mean_dict = {} - if key["plot"] in ["DLC", "Trodes", "Common"]: - position_mean_dict[key["plot"]]["position"] = np.asarray( - pos_df[["position_x", "position_y"]] - ) - position_mean_dict[key["plot"]]["orientation"] = np.asarray( - pos_df[["orientation"]] - ) - elif key["plot"] == "All": - position_mean_dict = { - source: { - "position": np.asarray( - pos_df[[f"position_x_{source}", f"position_y_{source}"]] - ), - "orientation": np.asarray( - pos_df[[f"orientation_{source}"]] - ), - } - for source in merge_dict.keys() - } - position_time = np.asarray(pos_df.index) - cm_per_pixel = meters_per_pixel * M_TO_CM + print("Making video...") make_video( @@ -535,10 +304,10 @@ def make(self, key): video_frame_inds, position_mean_dict, video_time, - position_time, + np.asarray(pos_df.index), processor="opencv", output_video_filename=output_video_filename, - cm_to_pixels=cm_per_pixel, + cm_to_pixels=meters_per_pixel * M_TO_CM, disable_progressbar=False, ) self.insert1(key) diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index b48ee6be4..c72cb9931 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -1,9 +1,11 @@ +import copy from pathlib import Path import datajoint as dj import numpy as np import pandas as pd import pynwb +from datajoint.utils import to_camel_case from tqdm import tqdm as tqdm from ...common.common_nwbfile import AnalysisNwbfile @@ -52,59 +54,60 @@ class DLCPosV1(dj.Computed): """ def make(self, key): + orig_key = copy.deepcopy(key) key["pose_eval_result"] = self.evaluate_pose_estimation(key) - position_nwb_data = (DLCCentroid & key).fetch_nwb()[0] - orientation_nwb_data = (DLCOrientation & key).fetch_nwb()[0] - position_object = position_nwb_data["dlc_position"].spatial_series[ - "position" - ] - velocity_object = 
position_nwb_data["dlc_velocity"].time_series[ - "velocity" - ] - video_frame_object = position_nwb_data["dlc_velocity"].time_series[ - "video_frame_ind" - ] - orientation_object = orientation_nwb_data[ - "dlc_orientation" - ].spatial_series["orientation"] + + pos_nwb = (DLCCentroid & key).fetch_nwb()[0] + ori_nwb = (DLCOrientation & key).fetch_nwb()[0] + + pos_obj = pos_nwb["dlc_position"].spatial_series["position"] + vel_obj = pos_nwb["dlc_velocity"].time_series["velocity"] + vid_frame_obj = pos_nwb["dlc_velocity"].time_series["video_frame_ind"] + ori_obj = ori_nwb["dlc_orientation"].spatial_series["orientation"] + position = pynwb.behavior.Position() orientation = pynwb.behavior.CompassDirection() velocity = pynwb.behavior.BehavioralTimeSeries() + position.create_spatial_series( - name=position_object.name, - timestamps=np.asarray(position_object.timestamps), - conversion=position_object.conversion, - data=np.asarray(position_object.data), - reference_frame=position_object.reference_frame, - comments=position_object.comments, - description=position_object.description, + name=pos_obj.name, + timestamps=np.asarray(pos_obj.timestamps), + conversion=pos_obj.conversion, + data=np.asarray(pos_obj.data), + reference_frame=pos_obj.reference_frame, + comments=pos_obj.comments, + description=pos_obj.description, ) + orientation.create_spatial_series( - name=orientation_object.name, - timestamps=np.asarray(orientation_object.timestamps), - conversion=orientation_object.conversion, - data=np.asarray(orientation_object.data), - reference_frame=orientation_object.reference_frame, - comments=orientation_object.comments, - description=orientation_object.description, + name=ori_obj.name, + timestamps=np.asarray(ori_obj.timestamps), + conversion=ori_obj.conversion, + data=np.asarray(ori_obj.data), + reference_frame=ori_obj.reference_frame, + comments=ori_obj.comments, + description=ori_obj.description, ) + velocity.create_timeseries( - name=velocity_object.name, - timestamps=np.asarray(velocity_object.timestamps), - conversion=velocity_object.conversion, - unit=velocity_object.unit, - data=np.asarray(velocity_object.data), - comments=velocity_object.comments, - description=velocity_object.description, + name=vel_obj.name, + timestamps=np.asarray(vel_obj.timestamps), + conversion=vel_obj.conversion, + unit=vel_obj.unit, + data=np.asarray(vel_obj.data), + comments=vel_obj.comments, + description=vel_obj.description, ) + velocity.create_timeseries( - name=video_frame_object.name, - unit=video_frame_object.unit, - timestamps=np.asarray(video_frame_object.timestamps), - data=np.asarray(video_frame_object.data), - description=video_frame_object.description, - comments=video_frame_object.comments, + name=vid_frame_obj.name, + unit=vid_frame_obj.unit, + timestamps=np.asarray(vid_frame_obj.timestamps), + data=np.asarray(vid_frame_obj.data), + description=vid_frame_obj.description, + comments=vid_frame_obj.comments, ) + # Add to Analysis NWB file key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] @@ -124,23 +127,15 @@ def make(self, key): nwb_file_name=key["nwb_file_name"], analysis_file_name=key["analysis_file_name"], ) - self.insert1(key) - from ..position_merge import PositionOutput - key["source"] = "DLC" - key["version"] = 1 - dlc_key = key.copy() - del dlc_key["pose_eval_result"] - key["interval_list_name"] = f"pos {key['epoch']-1} valid times" - valid_fields = PositionOutput().fetch().dtype.fields.keys() - entries_to_delete = [ - entry for entry in key.keys() if entry not in 
valid_fields - ] - for entry in entries_to_delete: - del key[entry] + from ..position_merge import PositionOutput - PositionOutput().insert1(key=key, params=dlc_key, skip_duplicates=True) + part_name = to_camel_case(self.table_name.split("__")[-1]) + # TODO: The next line belongs in a merge table function + PositionOutput._merge_insert( + [orig_key], part_name=part_name, skip_duplicates=True + ) def fetch_nwb(self, *attrs, **kwargs): return fetch_nwb( diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index a31aa158d..68f43526e 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -1,3 +1,4 @@ +import copy import os from pathlib import Path @@ -8,6 +9,7 @@ import pandas as pd import pynwb import pynwb.behavior +from datajoint.utils import to_camel_case from position_tools import ( get_angle, get_centriod, @@ -106,6 +108,7 @@ class TrodesPosV1(dj.Computed): """ def make(self, key): + orig_key = copy.deepcopy(key) print(f"Computing position for: {key}") key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] @@ -215,19 +218,13 @@ def make(self, key): AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) self.insert1(key) + from ..position_merge import PositionOutput - key["source"] = "Trodes" - key["version"] = 1 - trodes_key = key.copy() - valid_fields = PositionOutput().fetch().dtype.fields.keys() - entries_to_delete = [ - entry for entry in key.keys() if entry not in valid_fields - ] - for entry in entries_to_delete: - del key[entry] - PositionOutput().insert1( - key=key, params=trodes_key, skip_duplicates=True + part_name = to_camel_case(self.table_name.split("__")[-1]) + # TODO: The next line belongs in a merge table function + PositionOutput._merge_insert( + [orig_key], part_name=part_name, skip_duplicates=True ) @staticmethod diff --git a/src/spyglass/position_linearization/__init__.py b/src/spyglass/position_linearization/__init__.py new file mode 100644 index 000000000..49103ec16 --- /dev/null +++ b/src/spyglass/position_linearization/__init__.py @@ -0,0 +1,3 @@ +from spyglass.position_linearization.position_linearization_merge import ( + LinearizedPositionOutput, +) diff --git a/src/spyglass/position_linearization/position_linearization_merge.py b/src/spyglass/position_linearization/position_linearization_merge.py new file mode 100644 index 000000000..68f9d063a --- /dev/null +++ b/src/spyglass/position_linearization/position_linearization_merge.py @@ -0,0 +1,30 @@ +import datajoint as dj + +from spyglass.position_linearization.v1.linearization import ( + LinearizedPositionV1, +) # noqa F401 + +from ..utils.dj_merge_tables import _Merge + +schema = dj.schema("position_linearization_merge") + + +@schema +class LinearizedPositionOutput(_Merge): + definition = """ + merge_id: uuid + --- + source: varchar(32) + """ + + class LinearizedPositionV1(dj.Part): + definition = """ + -> LinearizedPositionOutput + --- + -> LinearizedPositionV1 + """ + + def fetch1_dataframe(self): + return self.fetch_nwb(self.proj())[0]["linearized_position"].set_index( + "time" + ) diff --git a/src/spyglass/position_linearization/v1/__init__.py b/src/spyglass/position_linearization/v1/__init__.py new file mode 100644 index 000000000..ec3dffc65 --- /dev/null +++ b/src/spyglass/position_linearization/v1/__init__.py @@ -0,0 +1,7 @@ +from spyglass.position_linearization.v1.linearization import ( + LinearizationParameters, + 
LinearizationSelection, + LinearizedPositionV1, + NodePicker, + TrackGraph, +) diff --git a/src/spyglass/position_linearization/v1/linearization.py b/src/spyglass/position_linearization/v1/linearization.py new file mode 100644 index 000000000..7f30be8c9 --- /dev/null +++ b/src/spyglass/position_linearization/v1/linearization.py @@ -0,0 +1,187 @@ +import copy +import datajoint as dj +from datajoint.utils import to_camel_case +import numpy as np +from track_linearization import ( + get_linearized_position, + make_track_graph, + plot_graph_as_1D, + plot_track_graph, +) + +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.position.position_merge import PositionOutput +from spyglass.utils.dj_helper_fn import fetch_nwb + +schema = dj.schema("position_linearization_v1") + + +@schema +class LinearizationParameters(dj.Lookup): + """Choose whether to use an HMM to linearize position. This can help when + the eucledian distances between separate arms are too close and the previous + position has some information about which arm the animal is on.""" + + definition = """ + linearization_param_name : varchar(80) # name for this set of parameters + --- + use_hmm = 0 : int # use HMM to determine linearization + # How much to prefer route distances between successive time points that are closer to the euclidean distance. Smaller numbers mean the route distance is more likely to be close to the euclidean distance. + route_euclidean_distance_scaling = 1.0 : float + sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm). + # Biases the transition matrix to prefer the current track segment. + diagonal_bias = 0.5 : float + """ + + +@schema +class TrackGraph(dj.Manual): + """Graph representation of track representing the spatial environment. 
+ Used for linearizing position.""" + + definition = """ + track_graph_name : varchar(80) + ---- + environment : varchar(80) # Type of Environment + node_positions : blob # 2D position of track_graph nodes, shape (n_nodes, 2) + edges: blob # shape (n_edges, 2) + linear_edge_order : blob # order of track graph edges in the linear space, shape (n_edges, 2) + linear_edge_spacing : blob # amount of space between edges in the linear space, shape (n_edges,) + """ + + def get_networkx_track_graph(self, track_graph_parameters=None): + if track_graph_parameters is None: + track_graph_parameters = self.fetch1() + return make_track_graph( + node_positions=track_graph_parameters["node_positions"], + edges=track_graph_parameters["edges"], + ) + + def plot_track_graph(self, ax=None, draw_edge_labels=False, **kwds): + """Plot the track graph in 2D position space.""" + track_graph = self.get_networkx_track_graph() + plot_track_graph( + track_graph, ax=ax, draw_edge_labels=draw_edge_labels, **kwds + ) + + def plot_track_graph_as_1D( + self, + ax=None, + axis="x", + other_axis_start=0.0, + draw_edge_labels=False, + node_size=300, + node_color="#1f77b4", + ): + """Plot the track graph in 1D to see how the linearization is set up.""" + track_graph_parameters = self.fetch1() + track_graph = self.get_networkx_track_graph( + track_graph_parameters=track_graph_parameters + ) + plot_graph_as_1D( + track_graph, + edge_order=track_graph_parameters["linear_edge_order"], + edge_spacing=track_graph_parameters["linear_edge_spacing"], + ax=ax, + axis=axis, + other_axis_start=other_axis_start, + draw_edge_labels=draw_edge_labels, + node_size=node_size, + node_color=node_color, + ) + + +@schema +class LinearizationSelection(dj.Lookup): + definition = """ + -> PositionOutput + -> TrackGraph + -> LinearizationParameters + --- + """ + + +@schema +class LinearizedPositionV1(dj.Computed): + """Linearized position for a given interval""" + + definition = """ + -> LinearizationSelection + --- + -> AnalysisNwbfile + linearized_position_object_id : varchar(40) + """ + + def make(self, key): + orig_key = copy.deepcopy(key) + print(f"Computing linear position for: {key}") + + position_nwb = PositionOutput.fetch_nwb(key)[0] + key["analysis_file_name"] = AnalysisNwbfile().create( + key["nwb_file_name"] + ) + position = np.asarray( + position_nwb["position"].get_spatial_series().data + ) + time = np.asarray( + position_nwb["position"].get_spatial_series().timestamps + ) + + linearization_parameters = ( + LinearizationParameters() + & {"linearization_param_name": key["linearization_param_name"]} + ).fetch1() + track_graph_info = ( + TrackGraph() & {"track_graph_name": key["track_graph_name"]} + ).fetch1() + + track_graph = make_track_graph( + node_positions=track_graph_info["node_positions"], + edges=track_graph_info["edges"], + ) + + linear_position_df = get_linearized_position( + position=position, + track_graph=track_graph, + edge_spacing=track_graph_info["linear_edge_spacing"], + edge_order=track_graph_info["linear_edge_order"], + use_HMM=linearization_parameters["use_hmm"], + route_euclidean_distance_scaling=linearization_parameters[ + "route_euclidean_distance_scaling" + ], + sensor_std_dev=linearization_parameters["sensor_std_dev"], + diagonal_bias=linearization_parameters["diagonal_bias"], + ) + + linear_position_df["time"] = time + + # Insert into analysis nwb file + nwb_analysis_file = AnalysisNwbfile() + + key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + 
nwb_object=linear_position_df, + ) + + nwb_analysis_file.add( + nwb_file_name=key["nwb_file_name"], + analysis_file_name=key["analysis_file_name"], + ) + + self.insert1(key) + + from ..position_linearization_merge import LinearizedPositionOutput + + part_name = to_camel_case(self.table_name.split("__")[-1]) + + LinearizedPositionOutput._merge_insert( + [orig_key], part_name=part_name, skip_duplicates=True + ) + + def fetch_nwb(self, *attrs, **kwargs): + return fetch_nwb( + self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs + ) + + def fetch1_dataframe(self): + return self.fetch_nwb()[0]["linearized_position"].set_index("time") diff --git a/src/spyglass/spikesorting/spikesorting_sorting.py b/src/spyglass/spikesorting/spikesorting_sorting.py index ede6a0ff1..4e2cd0c27 100644 --- a/src/spyglass/spikesorting/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/spikesorting_sorting.py @@ -201,6 +201,7 @@ def make(self, key: dict): # need to remove tempdir and whiten from sorter_params sorter_params.pop("tempdir", None) sorter_params.pop("whiten", None) + sorter_params.pop("outputs", None) # Detect peaks for clusterless decoding detected_spikes = detect_peaks(recording, **sorter_params) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 7c99ac893..7c0d2ee9a 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -3,7 +3,7 @@ from pprint import pprint import datajoint as dj -from datajoint.condition import make_condition +from datajoint.condition import AndList, make_condition from datajoint.errors import DataJointError from datajoint.preview import repr_html from datajoint.utils import from_camel_case, to_camel_case @@ -18,7 +18,14 @@ class Merge(dj.Manual): - """Adds funcs to support standard Merge table operations.""" + """Adds funcs to support standard Merge table operations. + + Many methods have the @classmethod decorator to permit MergeTable.method() + symtax. This makes access to instance attributes (e.g., (MergeTable & + "example='restriction'").restriction) harder, but these attributes have + limited utility when the user wants to, for example, restrict the merged + view rather than the master table itself. + """ def __init__(self): super().__init__() @@ -50,25 +57,28 @@ def _merge_restrict_parts( restriction: dict = True, as_objects: bool = True, return_empties: bool = True, + add_invalid_restrict: bool = True, ) -> list: """Returns a list of parts with restrictions applied. Parameters --------- - restriction: dict, optional - Restriction to apply to the merged view. Default True, no restrictions. + restriction: str, optional + Restriction to apply to the parts. Default True, no restrictions. as_objects: bool, optional Default True. Return part tables as objects return_empties: bool, optional Default True. Return empty part tables + add_invalid_restrict: bool, optional + Default True. Include part for which the restriction is invalid. 
Returns ------ list list of datajoint tables, parts of Merge Table """ - if not dj.conn.connection.dependencies._loaded: - dj.conn.connection.dependencies.load() # Otherwise parts returns none + + cls._ensure_dependencies_loaded() if not restriction: restriction = True @@ -81,7 +91,7 @@ def _merge_restrict_parts( if ( not return_empties and isinstance(restr_str, str) - and cls()._reserved_sk in restr_str + and f"`{cls()._reserved_sk}`" in restr_str ): parts_all = [ part @@ -95,13 +105,21 @@ def _merge_restrict_parts( ] if isinstance(restriction, dict): # restr by source already done above _ = restriction.pop(cls()._reserved_sk, None) # won't work for str + # If a dict restriction has all invalid keys, it is treated as True + if not add_invalid_restrict: + parts_all = [ # so exclude tables w/ nonmatching attrs + p + for p in parts_all + if all([k in p.heading.names for k in restriction.keys()]) + ] parts = [] for part in parts_all: try: parts.append(part.restrict(restriction)) except DataJointError: # If restriction not valid on given part - parts.append(part) + if add_invalid_restrict: + parts.append(part) if not return_empties: parts = [p for p in parts if len(p)] @@ -113,9 +131,11 @@ def _merge_restrict_parts( @classmethod def _merge_restrict_parents( cls, - restriction: dict = True, + restriction: str = True, + parent_name: str = None, as_objects: bool = True, return_empties: bool = True, + add_invalid_restrict: bool = True, ) -> list: """Returns a list of part parents with restrictions applied. @@ -124,40 +144,49 @@ def _merge_restrict_parents( Parameters --------- - restriction: dict, optional + restriction: str, optional Restriction to apply to the returned parent. Default True, no restrictions. + parent_name: str, optional + CamelCase name of the parent. as_objects: bool, optional Default True. Return part tables as objects return_empties: bool, optional Default True. Return empty part tables + add_invalid_restrict: bool, optional + Default True. Include part for which the restriction is invalid. Returns ------ list list of datajoint tables, parents of parts of Merge Table """ + # .restict(restriction) does not work on returned part FreeTable + # & part.fetch below restricts parent to entries in merge table part_parents = [ - parent & part - # .restict(restriction) - # .fetch( - # *part.heading.secondary_attributes, as_dict=True - # ) + parent + & part.fetch(*part.heading.secondary_attributes, as_dict=True) for part in cls()._merge_restrict_parts( - restriction=restriction, return_empties=return_empties + restriction=restriction, + return_empties=return_empties, + add_invalid_restrict=add_invalid_restrict, ) for parent in part.parents(as_objects=True) # ID respective parents if cls().table_name not in parent.full_table_name # Not merge table ] + if parent_name: + part_parents = [ + p + for p in part_parents + if from_camel_case(parent_name) in p.full_table_name + ] if not as_objects: part_parents = [p.full_table_name for p in part_parents] return part_parents @classmethod - def _merge_repr( - cls, restriction: dict = True, **kwargs - ) -> dj.expression.Union: + def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: """Merged view, including null entries for columns unique to one part table. 
Parameters @@ -172,7 +201,11 @@ def _merge_repr( parts = [ cls() * p # join with master to include sec key (i.e., 'source') - for p in cls._merge_restrict_parts(restriction=restriction) + for p in cls._merge_restrict_parts( + restriction=restriction, + add_invalid_restrict=False, + return_empties=False, + ) ] primary_attrs = list( @@ -201,13 +234,17 @@ def _merge_repr( return query @classmethod - def _merge_insert(cls, rows: list, **kwargs) -> None: - """Insert rows into merge table, ensuring db integrity and mutual exclusivity + def _merge_insert( + cls, rows: list, part_name: str = None, mutual_exclusvity=True, **kwargs + ) -> None: + """Insert rows into merge table. Ensure mutual exclusivity Parameters --------- rows: List[dict] An iterable where an element is a dictionary. + part: str, optional + CamelCase name of the part table Raises ------ @@ -217,6 +254,7 @@ def _merge_insert(cls, rows: list, **kwargs) -> None: If entry already exists, mutual exclusivity errors If data doesn't exist in part parents, integrity error """ + cls._ensure_dependencies_loaded() try: for r in iter(rows): @@ -227,46 +265,71 @@ def _merge_insert(cls, rows: list, **kwargs) -> None: raise TypeError('Input "rows" must be a list of dictionaries') parts = cls._merge_restrict_parts(as_objects=True) + if part_name: + parts = [ + p + for p in parts + if from_camel_case(part_name) in p.full_table_name + ] + master_entries = [] parts_entries = {p: [] for p in parts} for row in rows: - key = {} - for part in parts: - master = part.parents(as_objects=True)[-1] + keys = [] # empty to-be-inserted key + for part in parts: # check each part + part_parent = part.parents(as_objects=True)[-1] part_name = to_camel_case(part.table_name.split("__")[-1]) - if master & row: - if not key: - key = (master & row).fetch1("KEY") - master_pk = { - cls()._reserved_pk: dj.hash.key_hash(key), - } - parts_entries[part].append({**master_pk, **key}) - master_entries.append( - {**master_pk, cls()._reserved_sk: part_name} - ) - else: + if part_parent & row: # if row is in part parent + if keys and mutual_exclusvity: # if key from other part raise ValueError( "Mutual Exclusivity Error! Entry exists in more " + f"than one table - Entry: {row}" ) - if not key: + keys = (part_parent & row).fetch("KEY") # get pk + if len(keys) > 1: + raise ValueError( + "Ambiguous entry. Data has mult rows in " + + f"{part_name}:\n\tData:{row}\n\t{keys}" + ) + master_pk = { # make uuid + cls()._reserved_pk: dj.hash.key_hash(keys[0]), + } + parts_entries[part].append({**master_pk, **keys[0]}) + master_entries.append( + {**master_pk, cls()._reserved_sk: part_name} + ) + + if not keys: raise ValueError( "Non-existing entry in any of the parent tables - Entry: " + f"{row}" ) - # 1. nullcontext() allows use within `make` but decreases reliability - # 2. cls.connection.transaction is more reliable but throws errors if - # used within another transaction, i.e. in `make` - - with nullcontext(): # TODO: ensure this block within transaction + with cls._safe_context(): super().insert(cls(), master_entries, **kwargs) for part, part_entries in parts_entries.items(): part.insert(part_entries, **kwargs) @classmethod - def insert(cls, rows: list, **kwargs): + def _safe_context(cls): + """Return transaction if not already in one.""" + return ( + cls.connection.transaction + if not cls.connection.in_transaction + else nullcontext() + ) + + @classmethod + def _ensure_dependencies_loaded(cls) -> None: + """Ensure connection dependencies loaded. 
+ + Otherwise parts returns none + """ + if not dj.conn.connection.dependencies._loaded: + dj.conn.connection.dependencies.load() + + def insert(self, rows: list, mutual_exclusvity=True, **kwargs): """Merges table specific insert Ensuring db integrity and mutual exclusivity @@ -275,6 +338,8 @@ def insert(cls, rows: list, **kwargs): --------- rows: List[dict] An iterable where an element is a dictionary. + mutual_exclusvity: bool + Check for mutual exclusivity before insert. Default True. Raises ------ @@ -284,12 +349,14 @@ def insert(cls, rows: list, **kwargs): If entry already exists, mutual exclusivity errors If data doesn't exist in part parents, integrity error """ - cls._merge_insert(rows, **kwargs) + self._merge_insert(rows, mutual_exclusvity=mutual_exclusvity, **kwargs) @classmethod def merge_view(cls, restriction: dict = True): """Prints merged view, including null entries for unique columns. + Note: To handle this Union as a table-like object, use `merge_resrict` + Parameters --------- restriction: dict, optional @@ -309,7 +376,7 @@ def merge_html(cls, restriction: dict = True): return HTML(repr_html(cls._merge_repr(restriction=restriction))) @classmethod - def merge_restrict(cls, restriction: dict = True) -> dj.U: + def merge_restrict(cls, restriction: str = True) -> dj.U: """Given a restriction, return a merged view with restriction applied. Example @@ -371,7 +438,6 @@ def merge_delete_parent( kwargs: dict Additional keyword arguments for DataJoint delete. """ - part_parents = cls._merge_restrict_parents( restriction=restriction, as_objects=True, return_empties=False ) @@ -379,39 +445,60 @@ def merge_delete_parent( if dry_run: return part_parents - super().delete(cls(), **kwargs) - for part_parent in part_parents: - super().delete(part_parent, **kwargs) + with cls._safe_context(): + super().delete(cls(), **kwargs) + for part_parent in part_parents: + super().delete(part_parent, **kwargs) + + @classmethod + def fetch_nwb( + cls, restriction: str = True, multi_source=False, *attrs, **kwargs + ): + """Return the AnalysisNwbfile file linked in the source. - def fetch_nwb(self, *attrs, **kwargs): - part_parents = self._merge_restrict_parents( - restriction=self.restriction, return_empties=False + Parameters + ---------- + restriction: str, optional + Restriction to apply to parents before running fetch. Default none. + multi_source: bool + Return from multiple parents. Default False. + """ + part_parents = cls._merge_restrict_parents( + restriction=restriction, + return_empties=False, + add_invalid_restrict=False, ) - if len(part_parents) == 1: - return fetch_nwb( - part_parents[0], - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - else: + if not multi_source and len(part_parents) != 1: raise ValueError( - f"{len(part_parents)} possible sources found in Merge Table" - + part_parents + f"{len(part_parents)} possible sources found in Merge Table:" + + " and ".join([p.full_table_name for p in part_parents]) ) + nwbs = [] + for part_parent in part_parents: + nwbs.extend( + fetch_nwb( + part_parent, + (AnalysisNwbfile, "analysis_file_abs_path"), + *attrs, + **kwargs, + ) + ) + return nwbs + + @classmethod def merge_get_part( - self, - restriction: dict = True, + cls, + restriction: str = True, join_master: bool = False, restrict_part=True, + multi_source=False, ) -> dj.Table: """Retrieve part table from a restricted Merge table. - Note: This returns the whole unrestricted part table. 
The provided - restriction is only used to identify the relevant part as a native - table. + Note: unlike other Merge Table methods, returns the native table, not + a FreeTable Parameters ---------- @@ -423,7 +510,13 @@ def merge_get_part( restrict_part: bool Apply restriction to part. Default True. If False, return the native part table. + multi_source: bool + Return multiple parts. Default False. + Returns + ------ + Union[dj.Table, List[dj.Table]] + Native part table(s) of Merge. If `multi_source`, returns list. Example ------- @@ -433,42 +526,53 @@ def merge_get_part( Raises ------ ValueError - If multiple sources are found, lists and suggests restricting + If multiple sources are found, but not expected lists and suggests + restricting """ - sources = [ to_camel_case(n.split("__")[-1].strip("`")) # friendly part name - for n in self._merge_restrict_parts( - restriction=restriction, as_objects=False, return_empties=False + for n in cls._merge_restrict_parts( + restriction=restriction, + as_objects=False, + return_empties=False, + add_invalid_restrict=False, ) ] - if len(sources) != 1: + if not multi_source and len(sources) != 1: raise ValueError( f"Found multiple potential parts: {sources}\n\t" - + "Try adding a restriction before invoking `get_part`." + + "Try adding a restriction before invoking `get_part`.\n\t" + + "Or permitting multiple sources with `multi_source=True`." ) - part = ( - getattr(self, sources[0])().restrict(restriction) + parts = [ + getattr(cls, source)().restrict(restriction) if restrict_part # Re-apply restriction or don't - else getattr(self, sources[0])() - ) + else getattr(cls, source)() + for source in sources + ] + if join_master: + parts = [cls * part for part in parts] - return self * part if join_master else part + return parts if multi_source else parts[0] @classmethod def merge_get_parent( - self, restriction: dict = True, join_master: bool = False - ) -> list: + cls, + restriction: str = True, + join_master: bool = False, + multi_source=False, + ) -> dj.FreeTable: """Returns a list of part parents with restrictions applied. Rather than part tables, we look at parents of those parts, the source - of the data. + of the data, and only the rows that have keys inserted in the merge + table. Parameters ---------- - restriction: dict + restriction: str Optional restriction to apply before determining parent to return. Default True. join_master: bool @@ -476,30 +580,38 @@ def merge_get_parent( Returns ------ - list - list of datajoint tables, parents of parts of Merge Table + dj.FreeTable + Parent of parts of Merge Table as FreeTable. """ - part_parents = self._merge_restrict_parents( - restriction=restriction, as_objects=True, return_empties=False + + part_parents = cls._merge_restrict_parents( + restriction=restriction, + as_objects=True, + return_empties=False, + add_invalid_restrict=False, ) - if len(part_parents) != 1: + if not multi_source and len(sources) != 1: raise ValueError( f"Found multiple potential parents: {part_parents}\n\t" - + "Try adding a restriction when invoking `get_parent`." + + "Try adding a string restriction when invoking `get_parent`." + + "Or permitting multiple sources with `multi_source=True`." 
) - if join_master: # Alt: Master * Part shows source - return self * part_parents[0] - else: # Current default aligns with func name - return part_parents[0] + if join_master: + part_parents = [cls * part for part in parts] + + return part_parents if multi_source else part_parents[0] @classmethod - def merge_fetch(cls, *attrs, **kwargs) -> list: + def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list: """Perform a fetch across all parts. If >1 result, return as a list. Parameters ---------- + restriction: str + Optional restriction to apply before determining parent to return. + Default True. attrs, kwargs arguments passed to DataJoint `fetch` call @@ -509,8 +621,11 @@ def merge_fetch(cls, *attrs, **kwargs) -> list: Table contents, with type determined by kwargs """ results = [] - parts = cls()._merge_restrict_parts( - restriction=cls._restriction, return_empties=False + parts = self()._merge_restrict_parts( + restriction=restriction, + as_objects=True, + return_empties=False, + add_invalid_restrict=False, ) for part in parts: @@ -526,8 +641,22 @@ def merge_fetch(cls, *attrs, **kwargs) -> list: # for recarray, pd.DataFrame, or dict, and fetched contents differ if # attrs or "KEY" called. Intercept format, merge, and then transform? + if not results: + print( + "No merge_fetch results.\n\t" + + "If not restriction, try: `M.merge_fetch(True,'attr')\n\t" + + "If restricting by source, use dict: " + + "`M.merge_fetch({'source':'X'})" + ) return results[0] if len(results) == 1 else results + @classmethod + def merge_populate(source: str, key=None): + raise NotImplementedError( + "CBroz: In the future, this command will support executing " + + "part_parent `make` and then inserting all entries into Merge" + ) + _Merge = Merge @@ -558,6 +687,8 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. """ + restriction = AndList((table.restriction, restriction)) + if not restriction: restriction = True
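
A minimal usage sketch of the restriction-as-argument pattern documented in `docs/src/misc/merge_tables.md` above, applied to the new `PositionOutput` merge table. It only mirrors calls introduced in this patch (`merge_view`, `merge_get_part`, `merge_get_parent`, `merge_fetch`, `fetch_nwb`); it assumes a configured DataJoint connection and populated upstream tables, and `"example.nwb"` is a placeholder file name, not a value from this patch.

```python
# Illustrative sketch only: assumes a configured DataJoint connection and
# populated upstream tables; "example.nwb" is a placeholder nwb_file_name.
from spyglass.position.position_merge import PositionOutput

restr = {"nwb_file_name": "example.nwb"}

# Restrict with an argument, not the `&` operator.
PositionOutput.merge_view(restriction=restr)  # prints the merged view

# Part table for the matching source, joined with the master.
# Raises if more than one source matches unless multi_source=True is passed.
part = PositionOutput.merge_get_part(restriction=restr, join_master=True)

# Rows of the source (parent) table that have keys in the merge table.
parent = PositionOutput.merge_get_parent(restriction=restr)

# Fetch attributes across parts; the restriction is now the first argument.
files = PositionOutput.merge_fetch(restr, "nwb_file_name")

# Linked AnalysisNwbfile contents from the single matching source.
nwb = PositionOutput.fetch_nwb(restriction=restr)
```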