From 5f20d79a09bd52363988da329ae72fcb56e16f67 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 5 Jul 2024 10:32:21 -0500
Subject: [PATCH 01/10] feat(block_analysis): add `patch_threshold` at `BlockSubjectAnalysis` level
---
 aeon/dj_pipeline/analysis/block_analysis.py | 28 +++++++++++----------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index 5ddd4968..6a3d3197 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -1,16 +1,17 @@
 import json
+from datetime import datetime
+
 import datajoint as dj
 import numpy as np
 import pandas as pd
 import plotly.express as px
 import plotly.graph_objs as go
 from matplotlib import path as mpl_path
-from datetime import datetime
 
-from aeon.io import api as io_api
 from aeon.analysis import utils as analysis_utils
 from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
 from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
+from aeon.io import api as io_api
 
 schema = dj.schema(get_schema_name("block_analysis"))
 logger = dj.logger
@@ -34,8 +35,7 @@ class BlockDetection(dj.Computed):
     """
 
     def make(self, key):
-        """
-        On a per-chunk basis, check for the presence of new block, insert into Block table.
+        """On a per-chunk basis, check for the presence of new block, insert into Block table.
         High level logic
         1. Find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block)
         2. Remove any double 0s (0s within 1 second of each other) (pick the first 0)
@@ -353,6 +353,7 @@ class Patch(dj.Part):
         pellet_count: int
         pellet_timestamps: longblob
+        patch_threshold: longblob  # patch threshold value at each pellet delivery
         wheel_cumsum_distance_travelled: longblob  # wheel's cumulative distance travelled
         """
@@ -492,9 +493,9 @@ def make(self, key):
             )
             subject_in_patch = in_patch[subject_name]
             subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt
-            all_subj_patch_pref_dict[patch["patch_name"]][subject_name][
-                "cum_time"
-            ] = subject_in_patch_cum_time
+            all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_time"] = (
+                subject_in_patch_cum_time
+            )
 
             closest_subj_mask = closest_subjects_pellet_ts == subject_name
             subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
@@ -663,10 +664,10 @@ class BlockSubjectPlots(dj.Computed):
 
     def make(self, key):
         from aeon.analysis.block_plotting import (
-            subject_colors,
-            patch_markers_linestyles,
-            patch_markers,
             gen_hex_grad,
+            patch_markers,
+            patch_markers_linestyles,
+            subject_colors,
         )
 
         patch_names, subject_names = (BlockSubjectAnalysis.Preference & key).fetch(
@@ -790,8 +791,7 @@ class AnalysisNote(dj.Manual):
 
 def get_threshold_associated_pellets(patch_key, start, end):
-    """
-    Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
+    """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
     1. 
Get all patch state update timestamps (DepletionState): let's call these events "A" - Remove all events within 1 second of each other - Remove all events without threshold value (NaN) @@ -889,5 +889,7 @@ def get_threshold_associated_pellets(patch_key, start, end): # Shift back the pellet_timestamp values by 1 to match with the previous threshold update pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1) pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1) - pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp", "beam_break_timestamp"]) + pellet_ts_threshold_df = pellet_ts_threshold_df.dropna( + subset=["pellet_timestamp", "beam_break_timestamp"] + ) return pellet_ts_threshold_df From a53ec10f27f079ddc28826bc01b61445bc08cd57 Mon Sep 17 00:00:00 2001 From: Jai Date: Fri, 13 Sep 2024 11:01:51 +0100 Subject: [PATCH 02/10] added working 'get_foraging_bouts' function --- aeon/dj_pipeline/analysis/block_analysis.py | 113 ++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 6a3d3197..1702e037 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -893,3 +893,116 @@ def get_threshold_associated_pellets(patch_key, start, end): subset=["pellet_timestamp", "beam_break_timestamp"] ) return pellet_ts_threshold_df + + +"""Foraging bout function.""" + + +def get_foraging_bouts( + key: dict, + min_pellets: int = 3, + max_inactive_time: pd.Timedelta | None = None, # seconds + min_wheel_movement: float = 10, # cm +) -> pd.DataFrame: + """Gets foraging bouts for all subjects across all patches within a block. + + Args: + key: Block key - dict containing keys for 'experiment_name', 'block_start', 'block_end'. + min_pellets: Minimum number of pellets for a foraging bout. + max_inactive_time: Maximum time between `min_wheel_movement`s for a foraging bout. + min_wheel_movement: Minimum wheel movement for a foraging bout. + + Returns: + DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. 
+    """
+    max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time
+    subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame")
+    subject_patch_data.reset_index(level=["experiment_name"], drop=True, inplace=True)
+    subject_names = subject_patch_data.index.get_level_values("subject_name").unique()
+    wheel_ts = (BlockAnalysis.Patch & key).fetch("wheel_timestamps")[0]
+    # For each subject:
+    #   - Create cumulative wheel distance spun sum df combining all patches
+    #       - Columns: timestamp, wheel distance, patch
+    #   - Discretize into 'possible foraging events' based on `max_inactive_time`, and `min_wheel_movement`
+    #   - Filter out events with < `min_pellets`
+    #   - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF
+    bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"])
+    for subject in subject_names:
+        cur_subject_data = subject_patch_data.xs(subject, level="subject_name")
+        # Create combined cumulative wheel distance spun: ensure equal length wheel vals across patches
+        wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values
+        min_len = min(len(arr) for arr in wheel_vals)
+        comb_cum_wheel_dist = np.sum([arr[:min_len] for arr in wheel_vals], axis=0)
+        wheel_ts, comb_cum_wheel_dist = (  # ensure equal length wheel vals and wheel ts
+            arr[: min(len(wheel_ts), len(comb_cum_wheel_dist))] for arr in [wheel_ts, comb_cum_wheel_dist]
+        )
+        # For each wheel_ts, get the correspdoning patch that was spun
+        patch_spun = np.empty(len(wheel_ts), dtype="<U20")
+        patch_names = cur_subject_data.index.get_level_values(1)
+        wheel_spun_thresh = 0.03  # threshold for wheel movement (cm)
+        for i, patch_name in enumerate(patch_names):
+            spun_indices = np.where(np.diff(wheel_vals[i][:min_len], prepend=0) > wheel_spun_thresh)[0]
+            patch_spun[spun_indices] = patch_name
+        patch_spun_df = pd.DataFrame(index=wheel_ts, columns=["cum_wheel_dist", "patch_spun"])
+        patch_spun_df["cum_wheel_dist"] = comb_cum_wheel_dist
+        patch_spun_df["patch_spun"] = patch_spun
+        wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns")
+        win_len = int(max_inactive_time / wheel_s_r)
+        # Find times when foraging
+        max_windowed_wheel_vals = (
+            patch_spun_df["cum_wheel_dist"]
+            .shift(-(win_len - 1))
+            .rolling(window=win_len, min_periods=1)
+            .max()
+        )
+        foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement)
+        # Discretize into foraging bouts
+        bout_start_indxs = np.where(np.diff(foraging_mask.astype(int), prepend=0) == 1)[0]
+        bout_end_indxs = np.where(np.diff(foraging_mask.astype(int), prepend=0) == -1)[0]
+        assert len(bout_start_indxs) == len(bout_end_indxs)
+        bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype(  # in seconds
+            "timedelta64[ns]"
+        ).astype(float) / 1e9
+        bout_starts_ends = np.array(
+            [
+                (wheel_ts[start_idx], wheel_ts[end_idx])
+                for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True)
+            ]
+        )
+        all_pel_ts = np.sort(
+            np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0])
+        )
+        bout_pellets = np.array(
+            [
+                len(all_pel_ts[(all_pel_ts >= start) & (all_pel_ts <= end)])
+                for start, end in bout_starts_ends
+            ]
+        )
+        # Filter by `min_pellets`
+        bout_durations = bout_durations[bout_pellets >= min_pellets]
+        bout_starts_ends = bout_starts_ends[bout_pellets >= min_pellets]
+        bout_pellets = bout_pellets[bout_pellets >= min_pellets]
+        bout_cum_wheel_dist = np.array(
+            [
+                patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"]
+                for start, end in bout_starts_ends
+            ]
+        )
+        # Add to returned DF
+        bout_data = pd.concat(
+            [
+                bout_data,
+                pd.DataFrame(
+                    {
+                        "start": 
bout_starts_ends[:, 0], + "end": bout_starts_ends[:, 1], + "n_pellets": bout_pellets, + "cum_wheel_dist": bout_cum_wheel_dist, + "subject": subject, + } + ), + ] + ) + return bout_data.sort_values("start").reset_index(drop=True) From 7189c2ef246f08e3b1883b8e6b93165fb67e4648 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 20 Sep 2024 17:28:33 +0100 Subject: [PATCH 03/10] Re-add missing `patch_threshold` --- aeon/dj_pipeline/analysis/block_analysis.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 1702e037..0243caac 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -503,16 +503,16 @@ def make(self, key): self.Patch.insert1( key - | dict( - patch_name=patch["patch_name"], - subject_name=subject_name, - in_patch_timestamps=subject_in_patch.index.values, - in_patch_time=subject_in_patch_cum_time[-1], - pellet_count=len(subj_pellets), - pellet_timestamps=subj_pellets.index.values, - patch_threshold=subj_patch_thresh, - wheel_cumsum_distance_travelled=cum_wheel_dist_subj_df[subject_name].values, - ) + | { + "patch_name": patch["patch_name"], + "subject_name": subject_name, + "in_patch_timestamps": subject_in_patch.index.values, + "in_patch_time": subject_in_patch_cum_time[-1], + "pellet_count": len(subj_pellets), + "pellet_timestamps": subj_pellets.index.values, + "patch_threshold": subj_patch_thresh, + "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values, + } ) # Now that we have computed all individual patch and subject values, we iterate again through From 493ba9117e034e5dd8337e3297a6fd4bf6f766c1 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 20 Sep 2024 17:57:19 +0100 Subject: [PATCH 04/10] Simplify `get_foraging_bouts` --- aeon/dj_pipeline/analysis/block_analysis.py | 53 ++++++++++----------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 0243caac..16a76a10 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -907,7 +907,7 @@ def get_foraging_bouts( """Gets foraging bouts for all subjects across all patches within a block. Args: - key: Block key - dict containing keys for 'experiment_name', 'block_start', 'block_end'. + key: Block key - dict containing keys for 'experiment_name' and 'block_start'. min_pellets: Minimum number of pellets for a foraging bout. max_inactive_time: Maximum time between `min_wheel_movement`s for a foraging bout. min_wheel_movement: Minimum wheel movement for a foraging bout. 
@@ -918,49 +918,46 @@ def get_foraging_bouts( max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") subject_patch_data.reset_index(level=["experiment_name"], drop=True, inplace=True) - subject_names = subject_patch_data.index.get_level_values("subject_name").unique() - wheel_ts = (BlockAnalysis.Patch & key).fetch("wheel_timestamps")[0] + wheel_ts = (BlockAnalysis.Patch() & key).fetch("wheel_timestamps")[0] # For each subject: # - Create cumulative wheel distance spun sum df combining all patches # - Columns: timestamp, wheel distance, patch # - Discretize into 'possible foraging events' based on `max_inactive_time`, and `min_wheel_movement` + # - Look ahead by `max_inactive_time` and compare with current wheel distance, + # if the wheel will have moved by `min_wheel_movement`, then it is a foraging event + # - Because we "looked ahead" (shifted), we need to readjust the start time of a foraging bout + # - For the foraging bout end time, we need to account for the final pellet delivery time # - Filter out events with < `min_pellets` # - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) - for subject in subject_names: + for subject in subject_patch_data.index.unique("subject_name"): cur_subject_data = subject_patch_data.xs(subject, level="subject_name") - # Create combined cumulative wheel distance spun: ensure equal length wheel vals across patches + # Create combined cumulative wheel distance spun wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values - min_len = min(len(arr) for arr in wheel_vals) + # Ensure equal length wheel_vals across patches and wheel_ts + min_len = min(*(len(arr) for arr in wheel_vals), len(wheel_ts)) comb_cum_wheel_dist = np.sum([arr[:min_len] for arr in wheel_vals], axis=0) - wheel_ts, comb_cum_wheel_dist = ( # ensure equal length wheel vals and wheel ts - arr[: min(len(wheel_ts), len(comb_cum_wheel_dist))] for arr in [wheel_ts, comb_cum_wheel_dist] - ) - # For each wheel_ts, get the correspdoning patch that was spun - patch_spun = np.empty(len(wheel_ts), dtype=" wheel_spun_thresh)[0] - patch_spun[spun_indices] = patch_name - patch_spun_df = pd.DataFrame(index=wheel_ts, columns=["cum_wheel_dist", "patch_spun"]) - patch_spun_df["cum_wheel_dist"] = comb_cum_wheel_dist - patch_spun_df["patch_spun"] = patch_spun + diffs = np.diff( + np.stack(cur_subject_data["wheel_cumsum_distance_travelled"].values), axis=1, prepend=0 + ) + spun_indices = np.where(diffs > wheel_spun_thresh) + patch_spun[spun_indices[1]] = patch_names[spun_indices[0]] + patch_spun_df = pd.DataFrame( + {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts + ) wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = ( - patch_spun_df["cum_wheel_dist"] - .shift(-(win_len - 1)) - .rolling(window=win_len, min_periods=1) - .max() - ) + max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(win_len - 1)).ffill() foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask.astype(int), prepend=0) == 1)[0] - bout_end_indxs = np.where(np.diff(foraging_mask.astype(int), prepend=0) == 
-1)[0] + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (win_len - 1) + bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + 4 # TODO: Change this assert len(bout_start_indxs) == len(bout_end_indxs) bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds "timedelta64[ns]" From 4021c84df227b3b32d73297c192d754221151917 Mon Sep 17 00:00:00 2001 From: Jai Date: Wed, 25 Sep 2024 17:42:55 +0100 Subject: [PATCH 05/10] slight cleanup --- aeon/analysis/block_plotting.py | 39 +++++++-------------- aeon/dj_pipeline/analysis/block_analysis.py | 10 +++--- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 67ebed32..96387b19 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -1,32 +1,9 @@ +"""Helper functions for plotting block data.""" + from colorsys import hls_to_rgb, rgb_to_hls import numpy as np -import plotly - -"""Standardize subject colors, patch colors, and markers.""" - -subject_colors = plotly.colors.qualitative.Plotly -subject_colors_dict = { - "BAA-1104045": subject_colors[0], - "BAA-1104047": subject_colors[1], - "BAA-1104048": subject_colors[2], - "BAA-1104049": subject_colors[3], -} -patch_colors = plotly.colors.qualitative.Dark2 -patch_markers = [ - "circle", - "bowtie", - "square", - "hourglass", - "diamond", - "cross", - "x", - "triangle", - "star", -] -patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"] -patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols, strict=False)) -patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"] +from numpy.lib.stride_tricks import as_strided def gen_hex_grad(hex_col, vals, min_l=0.3): @@ -44,3 +21,13 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): grad[i] = cur_hex_col return grad + + +def conv2d(arr, kernel): + """Performs "valid" 2d convolution using numpy `as_strided` and `einsum`.""" + out_shape = tuple(np.subtract(arr.shape, kernel.shape) + 1) + sub_mat_shape = kernel.shape + out_shape + # Create "new view" of `arr` as submatrices at which kernel will be applied + sub_mats = as_strided(arr, shape=sub_mat_shape, strides=(arr.strides * 2)) + out = np.einsum("ij, ijkl -> kl", kernel, sub_mats) + return out diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 16a76a10..f0478c61 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -902,7 +902,7 @@ def get_foraging_bouts( key: dict, min_pellets: int = 3, max_inactive_time: pd.Timedelta | None = None, # seconds - min_wheel_movement: float = 10, # cm + min_wheel_movement: float = 5, # cm ) -> pd.DataFrame: """Gets foraging bouts for all subjects across all patches within a block. 
@@ -923,7 +923,7 @@ def get_foraging_bouts(
     #   - Create cumulative wheel distance spun sum df combining all patches
     #       - Columns: timestamp, wheel distance, patch
     #   - Discretize into 'possible foraging events' based on `max_inactive_time`, and `min_wheel_movement`
-    #     - Look ahead by `max_inactive_time` and compare with current wheel distance,
+    #     - Look ahead by `max_inactive_time` and compare with current wheel distance;
     #       if the wheel will have moved by `min_wheel_movement`, then it is a foraging event
     #     - Because we "looked ahead" (shifted), we need to readjust the start time of a foraging bout
     #     - For the foraging bout end time, we need to account for the final pellet delivery time
@@ -936,7 +936,8 @@ def get_foraging_bouts(
         wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values
         # Ensure equal length wheel_vals across patches and wheel_ts
         min_len = min(*(len(arr) for arr in wheel_vals), len(wheel_ts))
-        comb_cum_wheel_dist = np.sum([arr[:min_len] for arr in wheel_vals], axis=0)
+        wheel_vals = [arr[:min_len] for arr in wheel_vals]
+        comb_cum_wheel_dist = np.vstack(wheel_vals).sum(axis=0)
         wheel_ts = wheel_ts[:min_len]
         # For each wheel_ts, get the corresponding patch that was spun
         patch_spun = np.full(len(wheel_ts), "", dtype="<U20")
@@ -957,7 +958,8 @@ def get_foraging_bouts(
         foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement)
         # Discretize into foraging bouts
         bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (win_len - 1)
-        bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + 4  # TODO: Change this
+        n_samples_in_1s = int(1 / ((wheel_ts[1] - wheel_ts[0]).astype(int) / 1e9))
+        bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + n_samples_in_1s
         assert len(bout_start_indxs) == len(bout_end_indxs)
         bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype(  # in seconds
             "timedelta64[ns]"

From 59f82a7e5a77f9024284e1c37694aa1c077426b1 Mon Sep 17 00:00:00 2001
From: Jai
Date: Wed, 25 Sep 2024 17:48:03 +0100
Subject: [PATCH 06/10] slight cleanup
---
 aeon/dj_pipeline/analysis/block_analysis.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index f0478c61..c63d54a7 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -958,7 +958,7 @@ def get_foraging_bouts(
         foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement)
         # Discretize into foraging bouts
         bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (win_len - 1)
-        n_samples_in_1s = int(1 / ((wheel_ts[1] - wheel_ts[0]).astype(int) / 1e9))
+        n_samples_in_1s = int(1 / wheel_s_r.total_seconds())
         bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + n_samples_in_1s
         assert len(bout_start_indxs) == len(bout_end_indxs)

From 6867e52cfa38da07aaab5bf32e07503ad88154e5 Mon Sep 17 00:00:00 2001
From: Jai
Date: Fri, 27 Sep 2024 23:53:19 +0100
Subject: [PATCH 07/10] trim ongoing bout at block end in foraging bout detection
---
 aeon/dj_pipeline/analysis/block_analysis.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index c63d54a7..56628577 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -952,14 +952,15 @@ 
def get_foraging_bouts( {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts ) wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") - win_len = int(max_inactive_time / wheel_s_r) + max_inactive_win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(win_len - 1)).ffill() + max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (win_len - 1) + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + n_samples_in_1s + bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # trim ongoing bout at block end assert len(bout_start_indxs) == len(bout_end_indxs) bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds "timedelta64[ns]" @@ -979,6 +980,7 @@ def get_foraging_bouts( for start, end in bout_starts_ends ] ) + import ipdb; ipdb.set_trace() # Filter by `min_pellets` bout_durations = bout_durations[bout_pellets >= min_pellets] bout_starts_ends = bout_starts_ends[bout_pellets >= min_pellets] From 5b0e70a87c6f2fe9742c9316d6ffc4dbabe153bc Mon Sep 17 00:00:00 2001 From: Jai Date: Sat, 28 Sep 2024 14:21:36 +0100 Subject: [PATCH 08/10] added a catch for return empty df for 'get_foraging_bouts' --- aeon/dj_pipeline/analysis/block_analysis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 56628577..6fc19b84 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -916,7 +916,10 @@ def get_foraging_bouts( DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. 
""" max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time + bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") + if subject_patch_data.empty: + return bout_data subject_patch_data.reset_index(level=["experiment_name"], drop=True, inplace=True) wheel_ts = (BlockAnalysis.Patch() & key).fetch("wheel_timestamps")[0] # For each subject: @@ -929,7 +932,6 @@ def get_foraging_bouts( # - For the foraging bout end time, we need to account for the final pellet delivery time # - Filter out events with < `min_pellets` # - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF - bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) for subject in subject_patch_data.index.unique("subject_name"): cur_subject_data = subject_patch_data.xs(subject, level="subject_name") # Create combined cumulative wheel distance spun @@ -980,7 +982,6 @@ def get_foraging_bouts( for start, end in bout_starts_ends ] ) - import ipdb; ipdb.set_trace() # Filter by `min_pellets` bout_durations = bout_durations[bout_pellets >= min_pellets] bout_starts_ends = bout_starts_ends[bout_pellets >= min_pellets] From 2f593178fc70a0597f9c2f9b10d81aa042e2f018 Mon Sep 17 00:00:00 2001 From: Jai Date: Sun, 29 Sep 2024 19:54:32 +0100 Subject: [PATCH 09/10] cleaned up some edge cases --- aeon/dj_pipeline/analysis/block_analysis.py | 21 ++++++++++++++++----- aeon/util.py | 5 +++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 6fc19b84..23d01816 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -934,6 +934,9 @@ def get_foraging_bouts( # - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF for subject in subject_patch_data.index.unique("subject_name"): cur_subject_data = subject_patch_data.xs(subject, level="subject_name") + n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values]) + if n_pels < min_pellets: + continue # Create combined cumulative wheel distance spun wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values # Ensure equal length wheel_vals across patches and wheel_ts @@ -941,13 +944,13 @@ def get_foraging_bouts( wheel_vals = [arr[:min_len] for arr in wheel_vals] comb_cum_wheel_dist = np.vstack(wheel_vals).sum(axis=0) wheel_ts = wheel_ts[:min_len] + # Ensure monotically increasing wheel dist + comb_cum_wheel_dist = np.maximum.accumulate(comb_cum_wheel_dist) # For each wheel_ts, get the corresponding patch that was spun patch_spun = np.full(len(wheel_ts), "", dtype=" wheel_spun_thresh) patch_spun[spun_indices[1]] = patch_names[spun_indices[0]] patch_spun_df = pd.DataFrame( @@ -961,8 +964,16 @@ def get_foraging_bouts( # Discretize into foraging bouts bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) - bout_end_indxs = np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + n_samples_in_1s - bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # trim ongoing bout at block end + bout_end_indxs = ( + np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + + (max_inactive_win_len - 1) + + n_samples_in_1s + ) + bout_end_indxs[-1] = min(bout_end_indxs[-1], 
+        bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1)  # ensure last bout ends in block
+        # Remove bout that starts at block end
+        if bout_start_indxs[-1] >= len(wheel_ts):
+            bout_start_indxs = bout_start_indxs[:-1]
+            bout_end_indxs = bout_end_indxs[:-1]
         assert len(bout_start_indxs) == len(bout_end_indxs)
         bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype(  # in seconds
             "timedelta64[ns]"

diff --git a/aeon/util.py b/aeon/util.py
index 2251eaad..ceb0637a 100644
--- a/aeon/util.py
+++ b/aeon/util.py
@@ -14,9 +14,10 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
             found = find_nested_key(v, key)
             if found:
                 return found
-    else:
+    elif isinstance(obj, list):
         for item in obj:
             found = find_nested_key(item, key)
             if found:
                 return found
-    return None
+    else:
+        return None

From fc375d6797fcdf4b57dee2bba3622cf32d953891 Mon Sep 17 00:00:00 2001
From: lochhh
Date: Wed, 2 Oct 2024 14:45:53 +0100
Subject: [PATCH 10/10] Ruffen `block_analysis.py`
---
 aeon/dj_pipeline/analysis/block_analysis.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index 23d01816..e3f9a38d 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -36,6 +36,7 @@ class BlockDetection(dj.Computed):
 
     def make(self, key):
         """On a per-chunk basis, check for the presence of new block, insert into Block table.
+
         High level logic
         1. Find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block)
         2. Remove any double 0s (0s within 1 second of each other) (pick the first 0)
@@ -792,6 +793,7 @@ class AnalysisNote(dj.Manual):
 
 def get_threshold_associated_pellets(patch_key, start, end):
     """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
+
     1. Get all patch state update timestamps (DepletionState): let's call these events "A"
     - Remove all events within 1 second of each other
     - Remove all events without threshold value (NaN)
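
---
A minimal usage sketch for the `get_foraging_bouts` function added in this series. The experiment name and block start below are hypothetical placeholders (not values from this PR), and the call assumes the block has already been ingested and computed by `BlockAnalysis`/`BlockSubjectAnalysis`:

    import pandas as pd

    from aeon.dj_pipeline.analysis.block_analysis import get_foraging_bouts

    # Hypothetical block key: any dict with "experiment_name" and "block_start"
    # matching a computed BlockSubjectAnalysis entry will work.
    block_key = {"experiment_name": "social0.2-aeon3", "block_start": "2024-01-31 11:00:00"}

    bouts = get_foraging_bouts(
        block_key,
        min_pellets=3,  # keep only bouts with at least 3 pellet deliveries
        max_inactive_time=pd.Timedelta(seconds=60),  # bout ends after 60 s without wheel movement
        min_wheel_movement=5,  # cm of wheel travel required to count as foraging
    )
    # One row per detected bout: start, end, n_pellets, cum_wheel_dist, subject
    print(bouts.head())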