diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index bb86390e9..32215124c 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -152,7 +152,7 @@ class RawPosition(dj.Imported): -> PositionSource """ - class Object(dj.Part): + class PosObject(dj.Part): definition = """ -> master -> PositionSource.SpatialSeries.proj('id') @@ -170,6 +170,9 @@ def fetch1_dataframe(self): id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()] + if len(set(rp.interval for _, rp in id_rp)) > 1: + print("WARNING: loading DataFrame with multiple intervals.") + df_list = [ pd.DataFrame( data=rp.data, @@ -200,7 +203,7 @@ def make(self, key): ] self.insert1(key) - self.Object.insert( + self.PosObject.insert( [ dict( nwb_file_name=nwb_file_name, @@ -213,15 +216,34 @@ def make(self, key): ] ) - def fetch_nwb(self, *attrs, **kwargs): - raise NotImplementedError( - "fetch_nwb now operates on RawPosition.Object" - ) + def fetch_nwb(self, *attrs, **kwargs) -> list: + """ + Returns a condatenated list of nwb objects from RawPosition.PosObject + """ + ret = [] + for pos_obj in self.PosObject: + ret.append([nwb for nwb in pos_obj.fetch_nwb(*attrs, **kwargs)]) + return ret def fetch1_dataframe(self): - raise NotImplementedError( - "fetch1_dataframe now operates on RawPosition.Object" - ) + """Returns a dataframe with all RawPosition.PosObject items. + + Uses interval_list_name as column index. + """ + ret = {} + + pos_obj_set = self.PosObject & self.restriction + unique_intervals = set(pos_obj_set.fetch("interval_list_name")) + + for interval in unique_intervals: + ret[interval] = ( + pos_obj_set & {"interval_list_name": interval} + ).fetch1_dataframe() + + if len(unique_intervals) == 1: + return next(iter(ret.values())) + + return pd.concat(ret, axis=1) @schema diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index da7456232..5ff40c990 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -180,10 +180,10 @@ def make(self, key): ) position_info_parameters = (TrodesPosParams() & key).fetch1("params") - spatial_series = (RawPosition.Object & key).fetch_nwb()[0][ + spatial_series = (RawPosition.PosObject & key).fetch_nwb()[0][ "raw_position" ] - spatial_df = (RawPosition.Object & key).fetch1_dataframe() + spatial_df = (RawPosition.PosObject & key).fetch1_dataframe() video_frame_ind = getattr(spatial_df, "video_frame_ind", None) position = pynwb.behavior.Position() @@ -295,8 +295,8 @@ def make(self, key): def calculate_position_info_from_spatial_series( spatial_df: pd.DataFrame, meters_to_pixels: float, - max_separation, - max_speed, + max_LED_separation, + max_plausible_speed, speed_smoothing_std_dev, position_smoothing_duration, orient_smoothing_std_dev, @@ -347,7 +347,7 @@ def calculate_position_info_from_spatial_series( # Set points to NaN where the front and back LEDs are too separated dist_between_LEDs = get_distance(back_LED, front_LED) - is_too_separated = dist_between_LEDs >= max_separation + is_too_separated = dist_between_LEDs >= max_LED_separation back_LED[is_too_separated] = np.nan front_LED[is_too_separated] = np.nan @@ -367,8 +367,8 @@ def calculate_position_info_from_spatial_series( ) # Set to points to NaN where the speed is too fast - is_too_fast = (front_LED_speed > max_speed) | ( - back_LED_speed > max_speed + is_too_fast = (front_LED_speed > max_plausible_speed) | ( + back_LED_speed > max_plausible_speed ) back_LED[is_too_fast] = np.nan front_LED[is_too_fast] = np.nan @@ -528,7 +528,7 @@ def make(self, key): print("Loading position data...") raw_position_df = ( - RawPosition.Object + RawPosition.PosObject & { "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["interval_list_name"],