From a5aa589d4aa07014eecaf72f4039f9ef4bf65f0b Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Tue, 15 Aug 2023 11:50:37 -0700 Subject: [PATCH] WIP: PositionSource add part table --- .gitignore | 2 +- src/spyglass/common/common_behav.py | 85 +++++--- src/spyglass/common/common_ephys.py | 23 +-- src/spyglass/common/common_nwbfile.py | 10 +- src/spyglass/common/common_session.py | 4 +- src/spyglass/data_import/insert_sessions.py | 30 ++- src/spyglass/utils/nwb_helper_fn.py | 216 ++++++++++++++------ 7 files changed, 249 insertions(+), 121 deletions(-) diff --git a/.gitignore b/.gitignore index 02b506120..764acc0d4 100644 --- a/.gitignore +++ b/.gitignore @@ -166,7 +166,7 @@ temp_nwb/*s *.json *.gz *.pdf -dj_local_conf.json +dj_local_conf* !dj_local_conf_example.json !/.vscode/extensions.json diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index e00418b1d..9473b06fe 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -27,12 +27,17 @@ class PositionSource(dj.Manual): definition = """ -> Session - -> IntervalList --- - source: varchar(200) # source of data; current options are "trodes" and "dlc" (deep lab cut) - import_file_name: varchar(2000) # path to import file if importing position data + source: varchar(200) # source of data (e.g., trodes, dlc) + import_file_name: varchar(2000) # path to import file if importing """ + class IntervalList(dj.Part): + definition = """ + -> master + -> IntervalList + """ + @classmethod def insert_from_nwbfile(cls, nwb_file_name): """Given an NWB file name, get the spatial series and interval lists from the file, add the interval @@ -43,32 +48,50 @@ def insert_from_nwbfile(cls, nwb_file_name): nwb_file_name : str The name of the NWB file. """ - nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) + nwbf = get_nwb_file(nwb_file_name) + all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False) + sess_key = dict(nwb_file_name=nwb_file_name) + pos_source_key = dict(**sess_key, source="trodes", import_file_name="") + + if all_pos is None: + return + + intervals = [] + pos_intervals = [] + + for epoch, epoch_list in enumerate(all_pos.values()): + for index, pdict in enumerate(epoch_list): + pos_interval_name = cls.get_pos_interval_name([epoch, index]) + + intervals.append( + dict( + **sess_key, + interval_list_name=pos_interval_name, + valid_times=pos_dict["valid_times"], + ) + ) - pos_dict = get_all_spatial_series(nwbf, verbose=True) - if pos_dict is not None: - for epoch in pos_dict: - pdict = pos_dict[epoch] - pos_interval_list_name = cls.get_pos_interval_name(epoch) - - # create the interval list and insert it - interval_dict = dict() - interval_dict["nwb_file_name"] = nwb_file_name - interval_dict["interval_list_name"] = pos_interval_list_name - interval_dict["valid_times"] = pdict["valid_times"] - IntervalList().insert1(interval_dict, skip_duplicates=True) - - # add this interval list to the table - key = dict() - key["nwb_file_name"] = nwb_file_name - key["interval_list_name"] = pos_interval_list_name - key["source"] = "trodes" - key["import_file_name"] = "" - cls.insert1(key) + # UNTESTED + IntervalList.insert(intervals, skip_duplicates=True) + cls.IntervalList.insert(intervals, skip_duplicates=True) + cls.insert1(key) @staticmethod def get_pos_interval_name(pos_epoch_num): + """Retun string of the interval name from the epoch number. 
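For context, the epoch/index handling added here means an interval name can now be built either from a bare epoch number or from an [epoch, index] pair. A minimal standalone sketch of that naming convention follows; the helper name and sample values are illustrative, not part of the patch:

```python
# Minimal sketch of the naming convention implemented by the updated
# get_pos_interval_name; helper name and values are illustrative only.
def pos_interval_name(pos_epoch_num):
    # An [epoch, index] pair collapses to "epoch E index I" before formatting.
    if isinstance(pos_epoch_num, list) and len(pos_epoch_num) == 2:
        pos_epoch_num = f"epoch {pos_epoch_num[0]} index {pos_epoch_num[1]}"
    return f"pos {pos_epoch_num} valid times"


print(pos_interval_name(2))       # "pos 2 valid times" (old single-epoch form)
print(pos_interval_name([0, 1]))  # "pos epoch 0 index 1 valid times"
```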
+ + Parameters + ---------- + pos_epoch_num : int or str or list + If list of length 2, then a string of the form "epoch 1 index 2" + + Returns + ------- + str + Position interval name (e.g., pos epoch 1 index 2 valid times) + """ + if isinstance(pos_epoch_num, list) and len(pos_epoch_num) == 2: + pos_epoch_num = f"epoch {pos_epoch_num[0]} index {pos_epoch_num[1]}" return f"pos {pos_epoch_num} valid times" @@ -100,11 +123,13 @@ def make(self, key): for epoch in pos_dict: if key[ "interval_list_name" - ] == PositionSource.get_pos_interval_name(epoch): - pdict = pos_dict[epoch] - key["raw_position_object_id"] = pdict["raw_position_object_id"] - self.insert1(key) - break + ] != PositionSource.get_pos_interval_name(epoch): + continue + + pdict = pos_dict[epoch] + key["raw_position_object_id"] = pdict["raw_position_object_id"] + self.insert1(key) + break def fetch_nwb(self, *attrs, **kwargs): return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 5ca0af03c..eb792ff17 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -6,26 +6,26 @@ import pandas as pd import pynwb +from ..utils.dj_helper_fn import fetch_nwb # dj_replace +from ..utils.nwb_helper_fn import ( + estimate_sampling_rate, + get_config, + get_data_interface, + get_electrode_indices, + get_nwb_file, + get_valid_intervals, +) from .common_device import Probe # noqa: F401 from .common_filter import FirFilterParameters +from .common_interval import interval_list_censor # noqa: F401 from .common_interval import ( IntervalList, - interval_list_censor, # noqa: F401 interval_list_contains_ind, interval_list_intersect, ) from .common_nwbfile import AnalysisNwbfile, Nwbfile from .common_region import BrainRegion # noqa: F401 from .common_session import Session # noqa: F401 -from ..utils.dj_helper_fn import fetch_nwb # dj_replace -from ..utils.nwb_helper_fn import ( - estimate_sampling_rate, - get_data_interface, - get_electrode_indices, - get_nwb_file, - get_valid_intervals, - get_config, -) schema = dj.schema("common_ephys") @@ -251,9 +251,8 @@ def make(self, key): print("Estimating sampling rate...") # NOTE: Only use first 1e6 timepoints to save time sampling_rate = estimate_sampling_rate( - np.asarray(rawdata.timestamps[: int(1e6)]), 1.5 + np.asarray(rawdata.timestamps[: int(1e6)]), 1.5, verbose=True ) - print(f"Estimated sampling rate: {sampling_rate}") key["sampling_rate"] = sampling_rate interval_dict = dict() diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 493e98962..a4f7b1203 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -12,7 +12,7 @@ import spikeinterface as si from hdmf.common import DynamicTable -from ..settings import load_config +from ..settings import raw_dir from ..utils.dj_helper_fn import get_child_tables from ..utils.nwb_helper_fn import get_electrode_indices, get_nwb_file @@ -73,12 +73,14 @@ def insert_from_relative_file_name(cls, nwb_file_name): def get_abs_path(nwb_file_name): """Return absolute path for a stored raw NWB file given file name. - The SPYGLASS_BASE_DIR environment variable must be set. + The SPYGLASS_BASE_DIR must be set, either as an environment or part of + dj.config['custom']. See spyglass.settings.load_config Parameters ---------- nwb_file_name : str - The name of an NWB file that has been inserted into the Nwbfile() schema. 
+ The name of an NWB file that has been inserted into the Nwbfile() + schema. Returns ------- @@ -86,7 +88,7 @@ def get_abs_path(nwb_file_name): The absolute path for the given file name. """ - return load_config()["SPYGLASS_RAW_DIR"] + "/" + nwb_file_name + return raw_dir + "/" + nwb_file_name @staticmethod def add_to_lock(nwb_file_name): diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 949f3cd4e..5203eaa3b 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -1,11 +1,11 @@ import datajoint as dj +from ..settings import config +from ..utils.nwb_helper_fn import get_config, get_nwb_file from .common_device import CameraDevice, DataAcquisitionDevice, Probe from .common_lab import Institution, Lab, LabMember from .common_nwbfile import Nwbfile from .common_subject import Subject -from ..utils.nwb_helper_fn import get_nwb_file, get_config -from ..settings import config schema = dj.schema("common_session") diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index c732151c9..f47db38e9 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -7,7 +7,8 @@ import pynwb from ..common import Nwbfile, get_raw_eseries, populate_all_common -from ..settings import load_config +from ..settings import raw_dir +from ..utils.nwb_helper_fn import get_nwb_copy_filename def insert_sessions(nwb_file_names: Union[str, List[str]]): @@ -18,7 +19,9 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): ---------- nwb_file_names : str or List of str File names in raw directory ($SPYGLASS_RAW_DIR) pointing to - existing .nwb files. Each file represents a session. + existing .nwb files. Each file represents a session. Also accepts + strings with glob wildcards (e.g., *) so long as the wildcard specifies + exactly one file. 
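To illustrate the wildcard handling described above, here is a rough standalone sketch of the exactly-one-match resolution that insert_sessions now performs; the helper, directory, and file names are hypothetical:

```python
from pathlib import Path

# Hypothetical helper mirroring the single-match fallback in insert_sessions:
# a name or pattern must resolve to exactly one file in the raw directory.
def resolve_single_match(raw_dir: str, name_or_pattern: str) -> Path:
    matches = sorted(Path(raw_dir).glob(f"*{name_or_pattern}*"))
    if len(matches) != 1:
        raise FileNotFoundError(
            f"Expected exactly one match for {name_or_pattern!r}, "
            f"found {len(matches)}: {matches}"
        )
    return matches[0]


# Hypothetical usage (assumes one matching file exists in the raw directory):
# nwb_path = resolve_single_match("/data/nwb/raw", "eg20230815*.nwb")
```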
""" if not isinstance(nwb_file_names, list): @@ -29,11 +32,23 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): nwb_file_name = nwb_file_name.split("/")[-1] nwb_file_abs_path = Path(Nwbfile.get_abs_path(nwb_file_name)) + if not nwb_file_abs_path.exists(): - raise FileNotFoundError(f"File not found: {nwb_file_abs_path}") + possible_matches = sorted(Path(raw_dir).glob(f"*{nwb_file_name}*")) + + if len(possible_matches) == 1: + nwb_file_abs_path = possible_matches[0] + nwb_file_name = nwb_file_abs_path.name + + else: + raise FileNotFoundError( + f"File not found: {nwb_file_abs_path}\n\t" + + f"{len(possible_matches)} possible matches:" + + f"{possible_matches}" + ) # file name for the copied raw data - out_nwb_file_name = nwb_file_abs_path.stem + "_.nwb" + out_nwb_file_name = get_nwb_copy_filename(nwb_file_abs_path.stem) # Check whether the file already exists in the Nwbfile table if len(Nwbfile() & {"nwb_file_name": out_nwb_file_name}): @@ -72,11 +87,12 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): ) nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name) - assert os.path.exists( - nwb_file_abs_path - ), f"File does not exist: {nwb_file_abs_path}" + + if not os.path.exists(nwb_file_abs_path): + raise FileNotFoundError(f"Could not find raw file: {nwb_file_abs_path}") out_nwb_file_abs_path = Nwbfile.get_abs_path(out_nwb_file_name) + if os.path.exists(out_nwb_file_name): warnings.warn( f"Output file {out_nwb_file_abs_path} exists and will be " diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 58b9b4696..15389ba71 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -21,26 +21,37 @@ def get_nwb_file(nwb_file_path): """Return an NWBFile object with the given file path in read mode. - If the file is not found locally, this will check if it has been shared with kachery and if so, download it and open it. + + If the file is not found locally, this will check if it has been shared + with kachery and if so, download it and open it. + Parameters ---------- nwb_file_path : str - Path to the NWB file. + Path to the NWB file or NWB file name. If it does not start with a "/", + get path with Nwbfile.get_abs_path Returns ------- nwbfile : pynwb.NWBFile NWB file object for the given path opened in read mode. """ + if not nwb_file_path.startswith("/"): + from ..common import Nwbfile + + nwb_file_path = Nwbfile.get_abs_path(nwb_file_path) + _, nwbfile = __open_nwb_files.get(nwb_file_path, (None, None)) nwb_uri = None nwb_raw_uri = None + if nwbfile is None: # check to see if the file exists if not os.path.exists(nwb_file_path): print( - f"NWB file {nwb_file_path} does not exist locally; checking kachery" + f"NWB file not found locally; checking kachery for " + + f"{nwb_file_path}" ) # first try the analysis files from ..sharing.sharing_kachery import AnalysisNwbfileKachery @@ -165,49 +176,73 @@ def get_raw_eseries(nwbfile): return ret -def estimate_sampling_rate(timestamps, multiplier): +def estimate_sampling_rate( + timestamps, multiplier=1.75, verbose=False, filename="file" +): """Estimate the sampling rate given a list of timestamps. - Assumes that the most common temporal differences between timestamps approximate the sampling rate. Note that this - can fail for very high sampling rates and irregular timestamps. + Assumes that the most common temporal differences between timestamps + approximate the sampling rate. 
Note that this can fail for very high + sampling rates and irregular timestamps. Parameters ---------- timestamps : numpy.ndarray 1D numpy array of timestamp values. - multiplier : float or int + multiplier : float or int, optional + Deft + verbose : bool, optional + Print sampling rate to stdout. Default, False + filename : str, optional + Filename to reference when printing or err. Default, "file" Returns ------- estimated_rate : float The estimated sampling rate. + + Raises + ------ + ValueError + If estimated rate is less than 0. """ # approach: # 1. use a box car smoother and a histogram to get the modal value - # 2. identify adjacent samples as those that have a time difference < the multiplier * the modal value + # 2. identify adjacent samples as those that have a + # time difference < the multiplier * the modal value # 3. average the time differences between adjacent samples + sample_diff = np.diff(timestamps[~np.isnan(timestamps)]) + if len(sample_diff) < 10: raise ValueError( f"Only {len(sample_diff)} timestamps are valid. Check the data." ) - nsmooth = 10 - smoother = np.ones(nsmooth) / nsmooth - smooth_diff = np.convolve(sample_diff, smoother, mode="same") - # we histogram with 100 bins out to 3 * mean, which should be fine for any reasonable number of samples + smooth_diff = np.convolve(sample_diff, np.ones(10) / 10, mode="same") + + # we histogram with 100 bins out to 3 * mean, which should be fine for any + # reasonable number of samples hist, bins = np.histogram( smooth_diff, bins=100, range=[0, 3 * np.mean(smooth_diff)] ) - mode = bins[np.where(hist == np.max(hist))] + mode = bins[np.where(hist == np.max(hist))][0] + + adjacent = sample_diff < mode * multiplier - adjacent = sample_diff < mode[0] * multiplier - return np.round(1.0 / np.mean(sample_diff[adjacent])) + sampling_rate = np.round(1.0 / np.mean(sample_diff[adjacent])) + + if sampling_rate < 0: + raise ValueError(f"Error calculating sampling rate. For {filename}") + if verbose: + print(f"Estimated sampling rate for {filename}: {sampling_rate} Hz") + + return sampling_rate def get_valid_intervals( - timestamps, sampling_rate, gap_proportion, min_valid_len + timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=None ): """Finds the set of all valid intervals in a list of timestamps. Valid interval: (start time, stop time) during which there are @@ -219,13 +254,13 @@ def get_valid_intervals( 1D numpy array of timestamp values. sampling_rate : float Sampling rate of the data. - gap_proportion : float, greater than 1; unit: samples - Threshold for detecting a gap; - i.e. if the difference (in samples) between - consecutive timestamps exceeds gap_proportion, - it is considered a gap - min_valid_len : float - Length of smallest valid interval. + gap_proportion : float, optional + Threshold for detecting a gap; i.e. if the difference (in samples) + between consecutive timestamps exceeds gap_proportion, it is considered + a gap. Must be > 1. Default to 2.5 + min_valid_len : float, optional + Length of smallest valid interval. Default to sampling_rate. If greater + than interval duration, print warning and use half the total time. 
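The default handling for min_valid_len described above can be summarized with a small standalone sketch; the helper name is hypothetical and the real logic lives inside get_valid_intervals:

```python
import numpy as np

# Hypothetical helper showing how the new defaults appear to resolve before
# the gap search runs: min_valid_len falls back to the sampling rate, and a
# recording shorter than that falls back to half the total recording time.
def resolve_min_valid_len(timestamps, sampling_rate, min_valid_len=None):
    if not min_valid_len:
        min_valid_len = int(sampling_rate)
    total_time = timestamps[-1] - timestamps[0]
    if total_time < min_valid_len:
        print(f"WARNING: Setting minimum valid interval to {total_time / 2}")
        min_valid_len = total_time / 2
    return min_valid_len


# ~2 s of data at 30 Hz, so the value falls back to half the total time (~0.98)
print(resolve_min_valid_len(np.arange(0, 2, 1 / 30), 30.0))
```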
Returns ------- @@ -235,6 +270,15 @@ def get_valid_intervals( eps = 0.0000001 + if not min_valid_len: + min_valid_len = int(sampling_rate) + + total_time = timestamps[-1] - timestamps[0] + if total_time < min_valid_len: + half_total_time = total_time / 2 + print(f"WARNING: Setting minimum valid interval to {half_total_time}") + min_valid_len = half_total_time + # get rid of NaN elements timestamps = timestamps[~np.isnan(timestamps)] # find gaps @@ -308,7 +352,79 @@ def get_electrode_indices(nwb_object, electrode_ids): ] -def get_all_spatial_series(nwbf, verbose=False): +def _get_epoch_groups(position: pynwb.behavior.Position, old_format=True): + if old_format: + epoch_start_time = np.zeros(len(position.spatial_series.values())) + for pos_epoch, spatial_series in enumerate( + position.spatial_series.values() + ): + epoch_start_time[pos_epoch] = spatial_series.timestamps[0] + + return np.argsort(epoch_start_time) + + from itertools import groupby + + epoch_start_time = {} + for pos_epoch, spatial_series in enumerate( + position.spatial_series.values() + ): + epoch_start_time[pos_epoch] = spatial_series.timestamps[0] + + return { + i: [j[0] for j in j] + for i, j in groupby( + sorted(epoch_start_time.items(), key=lambda x: x[1]), lambda x: x[1] + ) + } + + +def _get_pos_dict(position, epoch_groups, nwbf, verbose=False, old_format=True): + # prev, this was just a list. now, we need to gen mult dict per epoch + pos_data_dict = dict() + all_spatial_series = list(position.spatial_series.values()) + if old_format: + # for index, orig_epoch in enumerate(sorted_order): + for index, orig_epoch in enumerate(epoch_groups): + spatial_series = all_spatial_series[orig_epoch] + # get the valid intervals for the position data + timestamps = np.asarray(spatial_series.timestamps) + sampling_rate = estimate_sampling_rate( + timestamps, verbose=verbose, filename=nwbf.session_id + ) + # add the valid intervals to the Interval list + pos_data_dict[index] = { + "valid_times": get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + ), + "raw_position_object_id": spatial_series.object_id, + } + + else: + for epoch, index_list in enumerate(epoch_groups.values()): + pos_data_dict[epoch] = [] + for index in index_list: + spatial_series = all_spatial_series[index] + # get the valid intervals for the position data + timestamps = np.asarray(spatial_series.timestamps) + sampling_rate = estimate_sampling_rate( + timestamps, verbose=verbose, filename=nwbf.session_id + ) + # add the valid intervals to the Interval list + pos_data_dict[epoch].append( + { + "valid_times": get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + ), + "raw_position_object_id": spatial_series.object_id, + } + ) + + return pos_data_dict + + +def get_all_spatial_series(nwbf, verbose=False, old_format=True): """Given an NWBFile, get the spatial series and interval lists from the file and return a dictionary by epoch. Parameters @@ -325,50 +441,17 @@ def get_all_spatial_series(nwbf, verbose=False): is no position data in the file. The 'raw_position_object_id' is the object ID of the SpatialSeries object. 
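The new-format grouping in _get_epoch_groups (spatial series that share a start time become one epoch, ordered by start time) can be illustrated with a toy example; the start times below are made up:

```python
from itertools import groupby

# Toy illustration of the new-format grouping: keys of epoch_start_time are
# spatial-series indices, values are their start times; series that share a
# start time end up in the same epoch group.
epoch_start_time = {0: 5.0, 1: 0.0, 2: 5.0, 3: 9.0}

epoch_groups = {
    start: [idx for idx, _ in grp]
    for start, grp in groupby(
        sorted(epoch_start_time.items(), key=lambda item: item[1]),
        key=lambda item: item[1],
    )
}
print(epoch_groups)  # {0.0: [1], 5.0: [0, 2], 9.0: [3]}
```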
""" position = get_data_interface(nwbf, "position", pynwb.behavior.Position) + if position is None: return None - # for some reason the spatial_series do not necessarily come out in order, so we need to figure out the right order - epoch_start_time = np.zeros(len(position.spatial_series.values())) - for pos_epoch, spatial_series in enumerate( - position.spatial_series.values() - ): - epoch_start_time[pos_epoch] = spatial_series.timestamps[0] - - sorted_order = np.argsort(epoch_start_time) - pos_data_dict = dict() - - for index, orig_epoch in enumerate(sorted_order): - spatial_series = list(position.spatial_series.values())[orig_epoch] - pos_data_dict[index] = dict() - # get the valid intervals for the position data - timestamps = np.asarray(spatial_series.timestamps) - - # estimate the sampling rate - timestamps = np.asarray(spatial_series.timestamps) - sampling_rate = estimate_sampling_rate(timestamps, 1.75) - if sampling_rate < 0: - raise ValueError( - f"Error adding position data for position epoch {index}" - ) - if verbose: - print( - "Processing raw position data. Estimated sampling rate: {} Hz".format( - sampling_rate - ) - ) - # add the valid intervals to the Interval list - pos_data_dict[index]["valid_times"] = get_valid_intervals( - timestamps, - sampling_rate, - gap_proportion=2.5, - min_valid_len=int(sampling_rate), - ) - pos_data_dict[index][ - "raw_position_object_id" - ] = spatial_series.object_id - - return pos_data_dict + return _get_pos_dict( + position=position, + epoch_groups=_get_epoch_groups(position, old_format=old_format), + nwbf=nwbf, + verbose=verbose, + old_format=old_format, + ) def get_nwb_copy_filename(nwb_file_name): @@ -376,6 +459,9 @@ def get_nwb_copy_filename(nwb_file_name): filename, file_extension = os.path.splitext(nwb_file_name) + if filename.endswith("_"): + print(f"WARNING: File may already be a copy: {nwb_file_name}") + return f"{filename}_{file_extension}"