Merge pull request #370 from ttngu207/datajoint_pipeline
fix: bugfix in Block Detection
ttngu207 authored Jun 20, 2024
2 parents 3b73245 + e3e51f2 commit 875e7af
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions aeon/dj_pipeline/analysis/block_analysis.py
```diff
@@ -35,7 +35,7 @@ def make(self, key):
         """On a per-chunk basis, check for the presence of new block, insert into Block table."""
         # find the 0s
         # that would mark the start of a new block
-        # if the 0 is the first index - look back at the previous chunk
+        # In the BlockState data - if the 0 is the first index - look back at the previous chunk
         # if the previous timestamp belongs to a previous epoch -> block_end is the previous timestamp
         # else block_end is the timestamp of this 0
         chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
@@ -55,17 +55,31 @@ def make(self, key):
             key["experiment_name"], previous_block_start, chunk_end
         )
 
-        block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
-        block_state_df = fetch_stream(block_state_query)[previous_block_start:chunk_end]
         # detecting block end times
         # pellet count reset - find 0s in BlockState
-        block_ends = block_state_df[block_state_df.pellet_ct.diff() < 0]
+
+        block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
+        block_state_df = fetch_stream(block_state_query)
+        block_state_df.index = block_state_df.index.round(
+            "us"
+        )  # timestamp precision in DJ is only at microseconds
+        block_state_df = block_state_df.loc[
+            (block_state_df.index > previous_block_start) & (block_state_df.index <= chunk_end)
+        ]
+
+        block_ends = block_state_df[block_state_df.pellet_ct == 0]
+        # account for the double 0s - find any 0s that are within 1 second of each other, remove the 2nd one
+        double_0s = block_ends.index.to_series().diff().dt.total_seconds() < 1
+        # find the indices of the 2nd 0s and remove
+        double_0s = double_0s.shift(-1).fillna(False)
+        block_ends = block_ends[~double_0s]
+
         block_entries = []
         for idx, block_end in enumerate(block_ends.index):
             if idx == 0:
                 if previous_block_key:
                     # if there is a previous block - insert "block_end" for the previous block
-                    previous_pellet_time = block_state_df[:block_end].index[-2]
+                    previous_pellet_time = block_state_df[:block_end].index[-1]
                     previous_epoch = (
                         acquisition.Epoch.join(acquisition.EpochEnd, left=True)
                         & exp_key
```
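The detection change is easier to see on toy data. The sketch below replays the new steps with pandas only; the `pellet_ct` column, the microsecond rounding, and the double-0 removal come from the diff, while the timestamps and counts are invented. Note that after `shift(-1)` the flag sits on the earlier member of each close pair of 0s, so it is the later duplicate that survives the filter.

```python
import pandas as pd

# Invented BlockState-like stream: pellet_ct resets to 0 at block boundaries,
# and one reset is logged twice within a second (a "double 0").
idx = pd.to_datetime(
    [
        "2024-06-20 10:00:00.0000004",  # rounds to 10:00:00.000000 at microsecond precision
        "2024-06-20 10:00:00.500000",   # duplicate 0, half a second later
        "2024-06-20 11:30:00",
        "2024-06-20 12:00:00",
    ]
)
block_state_df = pd.DataFrame({"pellet_ct": [0, 0, 25, 0]}, index=idx)

# DataJoint stores timestamps at microsecond precision, hence the rounding
block_state_df.index = block_state_df.index.round("us")

# pellet-count resets (0s) mark block ends
block_ends = block_state_df[block_state_df.pellet_ct == 0]

# flag 0s that fall within 1 s of the previous 0, shift the flag onto the
# earlier member of each pair, and drop that one
double_0s = block_ends.index.to_series().diff().dt.total_seconds() < 1
double_0s = double_0s.shift(-1).fillna(False)
block_ends = block_ends[~double_0s]

print(block_ends.index)
# keeps the 10:00:00.500000 and 12:00:00 rows
```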
```diff
@@ -233,6 +247,10 @@ def make(self, key):
                 }
             )
 
+            # update block_end if last timestamp of encoder_df is before the current block_end
+            if encoder_df.index[-1] < block_end:
+                block_end = encoder_df.index[-1]
+
         # Subject data
         # Get all unique subjects that visited the environment over the entire exp;
         # For each subject, see 'type' of visit most recent to start of block
```
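This guard, and the matching one for `pos_df` further down, serve the same purpose: if a raw stream ends before the block end detected from BlockState, the block is truncated to the last timestamp for which data actually exists, equivalent to `block_end = min(block_end, encoder_df.index[-1])`.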
```diff
@@ -248,6 +266,7 @@ def make(self, key):
             _df = subject_visits_df[subject_visits_df.id == subject_name]
             if _df.type[-1] != "Exit":
                 subject_names.append(subject_name)
+
         for subject_name in subject_names:
             # positions - query for CameraTop, identity_name matches subject_name,
             pos_query = (
@@ -291,6 +310,14 @@ def make(self, key):
                 }
             )
 
+            # update block_end if last timestamp of pos_df is before the current block_end
+            if pos_df.index[-1] < block_end:
+                block_end = pos_df.index[-1]
+
+        if block_end != (Block & key).fetch1("block_end"):
+            Block.update1({**key, "block_end": block_end})
+            self.update1({**key, "block_duration": (block_end - block_start).total_seconds() / 3600})
+
 
 @schema
 class BlockSubjectAnalysis(dj.Computed):
```
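`update1` is DataJoint's single-row update: the passed dict must contain the table's full primary key plus the attributes to overwrite, and it modifies exactly one existing row. Because `block_end` may have been truncated above, the upstream `Block` row and the current `BlockAnalysis` row are rewritten together, keeping `block_end` and the derived `block_duration` consistent. A minimal sketch of the duration arithmetic, with invented timestamps:

```python
from datetime import datetime

# invented block boundaries; in the pipeline, block_end is whatever
# survived the encoder_df/pos_df truncation above
block_start = datetime(2024, 6, 20, 10, 0, 0)
block_end = datetime(2024, 6, 20, 13, 30, 0)

# block_duration is stored in hours, matching the update in the diff
block_duration = (block_end - block_start).total_seconds() / 3600
print(block_duration)  # 3.5
```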
```diff
@@ -501,7 +528,7 @@ def make(self, key):
 
 @schema
 class BlockPlots(dj.Computed):
-    definition = """
+    definition = """
     -> BlockAnalysis
     ---
     subject_positions_plot: longblob
```
