Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pose reader #253

Merged
merged 5 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions aeon/schema/social.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Readers for data relevant to Social experiments."""

from pathlib import Path
from typing import List, Union
import json
from pathlib import Path

import numpy as np
import pandas as pd

from aeon import util
import aeon.io.reader as _reader
from aeon import util
jkbhagatio marked this conversation as resolved.
Show resolved Hide resolved


class Pose(_reader.Harp):
Expand All @@ -25,57 +25,71 @@ def __init__(self, pattern: str, extension: str="bin"):
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
super().__init__(pattern, columns=None, extension=extension)

def read(self, file: Path, ceph_proc_dir: Path=Path("/ceph/aeon/aeon/data/processed")) -> pd.DataFrame:
def read(
self, file: Path, ceph_proc_dir: str | Path = "/ceph/aeon/aeon/data/processed"
) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = Path(file.stem.replace("_", "/")).parent
config_file_dir = ceph_proc_dir / model_dir
assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
config_file = get_config_file(config_file_dir)
parts = self.get_bodyparts(config_file)

# Using bodyparts, assign column names to Harp register values, and read data in default format.
columns = ["class", "class_likelihood"]
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)


# Drop any repeat parts.
unique_parts, unique_idxs = np.unique(parts, return_index=True)
repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs)
if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols)
init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5
rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx])
keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs)
data = data.iloc[:, keep_part_col_idxs]
parts = unique_parts

# Set new columns, and reformat `data`.
n_parts = len(parts)
part_data_list = [None] * n_parts
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
new_data = pd.DataFrame(columns=new_columns)
for i, part in enumerate(parts):
part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
part_data = data[part_columns]
part_data = pd.DataFrame(data[part_columns])
part_data.insert(2, "part", part)
part_data.columns = new_columns
part_data_list[i] = part_data
new_data = pd.concat(part_data_list)
return new_data.sort_index()
jkbhagatio marked this conversation as resolved.
Show resolved Hide resolved

jkbhagatio marked this conversation as resolved.
Show resolved Hide resolved
def get_bodyparts(self, file: Path) -> Union[None, List[str]]:
def get_bodyparts(self, file: Path) -> list[str]:
"""Returns a list of bodyparts from a model's config file."""
parts = None
parts = []
with open(file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
parts = util.find_nested_key(heads, "part_names")
parts = [util.find_nested_key(heads, "anchor_part")]
parts += util.find_nested_key(heads, "part_names")
except KeyError as err:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
if parts is None:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
return parts
jkbhagatio marked this conversation as resolved.
Show resolved Hide resolved


def get_config_file(
config_file_dir: Path,
config_file_names: List[str]=[
"confmap_config.json", # SLEAP (add others for other trackers to this list)
],
):
config_file_names: None | list[str] = None,
) -> Path:
"""Returns the config file from a model's config directory."""
if config_file_names is None:
config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list)
config_file = None
for f in config_file_names:
if (config_file_dir / f).exists():
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ reportAssertAlwaysTrue = "error"
reportSelfClsParameterName = "error"
reportUnusedExpression = "error"
reportMatchNotExhaustive = "error"
reportImplicitOverride = "error"
reportShadowedImports = "error"
# *Note*: we may want to set all 'ReportOptional*' rules to "none", but leaving 'em default for now
venvPath = "."
Expand Down
Loading