Skip to content

Commit

Permalink
Fetch cbroz1/613
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Sep 27, 2023
2 parents 8c36f05 + 4e6b29f commit d333864
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 119 deletions.
187 changes: 102 additions & 85 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def insert_from_nwbfile(cls, nwb_file_name):
"""
nwbf = get_nwb_file(nwb_file_name)
all_pos = get_all_spatial_series(nwbf, verbose=True)
sess_key = dict(nwb_file_name=nwb_file_name)
sess_key = Nwbfile.get_file_key(nwb_file_name)
src_key = dict(**sess_key, source="trodes", import_file_name="")

if all_pos is None:
Expand Down Expand Up @@ -119,6 +119,10 @@ def get_pos_interval_name(epoch_num: int) -> str:
)
return f"pos {epoch_num} valid times"

@staticmethod
def _is_valid_name(name) -> bool:
    """Return True if ``name`` looks like a position interval name.

    A valid name has the form ``"pos <epoch> valid times"``: it starts with
    ``"pos "``, ends with ``" valid times"``, and has at least one character
    between the two.

    Parameters
    ----------
    name : str
        Candidate interval list name.

    Returns
    -------
    bool
        False for non-string input or for names that do not match the form
        above; True otherwise.
    """
    prefix, suffix = "pos ", " valid times"

    # Non-strings cannot be interval names; report invalid rather than raise
    if not isinstance(name, str):
        return False

    # The length check guarantees a non-empty middle field. Without it,
    # "pos valid times" would pass because the prefix and suffix can share
    # the single space between "pos" and "valid".
    return (
        len(name) > len(prefix) + len(suffix)
        and name.startswith(prefix)
        and name.endswith(suffix)
    )

@staticmethod
def get_epoch_num(name: str) -> int:
"""Return the epoch number from the interval name.
Expand All @@ -133,6 +137,8 @@ def get_epoch_num(name: str) -> int:
int
epoch number
"""
if not PositionSource._is_valid_name(name):
raise ValueError(f"Invalid interval name: {name}")
return int(name.replace("pos ", "").replace(" valid times", ""))


Expand All @@ -152,7 +158,7 @@ class RawPosition(dj.Imported):
-> PositionSource
"""

