Skip to content

Commit

Permalink
IntervalPositionInfo column ordering on fetch1_dataframe (LorenFrankL…
Browse files Browse the repository at this point in the history
…ab#673)

* Fix gitignore

* Fix LorenFrankLab#671

* Update changelog

* add static method decorator
  • Loading branch information
CBroz1 authored Nov 2, 2023
1 parent d441c6f commit c97d8e2
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 80 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

- Migrate `config` helper scripts to Spyglass codebase. #662
- Revise contribution guidelines. #655
- Minor bug fixes. #656, #657, #659, #651
- Minor bug fixes. #656, #657, #659, #651, #671

## [0.4.2] (October 10, 2023)

Expand Down
223 changes: 146 additions & 77 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,14 @@ def generate_pos_components(
}

@staticmethod
def calculate_position_info(
spatial_df: pd.DataFrame,
meters_to_pixels: float,
position_smoothing_duration,
led1_is_front,
is_upsampled,
upsampling_sampling_rate,
upsampling_interpolation_method,
def _fix_kwargs(
orient_smoothing_std_dev=None,
speed_smoothing_std_dev=None,
max_LED_separation=None,
max_plausible_speed=None,
**kwargs,
):
CM_TO_METERS = 100

"""Handles discrepancies between common and v1 param names."""
if not orient_smoothing_std_dev:
orient_smoothing_std_dev = kwargs.get(
"head_orient_smoothing_std_dev"
Expand All @@ -242,36 +234,149 @@ def calculate_position_info(
[speed_smoothing_std_dev, max_LED_separation, max_plausible_speed]
):
raise ValueError(
"Missing required parameters:\n\t"
"Missing at least one required parameter:\n\t"
+ f"speed_smoothing_std_dev: {speed_smoothing_std_dev}\n\t"
+ f"max_LED_separation: {max_LED_separation}\n\t"
+ f"max_plausible_speed: {max_plausible_speed}"
)
return (
orient_smoothing_std_dev,
speed_smoothing_std_dev,
max_LED_separation,
max_plausible_speed,
)

# Accepts x/y 'loc' or 'loc1' format for first pos. Renames to 'loc'
DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2", "xloc1", "yloc1"]
ALTERNATIVE_COLS = ["xloc1", "xloc2", "yloc1", "yloc2"]

if all([c in spatial_df.columns for c in DEFAULT_COLS[:4]]):
# move the 4 position columns to front, continue
spatial_df = spatial_df[DEFAULT_COLS[:4]]
elif all([c in spatial_df.columns for c in ALTERNATIVE_COLS]):
# move the 4 position columns to front, rename to default, continue
spatial_df = spatial_df[ALTERNATIVE_COLS]
spatial_df.columns = DEFAULT_COLS[:4]
else:
cols = list(spatial_df.columns)
if len(cols) != 4 or not all([c in DEFAULT_COLS for c in cols]):
choice = dj.utils.user_choice(
"Unexpected columns in raw position. Assume "
+ f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n"
@staticmethod
def _fix_col_names(spatial_df):
"""Renames columns in spatial dataframe according to previous norm
Accepts unnamed first led, 1 or 0 indexed.
Prompts user for confirmation of renaming unexpected columns.
For backwards compatibility, renames to "xloc", "yloc", "xloc2", "yloc2"
"""

DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"]
ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"]
ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"]

input_cols = list(spatial_df.columns)

has_default = all([c in input_cols for c in DEFAULT_COLS])
has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS])
has_1_idx = all([c in input_cols for c in ONE_IDX_COLS])

# if unexpected columns, ask user to confirm
if len(input_cols) != 4 or not (has_default or has_0_idx or has_1_idx):
choice = dj.utils.user_choice(
"Unexpected columns in raw position. Assume "
+ f"{DEFAULT_COLS[:4]}?\n{spatial_df}\n"
)
if choice.lower() not in ["yes", "y"]:
raise ValueError(
f"Unexpected columns in raw position: {input_cols}"
)
if choice.lower() not in ["yes", "y"]:
raise ValueError(
f"Unexpected columns in raw position: {cols}"
)
# rename first 4 columns, keep rest. Rest dropped below
spatial_df.columns = DEFAULT_COLS[:4] + cols[4:]
spatial_df.columns = DEFAULT_COLS + input_cols[4:]

# Ensure data order, only 4 col
spatial_df = (
spatial_df[DEFAULT_COLS]
if has_default
else spatial_df[ZERO_IDX_COLS]
if has_0_idx
else spatial_df[ONE_IDX_COLS]
)

# rename to default
spatial_df.columns = DEFAULT_COLS

return spatial_df

@staticmethod
def _upsample(
front_LED,
back_LED,
time,
sampling_rate,
upsampling_sampling_rate,
upsampling_interpolation_method,
**kwargs,
):
position_df = pd.DataFrame(
{
"time": time,
"back_LED_x": back_LED[:, 0],
"back_LED_y": back_LED[:, 1],
"front_LED_x": front_LED[:, 0],
"front_LED_y": front_LED[:, 1],
}
).set_index("time")

upsampling_start_time = time[0]
upsampling_end_time = time[-1]

n_samples = (
int(
np.ceil(
(upsampling_end_time - upsampling_start_time)
* upsampling_sampling_rate
)
)
+ 1
)
new_time = np.linspace(
upsampling_start_time, upsampling_end_time, n_samples
)
new_index = pd.Index(
np.unique(np.concatenate((position_df.index, new_time))),
name="time",
)
position_df = (
position_df.reindex(index=new_index)
.interpolate(method=upsampling_interpolation_method)
.reindex(index=new_time)
)

time = np.asarray(position_df.index)
back_LED = np.asarray(position_df.loc[:, ["back_LED_x", "back_LED_y"]])
front_LED = np.asarray(
position_df.loc[:, ["front_LED_x", "front_LED_y"]]
)

sampling_rate = upsampling_sampling_rate

return front_LED, back_LED, time, sampling_rate

def calculate_position_info(
self,
spatial_df: pd.DataFrame,
meters_to_pixels: float,
position_smoothing_duration,
led1_is_front,
is_upsampled,
upsampling_sampling_rate,
upsampling_interpolation_method,
orient_smoothing_std_dev=None,
speed_smoothing_std_dev=None,
max_LED_separation=None,
max_plausible_speed=None,
**kwargs,
):
CM_TO_METERS = 100

(
orient_smoothing_std_dev,
speed_smoothing_std_dev,
max_LED_separation,
max_plausible_speed,
) = self._fix_kwargs(
orient_smoothing_std_dev,
speed_smoothing_std_dev,
max_LED_separation,
max_plausible_speed,
**kwargs,
)

spatial_df = self._fix_col_names(spatial_df)
# Get spatial series properties
time = np.asarray(spatial_df.index) # seconds
position = np.asarray(spatial_df.iloc[:, :4]) # meters
Expand Down Expand Up @@ -338,51 +443,15 @@ def calculate_position_info(
)

if is_upsampled:
position_df = pd.DataFrame(
{
"time": time,
"back_LED_x": back_LED[:, 0],
"back_LED_y": back_LED[:, 1],
"front_LED_x": front_LED[:, 0],
"front_LED_y": front_LED[:, 1],
}
).set_index("time")

upsampling_start_time = time[0]
upsampling_end_time = time[-1]

n_samples = (
int(
np.ceil(
(upsampling_end_time - upsampling_start_time)
* upsampling_sampling_rate
)
)
+ 1
)
new_time = np.linspace(
upsampling_start_time, upsampling_end_time, n_samples
)
new_index = pd.Index(
np.unique(np.concatenate((position_df.index, new_time))),
name="time",
)
position_df = (
position_df.reindex(index=new_index)
.interpolate(method=upsampling_interpolation_method)
.reindex(index=new_time)
front_LED, back_LED, time, sampling_rate = self._upsample(
front_LED,
back_LED,
time,
sampling_rate,
upsampling_sampling_rate,
upsampling_interpolation_method,
)

time = np.asarray(position_df.index)
back_LED = np.asarray(
position_df.loc[:, ["back_LED_x", "back_LED_y"]]
)
front_LED = np.asarray(
position_df.loc[:, ["front_LED_x", "front_LED_y"]]
)

sampling_rate = upsampling_sampling_rate

# Calculate position, orientation, velocity, speed
position = get_centriod(back_LED, front_LED) # cm

Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ def make(self, key):

@staticmethod
def generate_pos_components(*args, **kwargs):
return IntervalPositionInfo.generate_pos_components(*args, **kwargs)
return IntervalPositionInfo().generate_pos_components(*args, **kwargs)

@staticmethod
def calculate_position_info(*args, **kwargs):
"""Calculate position info from 2D spatial series."""
return IntervalPositionInfo.calculate_position_info(*args, **kwargs)
return IntervalPositionInfo().calculate_position_info(*args, **kwargs)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
Expand Down

0 comments on commit c97d8e2

Please sign in to comment.