Skip to content

Commit

Permalink
Merge pull request #421 from SainsburyWellcomeCentre/gl-issue-418
Browse files Browse the repository at this point in the history
Allow reading pose model metadata from local folder
  • Loading branch information
glopesdev authored Oct 3, 2024
2 parents 7812b4f + f925d75 commit 83cd905
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 19 deletions.
49 changes: 31 additions & 18 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,37 @@ class (int): Int ID of a subject in the environment.
"""

def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"):
"""Pose reader constructor."""
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
"""Pose reader constructor.
The pattern for this reader should typically be `<device>_<hpcnode>_<jobid>*`.
If a register prefix is required, the pattern should end with a trailing
underscore, e.g. `Camera_202_*`. Otherwise, the pattern should include a
common prefix for the pose model folder excluding the trailing underscore,
e.g. `Camera_model-dir*`.
"""
super().__init__(pattern, columns=None)
self._model_root = model_root
self._pattern_offset = pattern.rfind("_") + 1

def read(self, file: Path) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
model_dir = Path(file.stem[self._pattern_offset :].replace("_", "/")).parent

# Check if model directory exists in local or shared directories.
# Local directory is prioritized over shared directory.
local_config_file_dir = file.parent / model_dir
shared_config_file_dir = Path(self._model_root) / model_dir
if local_config_file_dir.exists():
config_file_dir = local_config_file_dir
elif shared_config_file_dir.exists():
config_file_dir = shared_config_file_dir
else:
raise FileNotFoundError(
f"""Cannot find model dir in either local ({local_config_file_dir}) \
or shared ({shared_config_file_dir}) directories"""
)

config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names(config_file)
parts = self.get_bodyparts(config_file)
Expand Down Expand Up @@ -350,7 +369,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data, config_file)
data = self.class_int2str(data, identities)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"])
Expand Down Expand Up @@ -407,18 +426,12 @@ def get_bodyparts(config_file: Path) -> list[str]:
return parts

@staticmethod
def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
if config_file.stem == "confmap_config": # SLEAP
with open(config_file) as f:
config = json.load(f)
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {config_file}.") from err
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
if not classes:
raise ValueError("Classes list cannot be None or empty.")
identity_mapping = dict(enumerate(classes))
data["identity"] = data["identity"].replace(identity_mapping)
return data

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion aeon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
found = find_nested_key(v, key)
if found:
return found
else:
elif obj is not None:
for item in obj:
found = find_nested_key(item, key)
if found:
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
{
"data": {
"labels": {
"training_labels": "social_dev_b5350ff/aeon3_social_dev_b5350ff_ceph.slp",
"validation_labels": null,
"validation_fraction": 0.1,
"test_labels": null,
"split_by_inds": false,
"training_inds": null,
"validation_inds": null,
"test_inds": null,
"search_path_hints": [],
"skeletons": [
{
"directed": true,
"graph": {
"name": "Skeleton-1",
"num_edges_inserted": 0
},
"links": [],
"multigraph": true,
"nodes": [
{
"id": {
"py/object": "sleap.skeleton.Node",
"py/state": {
"py/tuple": [
"centroid",
1.0
]
}
}
}
]
}
]
},
"preprocessing": {
"ensure_rgb": false,
"ensure_grayscale": false,
"imagenet_mode": null,
"input_scaling": 1.0,
"pad_to_stride": 16,
"resize_and_pad_to_target": true,
"target_height": 1080,
"target_width": 1440
},
"instance_cropping": {
"center_on_part": "centroid",
"crop_size": 96,
"crop_size_detection_padding": 16
}
},
"model": {
"backbone": {
"leap": null,
"unet": {
"stem_stride": null,
"max_stride": 16,
"output_stride": 2,
"filters": 16,
"filters_rate": 1.5,
"middle_block": true,
"up_interpolate": false,
"stacks": 1
},
"hourglass": null,
"resnet": null,
"pretrained_encoder": null
},
"heads": {
"single_instance": null,
"centroid": null,
"centered_instance": null,
"multi_instance": null,
"multi_class_bottomup": null,
"multi_class_topdown": {
"confmaps": {
"anchor_part": "centroid",
"part_names": [
"centroid"
],
"sigma": 1.5,
"output_stride": 2,
"loss_weight": 1.0,
"offset_refinement": false
},
"class_vectors": {
"classes": [
"BAA-1104045",
"BAA-1104047"
],
"num_fc_layers": 3,
"num_fc_units": 256,
"global_pool": true,
"output_stride": 2,
"loss_weight": 0.01
}
}
},
"base_checkpoint": null
},
"optimization": {
"preload_data": true,
"augmentation_config": {
"rotate": true,
"rotation_min_angle": -180.0,
"rotation_max_angle": 180.0,
"translate": false,
"translate_min": -5,
"translate_max": 5,
"scale": false,
"scale_min": 0.9,
"scale_max": 1.1,
"uniform_noise": false,
"uniform_noise_min_val": 0.0,
"uniform_noise_max_val": 10.0,
"gaussian_noise": false,
"gaussian_noise_mean": 5.0,
"gaussian_noise_stddev": 1.0,
"contrast": false,
"contrast_min_gamma": 0.5,
"contrast_max_gamma": 2.0,
"brightness": false,
"brightness_min_val": 0.0,
"brightness_max_val": 10.0,
"random_crop": false,
"random_crop_height": 256,
"random_crop_width": 256,
"random_flip": false,
"flip_horizontal": true
},
"online_shuffling": true,
"shuffle_buffer_size": 128,
"prefetch": true,
"batch_size": 4,
"batches_per_epoch": 469,
"min_batches_per_epoch": 200,
"val_batches_per_epoch": 54,
"min_val_batches_per_epoch": 10,
"epochs": 200,
"optimizer": "adam",
"initial_learning_rate": 0.0001,
"learning_rate_schedule": {
"reduce_on_plateau": true,
"reduction_factor": 0.1,
"plateau_min_delta": 1e-08,
"plateau_patience": 20,
"plateau_cooldown": 3,
"min_learning_rate": 1e-08
},
"hard_keypoint_mining": {
"online_mining": false,
"hard_to_easy_ratio": 2.0,
"min_hard_keypoints": 2,
"max_hard_keypoints": null,
"loss_scale": 5.0
},
"early_stopping": {
"stop_training_on_plateau": true,
"plateau_min_delta": 1e-08,
"plateau_patience": 20
}
},
"outputs": {
"save_outputs": true,
"run_name": "aeon3_social_dev_b5350ff_ceph_topdown_top.centered_instance_multiclass",
"run_name_prefix": "",
"run_name_suffix": "",
"runs_folder": "social_dev_b5350ff/models",
"tags": [],
"save_visualizations": true,
"delete_viz_images": true,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": true,
"best_model": true,
"every_epoch": false,
"latest_model": false,
"final_model": false
},
"tensorboard": {
"write_logs": false,
"loss_frequency": "epoch",
"architecture_graph": false,
"profile_graph": false,
"visualizations": true
},
"zmq": {
"subscribe_to_controller": false,
"controller_address": "tcp://127.0.0.1:9000",
"controller_polling_timeout": 10,
"publish_updates": false,
"publish_address": "tcp://127.0.0.1:9001"
}
},
"name": "",
"description": "",
"sleap_version": "1.3.1",
"filename": "Z:/aeon/data/processed/test-node1/4310907/2024-01-12T19-00-00/topdown-multianimal-id-133/confmap_config.json"
}
25 changes: 25 additions & 0 deletions tests/io/test_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path

import pytest
from pytest import mark

import aeon
from aeon.schema.schemas import social02, social03

pose_path = Path(__file__).parent.parent / "data" / "pose"


@mark.api
def test_Pose_read_local_model_dir():
data = aeon.load(pose_path, social02.CameraTop.Pose)
assert len(data) > 0


@mark.api
def test_Pose_read_local_model_dir_with_register_prefix():
data = aeon.load(pose_path, social03.CameraTop.Pose)
assert len(data) > 0


if __name__ == "__main__":
pytest.main()

0 comments on commit 83cd905

Please sign in to comment.