class PositionSource(dj.Manual):
    """Map epochs of position data in an NWB file to their source interval.

    Each epoch's valid times are stored in IntervalList; the individual
    spatial series within an epoch are tracked in the SpatialSeries part.
    """

    definition = """
    -> Session
    -> IntervalList
    ---
    source: varchar(200)             # source of data (e.g., trodes, dlc)
    import_file_name: varchar(2000)  # path to import file if importing
    """

    class SpatialSeries(dj.Part):
        definition = """
        -> master
        id : int unsigned            # index of spatial series
        ---
        name=null: varchar(32)       # name of spatial series
        """

    @classmethod
    def insert_from_nwbfile(cls, nwb_file_name):
        """Insert position intervals and sources from an NWB file.

        Reads all spatial series from the file, adds one interval per epoch
        to IntervalList, one source row per epoch to PositionSource, and one
        row per spatial series to PositionSource.SpatialSeries — all within
        a single transaction so a partial insert cannot occur.

        Parameters
        ----------
        nwb_file_name : str
            Name of the NWB file to import from.
        """
        nwbf = get_nwb_file(nwb_file_name)
        all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False)
        sess_key = dict(nwb_file_name=nwb_file_name)
        src_key = dict(**sess_key, source="trodes", import_file_name="")

        if all_pos is None:
            return

        sources = []
        intervals = []
        spat_series = []

        for epoch, epoch_list in all_pos.items():
            ind_key = dict(interval_list_name=cls.get_pos_interval_name(epoch))

            sources.append(dict(**src_key, **ind_key))
            intervals.append(
                dict(
                    **sess_key,
                    **ind_key,
                    # valid times are shared across series within one epoch
                    valid_times=epoch_list[0]["valid_times"],
                )
            )

            for index, pdict in enumerate(epoch_list):
                spat_series.append(
                    dict(
                        **sess_key,
                        **ind_key,
                        id=index,  # fixed: was `ndex`, an undefined name
                        name=pdict.get("name"),
                    )
                )

        # insert all-or-nothing so the three tables stay consistent
        with cls.connection.transaction:
            IntervalList.insert(intervals)
            cls.insert(sources)
            cls.SpatialSeries.insert(spat_series)

    @staticmethod
    def get_pos_interval_name(epoch_num: int) -> str:
        """Return string of the interval name from the epoch number.

        Parameters
        ----------
        epoch_num : int
            Input epoch number

        Returns
        -------
        str
            Position interval name (e.g., pos 2 valid times)

        Raises
        ------
        ValueError
            If epoch_num cannot be interpreted as an integer.
        """
        try:
            int(epoch_num)
        except ValueError:
            raise ValueError(  # fixed doubled word: "must must"
                f"Epoch number must be an integer. Received: {epoch_num}"
            )
        return f"pos {epoch_num} valid times"

    @staticmethod
    def get_epoch_num(name: str) -> int:
        """Return the epoch number from the interval name.

        Inverse of get_pos_interval_name.

        Parameters
        ----------
        name : str
            Name of position interval (e.g., "pos 2 valid times")

        Returns
        -------
        int
            epoch number
        """
        return int(name.replace("pos ", "").replace(" valid times", ""))
