From e14477e80db2754c146b3d37b5e1fab423295499 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Mon, 21 Aug 2023 11:56:32 -0700
Subject: [PATCH 1/3] WIP: Spellcheck. Remove debug params. Remove assigned
 lambda E713

---
 notebooks/14_Theta.ipynb                    |   2 +-
 notebooks/20_Position_Trodes.ipynb          |   2 +-
 notebooks/24_Linearization.ipynb            |   3 +-
 notebooks/33_Decoding_Clusterless.ipynb     |   8 +-
 notebooks/py_scripts/14_Theta.py            |   2 +-
 notebooks/py_scripts/20_Position_Trodes.py  |   2 +-
 notebooks/py_scripts/24_Linearization.py    |   3 +-
 .../py_scripts/33_Decoding_Clusterless.py   |   8 +-
 src/spyglass/common/common_behav.py         |  55 ++++----
 src/spyglass/position/v1/dlc_utils.py       |   7 +-
 src/spyglass/utils/dj_merge_tables.py       |  17 ++-
 src/spyglass/utils/nwb_helper_fn.py         | 121 +++++++-----------
 12 files changed, 112 insertions(+), 118 deletions(-)

diff --git a/notebooks/14_Theta.ipynb b/notebooks/14_Theta.ipynb
index ddb5083e1..44e9e2aa8 100644
--- a/notebooks/14_Theta.ipynb
+++ b/notebooks/14_Theta.ipynb
@@ -594,7 +594,7 @@
     "\n",
     "We can overlay theta and detected phase for each electrode.\n",
     "\n",
-    "_Note:_ The red horizontal line indicates phase 0, corresponding to the through\n",
+    "_Note:_ The red horizontal line indicates phase 0, corresponding to the trough\n",
     "of theta."
    ]
   },
diff --git a/notebooks/20_Position_Trodes.ipynb b/notebooks/20_Position_Trodes.ipynb
index fa765059d..df10bfd41 100644
--- a/notebooks/20_Position_Trodes.ipynb
+++ b/notebooks/20_Position_Trodes.ipynb
@@ -142,7 +142,7 @@
     "available. To adjust the default, insert a new set into this table. The\n",
     "parameters are...\n",
     "\n",
-    "- `max_separation`, default 9 cm: maximium acceptable distance between red and\n",
+    "- `max_separation`, default 9 cm: maximum acceptable distance between red and\n",
     "  green LEDs.\n",
     "  - If exceeded, the times are marked as NaNs and inferred by interpolation.\n",
     "  - Useful when the inferred LED position tracks a reflection instead of the\n",
diff --git a/notebooks/24_Linearization.ipynb b/notebooks/24_Linearization.ipynb
index f0d2fa6fd..310df156c 100644
--- a/notebooks/24_Linearization.ipynb
+++ b/notebooks/24_Linearization.ipynb
@@ -87,7 +87,6 @@
     "\n",
     "import spyglass.common as sgc\n",
     "import spyglass.position.v1 as sgp\n",
-    "import spyglass as nd\n",
     "\n",
     "# ignore datajoint+jupyter async warnings\n",
     "import warnings\n",
@@ -1501,7 +1500,7 @@
     "    + 1\n",
     ")\n",
     "video_info = (\n",
-    "    nd.common.common_behav.VideoFile()\n",
+    "    sgc.common_behav.VideoFile()\n",
     "    & {\"nwb_file_name\": key[\"nwb_file_name\"], \"epoch\": epoch}\n",
     ").fetch1()\n",
     "\n",
diff --git a/notebooks/33_Decoding_Clusterless.ipynb b/notebooks/33_Decoding_Clusterless.ipynb
index 565329f12..ae6118729 100644
--- a/notebooks/33_Decoding_Clusterless.ipynb
+++ b/notebooks/33_Decoding_Clusterless.ipynb
@@ -32,8 +32,8 @@
     "  [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded \n",
     "  position data. If 1D decoding, this data should also be\n",
     "  [linearized](./24_Linearization.ipynb).\n",
-    "- Ths tutorial also assumes you're familiar with how to run processes on GPU, as\n",
-    "  presented in [this notebook](./32_Decoding_with_GPUs.ipynb)\n",
+    "- This tutorial also assumes you're familiar with how to run processes on GPU,\n",
+    "  as presented in [this notebook](./32_Decoding_with_GPUs.ipynb)\n",
     "\n",
     "Clusterless decoding can be performed on either 1D or 2D data. 
A few steps in\n",
    "this notebook will refer to a `decode_1d` variable set in \n",
@@ -143,10 +143,10 @@
    "source": [
    "First, we'll fetch marks with `fetch_xarray`, which provides a labeled array of\n",
    "shape (n_time, n_mark_features, n_electrodes). Time is in 2 ms bins with either\n",
-    "`NaN` if no spike occured or the value of the spike features.\n",
+    "`NaN` if no spike occurred or the value of the spike features.\n",
    "\n",
-    "If there is >1 spike per time bin per tetrode, we take an an average of the\n",
-    "marks. Ideally, we would use all the marks, this is a rare occurance and\n",
+    "If there is >1 spike per time bin per tetrode, we take an average of the\n",
+    "marks. Ideally, we would use all the marks, but this is a rare occurrence and\n",
    "decoding is generally robust to the averaging."
    ]
   },
diff --git a/notebooks/py_scripts/14_Theta.py b/notebooks/py_scripts/14_Theta.py
index 5ca3eaaba..fb1b366bc 100644
--- a/notebooks/py_scripts/14_Theta.py
+++ b/notebooks/py_scripts/14_Theta.py
@@ -128,7 +128,7 @@
 #
 # We can overlay theta and detected phase for each electrode.
 #
-# _Note:_ The red horizontal line indicates phase 0, corresponding to the through
+# _Note:_ The red horizontal line indicates phase 0, corresponding to the trough
 # of theta.
 
 # +
diff --git a/notebooks/py_scripts/20_Position_Trodes.py b/notebooks/py_scripts/20_Position_Trodes.py
index e1e00dcfb..d9483dc58 100644
--- a/notebooks/py_scripts/20_Position_Trodes.py
+++ b/notebooks/py_scripts/20_Position_Trodes.py
@@ -87,7 +87,7 @@
 # available. To adjust the default, insert a new set into this table. The
 # parameters are...
 #
-# - `max_separation`, default 9 cm: maximium acceptable distance between red and
+# - `max_separation`, default 9 cm: maximum acceptable distance between red and
 #   green LEDs.
 #   - If exceeded, the times are marked as NaNs and inferred by interpolation.
 #   - Useful when the inferred LED position tracks a reflection instead of the
diff --git a/notebooks/py_scripts/24_Linearization.py b/notebooks/py_scripts/24_Linearization.py
index 0e4f3b7c1..ad77b0de9 100644
--- a/notebooks/py_scripts/24_Linearization.py
+++ b/notebooks/py_scripts/24_Linearization.py
@@ -54,7 +54,6 @@
 
 import spyglass.common as sgc
 import spyglass.position.v1 as sgp
-import spyglass as nd
 
 # ignore datajoint+jupyter async warnings
 import warnings
@@ -335,7 +334,7 @@
     + 1
 )
 video_info = (
-    nd.common.common_behav.VideoFile()
+    sgc.common_behav.VideoFile()
     & {"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
 ).fetch1()
 
diff --git a/notebooks/py_scripts/33_Decoding_Clusterless.py b/notebooks/py_scripts/33_Decoding_Clusterless.py
index db0ecedcf..9c1ab81d5 100644
--- a/notebooks/py_scripts/33_Decoding_Clusterless.py
+++ b/notebooks/py_scripts/33_Decoding_Clusterless.py
@@ -28,8 +28,8 @@
 #   [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded
 #   position data. If 1D decoding, this data should also be
 #   [linearized](./24_Linearization.ipynb).
-# - Ths tutorial also assumes you're familiar with how to run processes on GPU, as
-#   presented in [this notebook](./32_Decoding_with_GPUs.ipynb)
+# - This tutorial also assumes you're familiar with how to run processes on GPU,
+#   as presented in [this notebook](./32_Decoding_with_GPUs.ipynb)
 #
 # Clusterless decoding can be performed on either 1D or 2D data. A few steps in
 # this notebook will refer to a `decode_1d` variable set in
@@ -87,10 +87,10 @@
 # First, we'll fetch marks with `fetch_xarray`, which provides a labeled array of
 # shape (n_time, n_mark_features, n_electrodes). 
Time is in 2 ms bins with either
-# `NaN` if no spike occured or the value of the spike features.
+# `NaN` if no spike occurred or the value of the spike features.
 #
-# If there is >1 spike per time bin per tetrode, we take an an average of the
-# marks. Ideally, we would use all the marks, this is a rare occurance and
+# If there is >1 spike per time bin per tetrode, we take an average of the
+# marks. Ideally, we would use all the marks, but this is a rare occurrence and
 # decoding is generally robust to the averaging.
 
 # +
diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py
index c775069b0..bcc701354 100644
--- a/src/spyglass/common/common_behav.py
+++ b/src/spyglass/common/common_behav.py
@@ -31,7 +31,7 @@ class PositionSource(dj.Manual):
     -> IntervalList
     ---
     source: varchar(200)             # source of data (e.g., trodes, dlc)
-    import_file_name: varchar(2000)  # path to import file if importing 
+    import_file_name: varchar(2000)  # path to import file if importing
     """
 
     class SpatialSeries(dj.Part):
@@ -44,8 +44,11 @@ class SpatialSeries(dj.Part):
 
     @classmethod
     def insert_from_nwbfile(cls, nwb_file_name):
-        """Given an NWB file name, get the spatial series and interval lists from the file, add the interval
-        lists to the IntervalList table, and populate the RawPosition table if possible.
+        """Add intervals to IntervalList and PositionSource.
+
+        Given an NWB file name, get the spatial series and interval lists from
+        the file, add the interval lists to the IntervalList table, and
+        populate the RawPosition table if possible.
 
         Parameters
         ----------
@@ -53,7 +56,7 @@ def insert_from_nwbfile(cls, nwb_file_name):
             The name of the NWB file.
         """
         nwbf = get_nwb_file(nwb_file_name)
-        all_pos = get_all_spatial_series(nwbf, verbose=True, old_format=False)
+        all_pos = get_all_spatial_series(nwbf, verbose=True)
         sess_key = dict(nwb_file_name=nwb_file_name)
         src_key = dict(**sess_key, source="trodes", import_file_name="")
@@ -81,7 +84,7 @@
                     dict(
                         **sess_key,
                         **ind_key,
-                        id=ndex,
+                        id=index,
                         name=pdict.get("name"),
                     )
                 )
@@ -189,9 +192,9 @@
         indices = (PositionSource.SpatialSeries & key).fetch("id")
 
         # incl_times = False -> don't do extra processing for valid_times
-        spat_objs = get_all_spatial_series(
-            nwbf, old_format=False, incl_times=False
-        )[PositionSource.get_epoch_num(interval_list_name)]
+        spat_objs = get_all_spatial_series(nwbf, incl_times=False)[
+            PositionSource.get_epoch_num(interval_list_name)
+        ]
 
         self.insert1(key)
         self.Object.insert(
@@ -227,7 +230,7 @@
     """
 
     def make(self, key):
-        """Add a new row to the StateScriptFile table. Requires keys "nwb_file_name", "file_object_id"."""
+        """Add a new row to the StateScriptFile table."""
         nwb_file_name = key["nwb_file_name"]
         nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
         nwbf = get_nwb_file(nwb_file_abspath)
@@ -237,8 +240,8 @@
         ) or nwbf.processing.get("associated files")
 
         if associated_files is None:
             print(
-                f'Unable to import StateScriptFile: no processing module named "associated_files" '
-                f"found in {nwb_file_name}."
+                "Unable to import StateScriptFile: no processing module named "
+                + f'"associated_files" found in {nwb_file_name}.' 
            )
            return
 
@@ -247,13 +250,16 @@
             associated_file_obj, ndx_franklab_novela.AssociatedFiles
         ):
             print(
-                f'Data interface {associated_file_obj.name} within "associated_files" processing module is not '
-                f"of expected type ndx_franklab_novela.AssociatedFiles\n"
+                f"Data interface {associated_file_obj.name} within "
+                + '"associated_files" processing module is not '
+                + "of expected type ndx_franklab_novela.AssociatedFiles\n"
             )
             return
+
         # parse the task_epochs string
-        # TODO update associated_file_obj.task_epochs to be an array of 1-based ints,
-        # not a comma-separated string of ints
+        # TODO: update associated_file_obj.task_epochs to be an array of
+        # 1-based ints, not a comma-separated string of ints
+
         epoch_list = associated_file_obj.task_epochs.split(",")
         # only insert if this is the statescript file
         print(associated_file_obj.description)
@@ -281,8 +287,9 @@
 
     Notes
     -----
-    The video timestamps come from: videoTimeStamps.cameraHWSync if PTP is used.
-    If PTP is not used, the video timestamps come from videoTimeStamps.cameraHWFrameCount .
+    The video timestamps come from videoTimeStamps.cameraHWSync if PTP is
+    used. If PTP is not used, the video timestamps come from
+    videoTimeStamps.cameraHWFrameCount.
 
     """
 
@@ -330,7 +337,9 @@
         if isinstance(video, pynwb.image.ImageSeries):
             video = [video]
         for video_obj in video:
-            # check to see if the times for this video_object are largely overlapping with the task epoch times
+            # check to see if the times for this video_object are largely
+            # overlapping with the task epoch times
+
             if len(
                 interval_list_contains(valid_times, video_obj.timestamps)
                 > 0.9 * len(video_obj.timestamps)
@@ -341,7 +350,8 @@
                     key["camera_name"] = video_obj.device.camera_name
                 else:
                     raise KeyError(
-                        f"No camera with camera_name: {camera_name} found in CameraDevice table."
+                        f"No camera with camera_name: {camera_name} found "
+                        + "in CameraDevice table."
                     )
                 key["video_file_object_id"] = video_obj.object_id
                 self.insert1(key)
@@ -365,16 +375,17 @@
             video_nwb = (cls & row).fetch_nwb()[0]
             if len(video_nwb) != 1:
                 raise ValueError(
-                    f"expecting 1 video file per entry, but {len(video_nwb)} files found"
+                    f"Expecting 1 video file per entry. {len(video_nwb)} found"
                 )
             row["camera_name"] = video_nwb[0]["video_file"].device.camera_name
             cls.update1(row=row)
 
     @classmethod
     def get_abs_path(cls, key: Dict):
-        """Return the absolute path for a stored video file given a key with the nwb_file_name and epoch number
+        """Return the absolute path for a stored video file given a key.
 
-        The SPYGLASS_VIDEO_DIR environment variable must be set.
+        Key must include the nwb_file_name and epoch number. The
+        SPYGLASS_VIDEO_DIR environment variable must be set.
 
         Parameters
         ----------
diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py
index ce3e3b3fc..3abfb8f72 100644
--- a/src/spyglass/position/v1/dlc_utils.py
+++ b/src/spyglass/position/v1/dlc_utils.py
@@ -533,7 +533,9 @@ def get_gpu_memory():
-        if subproccess command errors.
+        if the subprocess command errors. 
""" - output_to_list = lambda x: x.decode("ascii").split("\n")[:-1] + def output_to_list(x): + return x.decode("ascii").split("\n")[:-1] + query_cmd = "nvidia-smi --query-gpu=memory.used --format=csv" try: memory_use_info = output_to_list( @@ -541,7 +543,8 @@ def get_gpu_memory(): )[1:] except subprocess.CalledProcessError as err: raise RuntimeError( - f"command {err.cmd} return with error (code {err.returncode}): {err.output}" + f"command {err.cmd} return with error (code {err.returncode}): " + + f"{err.output}" ) from err memory_use_values = { i: int(x.split()[0]) for i, x in enumerate(memory_use_info) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 5d7c093eb..ef9f60674 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -445,10 +445,19 @@ def merge_delete_parent( if dry_run: return part_parents - with cls._safe_context(): - super().delete(cls(), **kwargs) - for part_parent in part_parents: - super().delete(part_parent, **kwargs) + merge_ids = cls.merge_restrict(restriction).fetch( + RESERVED_PRIMARY_KEY, as_dict=True + ) + + # CB: Removed transaction protection here bc 'no' confirmation resp + # still resulted in deletes. If re-add, consider transaction=False + super().delete((cls & merge_ids), **kwargs) + + if cls & merge_ids: # If 'no' on del prompt from above, skip below + return # User can still abort del below, but yes/no is unlikly + + for part_parent in part_parents: + super().delete(part_parent, **kwargs) # add safemode=False? @classmethod def fetch_nwb( diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 2d493371a..835dfedc7 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -3,6 +3,7 @@ import os import os.path import warnings +from itertools import groupby from pathlib import Path import numpy as np @@ -43,20 +44,19 @@ def get_nwb_file(nwb_file_path): nwb_file_path = Nwbfile.get_abs_path(nwb_file_path) _, nwbfile = __open_nwb_files.get(nwb_file_path, (None, None)) - nwb_uri = None - nwb_raw_uri = None if nwbfile is None: # check to see if the file exists if not os.path.exists(nwb_file_path): print( - f"NWB file not found locally; checking kachery for " + "NWB file not found locally; checking kachery for " + f"{nwb_file_path}" ) # first try the analysis files from ..sharing.sharing_kachery import AnalysisNwbfileKachery - # the download functions assume just the filename, so we need to get that from the path + # the download functions assume just the filename, so we need to + # get that from the path if not AnalysisNwbfileKachery.download_file( os.path.basename(nwb_file_path) ): @@ -109,7 +109,8 @@ def close_nwb_files(): def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): - """Search for a specified NWBDataInterface or DynamicTable in the processing modules of an NWB file. + """ + Search for NWBDataInterface or DynamicTable in processing modules of an NWB. Parameters ---------- @@ -118,13 +119,15 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): data_interface_name : str The name of the NWBDataInterface or DynamicTable to search for. data_interface_class : type, optional - The class (or superclass) to search for. This argument helps to prevent accessing an object with the same - name but the incorrect type. Default: no restriction. + The class (or superclass) to search for. 
This argument helps to prevent
+        accessing an object with the same name but the incorrect type. Default:
+        no restriction.
 
     Warns
     -----
     UserWarning
-        If multiple NWBDataInterface and DynamicTable objects with the matching name are found.
+        If multiple NWBDataInterface and DynamicTable objects with the matching
+        name are found.
 
     Returns
     -------
@@ -156,7 +159,8 @@
 def get_raw_eseries(nwbfile):
     """Return all ElectricalSeries in the acquisition group of an NWB file.
 
-    ElectricalSeries found within LFP objects in the acquisition will also be returned.
+    ElectricalSeries found within LFP objects in the acquisition will also be
+    returned.
 
     Parameters
     ----------
@@ -306,16 +310,20 @@ def get_valid_intervals(
 
 def get_electrode_indices(nwb_object, electrode_ids):
-    """Given an NWB file or electrical series object, return the indices of the specified electrode_ids.
+    """Return indices of the specified electrode_ids given an NWB file.
 
-    If an ElectricalSeries is given, then the indices returned are relative to the selected rows in
-    ElectricalSeries.electrodes. For example, if electricalseries.electrodes = [5], and row index 5 of
-    nwbfile.electrodes has ID 10, then calling get_electrode_indices(electricalseries, 10) will return 0, the
-    index of the matching electrode in electricalseries.electrodes.
+    Also accepts an ElectricalSeries object. If an ElectricalSeries is given,
+    then the indices returned are relative to the selected rows in
+    ElectricalSeries.electrodes. For example, if electricalseries.electrodes =
+    [5], and row index 5 of nwbfile.electrodes has ID 10, then calling
+    get_electrode_indices(electricalseries, 10) will return 0, the index of the
+    matching electrode in electricalseries.electrodes.
 
-    Indices for electrode_ids that are not in the electrical series are returned as np.nan
+    Indices for electrode_ids that are not in the electrical series are
+    returned as np.nan.
 
-    If an NWBFile is given, then the row indices with the matching IDs in the file's electrodes table are returned.
+    If an NWBFile is given, then the row indices with the matching IDs in the
+    file's electrodes table are returned.
 
     Parameters
     ----------
@@ -330,8 +338,9 @@ def get_electrode_indices(nwb_object, electrode_ids):
         Array of indices of the specified electrode IDs.
     """
     if isinstance(nwb_object, pynwb.ecephys.ElectricalSeries):
-        # electrodes is a DynamicTableRegion which may contain a subset of the rows in NWBFile.electrodes
-        # match against only the subset of electrodes referenced by this ElectricalSeries
+        # electrodes is a DynamicTableRegion which may contain a subset of the
+        # rows in NWBFile.electrodes. Match against only the subset of
+        # electrodes referenced by this ElectricalSeries.
         electrode_table_indices = nwb_object.electrodes.data[:]
         selected_elect_ids = [
             nwb_object.electrodes.table.id[x] for x in electrode_table_indices
@@ -344,7 +353,9 @@
             "nwb_object must be of type ElectricalSeries or NWBFile"
        )
 
-    # for each electrode_id, find its index in selected_elect_ids and return that if it's there and invalid_electrode_index if not.
+    # for each electrode_id, find its index in selected_elect_ids and return
+    # that if it's there and invalid_electrode_index if not. 
+
     return [
         selected_elect_ids.index(elect_id)
         if elect_id in selected_elect_ids
         else invalid_electrode_index
         for elect_id in electrode_ids
     ]
 
-def _get_epoch_groups(position: pynwb.behavior.Position, old_format=True):
-    if old_format:
-        epoch_start_time = np.zeros(len(position.spatial_series.values()))
-        for pos_epoch, spatial_series in enumerate(
-            position.spatial_series.values()
-        ):
-            epoch_start_time[pos_epoch] = spatial_series.timestamps[0]
-
-        return np.argsort(epoch_start_time)
-
-    from itertools import groupby
-
+def _get_epoch_groups(position: pynwb.behavior.Position):
     epoch_start_time = {}
     for pos_epoch, spatial_series in enumerate(
         position.spatial_series.values()
@@ -384,7 +384,6 @@ def _get_pos_dict(
     epoch_groups: dict,
     session_id: str = None,
     verbose: bool = False,
-    old_format: bool = True,  # TODO: remove after changing prod database
     incl_times: bool = True,
 ):
     """Return dict with obj ids and valid intervals for each epoch.
@@ -406,14 +405,13 @@ def _get_pos_dict(
     # prev, this was just a list. now, we need to gen mult dict per epoch
     pos_data_dict = dict()
     all_spatial_series = list(position.values())
-    if old_format:
-        # for index, orig_epoch in enumerate(sorted_order):
-        for index, orig_epoch in enumerate(epoch_groups):
-            spatial_series = all_spatial_series[orig_epoch]
-            # get the valid intervals for the position data
+
+    for epoch, index_list in enumerate(epoch_groups.values()):
+        pos_data_dict[epoch] = []
+        for index in index_list:
+            spatial_series = all_spatial_series[index]
             valid_times = None
-            if incl_times:
+            if incl_times:  # get the valid intervals for the position data
                 timestamps = np.asarray(spatial_series.timestamps)
                 sampling_rate = estimate_sampling_rate(
                     timestamps, verbose=verbose, filename=session_id
                 )
                 valid_times = get_valid_intervals(
                     timestamps=timestamps,
                     sampling_rate=sampling_rate,
                 )
-            # add the valid intervals to the Interval list
-            pos_data_dict[index] = {
-                "valid_times": valid_times,
-                "raw_position_object_id": spatial_series.object_id,
-            }
-
-    else:
-        for epoch, index_list in enumerate(epoch_groups.values()):
-            pos_data_dict[epoch] = []
-            for index in index_list:
-                spatial_series = all_spatial_series[index]
-                # get the valid intervals for the position data
-                valid_times = None
-                if incl_times:
-                    timestamps = np.asarray(spatial_series.timestamps)
-                    sampling_rate = estimate_sampling_rate(
-                        timestamps, verbose=verbose, filename=session_id
-                    )
-                    valid_times = get_valid_intervals(
-                        timestamps=timestamps,
-                        sampling_rate=sampling_rate,
-                    )
-                # add the valid intervals to the Interval list
-                pos_data_dict[epoch].append(
-                    {
-                        "valid_times": valid_times,
-                        "raw_position_object_id": spatial_series.object_id,
-                        "name": spatial_series.name,
-                    }
-                )
+            pos_data_dict[epoch].append(
+                {
+                    "valid_times": valid_times,
+                    "raw_position_object_id": spatial_series.object_id,
+                    "name": spatial_series.name,
+                }
+            )
 
     return pos_data_dict
 
-def get_all_spatial_series(
-    nwbf, verbose=False, old_format=True, incl_times=True
-) -> dict:
+def get_all_spatial_series(nwbf, verbose=False, incl_times=True) -> dict:
     """
     Given an NWB, get the spatial series and return a dictionary by epoch. 
@@ -492,10 +465,9 @@
 
     return _get_pos_dict(
         position=pos_interface.spatial_series,
-        epoch_groups=_get_epoch_groups(pos_interface, old_format=old_format),
+        epoch_groups=_get_epoch_groups(pos_interface),
         session_id=nwbf.session_id,
         verbose=verbose,
-        old_format=old_format,
         incl_times=incl_times,
     )
 
@@ -525,7 +497,8 @@ def change_group_permissions(
     # Loop through nwb file directories and change group permissions
     for target_content in target_contents:
         print(
-            f"For {target_content}, changing group to {set_group_name} and giving read/write/execute permissions"
+            f"For {target_content}, changing group to {set_group_name} "
+            + "and giving read/write/execute permissions"
         )
         # Change group
         os.system(f"chgrp -R {set_group_name} {target_content}")

From 2d6d87ed000dfdc8377af5acd989084be20f4cf0 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Mon, 21 Aug 2023 12:57:55 -0700
Subject: [PATCH 2/3] WIP: Pass tests. Remove codespell offending link

---
 notebooks/01_Insert_Data.ipynb              |  2 +-
 src/spyglass/common/common_nwbfile.py       |  7 ++++++-
 src/spyglass/data_import/insert_sessions.py | 10 +++++++---
 tests/conftest.py                           | 10 ++--------
 4 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/notebooks/01_Insert_Data.ipynb b/notebooks/01_Insert_Data.ipynb
index 6cb991994..d93efa838 100644
--- a/notebooks/01_Insert_Data.ipynb
+++ b/notebooks/01_Insert_Data.ipynb
@@ -692,7 +692,7 @@
     "- `minirec20230622.nwb`, .3 GB: minimal recording, on\n",
     "  [Box](https://ucsf.box.com/s/k3sgql6z475oia848q1rgms4zdh4rkjn)\n",
     "- `montague20200802.nwb`, 8 GB: full recording, on\n",
-    "  [DropBox](https://www.dropbox.com/scl/fo/4i5b1z4iapetzxfps0grf/h?dl=0&preview=montague20200802_tutorial_.nwb&rlkey=ctahes9v0r7bxes8yceh86gzg)\n",
+    "  DropBox (link coming soon)\n",
     "- For those in the UCSF network, these and many others on `/stelmo/nwb/raw`\n",
     "\n",
     "If you are connected to the Frank lab database, please rename any downloaded\n",
diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py
index dd65d60a9..52c070d2b 100644
--- a/src/spyglass/common/common_nwbfile.py
+++ b/src/spyglass/common/common_nwbfile.py
@@ -87,7 +87,7 @@ def get_file_key(cls, nwb_file_name: str) -> dict:
         return {"nwb_file_name": cls._get_file_name(nwb_file_name)}
 
     @classmethod
-    def get_abs_path(cls, nwb_file_name) -> str:
+    def get_abs_path(cls, nwb_file_name, new_file=False) -> str:
         """Return absolute path for a stored raw NWB file given file name.
 
         The SPYGLASS_BASE_DIR must be set, either as an environment or part of
@@ -98,12 +98,17 @@ def get_abs_path(cls, nwb_file_name) -> str:
         nwb_file_name : str
             The name of an NWB file that has been inserted into the Nwbfile()
             table. May be file substring. May include % wildcard(s).
+        new_file : bool, optional
+            Whether the file is new to the Nwbfile table. Defaults to False.
 
         Returns
         -------
         nwb_file_abspath : str
             The absolute path for the given file name. 
""" + if new_file: + return raw_dir + "/" + nwb_file_name + return raw_dir + "/" + cls._get_file_name(nwb_file_name) @staticmethod diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index f47db38e9..5466b58a6 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -31,7 +31,9 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): if "/" in nwb_file_name: nwb_file_name = nwb_file_name.split("/")[-1] - nwb_file_abs_path = Path(Nwbfile.get_abs_path(nwb_file_name)) + nwb_file_abs_path = Path( + Nwbfile.get_abs_path(nwb_file_name, new_file=True) + ) if not nwb_file_abs_path.exists(): possible_matches = sorted(Path(raw_dir).glob(f"*{nwb_file_name}*")) @@ -86,12 +88,14 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name): + f"with link to raw ephys data: {out_nwb_file_name}" ) - nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name) + nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) if not os.path.exists(nwb_file_abs_path): raise FileNotFoundError(f"Could not find raw file: {nwb_file_abs_path}") - out_nwb_file_abs_path = Nwbfile.get_abs_path(out_nwb_file_name) + out_nwb_file_abs_path = Nwbfile.get_abs_path( + out_nwb_file_name, new_file=True + ) if os.path.exists(out_nwb_file_name): warnings.warn( diff --git a/tests/conftest.py b/tests/conftest.py index 1356a80ef..bf8c8a0a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,10 +8,7 @@ import datajoint as dj from .datajoint._config import DATAJOINT_SERVER_PORT -from .datajoint._datajoint_server import ( - kill_datajoint_server, - run_datajoint_server, -) +from .datajoint._datajoint_server import kill_datajoint_server, run_datajoint_server thisdir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(thisdir) @@ -71,10 +68,7 @@ def _set_env(): """Set environment variables.""" print("Setting datajoint and kachery environment variables.") - spyglass_base_dir = pathlib.Path(tempfile.mkdtemp()) - from spyglass.settings import load_config - - _ = load_config(str(spyglass_base_dir), force_reload=True) + os.environ["SPYGLASS_BASE_DIR"] = str(tempfile.mkdtemp()) dj.config["database.host"] = "localhost" dj.config["database.port"] = DATAJOINT_SERVER_PORT From 8c36f05a880f3b0aba5ffe34e938369b43397248 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 21 Aug 2023 13:02:18 -0700 Subject: [PATCH 3/3] WIP: blackify --- notebooks/py_scripts/01_Insert_Data.py | 2 +- tests/conftest.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index d8c0daa8b..61ca9c9e3 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -109,7 +109,7 @@ # - `minirec20230622.nwb`, .3 GB: minimal recording, on # [Box](https://ucsf.box.com/s/k3sgql6z475oia848q1rgms4zdh4rkjn) # - `montague20200802.nwb`, 8 GB: full recording, on -# [DropBox](https://www.dropbox.com/scl/fo/4i5b1z4iapetzxfps0grf/h?dl=0&preview=montague20200802_tutorial_.nwb&rlkey=ctahes9v0r7bxes8yceh86gzg) +# DropBox (link coming soon) # - For those in the UCSF network, these and many others on `/stelmo/nwb/raw` # # If you are connected to the Frank lab database, please rename any downloaded diff --git a/tests/conftest.py b/tests/conftest.py index bf8c8a0a4..eae26c2c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,10 @@ import datajoint as dj from .datajoint._config import DATAJOINT_SERVER_PORT -from 
.datajoint._datajoint_server import kill_datajoint_server, run_datajoint_server
+from .datajoint._datajoint_server import (
+    kill_datajoint_server,
+    run_datajoint_server,
+)
 
 thisdir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(thisdir)
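
Usage note (not part of the patches above): a minimal sketch of the `new_file`
flag that PATCH 2/3 adds to `Nwbfile.get_abs_path`. The file name below is
hypothetical, and the sketch assumes a configured Spyglass environment with
`SPYGLASS_BASE_DIR` set, as the updated docstring requires.

```python
from spyglass.common import Nwbfile

# For a file not yet inserted into the Nwbfile table, new_file=True skips
# the stored-name resolution and simply joins the raw directory with the
# name as given.
new_path = Nwbfile.get_abs_path("mysession20230821_.nwb", new_file=True)

# For an already-inserted file, the default (new_file=False) resolves the
# stored file name via the table, so substrings and % wildcards work.
stored_path = Nwbfile.get_abs_path("mysession20230821")
```

This appears to be why `insert_sessions` and `copy_nwb_link_raw_ephys` now pass
`new_file=True`: both need to build an absolute path before the file exists as
a row in the Nwbfile table.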