Skip to content

Commit

Permalink
Refactor trodes position LorenFrankLab#613
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Aug 18, 2023
1 parent a5aa589 commit bb7580c
Show file tree
Hide file tree
Showing 6 changed files with 361 additions and 176 deletions.
163 changes: 119 additions & 44 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
from functools import reduce
from typing import Dict

import datajoint as dj
Expand Down Expand Up @@ -27,15 +28,18 @@
class PositionSource(dj.Manual):
definition = """
-> Session
-> IntervalList
---
source: varchar(200) # source of data (e.g., trodes, dlc)
import_file_name: varchar(2000) # path to import file if importing
"""

class IntervalList(dj.Part):
class SpatialSeries(dj.Part):
definition = """
-> master
-> IntervalList
id : int unsigned # index of spatial series
---
name=null: varchar(32) # name of spatial series
"""

@classmethod
Expand All @@ -51,48 +55,79 @@ def insert_from_nwbfile(cls, nwb_file_name):
nwbf = get_nwb_file(nwb_file_name)
all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False)
sess_key = dict(nwb_file_name=nwb_file_name)
pos_source_key = dict(**sess_key, source="trodes", import_file_name="")
src_key = dict(**sess_key, source="trodes", import_file_name="")

if all_pos is None:
return

sources = []
intervals = []
pos_intervals = []
spat_series = []

for epoch, epoch_list in enumerate(all_pos.values()):
for index, pdict in enumerate(epoch_list):
pos_interval_name = cls.get_pos_interval_name([epoch, index])
for epoch, epoch_list in all_pos.items():
ind_key = dict(interval_list_name=cls.get_pos_interval_name(epoch))