class Object(dj.Part):
class PosObject(dj.Part):
definition = """
-> master
-> PositionSource.SpatialSeries.proj('id')
Expand All @@ -170,6 +176,9 @@ def fetch1_dataframe(self):

id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()]

if len(set(rp.interval for _, rp in id_rp)) > 1:
print("WARNING: loading DataFrame with multiple intervals.")

df_list = [
pd.DataFrame(
data=rp.data,
Expand Down Expand Up @@ -200,7 +209,7 @@ def make(self, key):
]

self.insert1(key)
self.Object.insert(
self.PosObject.insert(
[
dict(
nwb_file_name=nwb_file_name,
Expand All @@ -213,15 +222,35 @@ def make(self, key):
]
)

def fetch_nwb(self, *attrs, **kwargs):
raise NotImplementedError(
"fetch_nwb now operates on RawPosition.Object"
def fetch_nwb(self, *attrs, **kwargs) -> list:
"""
Returns a concatenated list of nwb objects from RawPosition.PosObject
"""
return (
self.PosObject()
.restrict(self.restriction) # Avoids fetch_nwb on whole table
.fetch_nwb(*attrs, **kwargs)
)

def fetch1_dataframe(self):
raise NotImplementedError(
"fetch1_dataframe now operates on RawPosition.Object"
)
"""Returns a dataframe with all RawPosition.PosObject items.
Uses interval_list_name as column index.
"""
ret = {}

pos_obj_set = self.PosObject & self.restriction
unique_intervals = set(pos_obj_set.fetch("interval_list_name"))

for interval in unique_intervals:
ret[interval] = (
pos_obj_set & {"interval_list_name": interval}
).fetch1_dataframe()

if len(unique_intervals) == 1:
return next(iter(ret.values()))

return pd.concat(ret, axis=1)


@schema
Expand Down Expand Up @@ -434,139 +463,127 @@ def make(self, key):
self._no_transaction_make(key)

def _no_transaction_make(self, key):
# Find correspondence between pos valid times names and epochs
# Use epsilon to tolerate small differences in epoch boundaries across epoch/pos intervals
# Find correspondence between pos valid times names and epochs. Use
# epsilon to tolerate small differences in epoch boundaries across
# epoch/pos intervals

if not self.connection.in_transaction:
# if not called in the context of a make function, call its own make function
self.populate(key)
return

# *** HARD CODED VALUES ***
EPSILON = 0.11 # tolerated time difference in epoch boundaries across epoch/pos intervals
# *************************
EPSILON = 0.11 # tolerated time diff in bounds across epoch/pos
no_pop_msg = "CANNOT POPULATE PositionIntervalMap"

# Unpack key
nwb_file_name = key["nwb_file_name"]

# Get pos interval list names
pos_interval_list_names = get_pos_interval_list_names(nwb_file_name)
pos_intervals = get_pos_interval_list_names(nwb_file_name)

# Skip populating if no pos interval list names
if len(pos_interval_list_names) == 0:
print(
f"NO POS INTERVALS FOR {key}; CANNOT POPULATE PositionIntervalMap"
)
if len(pos_intervals) == 0:
print(f"NO POS INTERVALS FOR {key}; {no_pop_msg}")
return

# Get the interval times
valid_times = (IntervalList & key).fetch1("valid_times")
time_interval = [
time_bounds = [
valid_times[0][0] - EPSILON,
valid_times[-1][-1] + EPSILON,
] # [start, end], widen to tolerate small differences in epoch boundaries across epoch/pos intervals

# compare the position intervals against our interval
matching_pos_interval_list_names = []
for (
pos_interval_list_name
) in pos_interval_list_names: # for each pos valid time interval list
pos_valid_times = (
IntervalList
& {
"nwb_file_name": nwb_file_name,
"interval_list_name": pos_interval_list_name,
}
).fetch1(
]

matching_pos_intervals = []
restr = (
f"nwb_file_name='{nwb_file_name}' AND interval_list_name=" + "'{}'"
)
for pos_interval in pos_intervals:
# cbroz: fetch1->fetch. fetch1 would fail w/o result
pos_times = (IntervalList & restr.format(pos_interval)).fetch(
"valid_times"
) # get interval valid times
if len(pos_valid_times) == 0:
)

if len(pos_times) == 0:
continue
pos_time_interval = [
pos_valid_times[0][0],
pos_valid_times[-1][-1],
] # [pos valid time interval start, pos valid time interval end]
if (time_interval[0] < pos_time_interval[0]) and (
time_interval[1] > pos_time_interval[1]
): # if pos valid time interval within epoch interval
matching_pos_interval_list_names.append(
pos_interval_list_name
) # add pos interval list name to list of matching pos interval list names

pos_times = pos_times[0]

if all(
[
time_bounds[0] <= time <= time_bounds[1]
for time in [pos_times[0][0], pos_times[-1][-1]]
]
):
matching_pos_intervals.append(pos_interval)

if len(matching_pos_intervals) > 1:
break

# Check that each pos interval was matched to only one epoch
if len(matching_pos_interval_list_names) > 1:
print(
f"MULTIPLE POS INTERVALS MATCHED TO EPOCH {key}; CANNOT POPULATE PositionIntervalMap"
)
print(matching_pos_interval_list_names)
return
# Check that at least one pos interval was matched to an epoch
if len(matching_pos_interval_list_names) == 0:
if len(matching_pos_intervals) != 1:
print(
f"No pos intervals matched to epoch {key}; CANNOT POPULATE PositionIntervalMap"
f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
+ f"{no_pop_msg}\n{matching_pos_intervals}"
)
return

# Insert into table
key["position_interval_name"] = matching_pos_interval_list_names[0]
key["position_interval_name"] = matching_pos_intervals[0]
self.insert1(key, allow_direct_insert=True)
print(
f'Populated PosIntervalMap for {nwb_file_name}, {key["interval_list_name"]}'
"Populated PosIntervalMap for "
+ f'{nwb_file_name}, {key["interval_list_name"]}'
)


def get_pos_interval_list_names(nwb_file_name):
def get_pos_interval_list_names(nwb_file_name) -> list:
return [
interval_list_name
for interval_list_name in (
IntervalList & {"nwb_file_name": nwb_file_name}
).fetch("interval_list_name")
if (
(interval_list_name.split(" ")[0] == "pos")
and (" ".join(interval_list_name.split(" ")[2:]) == "valid times")
)
if PositionSource._is_valid_name(interval_list_name)
]


def convert_epoch_interval_name_to_position_interval_name(
key: dict, populate_missing: bool = True
) -> str:
"""Converts a primary key for IntervalList to the corresponding position interval name.
"""Converts IntervalList key to the corresponding position interval name.
Parameters
----------
key : dict
Lookup key
populate_missing: bool
whether to populate PositionIntervalMap for the key if missing. Should be False if this function is used inside of another populate call
Whether to populate PositionIntervalMap for the key if missing. Should
be False if this function is used inside of another populate call.
Defaults to True
Returns
-------
position_interval_name : str
"""
# get the interval list name if epoch given in key instead of interval list name
# get the interval list name if given epoch but not interval list name
if "interval_list_name" not in key and "epoch" in key:
key["interval_list_name"] = get_interval_list_name_from_epoch(
key["nwb_file_name"], key["epoch"]
)

