
Commit

Merge pull request #361 from ttngu207/datajoint_pipeline
Improve block analysis: skip subjects with no position data; fix a bug in block detection.
ttngu207 authored May 2, 2024
2 parents 5276cc1 + a8f41e3 commit ab51258
Showing 3 changed files with 38 additions and 25 deletions.
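
For context, the block-detection logic that this commit fixes keys on pellet-counter resets: a block ends wherever the environment's cumulative `pellet_ct` drops. Below is a minimal sketch of that rule on toy data; in the pipeline the frame comes from `fetch_stream` on `acquisition.Environment.BlockState`, and the timestamps and counts here are made up.

```python
import pandas as pd

# Toy stand-in for the BlockState stream: a cumulative pellet counter that
# resets at the start of each new block.
block_state_df = pd.DataFrame(
    {"pellet_ct": [1, 2, 3, 0, 1, 2, 0, 1]},
    index=pd.date_range("2024-05-01", periods=8, freq="h"),
)

# A negative first difference marks a counter reset, i.e. a block boundary.
block_ends = block_state_df[block_state_df.pellet_ct.diff() < 0]
print(block_ends.index)  # the two reset timestamps
```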
50 changes: 37 additions & 13 deletions aeon/dj_pipeline/analysis/block_analysis.py
```diff
@@ -5,12 +5,11 @@
 import plotly.express as px
 import plotly.graph_objs as go
 from matplotlib import path as mpl_path
+from datetime import datetime
 
 from aeon.analysis import utils as analysis_utils
-from aeon.dj_pipeline import (acquisition, fetch_stream, get_schema_name,
-                              streams, tracking)
-from aeon.dj_pipeline.analysis.visit import (filter_out_maintenance_periods,
-                                             get_maintenance_periods)
+from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
+from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
 
 schema = dj.schema(get_schema_name("block_analysis"))
 logger = dj.logger
@@ -56,17 +55,17 @@ def make(self, key):
             key["experiment_name"], previous_block_start, chunk_end
         )
 
-        block_query = acquisition.Environment.BlockState & chunk_restriction
-        block_df = fetch_stream(block_query)[previous_block_start:chunk_end]
+        block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
+        block_state_df = fetch_stream(block_state_query)[previous_block_start:chunk_end]
 
-        block_ends = block_df[block_df.pellet_ct.diff() < 0]
+        block_ends = block_state_df[block_state_df.pellet_ct.diff() < 0]
 
         block_entries = []
         for idx, block_end in enumerate(block_ends.index):
             if idx == 0:
                 if previous_block_key:
                     # if there is a previous block - insert "block_end" for the previous block
-                    previous_pellet_time = block_df[:block_end].index[-2]
+                    previous_pellet_time = block_state_df[:block_end].index[-2]
                     previous_epoch = (
                         acquisition.Epoch.join(acquisition.EpochEnd, left=True)
                         & exp_key
```
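
The bugfix in this hunk is the added `& exp_key` restriction: DataJoint restrictions compose with `&`, and without it the `BlockState` query was scoped only by time, so rows from other experiments sharing the same window could leak into block detection. A schematic fragment of the corrected query, reusing the names from the hunk above; it assumes an active pipeline connection, and the experiment name is purely illustrative.

```python
exp_key = {"experiment_name": "social0.2-aeon4"}  # illustrative experiment name

block_state_query = (
    acquisition.Environment.BlockState  # stream table queried in make()
    & exp_key            # restrict to this experiment (the fix)
    & chunk_restriction  # restrict to the current chunk's time window
)
block_state_df = fetch_stream(block_state_query)[previous_block_start:chunk_end]
```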
```diff
@@ -92,6 +91,7 @@
 
 # ---- Block Analysis and Visualization ----
 
+
 @schema
 class BlockAnalysis(dj.Computed):
     definition = """
@@ -203,11 +203,18 @@ def make(self, key):
 
             encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)
 
-            patch_rate = depletion_state_df.rate.unique()
-            patch_offset = depletion_state_df.offset.unique()
-            assert len(patch_rate) == 1, f"Found multiple patch rates: {patch_rate} for patch: {patch_name}"
-            patch_rate = patch_rate[0]
-            patch_offset = patch_offset[0]
+            if len(depletion_state_df.rate.unique()) > 1:
+                # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward
+                AnalysisNote.insert1(
+                    {
+                        "note_timestamp": datetime.utcnow(),
+                        "note_type": "Multiple patch rates",
+                        "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}",
+                    }
+                )
+
+            patch_rate = depletion_state_df.rate.iloc[0]
+            patch_offset = depletion_state_df.offset.iloc[0]
 
             self.Patch.insert1(
                 {
@@ -247,6 +254,7 @@ def make(self, key):
                 streams.SpinnakerVideoSource
                 * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", anchor_part="part_name")
                 * tracking.SLEAPTracking.Part
+                & key
                 & {
                     "spinnaker_video_source_name": "CameraTop",
                     "identity_name": subject_name,
@@ -256,6 +264,9 @@ def make(self, key):
             pos_df = fetch_stream(pos_query)[block_start:block_end]
             pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end)
 
+            if pos_df.empty:
+                continue
+
             position_diff = np.sqrt(
                 np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))
             )
```
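
The new `pos_df.empty` guard protects the per-subject distance computation that immediately follows it: a subject with no position data in the block is now skipped rather than crashing the analysis. A toy illustration of that computation; variable names beyond those in the hunk are illustrative.

```python
import numpy as np
import pandas as pd

pos_df = pd.DataFrame({"x": [0.0, 3.0, 3.0], "y": [0.0, 4.0, 8.0]})

if not pos_df.empty:  # a subject with no tracked positions is simply skipped
    # Euclidean step length between consecutive samples
    position_diff = np.sqrt(
        np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))
    )
    cumulative_distance = np.concatenate([[0.0], np.cumsum(position_diff)])
    print(cumulative_distance)  # [0. 5. 9.]
```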
```diff
@@ -553,3 +564,16 @@ def make(self, key):
                 "cumulative_pellet_plot": json.loads(cumulative_pellet_fig.to_json()),
             }
         )
+
+
+# ---- AnalysisNote ----
+
+
+@schema
+class AnalysisNote(dj.Manual):
+    definition = """ # Generic table to catch all notes generated during analysis
+    note_timestamp: datetime
+    ---
+    note_type='': varchar(64)
+    note: varchar(3000)
+    """
```
2 changes: 1 addition & 1 deletion aeon/dj_pipeline/populate/worker.py
```diff
@@ -64,7 +64,7 @@ def ingest_environment_visits():
 acquisition_worker(acquisition.EpochConfig)
 acquisition_worker(acquisition.Environment)
 # acquisition_worker(ingest_environment_visits)
-# acquisition_worker(block_analysis.BlockDetection)
+acquisition_worker(block_analysis.BlockDetection)
 
 # configure a worker to handle pyrat sync
 pyrat_worker = DataJointWorker(
```
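
With `BlockDetection` uncommented, the acquisition worker will process new chunks automatically. For backfilling, the same table can presumably also be populated by hand; `populate()` and the keyword arguments below are standard DataJoint autopopulate options, not something this diff adds.

```python
from aeon.dj_pipeline.analysis import block_analysis

# display_progress shows a progress bar; reserve_jobs coordinates with other workers
block_analysis.BlockDetection.populate(display_progress=True, reserve_jobs=True)
```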
11 changes: 0 additions & 11 deletions tests/conftest.py
```diff
@@ -58,7 +58,6 @@ def dj_config():
     dj.config["custom"][
         "database.prefix"
     ] = f"u_{dj.config['database.user']}_testsuite_"
-    return
 
 
 def load_pipeline():
@@ -137,8 +136,6 @@ def experiment_creation(test_params, pipeline):
         }
     )
 
-    return
-
 
 @pytest.fixture(scope="session")
 def epoch_chunk_ingestion(test_params, pipeline, experiment_creation):
@@ -154,8 +151,6 @@ def epoch_chunk_ingestion(test_params, pipeline, experiment_creation):
 
     acquisition.Chunk.ingest_chunks(experiment_name=test_params["experiment_name"])
 
-    return
-
 
 @pytest.fixture(scope="session")
 def experimentlog_ingestion(pipeline):
@@ -166,20 +161,14 @@ def experimentlog_ingestion(pipeline):
     acquisition.SubjectEnterExit.populate(**_populate_settings)
     acquisition.SubjectWeight.populate(**_populate_settings)
 
-    return
-
 
 @pytest.fixture(scope="session")
 def camera_qc_ingestion(pipeline, epoch_chunk_ingestion):
     qc = pipeline["qc"]
     qc.CameraQC.populate(**_populate_settings)
 
-    return
-
 
 @pytest.fixture(scope="session")
 def camera_tracking_ingestion(pipeline, camera_qc_ingestion):
     tracking = pipeline["tracking"]
     tracking.CameraTracking.populate(**_populate_settings)
 
-    return
```
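
The `tests/conftest.py` changes are pure cleanup: each deleted `return` was a bare trailing statement, and a fixture that runs ingestion for its side effect already returns `None` implicitly. A minimal sketch of the resulting pattern, with a hypothetical fixture name:

```python
import pytest

@pytest.fixture(scope="session")
def example_ingestion(pipeline):
    # Runs for its side effect only; no trailing `return` needed.
    pipeline["qc"].CameraQC.populate()
```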