intervals.append(
sources.append(dict(**src_key, **ind_key))
intervals.append(
dict(
**sess_key,
**ind_key,
valid_times=epoch_list[0]["valid_times"],
)
)

for index, pdict in enumerate(epoch_list):
spat_series.append(
dict(
**sess_key,
interval_list_name=pos_interval_name,
valid_times=pos_dict["valid_times"],
**ind_key,
id=ndex,
name=pdict.get("name"),
)
)

# UNTESTED
IntervalList.insert(intervals, skip_duplicates=True)
cls.IntervalList.insert(intervals, skip_duplicates=True)
cls.insert1(key)
with cls.connection.transaction:
IntervalList.insert(intervals)
cls.insert(sources)
cls.SpatialSeries.insert(spat_series)

@staticmethod
def get_pos_interval_name(pos_epoch_num):
"""Retun string of the interval name from the epoch number.
def get_pos_interval_name(epoch_num: int) -> str:
"""Return string of the interval name from the epoch number.
Parameters
----------
pos_epoch_num : int or str or list
If list of length 2, then a string of the form "epoch 1 index 2"
pos_epoch_num : int
Input epoch number
Returns
-------
str
Position interval name (e.g., pos epoch 1 index 2 valid times)
Position interval name (e.g., pos 2 valid times)
"""
if isinstance(pos_epoch_num, list) and len(pos_epoch_num) == 2:
pos_epoch_num = f"epoch {pos_epoch_num[0]} index {pos_epoch_num[1]}"
return f"pos {pos_epoch_num} valid times"
try:
int(epoch_num)
except ValueError:
raise ValueError(
f"Epoch number must must be an integer. Received: {epoch_num}"
)
return f"pos {epoch_num} valid times"

@staticmethod
def get_epoch_num(name: str) -> int:
"""Return the epoch number from the interval name.
Parameters
----------
name : str
Name of position interval (e.g., pos epoch 1 index 2 valid times)
Returns
-------
int
epoch number
"""
return int(name.replace("pos ", "").replace(" valid times", ""))


@schema
Expand All @@ -109,37 +144,77 @@ class RawPosition(dj.Imported):

definition = """
-> PositionSource
---
raw_position_object_id: varchar(40) # the object id of the spatial series for this epoch in the NWB file
"""

class Object(dj.Part):
definition = """
-> master
-> PositionSource.SpatialSeries.proj('id')
---
raw_position_object_id: varchar(40) # id of spatial series in NWB file
"""

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self):
INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1)

id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()]

df_list = [
pd.DataFrame(
data=rp.data,
index=pd.Index(rp.timestamps, name="time"),
columns=[
col # use existing columns if already numbered
if "1" in rp.description or "2" in rp.description
# else number them by id
else col + str(id + INDEX_ADJUST)
for col in rp.description.split(", ")
],
)
for id, rp in id_rp
]

return reduce(lambda x, y: pd.merge(x, y, on="time"), df_list)

def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)

# TODO refactor this. this calculates sampling rate (unused here) and is expensive to do twice
pos_dict = get_all_spatial_series(nwbf)
for epoch in pos_dict:
if key[
"interval_list_name"
] != PositionSource.get_pos_interval_name(epoch):
continue
interval_list_name = key["interval_list_name"]

pdict = pos_dict[epoch]
key["raw_position_object_id"] = pdict["raw_position_object_id"]
self.insert1(key)
break
nwbf = get_nwb_file(nwb_file_name)
indices = (PositionSource.SpatialSeries & key).fetch("id")

# incl_times = False -> don't do extra processing for valid_times
spat_objs = get_all_spatial_series(
nwbf, old_format=False, incl_times=False
)[PositionSource.get_epoch_num(interval_list_name)]

self.insert1(key)
self.Object.insert(
[
dict(
nwb_file_name=nwb_file_name,
interval_list_name=interval_list_name,
id=index,
raw_position_object_id=obj["raw_position_object_id"],
)
for index, obj in enumerate(spat_objs)
if index in indices
]
)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)
raise NotImplementedError(
"fetch_nwb now operates on RawPosition.Object"
)

def fetch1_dataframe(self):
raw_position_nwb = self.fetch_nwb()[0]["raw_position"]
return pd.DataFrame(
data=raw_position_nwb.data,
index=pd.Index(raw_position_nwb.timestamps, name="time"),
columns=raw_position_nwb.description.split(", "),
raise NotImplementedError(
"fetch1_dataframe now operates on RawPosition.Object"
)


Expand Down
26 changes: 21 additions & 5 deletions src/spyglass/common/common_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,25 @@ def insert_from_relative_file_name(cls, nwb_file_name):
key["nwb_file_abs_path"] = nwb_file_abs_path
cls.insert1(key, skip_duplicates=True)

@staticmethod
def get_abs_path(nwb_file_name):
@classmethod
def _get_file_name(cls, nwb_file_name: str) -> str:
"""Get valid nwb file name given substring."""
query = cls & f'nwb_file_name LIKE "%{nwb_file_name}%"'

if len(query) == 1:
return query.fetch1("nwb_file_name")

raise ValueError(
f"Found {len(query)} matches for {nwb_file_name}: \n{query}"
)

@classmethod
def get_file_key(cls, nwb_file_name: str) -> dict:
"""Return primary key using nwb_file_name substring."""
return {"nwb_file_name": cls._get_file_name(nwb_file_name)}

@classmethod
def get_abs_path(cls, nwb_file_name) -> str:
"""Return absolute path for a stored raw NWB file given file name.
The SPYGLASS_BASE_DIR must be set, either as an environment or part of
Expand All @@ -80,15 +97,14 @@ def get_abs_path(nwb_file_name):
----------
nwb_file_name : str
The name of an NWB file that has been inserted into the Nwbfile()
schema.
table. May be file substring. May include % wildcard(s).
Returns
-------
nwb_file_abspath : str
The absolute path for the given file name.
"""

return raw_dir + "/" + nwb_file_name
return raw_dir + "/" + cls._get_file_name(nwb_file_name)

@staticmethod
def add_to_lock(nwb_file_name):
Expand Down
Loading

0 comments on commit bb7580c

Please sign in to comment.