this calculates sampling rate (unused here) and is expensive to do twice - pos_dict = get_all_spatial_series(nwbf) - for epoch in pos_dict: - if key[ - "interval_list_name" - ] != PositionSource.get_pos_interval_name(epoch): - continue + interval_list_name = key["interval_list_name"] - pdict = pos_dict[epoch] - key["raw_position_object_id"] = pdict["raw_position_object_id"] - self.insert1(key) - break + nwbf = get_nwb_file(nwb_file_name) + indices = (PositionSource.SpatialSeries & key).fetch("id") + + # incl_times = False -> don't do extra processing for valid_times + spat_objs = get_all_spatial_series( + nwbf, old_format=False, incl_times=False + )[PositionSource.get_epoch_num(interval_list_name)] + + self.insert1(key) + self.Object.insert( + [ + dict( + nwb_file_name=nwb_file_name, + interval_list_name=interval_list_name, + id=index, + raw_position_object_id=obj["raw_position_object_id"], + ) + for index, obj in enumerate(spat_objs) + if index in indices + ] + ) def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) + raise NotImplementedError( + "fetch_nwb now operates on RawPosition.Object" + ) def fetch1_dataframe(self): - raw_position_nwb = self.fetch_nwb()[0]["raw_position"] - return pd.DataFrame( - data=raw_position_nwb.data, - index=pd.Index(raw_position_nwb.timestamps, name="time"), - columns=raw_position_nwb.description.split(", "), + raise NotImplementedError( + "fetch1_dataframe now operates on RawPosition.Object" ) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index a4f7b1203..dd65d60a9 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -69,8 +69,25 @@ def insert_from_relative_file_name(cls, nwb_file_name): key["nwb_file_abs_path"] = nwb_file_abs_path cls.insert1(key, skip_duplicates=True) - @staticmethod - def get_abs_path(nwb_file_name): + @classmethod + def _get_file_name(cls, nwb_file_name: str) -> 
@schema
class TrodesPosSelection(dj.Manual):
    """Pair a RawPosition entry with Trodes position-processing parameters."""

    definition = """
    -> RawPosition
    -> TrodesPosParams
    """

    @classmethod
    def insert_with_default(
        cls, key: dict, skip_duplicates: bool = False
    ) -> None:
        """Insert key with default parameters.

        Parameters
        ----------
        key: Union[dict, str]
            Restriction uniquely identifying entr(y/ies) in RawPosition.
        skip_duplicates: bool, optional
            Skip duplicate entries. Default False.

        Raises
        ------
        ValueError
            Key does not identify any entries in RawPosition.
        """
        query = RawPosition & key
        if not query:
            # fixed doubled word: "Found no entries found for"
            raise ValueError(f"Found no entries for {key}")

        # ensure the 'default' parameter set exists before referencing it
        _ = TrodesPosParams.get_default()

        cls.insert(
            [
                dict(**k, trodes_pos_params_name="default")
                for k in query.fetch("KEY", as_dict=True)
            ],
            skip_duplicates=skip_duplicates,
        )
