
Commit

WIP: Spellcheck. Remove debug params. Remove assigned lambda E713
CBroz1 committed Aug 21, 2023
1 parent 34134db commit e14477e
Showing 12 changed files with 112 additions and 118 deletions.
2 changes: 1 addition & 1 deletion notebooks/14_Theta.ipynb
@@ -594,7 +594,7 @@
"\n",
"We can overlay theta and detected phase for each electrode.\n",
"\n",
"_Note:_ The red horizontal line indicates phase 0, corresponding to the trough\n",
"_Note:_ The red horizontal line indicates phase 0, corresponding to the through\n",
"of theta."
]
},
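The cell above describes overlaying filtered theta with its detected phase and marking phase 0 (the trough) with a red horizontal line. A minimal sketch of that kind of overlay, using synthetic stand-in data rather than the notebook's LFP (all variable names here are illustrative):

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins: an 8 Hz "theta" trace and a wrapped phase trace.
t = np.linspace(0, 2, 2000)
theta = np.sin(2 * np.pi * 8 * t)
phase = np.angle(np.exp(1j * 2 * np.pi * 8 * t))  # wrapped to (-pi, pi]

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(t, theta, color="tab:blue", label="theta")
ax2 = ax.twinx()  # second y-axis for phase, in radians
ax2.plot(t, phase, color="tab:orange", alpha=0.6, label="phase")
ax2.axhline(0, color="red")  # phase 0, the trough per the note above
ax.set_xlabel("time (s)")
plt.show()
```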
2 changes: 1 addition & 1 deletion notebooks/20_Position_Trodes.ipynb
@@ -142,7 +142,7 @@
"available. To adjust the default, insert a new set into this table. The\n",
"parameters are...\n",
"\n",
"- `max_separation`, default 9 cm: maximium acceptable distance between red and\n",
"- `max_separation`, default 9 cm: maximum acceptable distance between red and\n",
" green LEDs.\n",
" - If exceeded, the times are marked as NaNs and inferred by interpolation.\n",
" - Useful when the inferred LED position tracks a reflection instead of the\n",
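To make the `max_separation` rule concrete, here is a sketch of the masking-and-interpolation step it describes; this is an illustration under assumed array shapes, not the Trodes pipeline's actual implementation:

```python
import numpy as np
import pandas as pd

def mask_led_separation(red_xy, green_xy, max_separation=9.0):
    """NaN-out frames where the red and green LEDs are farther apart
    than max_separation (cm), then fill the gaps by interpolation.
    Inputs are hypothetical (n_frames, 2) position arrays."""
    red = red_xy.astype(float).copy()
    green = green_xy.astype(float).copy()
    bad = np.linalg.norm(red - green, axis=1) > max_separation
    red[bad] = np.nan    # e.g., one LED tracked a reflection
    green[bad] = np.nan
    red = pd.DataFrame(red).interpolate(limit_direction="both").to_numpy()
    green = pd.DataFrame(green).interpolate(limit_direction="both").to_numpy()
    return red, green
```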
3 changes: 1 addition & 2 deletions notebooks/24_Linearization.ipynb
@@ -87,7 +87,6 @@
"\n",
"import spyglass.common as sgc\n",
"import spyglass.position.v1 as sgp\n",
"import spyglass as nd\n",
"\n",
"# ignore datajoint+jupyter async warnings\n",
"import warnings\n",
@@ -1501,7 +1500,7 @@
" + 1\n",
")\n",
"video_info = (\n",
" nd.common.common_behav.VideoFile()\n",
" sgc.common_behav.VideoFile()\n",
" & {\"nwb_file_name\": key[\"nwb_file_name\"], \"epoch\": epoch}\n",
").fetch1()\n",
"\n",
8 changes: 4 additions & 4 deletions notebooks/33_Decoding_Clusterless.ipynb
@@ -32,8 +32,8 @@
" [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded \n",
" position data. If 1D decoding, this data should also be\n",
" [linearized](./24_Linearization.ipynb).\n",
"- Ths tutorial also assumes you're familiar with how to run processes on GPU, as\n",
" presented in [this notebook](./32_Decoding_with_GPUs.ipynb)\n",
"- This tutorial also assumes you're familiar with how to run processes on GPU, \n",
" as presented in [this notebook](./32_Decoding_with_GPUs.ipynb)\n",
"\n",
"Clusterless decoding can be performed on either 1D or 2D data. A few steps in\n",
"this notebook will refer to a `decode_1d` variable set in \n",
@@ -143,10 +143,10 @@
"source": [
"First, we'll fetch marks with `fetch_xarray`, which provides a labeled array of\n",
"shape (n_time, n_mark_features, n_electrodes). Time is in 2 ms bins with either\n",
"`NaN` if no spike occured or the value of the spike features.\n",
"`NaN` if no spike occurred or the value of the spike features.\n",
"\n",
"If there is >1 spike per time bin per tetrode, we take an an average of the\n",
"marks. Ideally, we would use all the marks, this is a rare occurance and\n",
"marks. Ideally, we would use all the marks, this is a rare occurrence and\n",
"decoding is generally robust to the averaging."
]
},
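To picture the array described above, here is a small synthetic example with the stated (n_time, n_mark_features, n_electrodes) layout; dimension names and values are illustrative, not the actual `fetch_xarray` output:

```python
import numpy as np
import xarray as xr

n_time, n_features, n_electrodes = 5, 4, 2
data = np.full((n_time, n_features, n_electrodes), np.nan)
data[1, :, 0] = [80.0, 65.0, 90.0, 70.0]  # one spike's mark features
data[3, :, 1] = [55.0, 60.0, 58.0, 62.0]

marks = xr.DataArray(
    data,
    dims=("time", "mark_features", "electrodes"),
    coords={"time": np.arange(n_time) * 0.002},  # 2 ms bins
)
# Bins without a spike stay NaN; drop them to see the spikes.
print(marks.isel(electrodes=0).dropna("time", how="all"))
```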
2 changes: 1 addition & 1 deletion notebooks/py_scripts/14_Theta.py
@@ -128,7 +128,7 @@
#
# We can overlay theta and detected phase for each electrode.
#
# _Note:_ The red horizontal line indicates phase 0, corresponding to the through
# _Note:_ The red horizontal line indicates phase 0, corresponding to the trough
# of theta.

# +
2 changes: 1 addition & 1 deletion notebooks/py_scripts/20_Position_Trodes.py
@@ -87,7 +87,7 @@
# available. To adjust the default, insert a new set into this table. The
# parameters are...
#
# - `max_separation`, default 9 cm: maximium acceptable distance between red and
# - `max_separation`, default 9 cm: maximum acceptable distance between red and
# green LEDs.
# - If exceeded, the times are marked as NaNs and inferred by interpolation.
# - Useful when the inferred LED position tracks a reflection instead of the
3 changes: 1 addition & 2 deletions notebooks/py_scripts/24_Linearization.py
@@ -54,7 +54,6 @@

import spyglass.common as sgc
import spyglass.position.v1 as sgp
import spyglass as nd

# ignore datajoint+jupyter async warnings
import warnings
@@ -335,7 +334,7 @@
+ 1
)
video_info = (
nd.common.common_behav.VideoFile()
sgc.common_behav.VideoFile()
& {"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
).fetch1()

8 changes: 4 additions & 4 deletions notebooks/py_scripts/33_Decoding_Clusterless.py
@@ -28,8 +28,8 @@
# [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded
# position data. If 1D decoding, this data should also be
# [linearized](./24_Linearization.ipynb).
# - Ths tutorial also assumes you're familiar with how to run processes on GPU, as
# presented in [this notebook](./32_Decoding_with_GPUs.ipynb)
# - This tutorial also assumes you're familiar with how to run processes on GPU,
# as presented in [this notebook](./32_Decoding_with_GPUs.ipynb)
#
# Clusterless decoding can be performed on either 1D or 2D data. A few steps in
# this notebook will refer to a `decode_1d` variable set in
@@ -87,10 +87,10 @@

# First, we'll fetch marks with `fetch_xarray`, which provides a labeled array of
# shape (n_time, n_mark_features, n_electrodes). Time is in 2 ms bins with either
# `NaN` if no spike occured or the value of the spike features.
# `NaN` if no spike occurred or the value of the spike features.
#
# If there is >1 spike per time bin per tetrode, we take an average of the
# marks. Ideally, we would use all the marks, this is a rare occurance and
# marks. Ideally, we would use all the marks, but this is a rare occurrence and
# decoding is generally robust to the averaging.

# +
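The averaging rule described above (multiple spikes from one tetrode landing in the same 2 ms bin) can be sketched as follows; `spike_times` and `spike_marks` are hypothetical per-tetrode arrays, not pipeline outputs:

```python
import numpy as np
import pandas as pd

spike_times = np.array([0.0010, 0.0015, 0.0049])  # first two share a bin
spike_marks = np.array([[80.0, 60.0], [90.0, 70.0], [55.0, 65.0]])

bin_size = 0.002  # 2 ms bins, as in the text
bin_index = np.floor(spike_times / bin_size).astype(int)

# Average the mark features of spikes that fall in the same bin.
binned = pd.DataFrame(spike_marks).groupby(bin_index).mean()
print(binned)  # bin 0: mean of the first two spikes; bin 2: the third
```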
55 changes: 33 additions & 22 deletions src/spyglass/common/common_behav.py
@@ -31,7 +31,7 @@ class PositionSource(dj.Manual):
-> IntervalList
---
source: varchar(200) # source of data (e.g., trodes, dlc)
import_file_name: varchar(2000) # path to import file if importing
import_file_name: varchar(2000) # path to import file if importing
"""

class SpatialSeries(dj.Part):
@@ -44,16 +44,19 @@ class SpatialSeries(dj.Part):

@classmethod
def insert_from_nwbfile(cls, nwb_file_name):
"""Given an NWB file name, get the spatial series and interval lists from the file, add the interval
lists to the IntervalList table, and populate the RawPosition table if possible.
"""Add intervals to ItervalList and PositionSource.
Given an NWB file name, get the spatial series and interval lists from
the file, add the interval lists to the IntervalList table, and
populate the RawPosition table if possible.
Parameters
----------
nwb_file_name : str
The name of the NWB file.
"""
nwbf = get_nwb_file(nwb_file_name)
all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False)
all_pos = get_all_spatial_series(nwbf, verbose=True)
sess_key = dict(nwb_file_name=nwb_file_name)
src_key = dict(**sess_key, source="trodes", import_file_name="")

@@ -81,7 +84,7 @@ def insert_from_nwbfile(cls, nwb_file_name):
dict(
**sess_key,
**ind_key,
id=ndex,
id=index,
name=pdict.get("name"),
)
)
@@ -189,9 +192,9 @@ def make(self, key):
indices = (PositionSource.SpatialSeries & key).fetch("id")

# incl_times = False -> don't do extra processing for valid_times
spat_objs = get_all_spatial_series(
nwbf, old_format=False, incl_times=False
)[PositionSource.get_epoch_num(interval_list_name)]
spat_objs = get_all_spatial_series(nwbf, incl_times=False)[
PositionSource.get_epoch_num(interval_list_name)
]

self.insert1(key)
self.Object.insert(
@@ -227,7 +230,7 @@ class StateScriptFile(dj.Imported):
"""

def make(self, key):
"""Add a new row to the StateScriptFile table. Requires keys "nwb_file_name", "file_object_id"."""
"""Add a new row to the StateScriptFile table."""
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
@@ -237,8 +240,8 @@
) or nwbf.processing.get("associated files")
if associated_files is None:
print(
f'Unable to import StateScriptFile: no processing module named "associated_files" '
f"found in {nwb_file_name}."
"Unable to import StateScriptFile: no processing module named "
+ '"associated_files" found in {nwb_file_name}.'
)
return

@@ -247,13 +250,16 @@
associated_file_obj, ndx_franklab_novela.AssociatedFiles
):
print(
f'Data interface {associated_file_obj.name} within "associated_files" processing module is not '
f"of expected type ndx_franklab_novela.AssociatedFiles\n"
f"Data interface {associated_file_obj.name} within "
+ '"associated_files" processing module is not '
+ "of expected type ndx_franklab_novela.AssociatedFiles\n"
)
return

# parse the task_epochs string
# TODO update associated_file_obj.task_epochs to be an array of 1-based ints,
# not a comma-separated string of ints
# TODO: update associated_file_obj.task_epochs to be an array of
# 1-based ints, not a comma-separated string of ints

epoch_list = associated_file_obj.task_epochs.split(",")
# only insert if this is the statescript file
print(associated_file_obj.description)
@@ -281,8 +287,9 @@ class VideoFile(dj.Imported):
Notes
-----
The video timestamps come from: videoTimeStamps.cameraHWSync if PTP is used.
If PTP is not used, the video timestamps come from videoTimeStamps.cameraHWFrameCount .
The video timestamps come from: videoTimeStamps.cameraHWSync if PTP is
used. If PTP is not used, the video timestamps come from
videoTimeStamps.cameraHWFrameCount .
"""

@@ -330,7 +337,9 @@ def _no_transaction_make(self, key, verbose=True):
if isinstance(video, pynwb.image.ImageSeries):
video = [video]
for video_obj in video:
# check to see if the times for this video_object are largely overlapping with the task epoch times
# check to see if the times for this video_object are largely
# overlapping with the task epoch times

if len(
    interval_list_contains(valid_times, video_obj.timestamps)
) > 0.9 * len(video_obj.timestamps):
@@ -341,7 +350,8 @@
key["camera_name"] = video_obj.device.camera_name
else:
raise KeyError(
f"No camera with camera_name: {camera_name} found in CameraDevice table."
f"No camera with camera_name: {camera_name} found "
+ "in CameraDevice table."
)
key["video_file_object_id"] = video_obj.object_id
self.insert1(key)
@@ -365,16 +375,17 @@ def update_entries(cls, restrict={}):
video_nwb = (cls & row).fetch_nwb()[0]
if len(video_nwb) != 1:
raise ValueError(
f"expecting 1 video file per entry, but {len(video_nwb)} files found"
f"Expecting 1 video file per entry. {len(video_nwb)} found"
)
row["camera_name"] = video_nwb[0]["video_file"].device.camera_name
cls.update1(row=row)

@classmethod
def get_abs_path(cls, key: Dict):
"""Return the absolute path for a stored video file given a key with the nwb_file_name and epoch number
"""Return the absolute path for a stored video file given a key.
The SPYGLASS_VIDEO_DIR environment variable must be set.
Key must include the nwb_file_name and epoch number. The
SPYGLASS_VIDEO_DIR environment variable must be set.
Parameters
----------
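Per the revised `get_abs_path` docstring, a usage sketch; the directory, file name, and epoch below are placeholders:

```python
import os
from spyglass.common.common_behav import VideoFile

# The docstring requires SPYGLASS_VIDEO_DIR to be set (placeholder path).
os.environ.setdefault("SPYGLASS_VIDEO_DIR", "/path/to/video")

# The key must include nwb_file_name and epoch (placeholder values).
video_path = VideoFile.get_abs_path(
    {"nwb_file_name": "example20230821_.nwb", "epoch": 2}
)
print(video_path)
```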
7 changes: 5 additions & 2 deletions src/spyglass/position/v1/dlc_utils.py
@@ -533,15 +533,18 @@ def get_gpu_memory():
if subproccess command errors.
"""

output_to_list = lambda x: x.decode("ascii").split("\n")[:-1]
def output_to_list(x):
return x.decode("ascii").split("\n")[:-1]

query_cmd = "nvidia-smi --query-gpu=memory.used --format=csv"
try:
memory_use_info = output_to_list(
subprocess.check_output(query_cmd.split(), stderr=subprocess.STDOUT)
)[1:]
except subprocess.CalledProcessError as err:
raise RuntimeError(
f"command {err.cmd} return with error (code {err.returncode}): {err.output}"
f"command {err.cmd} return with error (code {err.returncode}): "
+ f"{err.output}"
) from err
memory_use_values = {
i: int(x.split()[0]) for i, x in enumerate(memory_use_info)
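The change above replaces an assigned lambda with a `def`, the fix PEP 8 recommends (flake8 code E731): a named function appears by name in tracebacks and can carry a docstring. The pattern, side by side:

```python
# Discouraged: assigning a lambda hides the function's name (flake8 E731)
output_to_list = lambda x: x.decode("ascii").split("\n")[:-1]

# Preferred: an equivalent named function, as in the diff above
def output_to_list(x):
    """Decode subprocess output bytes and split into lines."""
    return x.decode("ascii").split("\n")[:-1]
```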
17 changes: 13 additions & 4 deletions src/spyglass/utils/dj_merge_tables.py
@@ -445,10 +445,19 @@ def merge_delete_parent(
if dry_run:
return part_parents

with cls._safe_context():
super().delete(cls(), **kwargs)
for part_parent in part_parents:
super().delete(part_parent, **kwargs)
merge_ids = cls.merge_restrict(restriction).fetch(
RESERVED_PRIMARY_KEY, as_dict=True
)

# CB: Removed transaction protection here bc 'no' confirmation resp
# still resulted in deletes. If re-add, consider transaction=False
super().delete((cls & merge_ids), **kwargs)

if cls & merge_ids: # If 'no' on del prompt from above, skip below
return  # User can still abort del below, but yes/no is unlikely

for part_parent in part_parents:
super().delete(part_parent, **kwargs) # add safemode=False?

@classmethod
def fetch_nwb(
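A hedged usage sketch of the revised flow: master rows are deleted first, and part parents are only deleted if the master rows are actually gone. `MyMerge` and the restriction are hypothetical, and `dry_run=True` returns the part parents without deleting, per the code above:

```python
# MyMerge: a hypothetical table defined with the merge mixin above.
restriction = {"nwb_file_name": "example20230821_.nwb"}  # placeholder

# Inspect which part parents would be affected, without deleting.
part_parents = MyMerge.merge_delete_parent(
    restriction=restriction, dry_run=True
)

# Real delete: if the user answers 'no' at the master-delete prompt,
# the early return above skips the part-parent deletes.
MyMerge.merge_delete_parent(restriction=restriction, dry_run=False)
```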