From 38cf6258274d1e17b43bfcfc9279cb21353103b9 Mon Sep 17 00:00:00 2001
From: Jai Bhagat
Date: Thu, 17 Aug 2023 18:48:12 +0000
Subject: [PATCH] Finalized reading from Sleap config file

---
 aeon/schema/social.py | 54 ++++++++++++++-----------------
 aeon/util.py          | 19 +++++++++++++++
 2 files changed, 44 insertions(+), 29 deletions(-)
 create mode 100644 aeon/util.py

diff --git a/aeon/schema/social.py b/aeon/schema/social.py
index 8a67269d..4eef8ce8 100644
--- a/aeon/schema/social.py
+++ b/aeon/schema/social.py
@@ -5,8 +5,10 @@
 
 import pandas as pd
 
+from aeon import util
 import aeon.io.reader as _reader
 
+import ipdb
 
 class Pose(_reader.Harp):
     """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods."""
@@ -15,47 +17,43 @@ def __init__(self, pattern):
         self.extension = "bin"
 
     def read(self, file, ceph_proc_dir="/ceph/aeon/aeon/data/processed"):
-        """
-        Reads data from the Harp-binarized tracking file
-        """
+        """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
         model_dir = file.stem.replace("_", "/")
         config_file_dir = Path(ceph_proc_dir + "/" + model_dir)
         assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
         config_file = get_config_file(config_file_dir)
         parts = self.get_bodyparts(config_file)
-        if parts is None:
-            raise ValueError(f"Cannot find bodyparts in {config_file}.")
-        # With bodyparts, set columns as set by Bonsai, and read data in default format.
+        # Using bodyparts, assign column names to Harp register values, and read data in default format.
         columns = ["class", "class_likelihood"]
         for part in parts:
             columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
         self.columns = columns
         data = super().read(file)
         # Set new columns, and reformat `data`.
-        columns = ["class", "class_likelihood", "part", "part_likelihood", "x", "y"]
-        new_data = pd.DataFrame(columns=columns)
-        for part in parts:
-            part_data = data[["class", "class_likelihood", f"{part}_likelihood", f"{part}_x", f"{part}_y"]]
-            part_data.insert(2, "part", part)
-            part_data.columns = columns
-            new_data = pd.concat([new_data, part_data])
+        new_columns = ["class", "class_likelihood", "part", "part_likelihood", "x", "y"]
+        new_data = pd.DataFrame(columns=new_columns)
+        part_data_list = [None] * len(parts)
+        for i, part in enumerate(parts):
+            part_data_list[i] = data[
+                ["class", "class_likelihood", f"{part}_likelihood", f"{part}_x", f"{part}_y"]
+            ]
+            part_data_list[i].insert(2, "part", part)
+            part_data_list[i].columns = new_columns
+        new_data = pd.concat(part_data_list)
         return new_data.sort_index()
-    
-    def get_bodyparts(file):
+
+    def get_bodyparts(self, file):
         """Returns a list of bodyparts from a model's config file."""
         parts = None
-        if file.stem == "confmap_config":  # SLEAP
-            with open(file) as f:
+        with open(file) as f:
                 config = json.load(f)
+        if file.stem == "confmap_config":  # SLEAP
             try:
                 heads = config["model"]["heads"]
-                for model in heads.keys():
-                    if heads[model] is not None:
-                        parts = heads[model]["confmaps"]["part_names"]
-                        break
+                parts = util.find_nested_key(heads, "part_names")
             except KeyError as err:
-                    raise KeyError(f"Cannot find bodyparts in {file}.") from err
+                raise KeyError(f"Cannot find bodyparts in {file}.") from err
         return parts
 
 
@@ -74,6 +72,7 @@ def get_config_file(
     assert config_file is not None, f"Cannot find config file in {config_file_dir}"
     return config_file
 
+
 def class_int2str(tracking_df, config_file_dir):
     """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
     config_file = get_config_file(config_file_dir)
@@ -82,12 +81,9 @@ def class_int2str(tracking_df, config_file_dir):
             config = json.load(f)
         try:
             heads = config["model"]["heads"]
-            for model in heads.keys():
-                if heads[model] is not None:
-                    classes = heads[model]["class_vectors"]["classes"]
-                    break
-            for i, subj in enumerate(classes):
-                tracking_df.loc[tracking_df["class"] == i, "class"] = subj
-            return tracking_df
+            classes = util.find_nested_key(heads, "classes")
         except KeyError as err:
             raise KeyError(f"Cannot find classes in {config_file}.") from err
+        for i, subj in enumerate(classes):
+            tracking_df.loc[tracking_df["class"] == i, "class"] = subj
+        return tracking_df
diff --git a/aeon/util.py b/aeon/util.py
new file mode 100644
index 00000000..9e29513a
--- /dev/null
+++ b/aeon/util.py
@@ -0,0 +1,19 @@
+"""Utility functions."""
+
+from typing import Union
+
+def find_nested_key(obj: Union[dict, list], key: str):
+    """Returns the value of the first `key` found depth-first in nested dicts/lists, else None."""
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            if k == key:  # Found it!
+                return v
+            found = find_nested_key(v, key)
+            if found is not None:  # Compare to None so falsy hits (0, "", []) are still returned.
+                return found
+    elif isinstance(obj, list):
+        for item in obj:
+            found = find_nested_key(item, key)
+            if found is not None:  # Compare to None so falsy hits (0, "", []) are still returned.
+                return found
+    return None