pos_interval_names = (PositionIntervalMap & key).fetch(
"position_interval_name"
)
if len(pos_interval_names) == 0:
pos_query = PositionIntervalMap & key

if len(pos_query) == 0:
if populate_missing:
PositionIntervalMap()._no_transaction_make(key)
pos_interval_names = (PositionIntervalMap & key).fetch(
"position_interval_name"
)
else:
raise KeyError(
f"{key} must be populated in the PositionIntervalMap table prior to your current populate call"
f"{key} must be populated in the PositionIntervalMap table "
+ "prior to your current populate call"
)
if len(pos_interval_names) == 0:

if len(pos_query) == 0:
print(f"No position intervals found for {key}")
return []
if len(pos_interval_names) == 1:
return pos_interval_names[0]

if len(pos_query) == 1:
return pos_query.fetch1("position_interval_name")


def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str:
Expand All @@ -591,19 +608,19 @@ def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str:
)
if (x.split("_")[0] == f"{epoch:02}")
]
if len(interval_names) == 0:
print(f"No interval list name found for {nwb_file_name} epoch {epoch}")
return None
if len(interval_names) > 1:

if len(interval_names) != 1:
print(
f"Multiple interval list names found for {nwb_file_name} epoch {epoch}"
f"Found {len(interval_names)} interval list names found for "
+ f"{nwb_file_name} epoch {epoch}"
)
return None

return interval_names[0]


def populate_position_interval_map_session(nwb_file_name: str):
for interval_name in (TaskEpoch() & {"nwb_file_name": nwb_file_name}).fetch(
for interval_name in (TaskEpoch & {"nwb_file_name": nwb_file_name}).fetch(
"interval_list_name"
):
PositionIntervalMap.populate(
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def insert_from_nwbfile(cls, nwbf):
The NWB file with experimenter information.
"""
if isinstance(nwbf, str):
nwb_file_abspath = Nwbfile.get_abs_path(nwbf)
nwb_file_abspath = Nwbfile.get_abs_path(nwbf, new_file=True)
nwbf = get_nwb_file(nwb_file_abspath)

if nwbf.experimenter is None:
Expand Down
5 changes: 3 additions & 2 deletions src/spyglass/common/common_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def insert_from_relative_file_name(cls, nwb_file_name):
nwb_file_name : str
The relative path to the NWB file.
"""
nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name)
nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True)
assert os.path.exists(
nwb_file_abs_path
), f"File does not exist: {nwb_file_abs_path}"
Expand All @@ -78,7 +78,8 @@ def _get_file_name(cls, nwb_file_name: str) -> str:
return query.fetch1("nwb_file_name")

raise ValueError(
f"Found {len(query)} matches for {nwb_file_name}: \n{query}"
f"Found {len(query)} matches for {nwb_file_name} in Nwbfile table:"
+ f" \n{query}"
)

@classmethod
Expand Down
13 changes: 9 additions & 4 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

import datajoint as dj

from ..settings import config
from ..settings import config, debug_mode
from ..utils.nwb_helper_fn import get_config, get_nwb_file
from .common_device import CameraDevice, DataAcquisitionDevice, Probe
from .common_lab import Institution, Lab, LabMember
Expand Down Expand Up @@ -79,9 +81,10 @@ def make(self, key):
print("Subject...")
Subject().insert_from_nwbfile(nwbf)

print("Populate DataAcquisitionDevice...")
DataAcquisitionDevice.insert_from_nwbfile(nwbf, config)
print()
if not debug_mode: # TODO: remove when demo files agree on device
print("Populate DataAcquisitionDevice...")
DataAcquisitionDevice.insert_from_nwbfile(nwbf, config)
print()

print("Populate CameraDevice...")
CameraDevice.insert_from_nwbfile(nwbf)
Expand Down Expand Up @@ -260,6 +263,8 @@ def create_spyglass_view(session_group_name: str):
# datajoint prohibits deleting from a subtable without
# also deleting the parent table.
# See: https://docs.datajoint.org/python/computation/03-master-part.html


@schema
class SessionGroupSession(dj.Manual):
definition = """
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/common/populate_all_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def populate_all_common(nwb_file_name):

print("Populate DIOEvents...")
DIOEvents.populate(fp)

# sensor data (from analog ProcessingModule) is temporarily removed from NWBFile
# to reduce file size while it is not being used. add it back in by commenting out
# the removal code in spyglass/data_import/insert_sessions.py when ready
# print('Populate SensorData')
# SensorData.populate(fp)

print("Populate TaskEpochs")
TaskEpoch.populate(fp)
print("Populate StateScriptFile")
Expand Down
Loading

0 comments on commit d333864

Please sign in to comment.