Skip to content

Commit

Permalink
WIP: PositionSource add part table
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Aug 15, 2023
1 parent f8f33e6 commit a5aa589
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 121 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ temp_nwb/*s
*.json
*.gz
*.pdf
dj_local_conf.json
dj_local_conf*
!dj_local_conf_example.json

!/.vscode/extensions.json
Expand Down
85 changes: 55 additions & 30 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@
class PositionSource(dj.Manual):
definition = """
-> Session
-> IntervalList
---
source: varchar(200) # source of data; current options are "trodes" and "dlc" (deep lab cut)
import_file_name: varchar(2000) # path to import file if importing position data
source: varchar(200) # source of data (e.g., trodes, dlc)
import_file_name: varchar(2000) # path to import file if importing
"""

class IntervalList(dj.Part):
definition = """
-> master
-> IntervalList
"""

@classmethod
def insert_from_nwbfile(cls, nwb_file_name):
"""Given an NWB file name, get the spatial series and interval lists from the file, add the interval
Expand All @@ -43,32 +48,50 @@ def insert_from_nwbfile(cls, nwb_file_name):
nwb_file_name : str
The name of the NWB file.
"""
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
nwbf = get_nwb_file(nwb_file_name)
all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False)
sess_key = dict(nwb_file_name=nwb_file_name)
pos_source_key = dict(**sess_key, source="trodes", import_file_name="")

if all_pos is None:
return

intervals = []
pos_intervals = []

for epoch, epoch_list in enumerate(all_pos.values()):
for index, pdict in enumerate(epoch_list):
pos_interval_name = cls.get_pos_interval_name([epoch, index])

intervals.append(
dict(
**sess_key,
interval_list_name=pos_interval_name,
valid_times=pos_dict["valid_times"],
)
)

pos_dict = get_all_spatial_series(nwbf, verbose=True)
if pos_dict is not None:
for epoch in pos_dict:
pdict = pos_dict[epoch]
pos_interval_list_name = cls.get_pos_interval_name(epoch)

# create the interval list and insert it
interval_dict = dict()
interval_dict["nwb_file_name"] = nwb_file_name
interval_dict["interval_list_name"] = pos_interval_list_name
interval_dict["valid_times"] = pdict["valid_times"]
IntervalList().insert1(interval_dict, skip_duplicates=True)

# add this interval list to the table
key = dict()
key["nwb_file_name"] = nwb_file_name
key["interval_list_name"] = pos_interval_list_name
key["source"] = "trodes"
key["import_file_name"] = ""
cls.insert1(key)
# UNTESTED
IntervalList.insert(intervals, skip_duplicates=True)
cls.IntervalList.insert(intervals, skip_duplicates=True)
cls.insert1(key)

@staticmethod
def get_pos_interval_name(pos_epoch_num):
"""Retun string of the interval name from the epoch number.

Check failure on line 81 in src/spyglass/common/common_behav.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

Retun ==> Return
Parameters
----------
pos_epoch_num : int or str or list
If list of length 2, then a string of the form "epoch 1 index 2"
Returns
-------
str
Position interval name (e.g., pos epoch 1 index 2 valid times)
"""
if isinstance(pos_epoch_num, list) and len(pos_epoch_num) == 2:
pos_epoch_num = f"epoch {pos_epoch_num[0]} index {pos_epoch_num[1]}"
return f"pos {pos_epoch_num} valid times"


Expand Down Expand Up @@ -100,11 +123,13 @@ def make(self, key):
for epoch in pos_dict:
if key[
"interval_list_name"
] == PositionSource.get_pos_interval_name(epoch):
pdict = pos_dict[epoch]
key["raw_position_object_id"] = pdict["raw_position_object_id"]
self.insert1(key)
break
] != PositionSource.get_pos_interval_name(epoch):
continue

pdict = pos_dict[epoch]
key["raw_position_object_id"] = pdict["raw_position_object_id"]
self.insert1(key)
break

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)
Expand Down
23 changes: 11 additions & 12 deletions src/spyglass/common/common_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,26 @@
import pandas as pd
import pynwb

from ..utils.dj_helper_fn import fetch_nwb # dj_replace
from ..utils.nwb_helper_fn import (
estimate_sampling_rate,
get_config,
get_data_interface,
get_electrode_indices,
get_nwb_file,
get_valid_intervals,
)
from .common_device import Probe # noqa: F401
from .common_filter import FirFilterParameters
from .common_interval import interval_list_censor # noqa: F401
from .common_interval import (
IntervalList,
interval_list_censor, # noqa: F401
interval_list_contains_ind,
interval_list_intersect,
)
from .common_nwbfile import AnalysisNwbfile, Nwbfile
from .common_region import BrainRegion # noqa: F401
from .common_session import Session # noqa: F401
from ..utils.dj_helper_fn import fetch_nwb # dj_replace
from ..utils.nwb_helper_fn import (
estimate_sampling_rate,
get_data_interface,
get_electrode_indices,
get_nwb_file,
get_valid_intervals,
get_config,
)

schema = dj.schema("common_ephys")

Expand Down Expand Up @@ -251,9 +251,8 @@ def make(self, key):
print("Estimating sampling rate...")
# NOTE: Only use first 1e6 timepoints to save time
sampling_rate = estimate_sampling_rate(
np.asarray(rawdata.timestamps[: int(1e6)]), 1.5
np.asarray(rawdata.timestamps[: int(1e6)]), 1.5, verbose=True
)
print(f"Estimated sampling rate: {sampling_rate}")
key["sampling_rate"] = sampling_rate

interval_dict = dict()
Expand Down
10 changes: 6 additions & 4 deletions src/spyglass/common/common_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import spikeinterface as si
from hdmf.common import DynamicTable

from ..settings import load_config
from ..settings import raw_dir
from ..utils.dj_helper_fn import get_child_tables
from ..utils.nwb_helper_fn import get_electrode_indices, get_nwb_file

Expand Down Expand Up @@ -73,20 +73,22 @@ def insert_from_relative_file_name(cls, nwb_file_name):
def get_abs_path(nwb_file_name):
"""Return absolute path for a stored raw NWB file given file name.
The SPYGLASS_BASE_DIR environment variable must be set.
The SPYGLASS_BASE_DIR must be set, either as an environment or part of
dj.config['custom']. See spyglass.settings.load_config
Parameters
----------
nwb_file_name : str
The name of an NWB file that has been inserted into the Nwbfile() schema.
The name of an NWB file that has been inserted into the Nwbfile()
schema.
Returns
-------
nwb_file_abspath : str
The absolute path for the given file name.
"""

return load_config()["SPYGLASS_RAW_DIR"] + "/" + nwb_file_name
return raw_dir + "/" + nwb_file_name

@staticmethod
def add_to_lock(nwb_file_name):
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datajoint as dj

from ..settings import config
from ..utils.nwb_helper_fn import get_config, get_nwb_file
from .common_device import CameraDevice, DataAcquisitionDevice, Probe
from .common_lab import Institution, Lab, LabMember
from .common_nwbfile import Nwbfile
from .common_subject import Subject
from ..utils.nwb_helper_fn import get_nwb_file, get_config
from ..settings import config

schema = dj.schema("common_session")

Expand Down
30 changes: 23 additions & 7 deletions src/spyglass/data_import/insert_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pynwb

from ..common import Nwbfile, get_raw_eseries, populate_all_common
from ..settings import load_config
from ..settings import raw_dir
from ..utils.nwb_helper_fn import get_nwb_copy_filename


def insert_sessions(nwb_file_names: Union[str, List[str]]):
Expand All @@ -18,7 +19,9 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]):
----------
nwb_file_names : str or List of str
File names in raw directory ($SPYGLASS_RAW_DIR) pointing to
existing .nwb files. Each file represents a session.
existing .nwb files. Each file represents a session. Also accepts
strings with glob wildcards (e.g., *) so long as the wildcard specifies
exactly one file.
"""

if not isinstance(nwb_file_names, list):
Expand All @@ -29,11 +32,23 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]):
nwb_file_name = nwb_file_name.split("/")[-1]

nwb_file_abs_path = Path(Nwbfile.get_abs_path(nwb_file_name))

if not nwb_file_abs_path.exists():
raise FileNotFoundError(f"File not found: {nwb_file_abs_path}")
possible_matches = sorted(Path(raw_dir).glob(f"*{nwb_file_name}*"))

if len(possible_matches) == 1:
nwb_file_abs_path = possible_matches[0]
nwb_file_name = nwb_file_abs_path.name

else:
raise FileNotFoundError(
f"File not found: {nwb_file_abs_path}\n\t"
+ f"{len(possible_matches)} possible matches:"
+ f"{possible_matches}"
)

# file name for the copied raw data
out_nwb_file_name = nwb_file_abs_path.stem + "_.nwb"
out_nwb_file_name = get_nwb_copy_filename(nwb_file_abs_path.stem)

# Check whether the file already exists in the Nwbfile table
if len(Nwbfile() & {"nwb_file_name": out_nwb_file_name}):
Expand Down Expand Up @@ -72,11 +87,12 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name):
)

nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name)
assert os.path.exists(
nwb_file_abs_path
), f"File does not exist: {nwb_file_abs_path}"

if not os.path.exists(nwb_file_abs_path):
raise FileNotFoundError(f"Could not find raw file: {nwb_file_abs_path}")

out_nwb_file_abs_path = Nwbfile.get_abs_path(out_nwb_file_name)

if os.path.exists(out_nwb_file_name):
warnings.warn(
f"Output file {out_nwb_file_abs_path} exists and will be "
Expand Down
Loading

0 comments on commit a5aa589

Please sign in to comment.