Skip to content

Commit

Permalink
fix: update logic to associate true pellet times with each threshold …
Browse files Browse the repository at this point in the history
…update time
  • Loading branch information
ttngu207 committed Jul 18, 2024
1 parent 96e9e46 commit 3411fe8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
1 change: 1 addition & 0 deletions aeon/dj_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def fetch_stream(query, drop_pk=True):
df.rename(columns={"timestamps": "time"}, inplace=True)
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
return df


Expand Down
48 changes: 36 additions & 12 deletions aeon/dj_pipeline/analysis/block_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,42 @@ def make(self, key):
patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name")

for patch_key, patch_name in zip(patch_keys, patch_names):
delivered_pellet_df = fetch_stream(
# pellet delivery and patch threshold data
beam_break_df = fetch_stream(
streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction
)[block_start:block_end]
depletion_state_df = fetch_stream(
streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
)[block_start:block_end]
# remove NaNs from threshold column
depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
# identify & remove invalid indices where the time difference is less than 1 second
invalid_indices = np.where(depletion_state_df.index.to_series().diff().dt.total_seconds() < 1)[0]
depletion_state_df = depletion_state_df.drop(depletion_state_df.index[invalid_indices])

# find pellet times associated with each threshold update
# for each threshold, find the time of the next threshold update,
# find the closest beam break after this update time,
# and use this beam break time as the delivery time for the initial threshold
pellet_ts_threshold_df = depletion_state_df.copy()
pellet_ts_threshold_df["pellet_timestamp"] = pd.NaT
for threshold_idx in range(len(pellet_ts_threshold_df) - 1):
if np.isnan(pellet_ts_threshold_df.threshold.iloc[threshold_idx]):
continue
next_threshold_time = pellet_ts_threshold_df.index[threshold_idx + 1]
post_thresh_pellet_ts = beam_break_df.index[beam_break_df.index > next_threshold_time]
next_beam_break = post_thresh_pellet_ts[np.searchsorted(post_thresh_pellet_ts, next_threshold_time)]
pellet_ts_threshold_df.pellet_timestamp.iloc[threshold_idx] = next_beam_break
# remove NaNs from pellet_timestamp column (last row)
pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])

# wheel encoder data
encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[
block_start:block_end
]
# filter out maintenance period based on logs
pellet_df = filter_out_maintenance_periods(
delivered_pellet_df,
pellet_ts_threshold_df = filter_out_maintenance_periods(
pellet_ts_threshold_df,
maintenance_period,
block_end,
dropna=True,
Expand Down Expand Up @@ -229,22 +253,21 @@ def make(self, key):

patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]

# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate

self.Patch.insert1(
{
**key,
"patch_name": patch_name,
"pellet_count": len(pellet_df),
"pellet_timestamps": pellet_df.index.values,
"pellet_count": len(pellet_ts_threshold_df),
"pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values,
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[
::wheel_downsampling_factor
],
"wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor],
"patch_threshold": depletion_state_df.threshold.values,
"patch_threshold_timestamps": depletion_state_df.index.values,
"patch_threshold": pellet_ts_threshold_df.threshold.values,
"patch_threshold_timestamps": pellet_ts_threshold_df.index.values,
"patch_rate": patch_rate,
"patch_offset": patch_offset,
}
Expand All @@ -267,7 +290,7 @@ def make(self, key):
subject_names = []
for subject_name in set(subject_visits_df.id):
_df = subject_visits_df[subject_visits_df.id == subject_name]
if _df.type[-1] != "Exit":
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

for subject_name in subject_names:
Expand Down Expand Up @@ -454,7 +477,7 @@ def make(self, key):
"dist_to_patch"
].values

# Get closest subject to patch at each pel del timestep
# Get closest subject to patch at each pellet timestep
closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1)
# Get closest subject to patch at each wheel timestep
cum_wheel_dist_subj_df = pd.DataFrame(
Expand All @@ -481,9 +504,10 @@ def make(self, key):
all_subj_patch_pref_dict[patch["patch_name"]][subject_name][
"cum_time"
] = subject_in_patch_cum_time
subj_pellets = closest_subjects_pellet_ts[closest_subjects_pellet_ts == subject_name]

subj_patch_thresh = patch["patch_threshold"][np.searchsorted(patch["patch_threshold_timestamps"], subj_pellets.index.values) - 1]
closest_subj_mask = closest_subjects_pellet_ts == subject_name
subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
subj_patch_thresh = patch["patch_threshold"][closest_subj_mask]

self.Patch.insert1(
key
Expand Down

0 comments on commit 3411fe8

Please sign in to comment.