Skip to content

Commit

Permalink
Merge pull request #353 from ttngu207/datajoint_pipeline
Browse files Browse the repository at this point in the history
feat(block_analysis): retrieve patch offset
  • Loading branch information
ttngu207 authored Apr 11, 2024
2 parents e0564c5 + 0ca72d9 commit 91f4490
Showing 1 changed file with 71 additions and 65 deletions.
136 changes: 71 additions & 65 deletions aeon/dj_pipeline/analysis/block_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,72 @@ class Block(dj.Manual):
"""


@schema
class BlockDetection(dj.Computed):
definition = """
-> acquisition.Environment
"""

def make(self, key):
"""On a per-chunk basis, check for the presence of new block, insert into Block table."""
# find the 0s
# that would mark the start of a new block
# if the 0 is the first index - look back at the previous chunk
# if the previous timestamp belongs to a previous epoch -> block_end is the previous timestamp
# else block_end is the timestamp of this 0
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
exp_key = {"experiment_name": key["experiment_name"]}
# only consider the time period between the last block and the current chunk
previous_block = Block & exp_key & f"block_start <= '{chunk_start}'"
if previous_block:
previous_block_key = previous_block.fetch("KEY", limit=1, order_by="block_start DESC")[0]
previous_block_start = previous_block_key["block_start"]
else:
previous_block_key = None
previous_block_start = (acquisition.Chunk & exp_key).fetch(
"chunk_start", limit=1, order_by="chunk_start"
)[0]

chunk_restriction = acquisition.create_chunk_restriction(
key["experiment_name"], previous_block_start, chunk_end
)

block_query = acquisition.Environment.BlockState & chunk_restriction
block_df = fetch_stream(block_query).sort_index()[previous_block_start:chunk_end]

block_ends = block_df[block_df.pellet_ct.diff() < 0]

block_entries = []
for idx, block_end in enumerate(block_ends.index):
if idx == 0:
if previous_block_key:
# if there is a previous block - insert "block_end" for the previous block
previous_pellet_time = block_df[:block_end].index[-2]
previous_epoch = (
acquisition.Epoch.join(acquisition.EpochEnd, left=True)
& exp_key
& f"'{previous_pellet_time}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
).fetch1("KEY")
current_epoch = (
acquisition.Epoch.join(acquisition.EpochEnd, left=True)
& exp_key
& f"'{block_end}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
).fetch1("KEY")

previous_block_key["block_end"] = (
block_end if current_epoch == previous_epoch else previous_pellet_time
)
Block.update1(previous_block_key)
else:
block_entries[-1]["block_end"] = block_end
block_entries.append({**exp_key, "block_start": block_end, "block_end": None})

Block.insert(block_entries)
self.insert1(key)


# ---- Block Analysis and Visualization ----

@schema
class BlockAnalysis(dj.Computed):
definition = """
Expand All @@ -49,6 +115,7 @@ class Patch(dj.Part):
patch_threshold: longblob
patch_threshold_timestamps: longblob
patch_rate: float
patch_offset: float
"""

class Subject(dj.Part):
Expand Down Expand Up @@ -139,8 +206,10 @@ def make(self, key):
encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)

patch_rate = depletion_state_df.rate.unique()
assert len(patch_rate) == 1 # expects a single rate for this block
patch_offset = depletion_state_df.offset.unique()
assert len(patch_rate) == 1, f"Found multiple patch rates: {patch_rate} for patch: {patch_name}"
patch_rate = patch_rate[0]
patch_offset = patch_offset[0]

self.Patch.insert1(
{
Expand All @@ -155,6 +224,7 @@ def make(self, key):
"patch_threshold": depletion_state_df.threshold.values,
"patch_threshold_timestamps": depletion_state_df.index.values,
"patch_rate": patch_rate,
"patch_offset": patch_offset,
}
)

Expand Down Expand Up @@ -486,67 +556,3 @@ def make(self, key):
"cumulative_pellet_plot": json.loads(cumulative_pellet_fig.to_json()),
}
)


@schema
class BlockDetection(dj.Computed):
definition = """
-> acquisition.Environment
"""

def make(self, key):
"""On a per-chunk basis, check for the presence of new block, insert into Block table."""
# find the 0s
# that would mark the start of a new block
# if the 0 is the first index - look back at the previous chunk
# if the previous timestamp belongs to a previous epoch -> block_end is the previous timestamp
# else block_end is the timestamp of this 0
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
exp_key = {"experiment_name": key["experiment_name"]}
# only consider the time period between the last block and the current chunk
previous_block = Block & exp_key & f"block_start <= '{chunk_start}'"
if previous_block:
previous_block_key = previous_block.fetch("KEY", limit=1, order_by="block_start DESC")[0]
previous_block_start = previous_block_key["block_start"]
else:
previous_block_key = None
previous_block_start = (acquisition.Chunk & exp_key).fetch(
"chunk_start", limit=1, order_by="chunk_start"
)[0]

chunk_restriction = acquisition.create_chunk_restriction(
key["experiment_name"], previous_block_start, chunk_end
)

block_query = acquisition.Environment.BlockState & chunk_restriction
block_df = fetch_stream(block_query).sort_index()[previous_block_start:chunk_end]

block_ends = block_df[block_df.pellet_ct.diff() < 0]

block_entries = []
for idx, block_end in enumerate(block_ends.index):
if idx == 0:
if previous_block_key:
# if there is a previous block - insert "block_end" for the previous block
previous_pellet_time = block_df[:block_end].index[-2]
previous_epoch = (
acquisition.Epoch.join(acquisition.EpochEnd, left=True)
& exp_key
& f"'{previous_pellet_time}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
).fetch1("KEY")
current_epoch = (
acquisition.Epoch.join(acquisition.EpochEnd, left=True)
& exp_key
& f"'{block_end}' BETWEEN epoch_start AND IFNULL(epoch_end, '2200-01-01')"
).fetch1("KEY")

previous_block_key["block_end"] = (
block_end if current_epoch == previous_epoch else previous_pellet_time
)
Block.update1(previous_block_key)
else:
block_entries[-1]["block_end"] = block_end
block_entries.append({**exp_key, "block_start": block_end, "block_end": None})

Block.insert(block_entries)
self.insert1(key)

0 comments on commit 91f4490

Please sign in to comment.