diff --git a/CHANGELOG.md b/CHANGELOG.md index db2d7182d..80f2f4a13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ - Migrate `config` helper scripts to Spyglass codebase. #662 - Revise contribution guidelines. #655 -- Minor bug fixes. #656, #657, #659, #651 +- Minor bug fixes. #656, #657, #659, #651, #671 ## [0.4.2] (October 10, 2023) diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 0150561cf..259958181 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -212,22 +212,14 @@ def generate_pos_components( } @staticmethod - def calculate_position_info( - spatial_df: pd.DataFrame, - meters_to_pixels: float, - position_smoothing_duration, - led1_is_front, - is_upsampled, - upsampling_sampling_rate, - upsampling_interpolation_method, + def _fix_kwargs( orient_smoothing_std_dev=None, speed_smoothing_std_dev=None, max_LED_separation=None, max_plausible_speed=None, **kwargs, ): - CM_TO_METERS = 100 - + """Handles discrepancies between common and v1 param names.""" if not orient_smoothing_std_dev: orient_smoothing_std_dev = kwargs.get( "head_orient_smoothing_std_dev" @@ -242,36 +234,149 @@ def calculate_position_info( [speed_smoothing_std_dev, max_LED_separation, max_plausible_speed] ): raise ValueError( - "Missing required parameters:\n\t" + "Missing at least one required parameter:\n\t" + f"speed_smoothing_std_dev: {speed_smoothing_std_dev}\n\t" + f"max_LED_separation: {max_LED_separation}\n\t" + f"max_plausible_speed: {max_plausible_speed}" ) + return ( + orient_smoothing_std_dev, + speed_smoothing_std_dev, + max_LED_separation, + max_plausible_speed, + ) - # Accepts x/y 'loc' or 'loc1' format for first pos. Renames to 'loc' - DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2", "xloc1", "yloc1"] - ALTERNATIVE_COLS = ["xloc1", "xloc2", "yloc1", "yloc2"] - - if all([c in spatial_df.columns for c in DEFAULT_COLS[:4]]): - # move the 4 position columns to front, continue - spatial_df = spatial_df[DEFAULT_COLS[:4]] - elif all([c in spatial_df.columns for c in ALTERNATIVE_COLS]): - # move the 4 position columns to front, rename to default, continue - spatial_df = spatial_df[ALTERNATIVE_COLS] - spatial_df.columns = DEFAULT_COLS[:4] - else: - cols = list(spatial_df.columns) - if len(cols) != 4 or not all([c in DEFAULT_COLS for c in cols]): - choice = dj.utils.user_choice( - "Unexpected columns in raw position. Assume " - + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + @staticmethod + def _fix_col_names(spatial_df): + """Renames columns in spatial dataframe according to previous norm + + Accepts unnamed first led, 1 or 0 indexed. + Prompts user for confirmation of renaming unexpected columns. + For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2" + """ + + DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] + ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] + ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] + + input_cols = list(spatial_df.columns) + + has_default = all([c in input_cols for c in DEFAULT_COLS]) + has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) + has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) + + # if unexpected columns, ask user to confirm + if len(input_cols) != 4 or not (has_default or has_0_idx or has_1_idx): + choice = dj.utils.user_choice( + "Unexpected columns in raw position. Assume " + + f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n" + ) + if choice.lower() not in ["yes", "y"]: + raise ValueError( + f"Unexpected columns in raw position: {input_cols}" ) - if choice.lower() not in ["yes", "y"]: - raise ValueError( - f"Unexpected columns in raw position: {cols}" - ) - # rename first 4 columns, keep rest. Rest dropped below - spatial_df.columns = DEFAULT_COLS[:4] + cols[4:] + spatial_df.columns = DEFAULT_COLS + input_cols[4:] + + # Ensure data order, only 4 col + spatial_df = ( + spatial_df[DEFAULT_COLS] + if has_default + else spatial_df[ZERO_IDX_COLS] + if has_0_idx + else spatial_df[ONE_IDX_COLS] + ) + + # rename to default + spatial_df.columns = DEFAULT_COLS + + return spatial_df + + @staticmethod + def _upsample( + front_LED, + back_LED, + time, + sampling_rate, + upsampling_sampling_rate, + upsampling_interpolation_method, + **kwargs, + ): + position_df = pd.DataFrame( + { + "time": time, + "back_LED_x": back_LED[:, 0], + "back_LED_y": back_LED[:, 1], + "front_LED_x": front_LED[:, 0], + "front_LED_y": front_LED[:, 1], + } + ).set_index("time") + + upsampling_start_time = time[0] + upsampling_end_time = time[-1] + + n_samples = ( + int( + np.ceil( + (upsampling_end_time - upsampling_start_time) + * upsampling_sampling_rate + ) + ) + + 1 + ) + new_time = np.linspace( + upsampling_start_time, upsampling_end_time, n_samples + ) + new_index = pd.Index( + np.unique(np.concatenate((position_df.index, new_time))), + name="time", + ) + position_df = ( + position_df.reindex(index=new_index) + .interpolate(method=upsampling_interpolation_method) + .reindex(index=new_time) + ) + + time = np.asarray(position_df.index) + back_LED = np.asarray(position_df.loc[:, ["back_LED_x", "back_LED_y"]]) + front_LED = np.asarray( + position_df.loc[:, ["front_LED_x", "front_LED_y"]] + ) + + sampling_rate = upsampling_sampling_rate + + return front_LED, back_LED, time, sampling_rate + + def calculate_position_info( + self, + spatial_df: pd.DataFrame, + meters_to_pixels: float, + position_smoothing_duration, + led1_is_front, + is_upsampled, + upsampling_sampling_rate, + upsampling_interpolation_method, + orient_smoothing_std_dev=None, + speed_smoothing_std_dev=None, + max_LED_separation=None, + max_plausible_speed=None, + **kwargs, + ): + CM_TO_METERS = 100 + + ( + orient_smoothing_std_dev, + speed_smoothing_std_dev, + max_LED_separation, + max_plausible_speed, + ) = self._fix_kwargs( + orient_smoothing_std_dev, + speed_smoothing_std_dev, + max_LED_separation, + max_plausible_speed, + **kwargs, + ) + + spatial_df = self._fix_col_names(spatial_df) # Get spatial series properties time = np.asarray(spatial_df.index) # seconds position = np.asarray(spatial_df.iloc[:, :4]) # meters @@ -338,51 +443,15 @@ def calculate_position_info( ) if is_upsampled: - position_df = pd.DataFrame( - { - "time": time, - "back_LED_x": back_LED[:, 0], - "back_LED_y": back_LED[:, 1], - "front_LED_x": front_LED[:, 0], - "front_LED_y": front_LED[:, 1], - } - ).set_index("time") - - upsampling_start_time = time[0] - upsampling_end_time = time[-1] - - n_samples = ( - int( - np.ceil( - (upsampling_end_time - upsampling_start_time) - * upsampling_sampling_rate - ) - ) - + 1 - ) - new_time = np.linspace( - upsampling_start_time, upsampling_end_time, n_samples - ) - new_index = pd.Index( - np.unique(np.concatenate((position_df.index, new_time))), - name="time", - ) - position_df = ( - position_df.reindex(index=new_index) - .interpolate(method=upsampling_interpolation_method) - .reindex(index=new_time) + front_LED, back_LED, time, sampling_rate = self._upsample( + front_LED, + back_LED, + time, + sampling_rate, + upsampling_sampling_rate, + upsampling_interpolation_method, ) - time = np.asarray(position_df.index) - back_LED = np.asarray( - position_df.loc[:, ["back_LED_x", "back_LED_y"]] - ) - front_LED = np.asarray( - position_df.loc[:, ["front_LED_x", "front_LED_y"]] - ) - - sampling_rate = upsampling_sampling_rate - # Calculate position, orientation, velocity, speed position = get_centriod(back_LED, front_LED) # cm diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 5a3c5b9b1..4f0500949 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -205,12 +205,12 @@ def make(self, key): @staticmethod def generate_pos_components(*args, **kwargs): - return IntervalPositionInfo.generate_pos_components(*args, **kwargs) + return IntervalPositionInfo().generate_pos_components(*args, **kwargs) @staticmethod def calculate_position_info(*args, **kwargs): """Calculate position info from 2D spatial series.""" - return IntervalPositionInfo.calculate_position_info(*args, **kwargs) + return IntervalPositionInfo().calculate_position_info(*args, **kwargs) def fetch_nwb(self, *attrs, **kwargs): return fetch_nwb(