dpeg22 could not recall purpose + + position_info = self.calculate_position_info_from_spatial_series( + spatial_df=spatial_df, + meters_to_pixels=spatial_series.conversion, + **position_info_parameters, + ) + + time_comments = dict( + comments=spatial_series.comments, + timestamps=position_info["time"], + ) + time_comments_ref = dict( + **time_comments, + reference_frame=spatial_series.reference_frame, + ) + + # create nwb objects for insertion into analysis nwb file + position.create_spatial_series( + name="position", + conversion=METERS_PER_CM, + data=position_info["position"], + description="x_position, y_position", + **time_comments_ref, + ) + + orientation.create_spatial_series( + name="orientation", + conversion=1.0, + data=position_info["orientation"], + description="orientation", + **time_comments_ref, + ) + + velocity.create_timeseries( + name="velocity", + conversion=METERS_PER_CM, + unit="m/s", + data=np.concatenate( + ( + position_info["velocity"], + position_info["speed"][:, np.newaxis], + ), + axis=1, ), - columns=raw_position["raw_position"].description.split(", "), + description="x_velocity, y_velocity, speed", + **time_comments, ) - try: - # calculate the processed position - spatial_series = raw_position["raw_position"] - position_info = self.calculate_position_info_from_spatial_series( - spatial_series, - **position_info_parameters, - ) - # create nwb objects for insertion into analysis nwb file - position.create_spatial_series( - name="position", - timestamps=position_info["time"], - conversion=METERS_PER_CM, - data=position_info["position"], - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, - description="x_position, y_position", - ) - orientation.create_spatial_series( - name="orientation", - timestamps=position_info["time"], - conversion=1.0, - data=position_info["orientation"], - reference_frame=spatial_series.reference_frame, - comments=spatial_series.comments, - description="orientation", + if 
video_frame_ind: + velocity.create_timeseries( + name="video_frame_ind", + unit="index", + data=spatial_df.video_frame_ind.to_numpy(), + description="video_frame_ind", + **time_comments, + ) + else: + print( + "No video frame index found. Assuming all camera frames " + + "are present." ) - velocity.create_timeseries( - name="velocity", - timestamps=position_info["time"], - conversion=METERS_PER_CM, - unit="m/s", - data=np.concatenate( - ( - position_info["velocity"], - position_info["speed"][:, np.newaxis], - ), - axis=1, - ), - comments=spatial_series.comments, - description="x_velocity, y_velocity, speed", + name="video_frame_ind", + unit="index", + data=np.arange(len(position_info["time"])), + description="video_frame_ind", + **time_comments, ) - try: - velocity.create_timeseries( - name="video_frame_ind", - unit="index", - timestamps=position_info["time"], - data=raw_pos_df.video_frame_ind.to_numpy(), - description="video_frame_ind", - comments=spatial_series.comments, - ) - except AttributeError: - print( - "No video frame index found. Assuming all camera frames are present." 
- ) - velocity.create_timeseries( - name="video_frame_ind", - unit="index", - timestamps=position_info["time"], - data=np.arange(len(position_info["time"])), - description="video_frame_ind", - comments=spatial_series.comments, - ) - except ValueError: - pass # Insert into analysis nwb file nwb_analysis_file = AnalysisNwbfile() - key["position_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], position - ) - key["orientation_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], orientation - ) - key["velocity_object_id"] = nwb_analysis_file.add_nwb_object( - key["analysis_file_name"], velocity + key.update( + dict( + position_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], position + ), + orientation_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], orientation + ), + velocity_object_id=nwb_analysis_file.add_nwb_object( + key["analysis_file_name"], velocity + ), + ) ) AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) @@ -214,6 +253,7 @@ def make(self, key): from ..position_merge import PositionOutput part_name = to_camel_case(self.table_name.split("__")[-1]) + # TODO: The next line belongs in a merge table function PositionOutput._merge_insert( [orig_key], part_name=part_name, skip_duplicates=True @@ -221,7 +261,8 @@ def make(self, key): @staticmethod def calculate_position_info_from_spatial_series( - spatial_series, + spatial_df: pd.DataFrame, + meters_to_pixels: float, max_separation, max_speed, speed_smoothing_std_dev, @@ -232,16 +273,25 @@ def calculate_position_info_from_spatial_series( upsampling_sampling_rate, upsampling_interpolation_method, ): + """Calculate position info from 2D spatial series.""" CM_TO_METERS = 100 + # Accepts x/y 'loc' or 'loc1' format for first pos. 
Renames to 'loc' + DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2", "xloc1", "yloc1"] + + cols = list(spatial_df.columns) + if len(cols) != 4 or not all([c in DEFAULT_COLS for c in cols]): + choice = dj.utils.user_choice( + f"Unexpected columns in raw position. Assume " + + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + ) + if choice.lower() not in ["yes", "y"]: + raise ValueError(f"Unexpected columns in raw position: {cols}") + # rename first 4 columns, keep rest. Rest dropped below + spatial_df.columns = DEFAULT_COLS[:4] + cols[4:] # Get spatial series properties - time = np.asarray(spatial_series.timestamps) # seconds - position = np.asarray( - pd.DataFrame( - spatial_series.data, - columns=spatial_series.description.split(", "), - ).loc[:, ["xloc", "yloc", "xloc2", "yloc2"]] - ) # meters + time = np.asarray(spatial_df.index) # seconds + position = np.asarray(spatial_df.iloc[:, :4]) # meters # remove NaN times is_nan_time = np.isnan(time) @@ -250,7 +300,6 @@ def calculate_position_info_from_spatial_series( dt = np.median(np.diff(time)) sampling_rate = 1 / dt - meters_to_pixels = spatial_series.conversion # Define LEDs if led1_is_front: @@ -438,7 +487,6 @@ class TrodesPosVideo(dj.Computed): definition = """ -> TrodesPosV1 - --- """ def make(self, key): @@ -446,7 +494,7 @@ def make(self, key): print("Loading position data...") raw_position_df = ( - RawPosition() + RawPosition.Object & { "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["interval_list_name"], diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index da366fb7d..4ddf7d53a 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -67,7 +67,7 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: resolved_base = ( base_dir or dj_spyglass.get("base") - or os.environ.get("SPYGLASS_BASE_DIR", ".") + or os.environ.get("SPYGLASS_BASE_DIR") ) if not resolved_base: raise ValueError( diff --git a/src/spyglass/utils/dj_merge_tables.py 
b/src/spyglass/utils/dj_merge_tables.py index e56deea4c..5d7c093eb 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -6,7 +6,7 @@ from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.preview import repr_html -from datajoint.utils import from_camel_case, to_camel_case, get_master +from datajoint.utils import from_camel_case, get_master, to_camel_case from IPython.core.display import HTML from spyglass.common.common_nwbfile import AnalysisNwbfile @@ -772,7 +772,7 @@ def _master_table_pairs( table_list : List[dj.Table] A list of datajoint tables. restriction : str - A restriction string. Defalt True, no restriction. + A restriction string. Default True, no restriction. connection : datajoint.connection.Connection A database connection. Default None, use connection from first table. diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 15389ba71..2d493371a 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -143,13 +143,14 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): if len(ret) > 1: warnings.warn( f"Multiple data interfaces with name '{data_interface_name}' " - f"found in NWBFile with identifier {nwbfile.identifier}. Using the first one found. " + f"found in NWBFile with identifier {nwbfile.identifier}. " + + "Using the first one found. " "Use the data_interface_class argument to restrict the search." 
) if len(ret) >= 1: return ret[0] - else: - return None + + return None def get_raw_eseries(nwbfile): @@ -378,25 +379,53 @@ def _get_epoch_groups(position: pynwb.behavior.Position, old_format=True): } -def _get_pos_dict(position, epoch_groups, nwbf, verbose=False, old_format=True): +def _get_pos_dict( + position: dict, + epoch_groups: dict, + session_id: str = None, + verbose: bool = False, + old_format: bool = True, # TODO: remove after changing prod database + incl_times: bool = True, +): + """Return dict with obj ids and valid intervals for each epoch. + + Parameters + ---------- + position : hdmf.utils.LabeledDict + pynwb.behavior.Position.spatial_series + epoch_groups : dict + Epoch start times as keys, spatial series indices as values + session_id : str, optional + Optional session ID for verbose print during sampling rate estimation + verbose : bool, optional + Default to False. Print estimated sampling rate + incl_times : bool, optional + Default to True. Include valid intervals. Requires additional + computation not needed for RawPosition + """ # prev, this was just a list. 
now, we need to gen mult dict per epoch pos_data_dict = dict() - all_spatial_series = list(position.spatial_series.values()) + all_spatial_series = list(position.values()) if old_format: # for index, orig_epoch in enumerate(sorted_order): for index, orig_epoch in enumerate(epoch_groups): spatial_series = all_spatial_series[orig_epoch] + # get the valid intervals for the position data - timestamps = np.asarray(spatial_series.timestamps) - sampling_rate = estimate_sampling_rate( - timestamps, verbose=verbose, filename=nwbf.session_id - ) - # add the valid intervals to the Interval list - pos_data_dict[index] = { - "valid_times": get_valid_intervals( + valid_times = None + if incl_times: + timestamps = np.asarray(spatial_series.timestamps) + sampling_rate = estimate_sampling_rate( + timestamps, verbose=verbose, filename=session_id + ) + valid_times = get_valid_intervals( timestamps=timestamps, sampling_rate=sampling_rate, - ), + ) + + # add the valid intervals to the Interval list + pos_data_dict[index] = { + "valid_times": valid_times, "raw_position_object_id": spatial_series.object_id, } @@ -406,26 +435,35 @@ def _get_pos_dict(position, epoch_groups, nwbf, verbose=False, old_format=True): for index in index_list: spatial_series = all_spatial_series[index] # get the valid intervals for the position data - timestamps = np.asarray(spatial_series.timestamps) - sampling_rate = estimate_sampling_rate( - timestamps, verbose=verbose, filename=nwbf.session_id - ) + valid_times = None + if incl_times: + timestamps = np.asarray(spatial_series.timestamps) + sampling_rate = estimate_sampling_rate( + timestamps, verbose=verbose, filename=session_id + ) + valid_times = get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + ) # add the valid intervals to the Interval list pos_data_dict[epoch].append( { - "valid_times": get_valid_intervals( - timestamps=timestamps, - sampling_rate=sampling_rate, - ), + "valid_times": valid_times, "raw_position_object_id": 
def get_all_spatial_series(
    nwbf, verbose=False, old_format=True, incl_times=True
) -> dict:
    """
    Given an NWB, get the spatial series and return a dictionary by epoch.

    If incl_times is True, then the valid intervals are included in the output.

    Parameters
    ----------
    nwbf : pynwb.NWBFile
        The source NWB file object.
    verbose : bool
        Flag representing whether to print the sampling rate. Default, False.
    old_format : bool
        Use the legacy one-series-per-epoch layout. Default, True.
        TODO: remove after changing prod database.
    incl_times : bool
        Include valid times in the output. Default, True. Set to False for
        only spatial series object IDs; skips sampling-rate estimation.

    Returns
    -------
    pos_data_dict : dict or None
        Dict mapping indices to a dict with keys 'valid_times' and
        'raw_position_object_id'. Returns None if there is no position data in
        the file. The 'raw_position_object_id' is the object ID of the
        SpatialSeries object.
    """
    pos_interface = get_data_interface(
        nwbf, "position", pynwb.behavior.Position
    )

    if pos_interface is None:
        return None

    return _get_pos_dict(
        position=pos_interface.spatial_series,
        epoch_groups=_get_epoch_groups(pos_interface, old_format=old_format),
        session_id=nwbf.session_id,
        verbose=verbose,
        old_format=old_format,
        incl_times=incl_times,
    )