diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py
index 67ebed32..96387b19 100644
--- a/aeon/analysis/block_plotting.py
+++ b/aeon/analysis/block_plotting.py
@@ -1,32 +1,9 @@
+"""Helper functions for plotting block data."""
+
 from colorsys import hls_to_rgb, rgb_to_hls
 
 import numpy as np
-import plotly
-
-"""Standardize subject colors, patch colors, and markers."""
-
-subject_colors = plotly.colors.qualitative.Plotly
-subject_colors_dict = {
-    "BAA-1104045": subject_colors[0],
-    "BAA-1104047": subject_colors[1],
-    "BAA-1104048": subject_colors[2],
-    "BAA-1104049": subject_colors[3],
-}
-patch_colors = plotly.colors.qualitative.Dark2
-patch_markers = [
-    "circle",
-    "bowtie",
-    "square",
-    "hourglass",
-    "diamond",
-    "cross",
-    "x",
-    "triangle",
-    "star",
-]
-patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
-patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols, strict=False))
-patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]
+from numpy.lib.stride_tricks import as_strided
 
 
 def gen_hex_grad(hex_col, vals, min_l=0.3):
@@ -44,3 +21,13 @@ def gen_hex_grad(hex_col, vals, min_l=0.3):
         grad[i] = cur_hex_col
 
     return grad
+
+
+def conv2d(arr, kernel):
+    """Performs "valid" 2d convolution using numpy `as_strided` and `einsum`."""
+    out_shape = tuple(np.subtract(arr.shape, kernel.shape) + 1)
+    sub_mat_shape = kernel.shape + out_shape
+    # Create "new view" of `arr` as submatrices at which kernel will be applied
+    sub_mats = as_strided(arr, shape=sub_mat_shape, strides=(arr.strides * 2))
+    out = np.einsum("ij, ijkl -> kl", kernel, sub_mats)
+    return out
diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index 5ddd4968..e3f9a38d 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -1,16 +1,17 @@
 import json
+from datetime import datetime
+
 import datajoint as dj
 import numpy as np
 import pandas as pd
 import plotly.express as px
 import plotly.graph_objs as go
 from matplotlib import path as mpl_path
-from datetime import datetime
 
-from aeon.io import api as io_api
 from aeon.analysis import utils as analysis_utils
 from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
 from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
+from aeon.io import api as io_api
 
 schema = dj.schema(get_schema_name("block_analysis"))
 logger = dj.logger
@@ -34,8 +35,8 @@ class BlockDetection(dj.Computed):
     """
 
     def make(self, key):
-        """
-        On a per-chunk basis, check for the presence of new block, insert into Block table.
+        """On a per-chunk basis, check for the presence of new block, insert into Block table.
+
         High level logic
         1. Find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block)
         2. Remove any double 0s (0s within 1 second of each other) (pick the first 0)
@@ -492,9 +494,9 @@ def make(self, key):
                 )
                 subject_in_patch = in_patch[subject_name]
                 subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt
-                all_subj_patch_pref_dict[patch["patch_name"]][subject_name][
-                    "cum_time"
-                ] = subject_in_patch_cum_time
+                all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_time"] = (
+                    subject_in_patch_cum_time
+                )
 
                 closest_subj_mask = closest_subjects_pellet_ts == subject_name
                 subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
@@ -502,16 +504,16 @@
 
                 self.Patch.insert1(
                     key
-                    | dict(
-                        patch_name=patch["patch_name"],
-                        subject_name=subject_name,
-                        in_patch_timestamps=subject_in_patch.index.values,
-                        in_patch_time=subject_in_patch_cum_time[-1],
-                        pellet_count=len(subj_pellets),
-                        pellet_timestamps=subj_pellets.index.values,
-                        patch_threshold=subj_patch_thresh,
-                        wheel_cumsum_distance_travelled=cum_wheel_dist_subj_df[subject_name].values,
-                    )
+                    | {
+                        "patch_name": patch["patch_name"],
+                        "subject_name": subject_name,
+                        "in_patch_timestamps": subject_in_patch.index.values,
+                        "in_patch_time": subject_in_patch_cum_time[-1],
+                        "pellet_count": len(subj_pellets),
+                        "pellet_timestamps": subj_pellets.index.values,
+                        "patch_threshold": subj_patch_thresh,
+                        "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values,
+                    }
                 )
 
         # Now that we have computed all individual patch and subject values, we iterate again through
@@ -663,10 +665,10 @@ class BlockSubjectPlots(dj.Computed):
 
     def make(self, key):
        from aeon.analysis.block_plotting import (
-            subject_colors,
-            patch_markers_linestyles,
-            patch_markers,
             gen_hex_grad,
+            patch_markers,
+            patch_markers_linestyles,
+            subject_colors,
         )
 
         patch_names, subject_names = (BlockSubjectAnalysis.Preference & key).fetch(
@@ -790,8 +792,8 @@ class AnalysisNote(dj.Manual):
 
 
 def get_threshold_associated_pellets(patch_key, start, end):
-    """
-    Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
+    """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
+
     1. Get all patch state update timestamps (DepletionState): let's call these events "A"
         - Remove all events within 1 second of each other
         - Remove all events without threshold value (NaN)
@@ -889,5 +891,133 @@ def get_threshold_associated_pellets(patch_key, start, end):
     # Shift back the pellet_timestamp values by 1 to match with the previous threshold update
     pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1)
     pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1)
-    pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp", "beam_break_timestamp"])
+    pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(
+        subset=["pellet_timestamp", "beam_break_timestamp"]
+    )
     return pellet_ts_threshold_df
+
+
+"""Foraging bout function."""
+
+
+def get_foraging_bouts(
+    key: dict,
+    min_pellets: int = 3,
+    max_inactive_time: pd.Timedelta | None = None,  # seconds
+    min_wheel_movement: float = 5,  # cm
+) -> pd.DataFrame:
+    """Gets foraging bouts for all subjects across all patches within a block.
+
+    Args:
+        key: Block key - dict containing keys for 'experiment_name' and 'block_start'.
+        min_pellets: Minimum number of pellets for a foraging bout.
+        max_inactive_time: Maximum time between `min_wheel_movement`s for a foraging bout.
+        min_wheel_movement: Minimum wheel movement for a foraging bout.
+
+    Returns:
+        DataFrame containing foraging bouts. Columns: start, end, n_pellets, cum_wheel_dist, subject.
+    """
+    max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time
+    bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"])
+    subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame")
+    if subject_patch_data.empty:
+        return bout_data
+    subject_patch_data.reset_index(level=["experiment_name"], drop=True, inplace=True)
+    wheel_ts = (BlockAnalysis.Patch() & key).fetch("wheel_timestamps")[0]
+    # For each subject:
+    #   - Create cumulative wheel distance spun sum df combining all patches
+    #     - Columns: timestamp, wheel distance, patch
+    #   - Discretize into 'possible foraging events' based on `max_inactive_time`, and `min_wheel_movement`
+    #     - Look ahead by `max_inactive_time` and compare with current wheel distance;
+    #       if the wheel will have moved by `min_wheel_movement`, then it is a foraging event
+    #     - Because we "looked ahead" (shifted), we need to readjust the start time of a foraging bout
+    #     - For the foraging bout end time, we need to account for the final pellet delivery time
+    #   - Filter out events with < `min_pellets`
+    #   - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF
+    for subject in subject_patch_data.index.unique("subject_name"):
+        cur_subject_data = subject_patch_data.xs(subject, level="subject_name")
+        n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values])
+        if n_pels < min_pellets:
+            continue
+        # Create combined cumulative wheel distance spun
+        wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values
+        # Ensure equal length wheel_vals across patches and wheel_ts
+        min_len = min(*(len(arr) for arr in wheel_vals), len(wheel_ts))
+        wheel_vals = [arr[:min_len] for arr in wheel_vals]
+        comb_cum_wheel_dist = np.vstack(wheel_vals).sum(axis=0)
+        wheel_ts = wheel_ts[:min_len]
+        # Ensure monotonically increasing wheel dist
+        comb_cum_wheel_dist = np.maximum.accumulate(comb_cum_wheel_dist)
+        # For each wheel_ts, get the corresponding patch that was spun
+        patch_spun = np.full(len(wheel_ts), "", dtype="<U20")
+        patch_names = cur_subject_data.index.get_level_values(1)
+        wheel_spun_thresh = 0.03  # threshold for wheel movement (cm)
+        diffs = np.diff(np.stack(wheel_vals), axis=1, prepend=0)
+        spun_indices = np.where(diffs > wheel_spun_thresh)
+        patch_spun[spun_indices[1]] = patch_names[spun_indices[0]]
+        patch_spun_df = pd.DataFrame(
+            {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts
+        )
+        wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns")
+        max_inactive_win_len = int(max_inactive_time / wheel_s_r)
+        # Find times when foraging
+        max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill()
+        foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement)
+        # Discretize into foraging bouts
+        bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1)
+        n_samples_in_1s = int(1 / wheel_s_r.total_seconds())
+        bout_end_indxs = (
+            np.where(np.diff(foraging_mask, prepend=0) == -1)[0]
+            + (max_inactive_win_len - 1)
+            + n_samples_in_1s
+        )
+        bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1)  # ensure last bout ends in block
+        # Remove bout that starts at block end
+        if bout_start_indxs[-1] >= len(wheel_ts):
+            bout_start_indxs = bout_start_indxs[:-1]
+            bout_end_indxs = bout_end_indxs[:-1]
+        assert len(bout_start_indxs) == len(bout_end_indxs)
+        bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype(  # in seconds
+            "timedelta64[ns]"
+        ).astype(float) / 1e9
+        bout_starts_ends = np.array(
+            [
+                (wheel_ts[start_idx], wheel_ts[end_idx])
+                for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True)
+            ]
+        )
+        all_pel_ts = np.sort(
+            np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0])
+        )
+        bout_pellets = np.array(
+            [
+                len(all_pel_ts[(all_pel_ts >= start) & (all_pel_ts <= end)])
+                for start, end in bout_starts_ends
+            ]
+        )
+        # Filter by `min_pellets`
+        bout_durations = bout_durations[bout_pellets >= min_pellets]
+        bout_starts_ends = bout_starts_ends[bout_pellets >= min_pellets]
+        bout_pellets = bout_pellets[bout_pellets >= min_pellets]
+        bout_cum_wheel_dist = np.array(
+            [
+                patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"]
+                for start, end in bout_starts_ends
+            ]
+        )
+        # Add to returned DF
+        bout_data = pd.concat(
+            [
+                bout_data,
+                pd.DataFrame(
+                    {
+                        "start": bout_starts_ends[:, 0],
+                        "end": bout_starts_ends[:, 1],
+                        "n_pellets": bout_pellets,
+                        "cum_wheel_dist": bout_cum_wheel_dist,
+                        "subject": subject,
+                    }
+                ),
+            ]
+        )
+    return bout_data.sort_values("start").reset_index(drop=True)
diff --git a/aeon/util.py b/aeon/util.py
index 2251eaad..ceb0637a 100644
--- a/aeon/util.py
+++ b/aeon/util.py
@@ -14,9 +14,10 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
             found = find_nested_key(v, key)
             if found:
                 return found
-    else:
+    elif isinstance(obj, list):
         for item in obj:
             found = find_nested_key(item, key)
             if found:
                 return found
-    return None
+    else:
+        return None
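
A quick way to sanity-check the new `conv2d` helper in `aeon/analysis/block_plotting.py` is to compare it against scipy on a small symmetric kernel. This is only a sketch: the arrays are made up, scipy is assumed to be available for the comparison, and since `conv2d` applies the kernel without flipping it, the results match scipy's convolution only for symmetric kernels.

import numpy as np
from scipy.signal import convolve2d  # used only for the comparison

from aeon.analysis.block_plotting import conv2d

arr = np.arange(25, dtype=float).reshape(5, 5)  # toy 5x5 "image"
kernel = np.ones((3, 3)) / 9  # symmetric 3x3 box filter

out = conv2d(arr, kernel)  # "valid" output -> shape (3, 3)
expected = convolve2d(arr, kernel, mode="valid")
assert out.shape == (3, 3)
assert np.allclose(out, expected)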
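
The look-ahead rule in `get_foraging_bouts` (shift the cumulative wheel distance back by the inactivity window, forward-fill, then compare against the current value plus `min_wheel_movement`) can be seen in isolation on synthetic data. The sample period, burst timing, and thresholds below are invented purely to illustrate the masking step and are not taken from the pipeline.

import numpy as np
import pandas as pd

wheel_s_r = pd.Timedelta("20ms")  # assumed wheel sample period
max_inactive_time = pd.Timedelta("1s")
min_wheel_movement = 5.0  # cm

# Synthetic cumulative wheel distance: idle, a spinning burst, idle again.
n = 500
step = np.zeros(n)
step[150:300] = 0.2  # ~0.2 cm per sample while spinning
ts = pd.date_range("2024-01-01", periods=n, freq=wheel_s_r)
df = pd.DataFrame({"cum_wheel_dist": np.cumsum(step)}, index=ts)

# A sample counts as "foraging" if the wheel will have moved by at least
# `min_wheel_movement` within the next `max_inactive_time` window.
win = int(max_inactive_time / wheel_s_r)
ahead = df["cum_wheel_dist"].shift(-(win - 1)).ffill()
foraging_mask = ahead > (df["cum_wheel_dist"] + min_wheel_movement)

# Mask transitions give bout edges; shift them forward by the look-ahead
# window, as the block code does, so they line up with the actual movement.
edges = np.diff(foraging_mask.astype(int), prepend=0)
starts = np.where(edges == 1)[0] + (win - 1)
ends = np.where(edges == -1)[0] + (win - 1)
print(ts[starts[0]], ts[ends[0]])  # bout roughly spans the 150:300 burst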
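
Downstream use of the new function might look like the following. The experiment name and block start are placeholder values rather than real database entries, and the sketch assumes the `block_analysis` schema has been populated for that block.

import pandas as pd

from aeon.dj_pipeline.analysis.block_analysis import get_foraging_bouts

block_key = {
    "experiment_name": "social0.2-aeon3",  # placeholder experiment name
    "block_start": "2024-01-01 12:00:00",  # placeholder block start
}
bouts = get_foraging_bouts(block_key, min_pellets=3, max_inactive_time=pd.Timedelta(seconds=60))

# Per-subject summary: bout count, median bout duration (s), total pellets.
summary = (
    bouts.assign(
        duration_s=(pd.to_datetime(bouts["end"]) - pd.to_datetime(bouts["start"])).dt.total_seconds()
    )
    .groupby("subject")
    .agg(
        n_bouts=("start", "count"),
        median_duration_s=("duration_s", "median"),
        pellets=("n_pellets", "sum"),
    )
)
print(summary)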