Skip to content

Commit

Permalink
Finalized reading from Sleap config file
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbhagatio committed Aug 17, 2023
1 parent 0305a35 commit 38cf625
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 29 deletions.
54 changes: 25 additions & 29 deletions aeon/schema/social.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import pandas as pd

from aeon import util
import aeon.io.reader as _reader

import ipdb

class Pose(_reader.Harp):
"""Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods."""
Expand All @@ -15,47 +17,43 @@ def __init__(self, pattern):
self.extension = "bin"

def read(self, file, ceph_proc_dir="/ceph/aeon/aeon/data/processed"):
"""
Reads data from the Harp-binarized tracking file
"""
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = file.stem.replace("_", "/")
config_file_dir = Path(ceph_proc_dir + "/" + model_dir)
assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
config_file = get_config_file(config_file_dir)
parts = self.get_bodyparts(config_file)
if parts is None:
raise ValueError(f"Cannot find bodyparts in {config_file}.")
# With bodyparts, set columns as set by Bonsai, and read data in default format.
# Using bodyparts, assign column names to Harp register values, and read data in default format.
columns = ["class", "class_likelihood"]
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)
# Set new columns, and reformat `data`.
columns = ["class", "class_likelihood", "part", "part_likelihood", "x", "y"]
new_data = pd.DataFrame(columns=columns)
for part in parts:
part_data = data[["class", "class_likelihood", f"{part}_likelihood", f"{part}_x", f"{part}_y"]]
part_data.insert(2, "part", part)
part_data.columns = columns
new_data = pd.concat([new_data, part_data])
new_columns = ["class", "class_likelihood", "part", "part_likelihood", "x", "y"]
new_data = pd.DataFrame(columns=new_columns)
part_data_list = [None] * len(parts)
for i, part in enumerate(parts):
part_data_list[i] = data[
["class", "class_likelihood", f"{part}_likelihood", f"{part}_x", f"{part}_y"]
]
part_data_list[i].insert(2, "part", part)
part_data_list[i].columns = new_columns
new_data = pd.concat(part_data_list)
return new_data.sort_index()
def get_bodyparts(file):

def get_bodyparts(self, file):
"""Returns a list of bodyparts from a model's config file."""
parts = None
if file.stem == "confmap_config": # SLEAP
with open(file) as f:
with open(file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
for model in heads.keys():
if heads[model] is not None:
parts = heads[model]["confmaps"]["part_names"]
break
parts = util.find_nested_key(heads, "part_names")
except KeyError as err:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
raise KeyError(f"Cannot find bodyparts in {file}.") from err
return parts


Expand All @@ -74,6 +72,7 @@ def get_config_file(
assert config_file is not None, f"Cannot find config file in {config_file_dir}"
return config_file


def class_int2str(tracking_df, config_file_dir):
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
config_file = get_config_file(config_file_dir)
Expand All @@ -82,12 +81,9 @@ def class_int2str(tracking_df, config_file_dir):
config = json.load(f)
try:
heads = config["model"]["heads"]
for model in heads.keys():
if heads[model] is not None:
classes = heads[model]["class_vectors"]["classes"]
break
for i, subj in enumerate(classes):
tracking_df.loc[tracking_df["class"] == i, "class"] = subj
return tracking_df
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {config_file}.") from err
for i, subj in enumerate(classes):
tracking_df.loc[tracking_df["class"] == i, "class"] = subj
return tracking_df
19 changes: 19 additions & 0 deletions aeon/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Utility functions."""

from typing import Union

def find_nested_key(obj: Union[dict, list], key: str):
"""Returns the value of the first found nested key."""
if isinstance(obj, dict):
for k, v in obj.items():
if k == key: # Found it!
return v
found = find_nested_key(v, key)
if found:
return found
elif isinstance(obj, list):
for item in obj:
found = find_nested_key(item, key)
if found:
return found
return None

0 comments on commit 38cf625

Please sign in to comment.