From 6481c1b7c561213653f16617a32c451bcb1ce607 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 21 Aug 2023 13:20:42 -0700 Subject: [PATCH 1/9] Selective fetch from cbroz/master --- src/spyglass/common/common_behav.py | 251 ++++++++++---- src/spyglass/common/common_nwbfile.py | 35 +- src/spyglass/position/v1/dlc_utils.py | 28 +- .../position/v1/position_trodes_position.py | 314 +++++++++++------- src/spyglass/settings.py | 164 +++++++-- src/spyglass/utils/nwb_helper_fn.py | 289 +++++++++++----- tests/conftest.py | 5 +- 7 files changed, 759 insertions(+), 327 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 4881e40a5..bb86390e9 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -1,5 +1,6 @@ import os import pathlib +from functools import reduce from typing import Dict import datajoint as dj @@ -29,50 +30,110 @@ class PositionSource(dj.Manual): -> Session -> IntervalList --- - source: varchar(200) # source of data; current options are "trodes" and "dlc" (deep lab cut) - import_file_name: varchar(2000) # path to import file if importing position data + source: varchar(200) # source of data (e.g., trodes, dlc) + import_file_name: varchar(2000) # path to import file if importing """ + class SpatialSeries(dj.Part): + definition = """ + -> master + id : int unsigned # index of spatial series + --- + name=null: varchar(32) # name of spatial series + """ + @classmethod def insert_from_nwbfile(cls, nwb_file_name): - """Given an NWB file name, get the spatial series and interval lists from the file, add the interval - lists to the IntervalList table, and populate the RawPosition table if possible. + """Add intervals to ItervalList and PositionSource. + + Given an NWB file name, get the spatial series and interval lists from + the file, add the interval lists to the IntervalList table, and + populate the RawPosition table if possible. Parameters ---------- nwb_file_name : str The name of the NWB file. """ - nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) + nwbf = get_nwb_file(nwb_file_name) + all_pos = get_all_spatial_series(nwbf, verbose=True) + sess_key = dict(nwb_file_name=nwb_file_name) + src_key = dict(**sess_key, source="trodes", import_file_name="") + + if all_pos is None: + return + + sources = [] + intervals = [] + spat_series = [] - pos_dict = get_all_spatial_series(nwbf, verbose=True) - if pos_dict is not None: - for epoch in pos_dict: - pdict = pos_dict[epoch] - pos_interval_list_name = cls.get_pos_interval_name(epoch) - - # create the interval list and insert it - interval_dict = dict() - interval_dict["nwb_file_name"] = nwb_file_name - interval_dict["interval_list_name"] = pos_interval_list_name - interval_dict["valid_times"] = pdict["valid_times"] - IntervalList().insert1(interval_dict, skip_duplicates=True) - - # add this interval list to the table - key = dict() - key["nwb_file_name"] = nwb_file_name - key["interval_list_name"] = pos_interval_list_name - key["source"] = "trodes" - key["import_file_name"] = "" - cls.insert1(key) + for epoch, epoch_list in all_pos.items(): + ind_key = dict(interval_list_name=cls.get_pos_interval_name(epoch)) + + sources.append(dict(**src_key, **ind_key)) + intervals.append( + dict( + **sess_key, + **ind_key, + valid_times=epoch_list[0]["valid_times"], + ) + ) + + for index, pdict in enumerate(epoch_list): + spat_series.append( + dict( + **sess_key, + **ind_key, + id=index, + name=pdict.get("name"), + ) + ) + + with cls.connection.transaction: + IntervalList.insert(intervals) + cls.insert(sources) + cls.SpatialSeries.insert(spat_series) # make map from epoch intervals to position intervals populate_position_interval_map_session(nwb_file_name) @staticmethod - def get_pos_interval_name(pos_epoch_num): - return f"pos {pos_epoch_num} valid times" + def get_pos_interval_name(epoch_num: int) -> str: + """Return string of the interval name from the epoch number. + + Parameters + ---------- + pos_epoch_num : int + Input epoch number + + Returns + ------- + str + Position interval name (e.g., pos 2 valid times) + """ + try: + int(epoch_num) + except ValueError: + raise ValueError( + f"Epoch number must must be an integer. Received: {epoch_num}" + ) + return f"pos {epoch_num} valid times" + + @staticmethod + def get_epoch_num(name: str) -> int: + """Return the epoch number from the interval name. + + Parameters + ---------- + name : str + Name of position interval (e.g., pos epoch 1 index 2 valid times) + + Returns + ------- + int + epoch number + """ + return int(name.replace("pos ", "").replace(" valid times", "")) @schema @@ -89,35 +150,77 @@ class RawPosition(dj.Imported): definition = """ -> PositionSource - --- - raw_position_object_id: varchar(40) # the object id of the spatial series for this epoch in the NWB file """ + class Object(dj.Part): + definition = """ + -> master + -> PositionSource.SpatialSeries.proj('id') + --- + raw_position_object_id: varchar(40) # id of spatial series in NWB file + """ + + def fetch_nwb(self, *attrs, **kwargs): + return fetch_nwb( + self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs + ) + + def fetch1_dataframe(self): + INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) + + id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()] + + df_list = [ + pd.DataFrame( + data=rp.data, + index=pd.Index(rp.timestamps, name="time"), + columns=[ + col # use existing columns if already numbered + if "1" in rp.description or "2" in rp.description + # else number them by id + else col + str(id + INDEX_ADJUST) + for col in rp.description.split(", ") + ], + ) + for id, rp in id_rp + ] + + return reduce(lambda x, y: pd.merge(x, y, on="time"), df_list) + def make(self, key): nwb_file_name = key["nwb_file_name"] - nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) - - # TODO refactor this. this calculates sampling rate (unused here) and is expensive to do twice - pos_dict = get_all_spatial_series(nwbf) - for epoch in pos_dict: - if key[ - "interval_list_name" - ] == PositionSource.get_pos_interval_name(epoch): - pdict = pos_dict[epoch] - key["raw_position_object_id"] = pdict["raw_position_object_id"] - self.insert1(key) - break + interval_list_name = key["interval_list_name"] + + nwbf = get_nwb_file(nwb_file_name) + indices = (PositionSource.SpatialSeries & key).fetch("id") + + # incl_times = False -> don't do extra processing for valid_times + spat_objs = get_all_spatial_series(nwbf, incl_times=False)[ + PositionSource.get_epoch_num(interval_list_name) + ] + + self.insert1(key) + self.Object.insert( + [ + dict( + nwb_file_name=nwb_file_name, + interval_list_name=interval_list_name, + id=index, + raw_position_object_id=obj["raw_position_object_id"], + ) + for index, obj in enumerate(spat_objs) + if index in indices + ] + ) def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) + raise NotImplementedError( + "fetch_nwb now operates on RawPosition.Object" + ) def fetch1_dataframe(self): - raw_position_nwb = self.fetch_nwb()[0]["raw_position"] - return pd.DataFrame( - data=raw_position_nwb.data, - index=pd.Index(raw_position_nwb.timestamps, name="time"), - columns=raw_position_nwb.description.split(", "), + raise NotImplementedError( + "fetch1_dataframe now operates on RawPosition.Object" ) @@ -130,7 +233,7 @@ class StateScriptFile(dj.Imported): """ def make(self, key): - """Add a new row to the StateScriptFile table. Requires keys "nwb_file_name", "file_object_id".""" + """Add a new row to the StateScriptFile table.""" nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -140,8 +243,8 @@ def make(self, key): ) or nwbf.processing.get("associated files") if associated_files is None: print( - f'Unable to import StateScriptFile: no processing module named "associated_files" ' - f"found in {nwb_file_name}." + "Unable to import StateScriptFile: no processing module named " + + '"associated_files" found in {nwb_file_name}.' ) return @@ -150,13 +253,16 @@ def make(self, key): associated_file_obj, ndx_franklab_novela.AssociatedFiles ): print( - f'Data interface {associated_file_obj.name} within "associated_files" processing module is not ' - f"of expected type ndx_franklab_novela.AssociatedFiles\n" + f"Data interface {associated_file_obj.name} within " + + '"associated_files" processing module is not ' + + "of expected type ndx_franklab_novela.AssociatedFiles\n" ) return + # parse the task_epochs string - # TODO update associated_file_obj.task_epochs to be an array of 1-based ints, - # not a comma-separated string of ints + # TODO: update associated_file_obj.task_epochs to be an array of + # 1-based ints, not a comma-separated string of ints + epoch_list = associated_file_obj.task_epochs.split(",") # only insert if this is the statescript file print(associated_file_obj.description) @@ -184,8 +290,9 @@ class VideoFile(dj.Imported): Notes ----- - The video timestamps come from: videoTimeStamps.cameraHWSync if PTP is used. - If PTP is not used, the video timestamps come from videoTimeStamps.cameraHWFrameCount . + The video timestamps come from: videoTimeStamps.cameraHWSync if PTP is + used. If PTP is not used, the video timestamps come from + videoTimeStamps.cameraHWFrameCount . """ @@ -198,6 +305,13 @@ class VideoFile(dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key, verbose=True): + if not self.connection.in_transaction: + self.populate(key) + return + nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -226,7 +340,9 @@ def make(self, key): if isinstance(video, pynwb.image.ImageSeries): video = [video] for video_obj in video: - # check to see if the times for this video_object are largely overlapping with the task epoch times + # check to see if the times for this video_object are largely + # overlapping with the task epoch times + if len( interval_list_contains(valid_times, video_obj.timestamps) > 0.9 * len(video_obj.timestamps) @@ -237,14 +353,18 @@ def make(self, key): key["camera_name"] = video_obj.device.camera_name else: raise KeyError( - f"No camera with camera_name: {camera_name} found in CameraDevice table." + f"No camera with camera_name: {camera_name} found " + + "in CameraDevice table." ) key["video_file_object_id"] = video_obj.object_id self.insert1(key) is_found = True - if not is_found: - print(f"No video found corresponding to epoch {interval_list_name}") + if not is_found and verbose: + print( + f"No video found corresponding to file {nwb_file_name}, " + + f"epoch {interval_list_name}" + ) def fetch_nwb(self, *attrs, **kwargs): return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) @@ -258,16 +378,17 @@ def update_entries(cls, restrict={}): video_nwb = (cls & row).fetch_nwb()[0] if len(video_nwb) != 1: raise ValueError( - f"expecting 1 video file per entry, but {len(video_nwb)} files found" + f"Expecting 1 video file per entry. {len(video_nwb)} found" ) row["camera_name"] = video_nwb[0]["video_file"].device.camera_name cls.update1(row=row) @classmethod def get_abs_path(cls, key: Dict): - """Return the absolute path for a stored video file given a key with the nwb_file_name and epoch number + """Return the absolute path for a stored video file given a key. - The SPYGLASS_VIDEO_DIR environment variable must be set. + Key must include the nwb_file_name and epoch number. The + SPYGLASS_VIDEO_DIR environment variable must be set. Parameters ---------- diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 493e98962..52c070d2b 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -12,7 +12,7 @@ import spikeinterface as si from hdmf.common import DynamicTable -from ..settings import load_config +from ..settings import raw_dir from ..utils.dj_helper_fn import get_child_tables from ..utils.nwb_helper_fn import get_electrode_indices, get_nwb_file @@ -69,24 +69,47 @@ def insert_from_relative_file_name(cls, nwb_file_name): key["nwb_file_abs_path"] = nwb_file_abs_path cls.insert1(key, skip_duplicates=True) - @staticmethod - def get_abs_path(nwb_file_name): + @classmethod + def _get_file_name(cls, nwb_file_name: str) -> str: + """Get valid nwb file name given substring.""" + query = cls & f'nwb_file_name LIKE "%{nwb_file_name}%"' + + if len(query) == 1: + return query.fetch1("nwb_file_name") + + raise ValueError( + f"Found {len(query)} matches for {nwb_file_name}: \n{query}" + ) + + @classmethod + def get_file_key(cls, nwb_file_name: str) -> dict: + """Return primary key using nwb_file_name substring.""" + return {"nwb_file_name": cls._get_file_name(nwb_file_name)} + + @classmethod + def get_abs_path(cls, nwb_file_name, new_file=False) -> str: """Return absolute path for a stored raw NWB file given file name. - The SPYGLASS_BASE_DIR environment variable must be set. + The SPYGLASS_BASE_DIR must be set, either as an environment or part of + dj.config['custom']. See spyglass.settings.load_config Parameters ---------- nwb_file_name : str - The name of an NWB file that has been inserted into the Nwbfile() schema. + The name of an NWB file that has been inserted into the Nwbfile() + table. May be file substring. May include % wildcard(s). + new_file : bool, optional + Adding a new file to Nwbfile table. Defaults to False. Returns ------- nwb_file_abspath : str The absolute path for the given file name. """ + if new_file: + return raw_dir + "/" + nwb_file_name - return load_config()["SPYGLASS_RAW_DIR"] + "/" + nwb_file_name + return raw_dir + "/" + cls._get_file_name(nwb_file_name) @staticmethod def add_to_lock(nwb_file_name): diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 84a5384d9..3abfb8f72 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -18,6 +18,8 @@ import pandas as pd from tqdm import tqdm as tqdm +from ...settings import raw_dir + def _set_permissions(directory, mode, username: str, groupname: str = None): """ @@ -321,13 +323,17 @@ def get_video_path(key): from ...common.common_behav import VideoFile - video_info = ( - VideoFile() - & {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} - ).fetch1() - nwb_path = ( - f"{os.getenv('SPYGLASS_BASE_DIR')}/raw/{video_info['nwb_file_name']}" - ) + vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} + VideoFile()._no_transaction_make(vf_key, verbose=False) + video_query = VideoFile & vf_key + + if len(video_query) != 1: + print(f"Found {len(video_query)} videos for {vf_key}") + return None, None, None, None + + video_info = video_query.fetch1() + nwb_path = f"{raw_dir}/{video_info['nwb_file_name']}" + with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out: nwb_file = in_out.read() nwb_video = nwb_file.objects[video_info["video_file_object_id"]] @@ -338,6 +344,7 @@ def get_video_path(key): video_filename = video_filepath.split(video_dir)[-1] meters_per_pixel = nwb_video.device.meters_per_pixel timestamps = np.asarray(nwb_video.timestamps) + return video_dir, video_filename, meters_per_pixel, timestamps @@ -526,7 +533,9 @@ def get_gpu_memory(): if subproccess command errors. """ - output_to_list = lambda x: x.decode("ascii").split("\n")[:-1] + def output_to_list(x): + return x.decode("ascii").split("\n")[:-1] + query_cmd = "nvidia-smi --query-gpu=memory.used --format=csv" try: memory_use_info = output_to_list( @@ -534,7 +543,8 @@ def get_gpu_memory(): )[1:] except subprocess.CalledProcessError as err: raise RuntimeError( - f"command {err.cmd} return with error (code {err.returncode}): {err.output}" + f"command {err.cmd} return with error (code {err.returncode}): " + + f"{err.output}" ) from err memory_use_values = { i: int(x.split()[0]) for i, x in enumerate(memory_use_info) diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 68f43526e..da7456232 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -41,12 +41,13 @@ class TrodesPosParams(dj.Manual): params: longblob """ - @classmethod - def insert_default(cls, **kwargs): - """ - Insert default parameter set for position determination - """ - params = { + @property + def default_pk(self): + return {"trodes_pos_params_name": "default"} + + @property + def default_params(self): + return { "max_separation": 9.0, "max_speed": 300.0, "position_smoothing_duration": 0.125, @@ -57,25 +58,29 @@ def insert_default(cls, **kwargs): "upsampling_sampling_rate": None, "upsampling_interpolation_method": "linear", } + + @classmethod + def insert_default(cls, **kwargs): + """ + Insert default parameter set for position determination + """ cls.insert1( - {"trodes_pos_params_name": "default", "params": params}, + {**cls().default_pk, "params": cls().default_params}, skip_duplicates=True, ) @classmethod def get_default(cls): - query = cls & {"trodes_pos_params_name": "default"} + query = cls & cls().default_pk if not len(query) > 0: cls().insert_default(skip_duplicates=True) - default = (cls & {"trodes_pos_params_name": "default"}).fetch1() - else: - default = query.fetch1() - return default + return (cls & cls().default_pk).fetch1() + + return query.fetch1() @classmethod def get_accepted_params(cls): - default = cls.get_default() - return list(default["params"].keys()) + return [k for k in cls().default_params.keys()] @schema @@ -88,9 +93,66 @@ class TrodesPosSelection(dj.Manual): definition = """ -> RawPosition -> TrodesPosParams - --- """ + @classmethod + def insert_with_default( + cls, + key: dict, + skip_duplicates: bool = False, + edit_defaults: dict = {}, + edit_name: str = None, + ) -> None: + """Insert key with default parameters. + + To change defaults, supply a dict as edit_defaults with a name for + the new paramset as edit_name. + + Parameters + ---------- + key: Union[dict, str] + Restriction uniquely identifying entr(y/ies) in RawPosition. + skip_duplicates: bool, optional + Skip duplicate entries. + edit_defauts: dict, optional + Dictionary of overrides to default parameters. + edit_name: str, optional + If edit_defauts is passed, the name of the new entry + + Raises + ------ + ValueError + Key does not identify any entries in RawPosition. + """ + query = RawPosition & key + if not query: + raise ValueError(f"Found no entries found for {key}") + + param_pk, param_name = list(TrodesPosParams().default_pk.items())[0] + + if bool(edit_defaults) ^ bool(edit_name): # XOR: only one of them + raise ValueError("Must specify both edit_defauts and edit_name") + + elif edit_defaults and edit_name: + TrodesPosParams.insert1( + { + param_pk: edit_name, + "params": { + **TrodesPosParams().default_params, + **edit_defaults, + }, + }, + skip_duplicates=skip_duplicates, + ) + + cls.insert( + [ + {**k, param_pk: edit_name or param_name} + for k in query.fetch("KEY", as_dict=True) + ], + skip_duplicates=skip_duplicates, + ) + @schema class TrodesPosV1(dj.Computed): @@ -108,111 +170,112 @@ class TrodesPosV1(dj.Computed): """ def make(self, key): - orig_key = copy.deepcopy(key) + METERS_PER_CM = 0.01 + print(f"Computing position for: {key}") + + orig_key = copy.deepcopy(key) key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] ) - raw_position = (RawPosition() & key).fetch_nwb()[0] + position_info_parameters = (TrodesPosParams() & key).fetch1("params") + spatial_series = (RawPosition.Object & key).fetch_nwb()[0][ + "raw_position" + ] + spatial_df = (RawPosition.Object & key).fetch1_dataframe() + video_frame_ind = getattr(spatial_df, "video_frame_ind", None) + position = pynwb.behavior.Position() orientation = pynwb.behavior.CompassDirection() velocity = pynwb.behavior.BehavioralTimeSeries() - METERS_PER_CM = 0.01 - raw_pos_df = pd.DataFrame( - data=raw_position["raw_position"].data, - index=pd.Index( - raw_position["raw_position"].timestamps, name="time" + # NOTE: CBroz1 removed a try/except ValueError that surrounded all + # .create_X_series methods. dpeg22 could not recall purpose + + position_info = self.calculate_position_info_from_spatial_series( + spatial_df=spatial_df, + meters_to_pixels=spatial_series.conversion, + **position_info_parameters, + ) + + time_comments = dict( + comments=spatial_series.comments, + timestamps=position_info["time"], + ) + time_comments_ref = dict( + **time_comments, + reference_frame=spatial_series.reference_frame, + ) + + # create nwb objects for insertion into analysis nwb file + position.create_spatial_series( + name="position", + conversion=METERS_PER_CM, + data=position_info["position"], + description="x_position, y_position", + **time_comments_ref, + ) + + orientation.create_spatial_series( + name="orientation", + conversion=1.0, + data=position_info["orientation"], + description="orientation", + **time_comments_ref, + ) + + velocity.create_timeseries( + name="velocity", + conversion=METERS_PER_CM, + unit="m/s", + data=np.concatenate( + ( + position_info["velocity"], + position_info["speed"][:, np.newaxis], + ), + axis=1, ), - columns=raw_position["raw_position"].description.split(", "), + description="x_velocity, y_velocity, speed", + **time_comments, ) - try: - # calculate the processed position - spatial_series = raw_position["raw_position"] - position_info = self.calculate_position_info_from_spatial_series( - spatial_series, - position_info_parameters["max_separation"], - position_info_parameters["max_speed"], - position_info_parameters["speed_smoothing_std_dev"], - position_info_parameters["position_smoothing_duration"], - position_info_parameters["orient_smoothing_std_dev"], - position_info_parameters["led1_is_front"], - position_info_parameters["is_upsampled"], - position_info_parameters["upsampling_sampling_rate"], - position_info_parameters["upsampling_interpolation_method"], - ) - # create nwb objects for insertion into analysis nwb file - position.create_spatial_series( - name="position", - timestamps=position_info["time"], - conversion=METERS_PER_CM, - data=position_info["position"], - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, - description="x_position, y_position", - ) - orientation.create_spatial_series( - name="orientation", - timestamps=position_info["time"], - conversion=1.0, - data=position_info["orientation"], - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, - description="orientation", + if video_frame_ind: + velocity.create_timeseries( + name="video_frame_ind", + unit="index", + data=spatial_df.video_frame_ind.to_numpy(), + description="video_frame_ind", + **time_comments, + ) + else: + print( + "No video frame index found. Assuming all camera frames " + + "are present." ) - velocity.create_timeseries( - name="velocity", - timestamps=position_info["time"], - conversion=METERS_PER_CM, - unit="m/s", - data=np.concatenate( - ( - position_info["velocity"], - position_info["speed"][:, np.newaxis], - ), - axis=1, - ), - comments=spatial_series.comments, - description="x_velocity, y_velocity, speed", + name="video_frame_ind", + unit="index", + data=np.arange(len(position_info["time"])), + description="video_frame_ind", + **time_comments, ) - try: - velocity.create_timeseries( - name="video_frame_ind", - unit="index", - timestamps=position_info["time"], - data=raw_pos_df.video_frame_ind.to_numpy(), - description="video_frame_ind", - comments=spatial_series.comments, - ) - except AttributeError: - print( - "No video frame index found. Assuming all camera frames are present." - ) - velocity.create_timeseries( - name="video_frame_ind", - unit="index", - timestamps=position_info["time"], - data=np.arange(len(position_info["time"])), - description="video_frame_ind", - comments=spatial_series.comments, - ) - except ValueError: - pass # Insert into analysis nwb file nwb_analysis_file = AnalysisNwbfile() - key["position_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], position - ) - key["orientation_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], orientation - ) - key["velocity_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], velocity + key.update( + dict( + position_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], position + ), + orientation_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], orientation + ), + velocity_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], velocity + ), + ) ) AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) @@ -222,6 +285,7 @@ def make(self, key): from ..position_merge import PositionOutput part_name = to_camel_case(self.table_name.split("__")[-1]) + # TODO: The next line belongs in a merge table function PositionOutput._merge_insert( [orig_key], part_name=part_name, skip_duplicates=True @@ -229,9 +293,10 @@ def make(self, key): @staticmethod def calculate_position_info_from_spatial_series( - spatial_series, - max_LED_separation, - max_plausible_speed, + spatial_df: pd.DataFrame, + meters_to_pixels: float, + max_separation, + max_speed, speed_smoothing_std_dev, position_smoothing_duration, orient_smoothing_std_dev, @@ -240,16 +305,25 @@ def calculate_position_info_from_spatial_series( upsampling_sampling_rate, upsampling_interpolation_method, ): + """Calculate position info from 2D spatial series.""" CM_TO_METERS = 100 + # Accepts x/y 'loc' or 'loc1' format for first pos. Renames to 'loc' + DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2", "xloc1", "yloc1"] + + cols = list(spatial_df.columns) + if len(cols) != 4 or not all([c in DEFAULT_COLS for c in cols]): + choice = dj.utils.user_choice( + f"Unexpected columns in raw position. Assume " + + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + ) + if choice.lower() not in ["yes", "y"]: + raise ValueError(f"Unexpected columns in raw position: {cols}") + # rename first 4 columns, keep rest. Rest dropped below + spatial_df.columns = DEFAULT_COLS[:4] + cols[4:] # Get spatial series properties - time = np.asarray(spatial_series.timestamps) # seconds - position = np.asarray( - pd.DataFrame( - spatial_series.data, - columns=spatial_series.description.split(", "), - ).loc[:, ["xloc", "yloc", "xloc2", "yloc2"]] - ) # meters + time = np.asarray(spatial_df.index) # seconds + position = np.asarray(spatial_df.iloc[:, :4]) # meters # remove NaN times is_nan_time = np.isnan(time) @@ -258,7 +332,6 @@ def calculate_position_info_from_spatial_series( dt = np.median(np.diff(time)) sampling_rate = 1 / dt - meters_to_pixels = spatial_series.conversion # Define LEDs if led1_is_front: @@ -274,7 +347,7 @@ def calculate_position_info_from_spatial_series( # Set points to NaN where the front and back LEDs are too separated dist_between_LEDs = get_distance(back_LED, front_LED) - is_too_separated = dist_between_LEDs >= max_LED_separation + is_too_separated = dist_between_LEDs >= max_separation back_LED[is_too_separated] = np.nan front_LED[is_too_separated] = np.nan @@ -294,8 +367,8 @@ def calculate_position_info_from_spatial_series( ) # Set to points to NaN where the speed is too fast - is_too_fast = (front_LED_speed > max_plausible_speed) | ( - back_LED_speed > max_plausible_speed + is_too_fast = (front_LED_speed > max_speed) | ( + back_LED_speed > max_speed ) back_LED[is_too_fast] = np.nan front_LED[is_too_fast] = np.nan @@ -446,7 +519,8 @@ class TrodesPosVideo(dj.Computed): definition = """ -> TrodesPosV1 - --- + --- + has_video : bool """ def make(self, key): @@ -454,7 +528,7 @@ def make(self, key): print("Loading position data...") raw_position_df = ( - RawPosition() + RawPosition.Object & { "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["interval_list_name"], @@ -480,6 +554,11 @@ def make(self, key): ) = get_video_path( {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} ) + + if not video_path: + self.insert1(dict(**key, has_video=False)) + return + video_dir = os.path.dirname(video_path) + "/" video_path = check_videofile( video_path=video_dir, video_filename=video_filename @@ -513,6 +592,7 @@ def make(self, key): cm_to_pixels=cm_per_pixel, disable_progressbar=False, ) + self.insert1(dict(**key, has_video=True)) @staticmethod def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 9e5305831..4ddf7d53a 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -20,17 +20,18 @@ sorting="spikesorting", # "SPYGLASS_SORTING_DIR" waveforms="waveforms", temp="tmp", + video="video", ), kachery=dict( - cloud="kachery-storage", - storage="kachery-storage", + cloud="kachery_storage", + storage="kachery_storage", temp="tmp", ), ) def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: - """Gets syglass dirs from dj.config or environment variables. + """Gets Spyglass dirs from dj.config or environment variables. Uses a relative_dirs dict defined in settings.py to (a) gather user settings from dj.config or os environment variables or defaults relative to @@ -62,15 +63,17 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: dj_spyglass = dj.config.get("custom", {}).get("spyglass_dirs", {}) dj_kachery = dj.config.get("custom", {}).get("kachery_dirs", {}) + resolved_base = ( base_dir or dj_spyglass.get("base") - or os.environ.get("SPYGLASS_BASE_DIR", ".") + or os.environ.get("SPYGLASS_BASE_DIR") ) if not resolved_base: raise ValueError( "SPYGLASS_BASE_DIR not defined in dj.config or os env vars" ) + config_dirs = {"SPYGLASS_BASE_DIR": resolved_base} for prefix, dirs in relative_dirs.items(): for dir, dir_str in dirs.items(): @@ -85,19 +88,64 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: ).replace('"', "") config_dirs.update({dir_env_fmt: dir_location}) - _set_env_with_dict(config_dirs) + kachery_zone_dict = { + "KACHERY_ZONE": ( + os.environ.get("KACHERY_ZONE") + or dj.config.get("custom", {}).get("kachery_zone") + or "franklab.default" + ) + } + + loaded_env = _load_env_vars(env_defaults) + _set_env_with_dict({**config_dirs, **kachery_zone_dict, **loaded_env}) _mkdirs_from_dict_vals(config_dirs) _set_dj_config_stores(config_dirs) config = dict( - **config_defaults, - **config_dirs, + **config_defaults, **config_dirs, **kachery_zone_dict, **loaded_env ) config_loaded = True return config -def base_dir() -> str: +def _load_env_vars(env_dict: dict) -> dict: + """Loads env vars from dict {str: Any}.""" + loaded_dict = {} + for var, val in env_dict.items(): + loaded_dict[var] = os.getenv(var, val) + return loaded_dict + + +def _set_env_with_dict(env_dict: dict): + """Sets env vars from dict {str: Any} where Any is convertible to str.""" + for var, val in env_dict.items(): + os.environ[var] = str(val) + + +def _mkdirs_from_dict_vals(dir_dict: dict): + for dir_str in dir_dict.values(): + Path(dir_str).mkdir(exist_ok=True) + + +def _set_dj_config_stores(dir_dict: dict): + raw_dir = dir_dict["SPYGLASS_RAW_DIR"] + analysis_dir = dir_dict["SPYGLASS_ANALYSIS_DIR"] + + dj.config["stores"] = { + "raw": { + "protocol": "file", + "location": str(raw_dir), + "stage": str(raw_dir), + }, + "analysis": { + "protocol": "file", + "location": str(analysis_dir), + "stage": str(analysis_dir), + }, + } + + +def load_base_dir() -> str: """Retrieve the base directory from the configuration. Returns @@ -111,13 +159,13 @@ def base_dir() -> str: return config.get("SPYGLASS_BASE_DIR") -def raw_dir() -> str: - """Retrieve the base directory from the configuration. +def load_raw_dir() -> str: + """Retrieve the raw directory from the configuration. Returns ------- str - The base directory path. + The raw directory path. """ global config if not config_loaded or not config: @@ -125,31 +173,79 @@ def raw_dir() -> str: return config.get("SPYGLASS_RAW_DIR") -def _set_env_with_dict(env_dict: dict): - """Sets env vars from dict {str: Any} where Any is convertible to str.""" - env_to_set = {**env_defaults, **env_dict} - for var, val in env_to_set.items(): - os.environ[var] = str(val) +def load_analysis_dir() -> str: + """Retrieve the analysis directory from the configuration. + Returns + ------- + str + The recording directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_ANALYSIS_DIR") -def _mkdirs_from_dict_vals(dir_dict: dict): - for dir_str in dir_dict.values(): - Path(dir_str).mkdir(exist_ok=True) +def load_recording_dir() -> str: + """Retrieve the recording directory from the configuration. -def _set_dj_config_stores(dir_dict: dict): - raw_dir = dir_dict["SPYGLASS_RAW_DIR"] - analysis_dir = dir_dict["SPYGLASS_ANALYSIS_DIR"] + Returns + ------- + str + The recording directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_RECORDING_DIR") - dj.config["stores"] = { - "raw": { - "protocol": "file", - "location": str(raw_dir), - "stage": str(raw_dir), - }, - "analysis": { - "protocol": "file", - "location": str(analysis_dir), - "stage": str(analysis_dir), - }, - } + +def load_sorting_dir() -> str: + """Retrieve the sorting directory from the configuration. + + Returns + ------- + str + The sorting directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_SORTING_DIR") + + +def load_temp_dir() -> str: + """Retrieve the temp directory from the configuration. + + Returns + ------- + str + The temp directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_TEMP_DIR") + + +def load_waveform_dir() -> str: + """Retrieve the temp directory from the configuration. + + Returns + ------- + str + The temp directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_TEMP_DIR") + + +base_dir = load_base_dir() +raw_dir = load_raw_dir() +recording_dir = load_recording_dir() +temp_dir = load_temp_dir() +analysis_dir = load_analysis_dir() +sorting_dir = load_sorting_dir() diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 58b9b4696..835dfedc7 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -3,6 +3,7 @@ import os import os.path import warnings +from itertools import groupby from pathlib import Path import numpy as np @@ -21,31 +22,41 @@ def get_nwb_file(nwb_file_path): """Return an NWBFile object with the given file path in read mode. - If the file is not found locally, this will check if it has been shared with kachery and if so, download it and open it. + + If the file is not found locally, this will check if it has been shared + with kachery and if so, download it and open it. + Parameters ---------- nwb_file_path : str - Path to the NWB file. + Path to the NWB file or NWB file name. If it does not start with a "/", + get path with Nwbfile.get_abs_path Returns ------- nwbfile : pynwb.NWBFile NWB file object for the given path opened in read mode. """ + if not nwb_file_path.startswith("/"): + from ..common import Nwbfile + + nwb_file_path = Nwbfile.get_abs_path(nwb_file_path) + _, nwbfile = __open_nwb_files.get(nwb_file_path, (None, None)) - nwb_uri = None - nwb_raw_uri = None + if nwbfile is None: # check to see if the file exists if not os.path.exists(nwb_file_path): print( - f"NWB file {nwb_file_path} does not exist locally; checking kachery" + "NWB file not found locally; checking kachery for " + + f"{nwb_file_path}" ) # first try the analysis files from ..sharing.sharing_kachery import AnalysisNwbfileKachery - # the download functions assume just the filename, so we need to get that from the path + # the download functions assume just the filename, so we need to + # get that from the path if not AnalysisNwbfileKachery.download_file( os.path.basename(nwb_file_path) ): @@ -98,7 +109,8 @@ def close_nwb_files(): def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): - """Search for a specified NWBDataInterface or DynamicTable in the processing modules of an NWB file. + """ + Search for NWBDataInterface or DynamicTable in processing modules of an NWB. Parameters ---------- @@ -107,13 +119,15 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): data_interface_name : str The name of the NWBDataInterface or DynamicTable to search for. data_interface_class : type, optional - The class (or superclass) to search for. This argument helps to prevent accessing an object with the same - name but the incorrect type. Default: no restriction. + The class (or superclass) to search for. This argument helps to prevent + accessing an object with the same name but the incorrect type. Default: + no restriction. Warns ----- UserWarning - If multiple NWBDataInterface and DynamicTable objects with the matching name are found. + If multiple NWBDataInterface and DynamicTable objects with the matching + name are found. Returns ------- @@ -132,19 +146,21 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): if len(ret) > 1: warnings.warn( f"Multiple data interfaces with name '{data_interface_name}' " - f"found in NWBFile with identifier {nwbfile.identifier}. Using the first one found. " + f"found in NWBFile with identifier {nwbfile.identifier}. " + + "Using the first one found. " "Use the data_interface_class argument to restrict the search." ) if len(ret) >= 1: return ret[0] - else: - return None + + return None def get_raw_eseries(nwbfile): """Return all ElectricalSeries in the acquisition group of an NWB file. - ElectricalSeries found within LFP objects in the acquisition will also be returned. + ElectricalSeries found within LFP objects in the acquisition will also be + returned. Parameters ---------- @@ -165,49 +181,73 @@ def get_raw_eseries(nwbfile): return ret -def estimate_sampling_rate(timestamps, multiplier): +def estimate_sampling_rate( + timestamps, multiplier=1.75, verbose=False, filename="file" +): """Estimate the sampling rate given a list of timestamps. - Assumes that the most common temporal differences between timestamps approximate the sampling rate. Note that this - can fail for very high sampling rates and irregular timestamps. + Assumes that the most common temporal differences between timestamps + approximate the sampling rate. Note that this can fail for very high + sampling rates and irregular timestamps. Parameters ---------- timestamps : numpy.ndarray 1D numpy array of timestamp values. - multiplier : float or int + multiplier : float or int, optional + Deft + verbose : bool, optional + Print sampling rate to stdout. Default, False + filename : str, optional + Filename to reference when printing or err. Default, "file" Returns ------- estimated_rate : float The estimated sampling rate. + + Raises + ------ + ValueError + If estimated rate is less than 0. """ # approach: # 1. use a box car smoother and a histogram to get the modal value - # 2. identify adjacent samples as those that have a time difference < the multiplier * the modal value + # 2. identify adjacent samples as those that have a + # time difference < the multiplier * the modal value # 3. average the time differences between adjacent samples + sample_diff = np.diff(timestamps[~np.isnan(timestamps)]) + if len(sample_diff) < 10: raise ValueError( f"Only {len(sample_diff)} timestamps are valid. Check the data." ) - nsmooth = 10 - smoother = np.ones(nsmooth) / nsmooth - smooth_diff = np.convolve(sample_diff, smoother, mode="same") - # we histogram with 100 bins out to 3 * mean, which should be fine for any reasonable number of samples + smooth_diff = np.convolve(sample_diff, np.ones(10) / 10, mode="same") + + # we histogram with 100 bins out to 3 * mean, which should be fine for any + # reasonable number of samples hist, bins = np.histogram( smooth_diff, bins=100, range=[0, 3 * np.mean(smooth_diff)] ) - mode = bins[np.where(hist == np.max(hist))] + mode = bins[np.where(hist == np.max(hist))][0] + + adjacent = sample_diff < mode * multiplier + + sampling_rate = np.round(1.0 / np.mean(sample_diff[adjacent])) - adjacent = sample_diff < mode[0] * multiplier - return np.round(1.0 / np.mean(sample_diff[adjacent])) + if sampling_rate < 0: + raise ValueError(f"Error calculating sampling rate. For {filename}") + if verbose: + print(f"Estimated sampling rate for {filename}: {sampling_rate} Hz") + + return sampling_rate def get_valid_intervals( - timestamps, sampling_rate, gap_proportion, min_valid_len + timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=None ): """Finds the set of all valid intervals in a list of timestamps. Valid interval: (start time, stop time) during which there are @@ -219,13 +259,13 @@ def get_valid_intervals( 1D numpy array of timestamp values. sampling_rate : float Sampling rate of the data. - gap_proportion : float, greater than 1; unit: samples - Threshold for detecting a gap; - i.e. if the difference (in samples) between - consecutive timestamps exceeds gap_proportion, - it is considered a gap - min_valid_len : float - Length of smallest valid interval. + gap_proportion : float, optional + Threshold for detecting a gap; i.e. if the difference (in samples) + between consecutive timestamps exceeds gap_proportion, it is considered + a gap. Must be > 1. Default to 2.5 + min_valid_len : float, optional + Length of smallest valid interval. Default to sampling_rate. If greater + than interval duration, print warning and use half the total time. Returns ------- @@ -235,6 +275,15 @@ def get_valid_intervals( eps = 0.0000001 + if not min_valid_len: + min_valid_len = int(sampling_rate) + + total_time = timestamps[-1] - timestamps[0] + if total_time < min_valid_len: + half_total_time = total_time / 2 + print(f"WARNING: Setting minimum valid interval to {half_total_time}") + min_valid_len = half_total_time + # get rid of NaN elements timestamps = timestamps[~np.isnan(timestamps)] # find gaps @@ -261,16 +310,20 @@ def get_valid_intervals( def get_electrode_indices(nwb_object, electrode_ids): - """Given an NWB file or electrical series object, return the indices of the specified electrode_ids. + """Return indices of the specified electrode_ids given an NWB file. - If an ElectricalSeries is given, then the indices returned are relative to the selected rows in - ElectricalSeries.electrodes. For example, if electricalseries.electrodes = [5], and row index 5 of - nwbfile.electrodes has ID 10, then calling get_electrode_indices(electricalseries, 10) will return 0, the - index of the matching electrode in electricalseries.electrodes. + Also accepts electrical series object. If an ElectricalSeries is given, + then the indices returned are relative to the selected rows in + ElectricalSeries.electrodes. For example, if electricalseries.electrodes = + [5], and row index 5 of nwbfile.electrodes has ID 10, then calling + get_electrode_indices(electricalseries, 10) will return 0, the index of the + matching electrode in electricalseries.electrodes. - Indices for electrode_ids that are not in the electrical series are returned as np.nan + Indices for electrode_ids that are not in the electrical series are + returned as np.nan - If an NWBFile is given, then the row indices with the matching IDs in the file's electrodes table are returned. + If an NWBFile is given, then the row indices with the matching IDs in the + file's electrodes table are returned. Parameters ---------- @@ -285,8 +338,9 @@ def get_electrode_indices(nwb_object, electrode_ids): Array of indices of the specified electrode IDs. """ if isinstance(nwb_object, pynwb.ecephys.ElectricalSeries): - # electrodes is a DynamicTableRegion which may contain a subset of the rows in NWBFile.electrodes - # match against only the subset of electrodes referenced by this ElectricalSeries + # electrodes is a DynamicTableRegion which may contain a subset of the + # rows in NWBFile.electrodes match against only the subset of + # electrodes referenced by this ElectricalSeries electrode_table_indices = nwb_object.electrodes.data[:] selected_elect_ids = [ nwb_object.electrodes.table.id[x] for x in electrode_table_indices @@ -299,7 +353,9 @@ def get_electrode_indices(nwb_object, electrode_ids): "nwb_object must be of type ElectricalSeries or NWBFile" ) - # for each electrode_id, find its index in selected_elect_ids and return that if it's there and invalid_electrode_index if not. + # for each electrode_id, find its index in selected_elect_ids and return + # that if it's there and invalid_electrode_index if not. + return [ selected_elect_ids.index(elect_id) if elect_id in selected_elect_ids @@ -308,8 +364,79 @@ def get_electrode_indices(nwb_object, electrode_ids): ] -def get_all_spatial_series(nwbf, verbose=False): - """Given an NWBFile, get the spatial series and interval lists from the file and return a dictionary by epoch. +def _get_epoch_groups(position: pynwb.behavior.Position): + epoch_start_time = {} + for pos_epoch, spatial_series in enumerate( + position.spatial_series.values() + ): + epoch_start_time[pos_epoch] = spatial_series.timestamps[0] + + return { + i: [j[0] for j in j] + for i, j in groupby( + sorted(epoch_start_time.items(), key=lambda x: x[1]), lambda x: x[1] + ) + } + + +def _get_pos_dict( + position: dict, + epoch_groups: dict, + session_id: str = None, + verbose: bool = False, + incl_times: bool = True, +): + """Return dict with obj ids and valid intervals for each epoch. + + Parameters + ---------- + position : hdmf.utils.LabeledDict + pynwb.behavior.Position.spatial_series + epoch_groups : dict + Epoch start times as keys, spatial series indices as values + session_id : str, optional + Optional session ID for verbose print during sampling rate estimation + verbose : bool, optional + Default to False. Print estimated sampling rate + incl_times : bool, optional + Default to True. Include valid intervals. Requires additional + computation not needed for RawPosition + """ + # prev, this was just a list. now, we need to gen mult dict per epoch + pos_data_dict = dict() + all_spatial_series = list(position.values()) + + for epoch, index_list in enumerate(epoch_groups.values()): + pos_data_dict[epoch] = [] + for index in index_list: + spatial_series = all_spatial_series[index] + valid_times = None + if incl_times: # get the valid intervals for the position data + timestamps = np.asarray(spatial_series.timestamps) + sampling_rate = estimate_sampling_rate( + timestamps, verbose=verbose, filename=session_id + ) + valid_times = get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + ) + # add the valid intervals to the Interval list + pos_data_dict[epoch].append( + { + "valid_times": valid_times, + "raw_position_object_id": spatial_series.object_id, + "name": spatial_series.name, + } + ) + + return pos_data_dict + + +def get_all_spatial_series(nwbf, verbose=False, incl_times=True) -> dict: + """ + Given an NWB, get the spatial series and return a dictionary by epoch. + + If incl_times is True, then the valid intervals are included in the output. Parameters ---------- @@ -317,58 +444,32 @@ def get_all_spatial_series(nwbf, verbose=False): The source NWB file object. verbose : bool Flag representing whether to print the sampling rate. + incl_times : bool + Include valid times in the output. Default, True. Set to False for only + spatial series object IDs. Returns ------- pos_data_dict : dict - Dict mapping indices to a dict with keys 'valid_times' and 'raw_position_object_id'. Returns None if there - is no position data in the file. The 'raw_position_object_id' is the object ID of the SpatialSeries object. + Dict mapping indices to a dict with keys 'valid_times' and + 'raw_position_object_id'. Returns None if there is no position data in + the file. The 'raw_position_object_id' is the object ID of the + SpatialSeries object. """ - position = get_data_interface(nwbf, "position", pynwb.behavior.Position) - if position is None: - return None - - # for some reason the spatial_series do not necessarily come out in order, so we need to figure out the right order - epoch_start_time = np.zeros(len(position.spatial_series.values())) - for pos_epoch, spatial_series in enumerate( - position.spatial_series.values() - ): - epoch_start_time[pos_epoch] = spatial_series.timestamps[0] - - sorted_order = np.argsort(epoch_start_time) - pos_data_dict = dict() + pos_interface = get_data_interface( + nwbf, "position", pynwb.behavior.Position + ) - for index, orig_epoch in enumerate(sorted_order): - spatial_series = list(position.spatial_series.values())[orig_epoch] - pos_data_dict[index] = dict() - # get the valid intervals for the position data - timestamps = np.asarray(spatial_series.timestamps) - - # estimate the sampling rate - timestamps = np.asarray(spatial_series.timestamps) - sampling_rate = estimate_sampling_rate(timestamps, 1.75) - if sampling_rate < 0: - raise ValueError( - f"Error adding position data for position epoch {index}" - ) - if verbose: - print( - "Processing raw position data. Estimated sampling rate: {} Hz".format( - sampling_rate - ) - ) - # add the valid intervals to the Interval list - pos_data_dict[index]["valid_times"] = get_valid_intervals( - timestamps, - sampling_rate, - gap_proportion=2.5, - min_valid_len=int(sampling_rate), - ) - pos_data_dict[index][ - "raw_position_object_id" - ] = spatial_series.object_id + if pos_interface is None: + return None - return pos_data_dict + return _get_pos_dict( + position=pos_interface.spatial_series, + epoch_groups=_get_epoch_groups(pos_interface), + session_id=nwbf.session_id, + verbose=verbose, + incl_times=incl_times, + ) def get_nwb_copy_filename(nwb_file_name): @@ -376,6 +477,9 @@ def get_nwb_copy_filename(nwb_file_name): filename, file_extension = os.path.splitext(nwb_file_name) + if filename.endswith("_"): + print(f"WARNING: File may already be a copy: {nwb_file_name}") + return f"{filename}_{file_extension}" @@ -393,7 +497,8 @@ def change_group_permissions( # Loop through nwb file directories and change group permissions for target_content in target_contents: print( - f"For {target_content}, changing group to {set_group_name} and giving read/write/execute permissions" + f"For {target_content}, changing group to {set_group_name} " + + "and giving read/write/execute permissions" ) # Change group os.system(f"chgrp -R {set_group_name} {target_content}") diff --git a/tests/conftest.py b/tests/conftest.py index 1356a80ef..eae26c2c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,10 +71,7 @@ def _set_env(): """Set environment variables.""" print("Setting datajoint and kachery environment variables.") - spyglass_base_dir = pathlib.Path(tempfile.mkdtemp()) - from spyglass.settings import load_config - - _ = load_config(str(spyglass_base_dir), force_reload=True) + os.environ["SPYGLASS_BASE_DIR"] = str(tempfile.mkdtemp()) dj.config["database.host"] = "localhost" dj.config["database.port"] = DATAJOINT_SERVER_PORT From 2dacdc6c8e7bae17b1ac46bbbfcf736fafc9bd87 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 22 Aug 2023 06:48:15 -0700 Subject: [PATCH 2/9] Fetch additional file from cbroz1/master to pass CI/CD --- src/spyglass/data_import/insert_sessions.py | 43 +++++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 9c9cc906e..5466b58a6 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -7,7 +7,8 @@ import pynwb from ..common import Nwbfile, get_raw_eseries, populate_all_common -from ..settings import load_config +from ..settings import raw_dir +from ..utils.nwb_helper_fn import get_nwb_copy_filename def insert_sessions(nwb_file_names: Union[str, List[str]]): @@ -18,9 +19,10 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): ---------- nwb_file_names : str or List of str File names in raw directory ($SPYGLASS_RAW_DIR) pointing to - existing .nwb files. Each file represents a session. + existing .nwb files. Each file represents a session. Also accepts + strings with glob wildcards (e.g., *) so long as the wildcard specifies + exactly one file. """ - _ = load_config() if not isinstance(nwb_file_names, list): nwb_file_names = [nwb_file_names] @@ -29,18 +31,32 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): if "/" in nwb_file_name: nwb_file_name = nwb_file_name.split("/")[-1] - nwb_file_abs_path = Path(Nwbfile.get_abs_path(nwb_file_name)) + nwb_file_abs_path = Path( + Nwbfile.get_abs_path(nwb_file_name, new_file=True) + ) + if not nwb_file_abs_path.exists(): - raise FileNotFoundError(f"File not found: {nwb_file_abs_path}") + possible_matches = sorted(Path(raw_dir).glob(f"*{nwb_file_name}*")) + + if len(possible_matches) == 1: + nwb_file_abs_path = possible_matches[0] + nwb_file_name = nwb_file_abs_path.name + + else: + raise FileNotFoundError( + f"File not found: {nwb_file_abs_path}\n\t" + + f"{len(possible_matches)} possible matches:" + + f"{possible_matches}" + ) # file name for the copied raw data - out_nwb_file_name = nwb_file_abs_path.stem + "_.nwb" + out_nwb_file_name = get_nwb_copy_filename(nwb_file_abs_path.stem) # Check whether the file already exists in the Nwbfile table if len(Nwbfile() & {"nwb_file_name": out_nwb_file_name}): warnings.warn( f"Cannot insert data from {nwb_file_name}: {out_nwb_file_name}" - + "is already in Nwbfile table." + + " is already in Nwbfile table." ) continue @@ -72,12 +88,15 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): + f"with link to raw ephys data: {out_nwb_file_name}" ) - nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name) - assert os.path.exists( - nwb_file_abs_path - ), f"File does not exist: {nwb_file_abs_path}" + nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) + + if not os.path.exists(nwb_file_abs_path): + raise FileNotFoundError(f"Could not find raw file: {nwb_file_abs_path}") + + out_nwb_file_abs_path = Nwbfile.get_abs_path( + out_nwb_file_name, new_file=True + ) - out_nwb_file_abs_path = Nwbfile.get_abs_path(out_nwb_file_name) if os.path.exists(out_nwb_file_name): warnings.warn( f"Output file {out_nwb_file_abs_path} exists and will be " From a2a1272a35f1dec368e053462ffa47ff59d2c001 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Tue, 22 Aug 2023 10:56:51 -0700 Subject: [PATCH 3/9] Add RawPos fetch method implementations. Object -> PosObject --- src/spyglass/common/common_behav.py | 40 ++++++++++++++----- .../position/v1/position_trodes_position.py | 16 ++++---- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index bb86390e9..32215124c 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -152,7 +152,7 @@ class RawPosition(dj.Imported): -> PositionSource """ - class Object(dj.Part): + class PosObject(dj.Part): definition = """ -> master -> PositionSource.SpatialSeries.proj('id') @@ -170,6 +170,9 @@ def fetch1_dataframe(self): id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()] + if len(set(rp.interval for _, rp in id_rp)) > 1: + print("WARNING: loading DataFrame with multiple intervals.") + df_list = [ pd.DataFrame( data=rp.data, @@ -200,7 +203,7 @@ def make(self, key): ] self.insert1(key) - self.Object.insert( + self.PosObject.insert( [ dict( nwb_file_name=nwb_file_name, @@ -213,15 +216,34 @@ def make(self, key): ] ) - def fetch_nwb(self, *attrs, **kwargs): - raise NotImplementedError( - "fetch_nwb now operates on RawPosition.Object" - ) + def fetch_nwb(self, *attrs, **kwargs) -> list: + """ + Returns a condatenated list of nwb objects from RawPosition.PosObject + """ + ret = [] + for pos_obj in self.PosObject: + ret.append([nwb for nwb in pos_obj.fetch_nwb(*attrs, **kwargs)]) + return ret def fetch1_dataframe(self): - raise NotImplementedError( - "fetch1_dataframe now operates on RawPosition.Object" - ) + """Returns a dataframe with all RawPosition.PosObject items. + + Uses interval_list_name as column index. + """ + ret = {} + + pos_obj_set = self.PosObject & self.restriction + unique_intervals = set(pos_obj_set.fetch("interval_list_name")) + + for interval in unique_intervals: + ret[interval] = ( + pos_obj_set & {"interval_list_name": interval} + ).fetch1_dataframe() + + if len(unique_intervals) == 1: + return next(iter(ret.values())) + + return pd.concat(ret, axis=1) @schema diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index da7456232..5ff40c990 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -180,10 +180,10 @@ def make(self, key): ) position_info_parameters = (TrodesPosParams() & key).fetch1("params") - spatial_series = (RawPosition.Object & key).fetch_nwb()[0][ + spatial_series = (RawPosition.PosObject & key).fetch_nwb()[0][ "raw_position" ] - spatial_df = (RawPosition.Object & key).fetch1_dataframe() + spatial_df = (RawPosition.PosObject & key).fetch1_dataframe() video_frame_ind = getattr(spatial_df, "video_frame_ind", None) position = pynwb.behavior.Position() @@ -295,8 +295,8 @@ def make(self, key): def calculate_position_info_from_spatial_series( spatial_df: pd.DataFrame, meters_to_pixels: float, - max_separation, - max_speed, + max_LED_separation, + max_plausible_speed, speed_smoothing_std_dev, position_smoothing_duration, orient_smoothing_std_dev, @@ -347,7 +347,7 @@ def calculate_position_info_from_spatial_series( # Set points to NaN where the front and back LEDs are too separated dist_between_LEDs = get_distance(back_LED, front_LED) - is_too_separated = dist_between_LEDs >= max_separation + is_too_separated = dist_between_LEDs >= max_LED_separation back_LED[is_too_separated] = np.nan front_LED[is_too_separated] = np.nan @@ -367,8 +367,8 @@ def calculate_position_info_from_spatial_series( ) # Set to points to NaN where the speed is too fast - is_too_fast = (front_LED_speed > max_speed) | ( - back_LED_speed > max_speed + is_too_fast = (front_LED_speed > max_plausible_speed) | ( + back_LED_speed > max_plausible_speed ) back_LED[is_too_fast] = np.nan front_LED[is_too_fast] = np.nan @@ -528,7 +528,7 @@ def make(self, key): print("Loading position data...") raw_position_df = ( - RawPosition.Object + RawPosition.PosObject & { "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["interval_list_name"], From 1a62d15e190d47cee1571012032fe07ff3d0a7cf Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Wed, 23 Aug 2023 09:08:11 -0700 Subject: [PATCH 4/9] Refactor PosIntervalMap helpers --- src/spyglass/common/common_behav.py | 143 +++++++++++++--------------- 1 file changed, 66 insertions(+), 77 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 32215124c..0ea194223 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -119,6 +119,10 @@ def get_pos_interval_name(epoch_num: int) -> str: ) return f"pos {epoch_num} valid times" + @classmethod + def _is_valid_name(self, name) -> bool: + return name.startswith("pos ") and not name.endswith(" valid times") + @staticmethod def get_epoch_num(name: str) -> int: """Return the epoch number from the interval name. @@ -133,11 +137,13 @@ def get_epoch_num(name: str) -> int: int epoch number """ + if not self._is_valid_name(name): + raise ValueError(f"Invalid interval name: {name}") return int(name.replace("pos ", "").replace(" valid times", "")) @schema -class RawPosition(dj.Imported): +class PosObject(dj.Imported): """ Notes @@ -456,8 +462,9 @@ def make(self, key): self._no_transaction_make(key) def _no_transaction_make(self, key): - # Find correspondence between pos valid times names and epochs - # Use epsilon to tolerate small differences in epoch boundaries across epoch/pos intervals + # Find correspondence between pos valid times names and epochs. Use + # epsilon to tolerate small differences in epoch boundaries across + # epoch/pos intervals if not self.connection.in_transaction: # if not called in the context of a make function, call its own make function @@ -465,130 +472,112 @@ def _no_transaction_make(self, key): return # *** HARD CODED VALUES *** - EPSILON = 0.11 # tolerated time difference in epoch boundaries across epoch/pos intervals - # ************************* + EPSILON = 0.11 # tolerated time diff in bounds across epoch/pos + no_pop_msg = "CANNOT POPULATE PositionIntervalMap" - # Unpack key nwb_file_name = key["nwb_file_name"] - - # Get pos interval list names - pos_interval_list_names = get_pos_interval_list_names(nwb_file_name) + pos_intervals = get_pos_interval_list_names(nwb_file_name) # Skip populating if no pos interval list names - if len(pos_interval_list_names) == 0: - print( - f"NO POS INTERVALS FOR {key}; CANNOT POPULATE PositionIntervalMap" - ) + if len(pos_intervals) == 0: + print(f"NO POS INTERVALS FOR {key}; {no_pop_msg}") return - # Get the interval times valid_times = (IntervalList & key).fetch1("valid_times") - time_interval = [ + time_bounds = [ valid_times[0][0] - EPSILON, valid_times[-1][-1] + EPSILON, - ] # [start, end], widen to tolerate small differences in epoch boundaries across epoch/pos intervals - - # compare the position intervals against our interval - matching_pos_interval_list_names = [] - for ( - pos_interval_list_name - ) in pos_interval_list_names: # for each pos valid time interval list - pos_valid_times = ( - IntervalList - & { - "nwb_file_name": nwb_file_name, - "interval_list_name": pos_interval_list_name, - } - ).fetch1( + ] + + matching_pos_intervals = [] + restr = ( + f"nwb_file_name='{nwb_file_name}' AND interval_list_name=" + "'{}'" + ) + for pos_interval in pos_intervals: + # cbroz: fetch1->fetch. fetch1 would fail w/o result + pos_times = (IntervalList & restr.format(pos_interval)).fetch( "valid_times" - ) # get interval valid times - if len(pos_valid_times) == 0: + ) + + if len(pos_times) == 0: continue - pos_time_interval = [ - pos_valid_times[0][0], - pos_valid_times[-1][-1], - ] # [pos valid time interval start, pos valid time interval end] - if (time_interval[0] < pos_time_interval[0]) and ( - time_interval[1] > pos_time_interval[1] - ): # if pos valid time interval within epoch interval - matching_pos_interval_list_names.append( - pos_interval_list_name - ) # add pos interval list name to list of matching pos interval list names + + if all( + [ + time_bounds[0] <= time <= time_bounds[1] + for time in [pos_times[0][0], pos_times[-1][-1]] + ] + ): + matching_pos_intervals.append(pos_intervals) # Check that each pos interval was matched to only one epoch - if len(matching_pos_interval_list_names) > 1: + if len(matching_pos_intervals) != 1: print( - f"MULTIPLE POS INTERVALS MATCHED TO EPOCH {key}; CANNOT POPULATE PositionIntervalMap" - ) - print(matching_pos_interval_list_names) - return - # Check that at least one pos interval was matched to an epoch - if len(matching_pos_interval_list_names) == 0: - print( - f"No pos intervals matched to epoch {key}; CANNOT POPULATE PositionIntervalMap" + f"Found {len(matching_pos_intervals)} pos intervals for {key}; " + + f"{no_pop_msg}\n{matching_pos_intervals}" ) return + # Insert into table - key["position_interval_name"] = matching_pos_interval_list_names[0] + key["position_interval_name"] = matching_pos_intervals[0] self.insert1(key, allow_direct_insert=True) print( - f'Populated PosIntervalMap for {nwb_file_name}, {key["interval_list_name"]}' + "Populated PosIntervalMap for " + + f'{nwb_file_name}, {key["interval_list_name"]}' ) -def get_pos_interval_list_names(nwb_file_name): +def get_pos_interval_list_names(nwb_file_name) -> list: return [ interval_list_name for interval_list_name in ( IntervalList & {"nwb_file_name": nwb_file_name} ).fetch("interval_list_name") - if ( - (interval_list_name.split(" ")[0] == "pos") - and (" ".join(interval_list_name.split(" ")[2:]) == "valid times") - ) + if PositionSource._is_valid_name(interval_list_name) ] def convert_epoch_interval_name_to_position_interval_name( key: dict, populate_missing: bool = True ) -> str: - """Converts a primary key for IntervalList to the corresponding position interval name. + """Converts IntervalList key to the corresponding position interval name. Parameters ---------- key : dict Lookup key populate_missing: bool - whether to populate PositionIntervalMap for the key if missing. Should be False if this function is used inside of another populate call + Whether to populate PositionIntervalMap for the key if missing. Should + be False if this function is used inside of another populate call. + Defaults to True Returns ------- position_interval_name : str """ - # get the interval list name if epoch given in key instead of interval list name + # get the interval list name if given epoch but not interval list name if "interval_list_name" not in key and "epoch" in key: key["interval_list_name"] = get_interval_list_name_from_epoch( key["nwb_file_name"], key["epoch"] ) - pos_interval_names = (PositionIntervalMap & key).fetch( - "position_interval_name" - ) - if len(pos_interval_names) == 0: + pos_query = PositionIntervalMap & key + + if len(pos_query) == 0: if populate_missing: PositionIntervalMap()._no_transaction_make(key) - pos_interval_names = (PositionIntervalMap & key).fetch( - "position_interval_name" - ) else: raise KeyError( - f"{key} must be populated in the PositionIntervalMap table prior to your current populate call" + f"{key} must be populated in the PositionIntervalMap table " + + "prior to your current populate call" ) - if len(pos_interval_names) == 0: + + if len(pos_query) == 0: print(f"No position intervals found for {key}") return [] - if len(pos_interval_names) == 1: - return pos_interval_names[0] + + if len(pos_query) == 1: + return pos_query.fetch1("position_interval_name") def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str: @@ -613,19 +602,19 @@ def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str: ) if (x.split("_")[0] == f"{epoch:02}") ] - if len(interval_names) == 0: - print(f"No interval list name found for {nwb_file_name} epoch {epoch}") - return None - if len(interval_names) > 1: + + if len(interval_names) != 1: print( - f"Multiple interval list names found for {nwb_file_name} epoch {epoch}" + f"Found {len(interval_name)} interval list names found for " + + f"{nwb_file_name} epoch {epoch}" ) return None + return interval_names[0] def populate_position_interval_map_session(nwb_file_name: str): - for interval_name in (TaskEpoch() & {"nwb_file_name": nwb_file_name}).fetch( + for interval_name in (TaskEpoch & {"nwb_file_name": nwb_file_name}).fetch( "interval_list_name" ): PositionIntervalMap.populate( From 0b81a7813292707fe412a95c984c3e52aea91850 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Wed, 23 Aug 2023 10:03:24 -0700 Subject: [PATCH 5/9] Revert typo --- .gitignore | 2 +- src/spyglass/common/common_behav.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 69148ee54..fad710aac 100644 --- a/.gitignore +++ b/.gitignore @@ -167,7 +167,7 @@ temp_nwb/*s *.json *.gz *.pdf -dj_local_conf.json +dj_local_con*.json !dj_local_conf_example.json !/.vscode/extensions.json diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 0ea194223..765d34211 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -143,7 +143,7 @@ def get_epoch_num(name: str) -> int: @schema -class PosObject(dj.Imported): +class RawPosition(dj.Imported): """ Notes From 0bb20603cea93864bccee2dd40aee1817248977c Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 28 Aug 2023 14:10:35 -0700 Subject: [PATCH 6/9] Bugfixes for ripple --- src/spyglass/common/common_behav.py | 17 ++++++------ src/spyglass/common/common_nwbfile.py | 2 +- src/spyglass/common/common_session.py | 13 ++++++--- src/spyglass/common/populate_all_common.py | 9 +++---- src/spyglass/lfp/v1/lfp.py | 11 +++++--- .../position/v1/position_trodes_position.py | 4 +-- src/spyglass/ripple/v1/ripple.py | 10 ++----- src/spyglass/settings.py | 27 +++++++++++++++---- src/spyglass/utils/nwb_helper_fn.py | 9 +++---- 9 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 765d34211..d26b2c0f6 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -57,7 +57,7 @@ def insert_from_nwbfile(cls, nwb_file_name): """ nwbf = get_nwb_file(nwb_file_name) all_pos = get_all_spatial_series(nwbf, verbose=True) - sess_key = dict(nwb_file_name=nwb_file_name) + sess_key = Nwbfile.get_file_key(nwb_file_name) src_key = dict(**sess_key, source="trodes", import_file_name="") if all_pos is None: @@ -119,9 +119,9 @@ def get_pos_interval_name(epoch_num: int) -> str: ) return f"pos {epoch_num} valid times" - @classmethod - def _is_valid_name(self, name) -> bool: - return name.startswith("pos ") and not name.endswith(" valid times") + @staticmethod + def _is_valid_name(name) -> bool: + return name.startswith("pos ") and name.endswith(" valid times") @staticmethod def get_epoch_num(name: str) -> int: @@ -137,7 +137,7 @@ def get_epoch_num(name: str) -> int: int epoch number """ - if not self._is_valid_name(name): + if not PositionSource._is_valid_name(name): raise ValueError(f"Invalid interval name: {name}") return int(name.replace("pos ", "").replace(" valid times", "")) @@ -226,10 +226,7 @@ def fetch_nwb(self, *attrs, **kwargs) -> list: """ Returns a condatenated list of nwb objects from RawPosition.PosObject """ - ret = [] - for pos_obj in self.PosObject: - ret.append([nwb for nwb in pos_obj.fetch_nwb(*attrs, **kwargs)]) - return ret + return self.PosObject().fetch_nwb(*attrs, **kwargs) def fetch1_dataframe(self): """Returns a dataframe with all RawPosition.PosObject items. @@ -502,6 +499,8 @@ def _no_transaction_make(self, key): if len(pos_times) == 0: continue + pos_times = pos_times[0] + if all( [ time_bounds[0] <= time <= time_bounds[1] diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 52c070d2b..cfb66ab2d 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -59,7 +59,7 @@ def insert_from_relative_file_name(cls, nwb_file_name): nwb_file_name : str The relative path to the NWB file. """ - nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name) + nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) assert os.path.exists( nwb_file_abs_path ), f"File does not exist: {nwb_file_abs_path}" diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 4054de52e..8bd79cb6e 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -1,11 +1,13 @@ import os + import datajoint as dj +from ..settings import config, debug_mode +from ..utils.nwb_helper_fn import get_config, get_nwb_file from .common_device import CameraDevice, DataAcquisitionDevice, Probe from .common_lab import Institution, Lab, LabMember from .common_nwbfile import Nwbfile from .common_subject import Subject -from ..utils.nwb_helper_fn import get_nwb_file, get_config schema = dj.schema("common_session") @@ -79,9 +81,10 @@ def make(self, key): print("Subject...") Subject().insert_from_nwbfile(nwbf) - print("Populate DataAcquisitionDevice...") - DataAcquisitionDevice.insert_from_nwbfile(nwbf, config) - print() + if not debug_mode: # TODO: remove when demo files agree on device + print("Populate DataAcquisitionDevice...") + DataAcquisitionDevice.insert_from_nwbfile(nwbf, config) + print() print("Populate CameraDevice...") CameraDevice.insert_from_nwbfile(nwbf) @@ -254,6 +257,8 @@ def create_spyglass_view(session_group_name: str): # datajoint prohibits deleting from a subtable without # also deleting the parent table. # See: https://docs.datajoint.org/python/computation/03-master-part.html + + @schema class SessionGroupSession(dj.Manual): definition = """ diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index da563ab23..4dcd3562c 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -1,9 +1,4 @@ -from .common_behav import ( - PositionSource, - RawPosition, - StateScriptFile, - VideoFile, -) +from .common_behav import PositionSource, RawPosition, StateScriptFile, VideoFile from .common_dio import DIOEvents from .common_ephys import Electrode, ElectrodeGroup, Raw, SampleCount from .common_nwbfile import Nwbfile @@ -35,11 +30,13 @@ def populate_all_common(nwb_file_name): print("Populate DIOEvents...") DIOEvents.populate(fp) + # sensor data (from analog ProcessingModule) is temporarily removed from NWBFile # to reduce file size while it is not being used. add it back in by commenting out # the removal code in spyglass/data_import/insert_sessions.py when ready # print('Populate SensorData') # SensorData.populate(fp) + print("Populate TaskEpochs") TaskEpoch.populate(fp) print("Populate StateScriptFile") diff --git a/src/spyglass/lfp/v1/lfp.py b/src/spyglass/lfp/v1/lfp.py index 109792d74..32d168f46 100644 --- a/src/spyglass/lfp/v1/lfp.py +++ b/src/spyglass/lfp/v1/lfp.py @@ -5,7 +5,6 @@ import pandas as pd from spyglass.common.common_ephys import Raw -from spyglass.lfp.lfp_electrode import LFPElectrodeGroup from spyglass.common.common_filter import FirFilterParameters from spyglass.common.common_interval import ( IntervalList, @@ -14,6 +13,7 @@ ) from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.common.common_session import Session # noqa: F401 +from spyglass.lfp.lfp_electrode import LFPElectrodeGroup from spyglass.utils.dj_helper_fn import fetch_nwb # dj_replace schema = dj.schema("lfp_v1") @@ -60,6 +60,9 @@ def make(self, key): # get the NWB object with the data nwbf_key = {"nwb_file_name": key["nwb_file_name"]} rawdata = (Raw & nwbf_key).fetch_nwb()[0]["raw"] + + # CBroz: assumes Raw sampling rate matches FirFilterParameters set? + # if we just pull rate from Raw, why include in Param table? sampling_rate, raw_interval_list_name = (Raw & nwbf_key).fetch1( "sampling_rate", "interval_list_name" ) @@ -95,8 +98,10 @@ def make(self, key): # get the LFP filter that matches the raw data filter = ( FirFilterParameters() - & {"filter_name": key["filter_name"]} - & {"filter_sampling_rate": sampling_rate} + & { + "filter_name": key["filter_name"], + "filter_sampling_rate": sampling_rate, + } # not key['filter_sampling_rate']? ).fetch(as_dict=True)[0] # there should only be one filter that matches, so we take the first of the dictionaries diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 5ff40c990..fcba493ff 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -48,8 +48,8 @@ def default_pk(self): @property def default_params(self): return { - "max_separation": 9.0, - "max_speed": 300.0, + "max_LED_separation": 9.0, + "max_plausible_speed": 300.0, "position_smoothing_duration": 0.125, "speed_smoothing_std_dev": 0.100, "orient_smoothing_std_dev": 0.001, diff --git a/src/spyglass/ripple/v1/ripple.py b/src/spyglass/ripple/v1/ripple.py index 65998ec25..d7e82bcb0 100644 --- a/src/spyglass/ripple/v1/ripple.py +++ b/src/spyglass/ripple/v1/ripple.py @@ -5,15 +5,9 @@ from ripple_detection import Karlsson_ripple_detector, Kay_ripple_detector from ripple_detection.core import gaussian_smooth, get_envelope -from spyglass.common.common_interval import ( - IntervalList, - interval_list_intersect, -) +from spyglass.common.common_interval import IntervalList, interval_list_intersect from spyglass.common.common_nwbfile import AnalysisNwbfile -from spyglass.lfp.analysis.v1.lfp_band import ( - LFPBandSelection, - LFPBandV1, -) +from spyglass.lfp.analysis.v1.lfp_band import LFPBandSelection, LFPBandV1 from spyglass.position import PositionOutput from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.nwb_helper_fn import get_electrode_indices diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 4ddf7d53a..203888bd4 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -61,8 +61,9 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: if config_loaded and not force_reload: return config - dj_spyglass = dj.config.get("custom", {}).get("spyglass_dirs", {}) - dj_kachery = dj.config.get("custom", {}).get("kachery_dirs", {}) + dj_custom = dj.config.get("custom", {}) + dj_spyglass = dj_custom.get("spyglass_dirs", {}) + dj_kachery = dj_custom.get("kachery_dirs", {}) resolved_base = ( base_dir @@ -102,7 +103,11 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: _set_dj_config_stores(config_dirs) config = dict( - **config_defaults, **config_dirs, **kachery_zone_dict, **loaded_env + debug_mode=dj_custom.get("debug_mode", False), + **config_defaults, + **config_dirs, + **kachery_zone_dict, + **loaded_env, ) config_loaded = True return config @@ -145,6 +150,9 @@ def _set_dj_config_stores(dir_dict: dict): } +# TODO: Change redundancy here to class with @properties + + def load_base_dir() -> str: """Retrieve the base directory from the configuration. @@ -235,12 +243,19 @@ def load_waveform_dir() -> str: Returns ------- str - The temp directory path. + The waveform directory path. """ global config if not config_loaded or not config: config = load_config() - return config.get("SPYGLASS_TEMP_DIR") + return config.get("SPYGLASS_WAVEFORM_DIR") + + +def load_debug_mode() -> bool: + global config + if not config_loaded or not config: + config = load_config() + return config.get("debug_mode", False) base_dir = load_base_dir() @@ -249,3 +264,5 @@ def load_waveform_dir() -> str: temp_dir = load_temp_dir() analysis_dir = load_analysis_dir() sorting_dir = load_sorting_dir() +waveform_dir = load_waveform_dir() +debug_mode = load_debug_mode() diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 835dfedc7..58b5cf243 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -247,7 +247,7 @@ def estimate_sampling_rate( def get_valid_intervals( - timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=None + timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=0 ): """Finds the set of all valid intervals in a list of timestamps. Valid interval: (start time, stop time) during which there are @@ -264,7 +264,7 @@ def get_valid_intervals( between consecutive timestamps exceeds gap_proportion, it is considered a gap. Must be > 1. Default to 2.5 min_valid_len : float, optional - Length of smallest valid interval. Default to sampling_rate. If greater + Length of smallest valid interval. Default to 0. If greater than interval duration, print warning and use half the total time. Returns @@ -275,10 +275,8 @@ def get_valid_intervals( eps = 0.0000001 - if not min_valid_len: - min_valid_len = int(sampling_rate) - total_time = timestamps[-1] - timestamps[0] + if total_time < min_valid_len: half_total_time = total_time / 2 print(f"WARNING: Setting minimum valid interval to {half_total_time}") @@ -419,6 +417,7 @@ def _get_pos_dict( valid_times = get_valid_intervals( timestamps=timestamps, sampling_rate=sampling_rate, + min_valid_len=int(sampling_rate), ) # add the valid intervals to the Interval list pos_data_dict[epoch].append( From f12013bea67f29be379eee9f081b0473e556a006 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 28 Aug 2023 14:14:49 -0700 Subject: [PATCH 7/9] Blackify --- notebooks/12_Ripple_Detection.ipynb | 50 +++++++++++++--------- src/spyglass/common/populate_all_common.py | 7 ++- src/spyglass/ripple/v1/ripple.py | 5 ++- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/notebooks/12_Ripple_Detection.ipynb b/notebooks/12_Ripple_Detection.ipynb index dda53d683..1d96cebc2 100644 --- a/notebooks/12_Ripple_Detection.ipynb +++ b/notebooks/12_Ripple_Detection.ipynb @@ -460,12 +460,16 @@ ], "source": [ "electrodes = (\n", - " (sgc.Electrode() & {\"nwb_file_name\": nwb_file_name}) * \n", - " (lfp_analysis.LFPBandSelection.LFPBandElectrode() & {\n", - " \"nwb_file_name\": nwb_file_name,\n", - " \"filter_name\": filter_name,\n", - " \"target_interval_list_name\": interval_list_name\n", - " }) * sgc.BrainRegion\n", + " (sgc.Electrode() & {\"nwb_file_name\": nwb_file_name})\n", + " * (\n", + " lfp_analysis.LFPBandSelection.LFPBandElectrode()\n", + " & {\n", + " \"nwb_file_name\": nwb_file_name,\n", + " \"filter_name\": filter_name,\n", + " \"target_interval_list_name\": interval_list_name,\n", + " }\n", + " )\n", + " * sgc.BrainRegion\n", ").fetch(format=\"frame\")\n", "electrodes" ] @@ -817,8 +821,7 @@ "source": [ "hpc_names = [\"ca1\", \"hippocampus\", \"CA1\", \"Hippocampus\"]\n", "electrodes.loc[\n", - " (electrodes.region_name.isin(hpc_names))\n", - " & (electrodes.probe_electrode == 0)\n", + " (electrodes.region_name.isin(hpc_names)) & (electrodes.probe_electrode == 0)\n", "]" ] }, @@ -838,14 +841,16 @@ "metadata": {}, "outputs": [], "source": [ - "electrode_list = np.unique((\n", - " electrodes.loc[\n", - " (electrodes.region_name.isin(hpc_names))\n", - " & (electrodes.probe_electrode == 0)\n", - " ]\n", - " .reset_index()\n", - " .electrode_id\n", - ").tolist())\n", + "electrode_list = np.unique(\n", + " (\n", + " electrodes.loc[\n", + " (electrodes.region_name.isin(hpc_names))\n", + " & (electrodes.probe_electrode == 0)\n", + " ]\n", + " .reset_index()\n", + " .electrode_id\n", + " ).tolist()\n", + ")\n", "\n", "electrode_list.sort()" ] @@ -876,7 +881,10 @@ "outputs": [], "source": [ "group_name = \"CA1_test\"\n", - "lfp_band_key = (lfp_analysis.LFPBandV1() & {\"filter_name\": filter_name, \"nwb_file_name\": nwb_file_name}).fetch1(\"KEY\")\n", + "lfp_band_key = (\n", + " lfp_analysis.LFPBandV1()\n", + " & {\"filter_name\": filter_name, \"nwb_file_name\": nwb_file_name}\n", + ").fetch1(\"KEY\")\n", "sgrip.RippleLFPSelection.set_lfp_electrodes(\n", " lfp_band_key,\n", " electrode_list=electrode_list,\n", @@ -1402,11 +1410,13 @@ } ], "source": [ - "pos_key = PositionOutput.merge_get_part({\n", + "pos_key = PositionOutput.merge_get_part(\n", + " {\n", " \"nwb_file_name\": nwb_file_name,\n", " \"position_info_param_name\": \"default\",\n", " \"interval_list_name\": \"pos 1 valid times\",\n", - " }).fetch1(\"KEY\")\n", + " }\n", + ").fetch1(\"KEY\")\n", "(PositionOutput & pos_key).fetch1_dataframe()" ] }, @@ -1492,7 +1502,7 @@ "key = {\n", " \"ripple_param_name\": \"default\",\n", " **rip_sel_key,\n", - " \"pos_merge_id\": pos_key[\"merge_id\"]\n", + " \"pos_merge_id\": pos_key[\"merge_id\"],\n", "}\n", "sgrip.RippleTimesV1().populate(key)" ] diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index 4dcd3562c..da0bedac0 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -1,4 +1,9 @@ -from .common_behav import PositionSource, RawPosition, StateScriptFile, VideoFile +from .common_behav import ( + PositionSource, + RawPosition, + StateScriptFile, + VideoFile, +) from .common_dio import DIOEvents from .common_ephys import Electrode, ElectrodeGroup, Raw, SampleCount from .common_nwbfile import Nwbfile diff --git a/src/spyglass/ripple/v1/ripple.py b/src/spyglass/ripple/v1/ripple.py index d7e82bcb0..51832de7d 100644 --- a/src/spyglass/ripple/v1/ripple.py +++ b/src/spyglass/ripple/v1/ripple.py @@ -5,7 +5,10 @@ from ripple_detection import Karlsson_ripple_detector, Kay_ripple_detector from ripple_detection.core import gaussian_smooth, get_envelope -from spyglass.common.common_interval import IntervalList, interval_list_intersect +from spyglass.common.common_interval import ( + IntervalList, + interval_list_intersect, +) from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.lfp.analysis.v1.lfp_band import LFPBandSelection, LFPBandV1 from spyglass.position import PositionOutput From f76666f9cec6740bc29a86178478390d2f50fe6c Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 18 Sep 2023 11:50:40 -0700 Subject: [PATCH 8/9] WIP: minor edits --- src/spyglass/common/common_behav.py | 5 ++++- src/spyglass/common/common_lab.py | 2 +- src/spyglass/common/common_nwbfile.py | 3 ++- src/spyglass/data_import/insert_sessions.py | 8 +++++--- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index d26b2c0f6..d362f3a6c 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -507,7 +507,10 @@ def _no_transaction_make(self, key): for time in [pos_times[0][0], pos_times[-1][-1]] ] ): - matching_pos_intervals.append(pos_intervals) + matching_pos_intervals.append(pos_interval) + + if len(matching_pos_intervals) > 1: + break # Check that each pos interval was matched to only one epoch if len(matching_pos_intervals) != 1: diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index 56333b679..9be9d15b2 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -39,7 +39,7 @@ def insert_from_nwbfile(cls, nwbf): The NWB file with experimenter information. """ if isinstance(nwbf, str): - nwb_file_abspath = Nwbfile.get_abs_path(nwbf) + nwb_file_abspath = Nwbfile.get_abs_path(nwbf, new_file=True) nwbf = get_nwb_file(nwb_file_abspath) if nwbf.experimenter is None: diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index cfb66ab2d..1becbdf9e 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -78,7 +78,8 @@ def _get_file_name(cls, nwb_file_name: str) -> str: return query.fetch1("nwb_file_name") raise ValueError( - f"Found {len(query)} matches for {nwb_file_name}: \n{query}" + f"Found {len(query)} matches for {nwb_file_name} in Nwbfile table:" + + f" \n{query}" ) @classmethod diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 5466b58a6..caee7682e 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -7,7 +7,7 @@ import pynwb from ..common import Nwbfile, get_raw_eseries, populate_all_common -from ..settings import raw_dir +from ..settings import debug_mode, raw_dir from ..utils.nwb_helper_fn import get_nwb_copy_filename @@ -50,7 +50,7 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): ) # file name for the copied raw data - out_nwb_file_name = get_nwb_copy_filename(nwb_file_abs_path.stem) + out_nwb_file_name = get_nwb_copy_filename(nwb_file_abs_path.name) # Check whether the file already exists in the Nwbfile table if len(Nwbfile() & {"nwb_file_name": out_nwb_file_name}): @@ -97,7 +97,9 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): out_nwb_file_name, new_file=True ) - if os.path.exists(out_nwb_file_name): + if os.path.exists(out_nwb_file_abs_path): + if debug_mode: + return out_nwb_file_abs_path warnings.warn( f"Output file {out_nwb_file_abs_path} exists and will be " + "overwritten." From 4e6b29f344ace0a2c02276e556c02bb94d081d6c Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Thu, 21 Sep 2023 11:58:48 -0700 Subject: [PATCH 9/9] Add restriction to fetch1_dataframe --- src/spyglass/common/common_behav.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index d362f3a6c..68f17c156 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -226,7 +226,11 @@ def fetch_nwb(self, *attrs, **kwargs) -> list: """ Returns a condatenated list of nwb objects from RawPosition.PosObject """ - return self.PosObject().fetch_nwb(*attrs, **kwargs) + return ( + self.PosObject() + .restrict(self.restriction) # Avoids fetch_nwb on whole table + .fetch_nwb(*attrs, **kwargs) + ) def fetch1_dataframe(self): """Returns a dataframe with all RawPosition.PosObject items.