Commit
Merge pull request #374 from ttngu207/datajoint_pipeline
Update SLEAP data ingestion to work with the new Pose Reader for Bonsai-Sleap0.3
jkbhagatio authored Jul 3, 2024
2 parents 875e7af + 069385a commit 5cbd091
Showing 2 changed files with 72 additions and 68 deletions.
46 changes: 17 additions & 29 deletions aeon/dj_pipeline/tracking.py
@@ -174,39 +174,24 @@ def make(self, key):
         if not len(pose_data):
             raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")

-        # Find the config file for the SLEAP model
-        for data_dir in data_dirs:
-            try:
-                f = next(
-                    data_dir.glob(
-                        f"**/**/{stream_reader.pattern}{io_api.chunk(chunk_start).strftime('%Y-%m-%dT%H-%M-%S')}*.{stream_reader.extension}"
-                    )
-                )
-            except StopIteration:
-                continue
-            else:
-                config_file = stream_reader.get_config_file(
-                    stream_reader._model_root / Path(*Path(f.stem.replace("_", "/")).parent.parts[1:])
-                )
-                break
-        else:
-            raise FileNotFoundError(f"Unable to find SLEAP model config file for: {stream_reader.pattern}")
-
         # get bodyparts and classes
-        bodyparts = stream_reader.get_bodyparts(config_file)
+        bodyparts = stream_reader.get_bodyparts()
         anchor_part = bodyparts[0]  # anchor_part is always the first one
-        class_names = stream_reader.get_class_names(config_file)
+        class_names = stream_reader.get_class_names()
         identity_mapping = {n: i for i, n in enumerate(class_names)}

         # ingest parts and classes
         pose_identity_entries, part_entries = [], []
-        for class_idx in set(pose_data["class"].values.astype(int)):
-            class_position = pose_data[pose_data["class"] == class_idx]
-            for part in set(class_position.part.values):
-                part_position = class_position[class_position.part == part]
+        for identity in identity_mapping:
+            identity_position = pose_data[pose_data["identity"] == identity]
+            if identity_position.empty:
+                continue
+            for part in set(identity_position.part.values):
+                part_position = identity_position[identity_position.part == part]
                 part_entries.append(
                     {
                         **key,
-                        "identity_idx": class_idx,
+                        "identity_idx": identity_mapping[identity],
                         "part_name": part,
                         "timestamps": part_position.index.values,
                         "x": part_position.x.values,
@@ -216,14 +201,17 @@ def make(self, key):
                     }
                 )
                 if part == anchor_part:
-                    class_likelihood = part_position.class_likelihood.values
+                    identity_likelihood = part_position.identity_likelihood.values
+                    if isinstance(identity_likelihood[0], dict):
+                        identity_likelihood = np.array([v[identity] for v in identity_likelihood])

                     pose_identity_entries.append(
                         {
                             **key,
-                            "identity_idx": class_idx,
-                            "identity_name": class_names[class_idx],
+                            "identity_idx": identity_mapping[identity],
+                            "identity_name": identity,
                             "anchor_part": anchor_part,
-                            "identity_likelihood": class_likelihood,
+                            "identity_likelihood": identity_likelihood,
                         }
                     )

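To illustrate the new identity-keyed ingestion, here is a minimal, self-contained sketch with a toy `pose_data` in the long format the reader now returns (the subject ids and values are invented; real data comes from the `Pose` reader changed below):

```python
import numpy as np
import pandas as pd

# Toy stand-ins for the reader's output; real data comes from aeon.io.reader.Pose.
class_names = ["BAA-1104045", "BAA-1104047"]  # made-up subject ids
identity_mapping = {n: i for i, n in enumerate(class_names)}
pose_data = pd.DataFrame(
    {
        "identity": ["BAA-1104045", "BAA-1104045"],
        # Bonsai.Sleap0.3 stores one dict of per-identity likelihoods per row.
        "identity_likelihood": [{"BAA-1104045": 0.9, "BAA-1104047": 0.1}] * 2,
        "part": ["spine2", "spine2"],
        "x": [10.0, 11.0],
        "y": [20.0, 21.0],
        "part_likelihood": [0.99, 0.98],
    },
    index=pd.to_datetime(["2024-07-03 00:00:00", "2024-07-03 00:00:02"]),
)

for identity in identity_mapping:
    identity_position = pose_data[pose_data["identity"] == identity]
    if identity_position.empty:
        continue  # skip identities absent from this chunk
    likelihood = identity_position.identity_likelihood.values
    if isinstance(likelihood[0], dict):  # Bonsai.Sleap0.3; scalar under 0.2
        likelihood = np.array([v[identity] for v in likelihood])
    print(identity, identity_mapping[identity], likelihood)
```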
94 changes: 55 additions & 39 deletions aeon/io/reader.py
@@ -286,23 +286,35 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
        # `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
        super().__init__(pattern, columns=None)
        self._model_root = model_root
+        self.config_file = None  # requires reading the data file to be set

    def read(self, file: Path) -> pd.DataFrame:
        """Reads data from the Harp-binarized tracking file."""
        # Get config file from `file`, then bodyparts from config file.
-        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:])
+        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
        config_file_dir = Path(self._model_root) / model_dir
        if not config_file_dir.exists():
            raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
-        config_file = self.get_config_file(config_file_dir)
-        parts = self.get_bodyparts(config_file)
+        self.config_file = self.get_config_file(config_file_dir)
+        identities = self.get_class_names()
+        parts = self.get_bodyparts()

        # Using bodyparts, assign column names to Harp register values, and read data in default format.
-        columns = ["class", "class_likelihood"]
-        for part in parts:
-            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-        self.columns = columns
-        data = super().read(file)
+        try:  # Bonsai.Sleap0.2
+            bonsai_sleap_v = 0.2
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        except ValueError:  # column mismatch; Bonsai.Sleap0.3
+            bonsai_sleap_v = 0.3
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)

        # Drop any repeat parts.
        unique_parts, unique_idxs = np.unique(parts, return_index=True)
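The try/except above sniffs the writer version from the column count: Bonsai.Sleap0.2 stores a single likelihood for the predicted identity, while 0.3 stores one likelihood column per identity, so reading a 0.3 file with the 0.2 layout raises a ValueError and triggers the fallback. A toy illustration of the two layouts (identity and part names made up):

```python
identities = ["mouse_a", "mouse_b"]  # toy class names from the model config
parts = ["anchor", "nose"]           # toy bodyparts, anchor part first

# Bonsai.Sleap0.2: a single likelihood column for the predicted identity.
cols_v02 = ["identity", "identity_likelihood"]
# Bonsai.Sleap0.3: one likelihood column per identity in the model.
cols_v03 = ["identity"] + [f"{i}_likelihood" for i in identities]
for part in parts:
    per_part = [f"{part}_x", f"{part}_y", f"{part}_likelihood"]
    cols_v02 += per_part
    cols_v03 += per_part

print(len(cols_v02))  # 8 -> reading a 0.3 file with this layout raises ValueError
print(len(cols_v03))  # 9 -> the except branch retries with this layout
```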
@@ -315,54 +327,74 @@ def read(self, file: Path) -> pd.DataFrame:
        parts = unique_parts

        # Set new columns, and reformat `data`.
+        data = self.class_int2str(data)
        n_parts = len(parts)
        part_data_list = [pd.DataFrame()] * n_parts
-        new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
+        new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
        new_data = pd.DataFrame(columns=new_columns)
        for i, part in enumerate(parts):
-            part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+            part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
+            part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
            part_data = pd.DataFrame(data[part_columns])
+            if bonsai_sleap_v == 0.3:  # combine all identity_likelihood cols into a single col as dict
+                part_data["identity_likelihood"] = part_data.apply(
+                    lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1
+                )
+                part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
+                part_data = part_data[  # reorder columns
+                    ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+                ]
            part_data.insert(2, "part", part)
            part_data.columns = new_columns
            part_data_list[i] = part_data
        new_data = pd.concat(part_data_list)
        return new_data.sort_index()

-    def get_class_names(self, file: Path) -> list[str]:
+    def get_class_names(self) -> list[str]:
        """Returns a list of classes from a model's config file."""
        classes = None
-        with open(file) as f:
+        with open(self.config_file) as f:
            config = json.load(f)
-        if file.stem == "confmap_config":  # SLEAP
+        if self.config_file.stem == "confmap_config":  # SLEAP
            try:
                heads = config["model"]["heads"]
                classes = util.find_nested_key(heads, "class_vectors")["classes"]
            except KeyError as err:
                if not classes:
-                    raise KeyError(f"Cannot find class_vectors in {file}.") from err
+                    raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
        return classes

-    def get_bodyparts(self, file: Path) -> list[str]:
+    def get_bodyparts(self) -> list[str]:
        """Returns a list of bodyparts from a model's config file."""
        parts = []
-        with open(file) as f:
+        with open(self.config_file) as f:
            config = json.load(f)
-        if file.stem == "confmap_config":  # SLEAP
+        if self.config_file.stem == "confmap_config":  # SLEAP
            try:
                heads = config["model"]["heads"]
                parts = [util.find_nested_key(heads, "anchor_part")]
                parts += util.find_nested_key(heads, "part_names")
            except KeyError as err:
                if not parts:
-                    raise KeyError(f"Cannot find bodyparts in {file}.") from err
+                    raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
        return parts

+    def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
+        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
+        if self.config_file.stem == "confmap_config":  # SLEAP
+            with open(self.config_file) as f:
+                config = json.load(f)
+            try:
+                heads = config["model"]["heads"]
+                classes = util.find_nested_key(heads, "classes")
+            except KeyError as err:
+                raise KeyError(f"Cannot find classes in {self.config_file}.") from err
+            for i, subj in enumerate(classes):
+                data.loc[data["identity"] == i, "identity"] = subj
+        return data
+
    @classmethod
-    def get_config_file(
-        cls,
-        config_file_dir: Path,
-        config_file_names: None | list[str] = None,
-    ) -> Path:
+    def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path:
        """Returns the config file from a model's config directory."""
        if config_file_names is None:
            config_file_names = ["confmap_config.json"]  # SLEAP (add for other trackers to this list)
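A sketch of the reader's new stateful flow from this file's diff (the pattern and paths below are hypothetical; `Pose(pattern, model_root)` and `read(file)` are as defined above, and `config_file` is only set once `read` has run):

```python
from pathlib import Path

from aeon.io import reader as io_reader

# Hypothetical pattern and file; `read` resolves and caches the model config,
# after which get_class_names()/get_bodyparts() reuse self.config_file.
pose_reader = io_reader.Pose(pattern="CameraTop_node1_42*",
                             model_root="/ceph/aeon/aeon/data/processed")
data = pose_reader.read(Path("/raw/CameraTop_node1_42_2024-07-03T00-00-00.bin"))

# Output is long-format regardless of Bonsai.Sleap version:
# identity, identity_likelihood, part, x, y, part_likelihood (indexed by time).
print(data.head())
```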
@@ -375,22 +407,6 @@ def get_config_file(
            raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
        return config_file

-    @classmethod
-    def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
-        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
-        config_file = cls.get_config_file(config_file_dir)
-        if config_file.stem == "confmap_config":  # SLEAP
-            with open(config_file) as f:
-                config = json.load(f)
-            try:
-                heads = config["model"]["heads"]
-                classes = util.find_nested_key(heads, "classes")
-            except KeyError as err:
-                raise KeyError(f"Cannot find classes in {config_file}.") from err
-            for i, subj in enumerate(classes):
-                data.loc[data["class"] == i, "class"] = subj
-        return data
-

def from_dict(data, pattern=None):
    reader_type = data.get("type", None)
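For reference, a worked example of the model-directory resolution used in `read` above (the file stem is hypothetical; real stems follow the '<hpcnode>_<jobid>*' pattern noted in `__init__`):

```python
from pathlib import Path

# Hypothetical pose file: underscore-delimited stem ending in the chunk time.
f = Path("CameraTop_node1_42_multianimal_2024-07-03T00-00-00.bin")

stem_as_path = Path(f.stem.replace("_", "/"))      # underscores -> path separators
print(stem_as_path.parent.parts)                   # ('CameraTop', 'node1', '42', 'multianimal')

model_dir = Path(*stem_as_path.parent.parts[-4:])  # new: fixed depth of 4 from the end
print(model_dir)                                   # CameraTop/node1/42/multianimal
# The old code used parts[1:], which only dropped the leading component and so
# depended on how many components preceded the model directory in the stem.
```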
