diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index e7c788fc..22ddf978 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -174,39 +174,24 @@ def make(self, key): if not len(pose_data): raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") - # Find the config file for the SLEAP model - for data_dir in data_dirs: - try: - f = next( - data_dir.glob( - f"**/**/{stream_reader.pattern}{io_api.chunk(chunk_start).strftime('%Y-%m-%dT%H-%M-%S')}*.{stream_reader.extension}" - ) - ) - except StopIteration: - continue - else: - config_file = stream_reader.get_config_file( - stream_reader._model_root / Path(*Path(f.stem.replace("_", "/")).parent.parts[1:]) - ) - break - else: - raise FileNotFoundError(f"Unable to find SLEAP model config file for: {stream_reader.pattern}") - # get bodyparts and classes - bodyparts = stream_reader.get_bodyparts(config_file) + bodyparts = stream_reader.get_bodyparts() anchor_part = bodyparts[0] # anchor_part is always the first one - class_names = stream_reader.get_class_names(config_file) + class_names = stream_reader.get_class_names() + identity_mapping = {n: i for i, n in enumerate(class_names)} # ingest parts and classes pose_identity_entries, part_entries = [], [] - for class_idx in set(pose_data["class"].values.astype(int)): - class_position = pose_data[pose_data["class"] == class_idx] - for part in set(class_position.part.values): - part_position = class_position[class_position.part == part] + for identity in identity_mapping: + identity_position = pose_data[pose_data["identity"] == identity] + if identity_position.empty: + continue + for part in set(identity_position.part.values): + part_position = identity_position[identity_position.part == part] part_entries.append( { **key, - "identity_idx": class_idx, + "identity_idx": identity_mapping[identity], "part_name": part, "timestamps": part_position.index.values, "x": part_position.x.values, @@ -216,14 +201,17 @@ def make(self, key): } ) if part == anchor_part: - class_likelihood = part_position.class_likelihood.values + identity_likelihood = part_position.identity_likelihood.values + if isinstance(identity_likelihood[0], dict): + identity_likelihood = np.array([v[identity] for v in identity_likelihood]) + pose_identity_entries.append( { **key, - "identity_idx": class_idx, - "identity_name": class_names[class_idx], + "identity_idx": identity_mapping[identity], + "identity_name": identity, "anchor_part": anchor_part, - "identity_likelihood": class_likelihood, + "identity_likelihood": identity_likelihood, } ) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 34f258cd..25db82a8 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -286,23 +286,35 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None) self._model_root = model_root + self.config_file = None # requires reading the data file to be set def read(self, file: Path) -> pd.DataFrame: """Reads data from the Harp-binarized tracking file.""" # Get config file from `file`, then bodyparts from config file. - model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:]) + model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:]) config_file_dir = Path(self._model_root) / model_dir if not config_file_dir.exists(): raise FileNotFoundError(f"Cannot find model dir {config_file_dir}") - config_file = self.get_config_file(config_file_dir) - parts = self.get_bodyparts(config_file) + self.config_file = self.get_config_file(config_file_dir) + identities = self.get_class_names() + parts = self.get_bodyparts() # Using bodyparts, assign column names to Harp register values, and read data in default format. - columns = ["class", "class_likelihood"] - for part in parts: - columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) - self.columns = columns - data = super().read(file) + try: # Bonsai.Sleap0.2 + bonsai_sleap_v = 0.2 + columns = ["identity", "identity_likelihood"] + for part in parts: + columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) + self.columns = columns + data = super().read(file) + except ValueError: # column mismatch; Bonsai.Sleap0.3 + bonsai_sleap_v = 0.3 + columns = ["identity"] + columns.extend([f"{identity}_likelihood" for identity in identities]) + for part in parts: + columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) + self.columns = columns + data = super().read(file) # Drop any repeat parts. unique_parts, unique_idxs = np.unique(parts, return_index=True) @@ -315,54 +327,74 @@ def read(self, file: Path) -> pd.DataFrame: parts = unique_parts # Set new columns, and reformat `data`. + data = self.class_int2str(data) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts - new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"] + new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"] new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): - part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"] + part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2] + part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) part_data = pd.DataFrame(data[part_columns]) + if bonsai_sleap_v == 0.3: # combine all identity_likelihood cols into a single col as dict + part_data["identity_likelihood"] = part_data.apply( + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1 + ) + part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) + part_data = part_data[ # reorder columns + ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"] + ] part_data.insert(2, "part", part) part_data.columns = new_columns part_data_list[i] = part_data new_data = pd.concat(part_data_list) return new_data.sort_index() - def get_class_names(self, file: Path) -> list[str]: + def get_class_names(self) -> list[str]: """Returns a list of classes from a model's config file.""" classes = None - with open(file) as f: + with open(self.config_file) as f: config = json.load(f) - if file.stem == "confmap_config": # SLEAP + if self.config_file.stem == "confmap_config": # SLEAP try: heads = config["model"]["heads"] classes = util.find_nested_key(heads, "class_vectors")["classes"] except KeyError as err: if not classes: - raise KeyError(f"Cannot find class_vectors in {file}.") from err + raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err return classes - def get_bodyparts(self, file: Path) -> list[str]: + def get_bodyparts(self) -> list[str]: """Returns a list of bodyparts from a model's config file.""" parts = [] - with open(file) as f: + with open(self.config_file) as f: config = json.load(f) - if file.stem == "confmap_config": # SLEAP + if self.config_file.stem == "confmap_config": # SLEAP try: heads = config["model"]["heads"] parts = [util.find_nested_key(heads, "anchor_part")] parts += util.find_nested_key(heads, "part_names") except KeyError as err: if not parts: - raise KeyError(f"Cannot find bodyparts in {file}.") from err + raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err return parts + def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame: + """Converts a class integer in a tracking data dataframe to its associated string (subject id).""" + if self.config_file.stem == "confmap_config": # SLEAP + with open(self.config_file) as f: + config = json.load(f) + try: + heads = config["model"]["heads"] + classes = util.find_nested_key(heads, "classes") + except KeyError as err: + raise KeyError(f"Cannot find classes in {self.config_file}.") from err + for i, subj in enumerate(classes): + data.loc[data["identity"] == i, "identity"] = subj + return data + @classmethod - def get_config_file( - cls, - config_file_dir: Path, - config_file_names: None | list[str] = None, - ) -> Path: + def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path: """Returns the config file from a model's config directory.""" if config_file_names is None: config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) @@ -375,22 +407,6 @@ def get_config_file( raise FileNotFoundError(f"Cannot find config file in {config_file_dir}") return config_file - @classmethod - def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame: - """Converts a class integer in a tracking data dataframe to its associated string (subject id).""" - config_file = cls.get_config_file(config_file_dir) - if config_file.stem == "confmap_config": # SLEAP - with open(config_file) as f: - config = json.load(f) - try: - heads = config["model"]["heads"] - classes = util.find_nested_key(heads, "classes") - except KeyError as err: - raise KeyError(f"Cannot find classes in {config_file}.") from err - for i, subj in enumerate(classes): - data.loc[data["class"] == i, "class"] = subj - return data - def from_dict(data, pattern=None): reader_type = data.get("type", None)