Commit
Merge pull request #410 from SainsburyWellcomeCentre/block_foraging_bout_detection

added working 'get_foraging_bouts' function
jkbhagatio authored Oct 3, 2024
2 parents 07fe4c2 + fc375d6 commit 775e0ca
Showing 3 changed files with 169 additions and 51 deletions.
39 changes: 13 additions & 26 deletions aeon/analysis/block_plotting.py
@@ -1,32 +1,9 @@
"""Helper functions for plotting block data."""

from colorsys import hls_to_rgb, rgb_to_hls

import numpy as np
import plotly

"""Standardize subject colors, patch colors, and markers."""

subject_colors = plotly.colors.qualitative.Plotly
subject_colors_dict = {
"BAA-1104045": subject_colors[0],
"BAA-1104047": subject_colors[1],
"BAA-1104048": subject_colors[2],
"BAA-1104049": subject_colors[3],
}
patch_colors = plotly.colors.qualitative.Dark2
patch_markers = [
"circle",
"bowtie",
"square",
"hourglass",
"diamond",
"cross",
"x",
"triangle",
"star",
]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols, strict=False))
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]
from numpy.lib.stride_tricks import as_strided


def gen_hex_grad(hex_col, vals, min_l=0.3):
@@ -44,3 +21,13 @@ def gen_hex_grad(hex_col, vals, min_l=0.3):
grad[i] = cur_hex_col

return grad


def conv2d(arr, kernel):
"""Performs "valid" 2d convolution using numpy `as_strided` and `einsum`."""
out_shape = tuple(np.subtract(arr.shape, kernel.shape) + 1)
sub_mat_shape = kernel.shape + out_shape
# Create "new view" of `arr` as submatrices at which kernel will be applied
sub_mats = as_strided(arr, shape=sub_mat_shape, strides=(arr.strides * 2))
out = np.einsum("ij, ijkl -> kl", kernel, sub_mats)
return out
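A quick, hedged sanity check of the new `conv2d` helper (illustrative values; SciPy is used here only for cross-checking). Note that the kernel is not flipped, so the result matches "valid"-mode cross-correlation rather than textbook convolution:

```python
# assumes: from aeon.analysis.block_plotting import conv2d
import numpy as np
from scipy.signal import correlate2d  # cross-check only

arr = np.arange(16.0).reshape(4, 4)
kernel = np.array([[1.0, 0.0], [0.0, -1.0]])

out = conv2d(arr, kernel)  # "valid" output shape: (4-2+1, 4-2+1) == (3, 3)
assert np.allclose(out, correlate2d(arr, kernel, mode="valid"))
```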
176 changes: 153 additions & 23 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -1,16 +1,17 @@
import json
from datetime import datetime

import datajoint as dj
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from matplotlib import path as mpl_path

from aeon.analysis import utils as analysis_utils
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
from aeon.io import api as io_api

schema = dj.schema(get_schema_name("block_analysis"))
logger = dj.logger
@@ -34,8 +35,8 @@ class BlockDetection(dj.Computed):
"""

def make(self, key):
"""
On a per-chunk basis, check for the presence of new block, insert into Block table.
"""On a per-chunk basis, check for the presence of new block, insert into Block table.
High level logic
1. Find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block)
2. Remove any double 0s (0s within 1 second of each other) (pick the first 0)
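The rest of `make()` is collapsed in this view. As a minimal sketch of step 2 above, under the assumption that the pellet counts arrive as a timestamp-indexed pandas Series (the name `pellet_ct` comes from step 1; everything else here is made up):

```python
import pandas as pd

pellet_ct = pd.Series(
    [0, 0, 1, 2, 0, 1],
    index=pd.to_datetime([
        "2024-10-03 00:00:00.0",  # block start (reset)
        "2024-10-03 00:00:00.5",  # double 0 within 1 s -> dropped
        "2024-10-03 00:10:00",
        "2024-10-03 00:20:00",
        "2024-10-03 02:00:00",    # next block start (reset)
        "2024-10-03 02:10:00",
    ]),
)
zero_ts = pellet_ct.index[pellet_ct == 0].to_series()
gaps = zero_ts.diff()
# Keep the first reset, plus any reset more than 1 s after the previous one
block_starts = zero_ts[gaps.isna() | (gaps > pd.Timedelta("1s"))]
print(block_starts.tolist())  # the two block-start timestamps
```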
@@ -353,6 +354,7 @@ class Patch(dj.Part):
pellet_count: int
pellet_timestamps: longblob
patch_threshold: longblob # patch threshold value at each pellet delivery
wheel_cumsum_distance_travelled: longblob # wheel's cumulative distance travelled
"""

@@ -492,26 +494,26 @@ def make(self, key):
)
subject_in_patch = in_patch[subject_name]
subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt
all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_time"] = (
    subject_in_patch_cum_time
)

closest_subj_mask = closest_subjects_pellet_ts == subject_name
subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
subj_patch_thresh = patch["patch_threshold"][closest_subj_mask]

self.Patch.insert1(
key
| {
"patch_name": patch["patch_name"],
"subject_name": subject_name,
"in_patch_timestamps": subject_in_patch.index.values,
"in_patch_time": subject_in_patch_cum_time[-1],
"pellet_count": len(subj_pellets),
"pellet_timestamps": subj_pellets.index.values,
"patch_threshold": subj_patch_thresh,
"wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values,
}
)

# Now that we have computed all individual patch and subject values, we iterate again through
@@ -663,10 +665,10 @@ class BlockSubjectPlots(dj.Computed):

def make(self, key):
from aeon.analysis.block_plotting import (
    gen_hex_grad,
    patch_markers,
    patch_markers_linestyles,
    subject_colors,
)

patch_names, subject_names = (BlockSubjectAnalysis.Preference & key).fetch(
@@ -790,8 +792,8 @@ class AnalysisNote(dj.Manual):


def get_threshold_associated_pellets(patch_key, start, end):
"""
Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
"""Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
1. Get all patch state update timestamps (DepletionState): let's call these events "A"
- Remove all events within 1 second of each other
- Remove all events without threshold value (NaN)
@@ -889,5 +891,133 @@ def get_threshold_associated_pellets(patch_key, start, end):
# Shift back the pellet_timestamp values by 1 to match with the previous threshold update
pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1)
pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1)
pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(
    subset=["pellet_timestamp", "beam_break_timestamp"]
)
return pellet_ts_threshold_df
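A toy illustration of the `shift(-1)` re-alignment just above (values made up): each pellet delivery is paired with the threshold update that preceded it, and the final threshold, which has no pellet yet, is dropped:

```python
import pandas as pd

df = pd.DataFrame({
    "threshold": [10.0, 20.0, 30.0],
    "pellet_timestamp": pd.to_datetime([
        "2024-10-03 00:00:00",
        "2024-10-03 00:00:12",
        "2024-10-03 00:00:25",
    ]),
})
df["pellet_timestamp"] = df["pellet_timestamp"].shift(-1)
df = df.dropna(subset=["pellet_timestamp"])
# threshold 10.0 -> pellet at 00:00:12; threshold 20.0 -> pellet at 00:00:25
```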


"""Foraging bout function."""


def get_foraging_bouts(
key: dict,
min_pellets: int = 3,
max_inactive_time: pd.Timedelta | None = None, # seconds
min_wheel_movement: float = 5, # cm
) -> pd.DataFrame:
"""Gets foraging bouts for all subjects across all patches within a block.
Args:
key: Block key - dict containing keys for 'experiment_name' and 'block_start'.
min_pellets: Minimum number of pellets for a foraging bout.
max_inactive_time: Maximum time between `min_wheel_movement`s for a foraging bout.
min_wheel_movement: Minimum wheel movement for a foraging bout.
Returns:
DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject.
"""
max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time
bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"])
subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame")
if subject_patch_data.empty:
return bout_data
subject_patch_data.reset_index(level=["experiment_name"], drop=True, inplace=True)
wheel_ts = (BlockAnalysis.Patch() & key).fetch("wheel_timestamps")[0]
# For each subject:
# - Create cumulative wheel distance spun sum df combining all patches
# - Columns: timestamp, wheel distance, patch
# - Discretize into 'possible foraging events' based on `max_inactive_time`, and `min_wheel_movement`
# - Look ahead by `max_inactive_time` and compare with current wheel distance;
# if the wheel will have moved by `min_wheel_movement`, then it is a foraging event
# - Because we "looked ahead" (shifted), we need to readjust the start time of a foraging bout
# - For the foraging bout end time, we need to account for the final pellet delivery time
# - Filter out events with < `min_pellets`
# - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF
for subject in subject_patch_data.index.unique("subject_name"):
cur_subject_data = subject_patch_data.xs(subject, level="subject_name")
n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values])
if n_pels < min_pellets:
continue
# Create combined cumulative wheel distance spun
wheel_vals = cur_subject_data["wheel_cumsum_distance_travelled"].values
# Ensure equal length wheel_vals across patches and wheel_ts
min_len = min(*(len(arr) for arr in wheel_vals), len(wheel_ts))
wheel_vals = [arr[:min_len] for arr in wheel_vals]
comb_cum_wheel_dist = np.vstack(wheel_vals).sum(axis=0)
wheel_ts = wheel_ts[:min_len]
# Ensure monotonically increasing wheel dist
comb_cum_wheel_dist = np.maximum.accumulate(comb_cum_wheel_dist)
# For each wheel_ts, get the corresponding patch that was spun
patch_spun = np.full(len(wheel_ts), "", dtype="<U20")
patch_names = cur_subject_data.index.get_level_values(1)
wheel_spun_thresh = 0.03 # threshold for wheel movement (cm)
diffs = np.diff(np.stack(wheel_vals), axis=1, prepend=0)
spun_indices = np.where(diffs > wheel_spun_thresh)
patch_spun[spun_indices[1]] = patch_names[spun_indices[0]]
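# (`spun_indices[0]` indexes patches and `spun_indices[1]` indexes time
# samples, so each moving sample is labeled with the patch whose wheel moved)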
patch_spun_df = pd.DataFrame(
{"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts
)
wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns")
max_inactive_win_len = int(max_inactive_time / wheel_s_r)
# Find times when foraging
max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill()
foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement)
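# NB: `cum_wheel_dist` is monotone non-decreasing (enforced by
# `np.maximum.accumulate` above), so the value at the end of the look-ahead
# window equals the window maximum; the plain shift acts as a
# forward-looking rolling max.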
# Discretize into foraging bouts
bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1)
n_samples_in_1s = int(1 / wheel_s_r.total_seconds())
bout_end_indxs = (
np.where(np.diff(foraging_mask, prepend=0) == -1)[0]
+ (max_inactive_win_len - 1)
+ n_samples_in_1s
)
bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # ensure last bout ends in block
# Remove bout that starts at block end
if bout_start_indxs[-1] >= len(wheel_ts):
bout_start_indxs = bout_start_indxs[:-1]
bout_end_indxs = bout_end_indxs[:-1]
assert len(bout_start_indxs) == len(bout_end_indxs)
bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds
"timedelta64[ns]"
).astype(float) / 1e9
bout_starts_ends = np.array(
[
(wheel_ts[start_idx], wheel_ts[end_idx])
for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True)
]
)
all_pel_ts = np.sort(
np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0])
)
bout_pellets = np.array(
[
len(all_pel_ts[(all_pel_ts >= start) & (all_pel_ts <= end)])
for start, end in bout_starts_ends
]
)
# Filter by `min_pellets`
bout_durations = bout_durations[bout_pellets >= min_pellets]
bout_starts_ends = bout_starts_ends[bout_pellets >= min_pellets]
bout_pellets = bout_pellets[bout_pellets >= min_pellets]
bout_cum_wheel_dist = np.array(
[
patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"]
for start, end in bout_starts_ends
]
)
# Add to returned DF
bout_data = pd.concat(
[
bout_data,
pd.DataFrame(
{
"start": bout_starts_ends[:, 0],
"end": bout_starts_ends[:, 1],
"n_pellets": bout_pellets,
"cum_wheel_dist": bout_cum_wheel_dist,
"subject": subject,
}
),
]
)
return bout_data.sort_values("start").reset_index(drop=True)
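A hedged usage sketch of the new function; the key below is a placeholder, not a value from this PR. Note that, per the implementation, the returned frame carries bout `start`/`end` timestamps (the docstring's `duration` is computed internally but not returned):

```python
import pandas as pd

# Placeholder key; real keys come from the Block / BlockAnalysis tables
key = {"experiment_name": "example-experiment", "block_start": "2024-10-03 10:00:00"}
bouts = get_foraging_bouts(
    key,
    min_pellets=3,
    max_inactive_time=pd.Timedelta(seconds=60),
    min_wheel_movement=5,
)
print(bouts.columns.tolist())  # ['start', 'end', 'n_pellets', 'cum_wheel_dist', 'subject']
```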
5 changes: 3 additions & 2 deletions aeon/util.py
@@ -14,9 +14,10 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
found = find_nested_key(v, key)
if found:
return found
elif isinstance(obj, list):
for item in obj:
found = find_nested_key(item, key)
if found:
return found
else:
    return None
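Assuming the folded lines above return `obj[key]` when the key is present in a dict (an assumption; those lines sit outside this hunk), usage would look like:

```python
# assumes: from aeon.util import find_nested_key
cfg = {"outer": [{"inner": {"target": 42}}]}
print(find_nested_key(cfg, "target"))   # 42 (dict -> list -> dict recursion)
print(find_nested_key(cfg, "missing"))  # None
```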
