From 92764ccbb3d31fdbb7cbad8c07bc899ef5cf05d4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 8 Aug 2023 17:41:37 -0500 Subject: [PATCH] Blackify py scrips. Continue config changes --- notebooks/README.md | 3 + notebooks/py_scripts/01_Insert_Data.py | 3 +- notebooks/py_scripts/02_Data_Sync.py | 30 ++-- notebooks/py_scripts/11_Curation.py | 21 ++- notebooks/py_scripts/14_Theta.py | 31 ++-- notebooks/py_scripts/20_Position_Trodes.py | 24 ++- notebooks/py_scripts/21_Position_DLC_1.py | 16 +- notebooks/py_scripts/22_Position_DLC_2.py | 8 +- notebooks/py_scripts/23_Position_DLC_3.py | 6 +- notebooks/py_scripts/24_Linearization.py | 2 +- notebooks/py_scripts/30_Ripple_Detection.py | 10 +- .../py_scripts/31_Extract_Mark_Indicators.py | 2 +- notebooks/py_scripts/32_Decoding_with_GPUs.py | 17 +- .../py_scripts/33_Decoding_Clusterless.py | 14 +- src/spyglass/__init__.py | 4 +- src/spyglass/common/common_session.py | 18 +- .../common/prepopulate/prepopulate.py | 15 +- src/spyglass/data_import/insert_sessions.py | 1 - src/spyglass/position/v1/dlc_utils.py | 8 +- src/spyglass/settings.py | 154 ++++++++++++++---- .../spikesorting/spikesorting_sorting.py | 19 +-- src/spyglass/utils/dj_merge_tables.py | 91 ++++++++--- 22 files changed, 315 insertions(+), 182 deletions(-) diff --git a/notebooks/README.md b/notebooks/README.md index 8ad86120b..8fc21cc31 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -48,4 +48,7 @@ root Spyglass directory pip install jupytext jupytext --to py notebooks/*ipynb mv notebooks/*py notebooks/py_scripts +black . ``` + +Unfortunately, jupytext-generated py script are not black-compliant by default. diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index c2081c2b8..3c8143882 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -53,6 +53,7 @@ # spyglass.data_import has tools for inserting NWB files into the database import spyglass.data_import as sgi + # - # ## Visualizing the database @@ -91,7 +92,7 @@ # By adding diagrams together, of adding and subtracting levels, we can visualize # key parts of Spyglass. # -# _Note:_ Notice the *Selection* tables. This is a design pattern that selects a +# _Note:_ Notice the *Selection* tables. This is a design pattern that selects a # subset of upstream items for further processing. In some cases, these also pair # the selected data with processing parameters. diff --git a/notebooks/py_scripts/02_Data_Sync.py b/notebooks/py_scripts/02_Data_Sync.py index 351c8481c..9742f7284 100644 --- a/notebooks/py_scripts/02_Data_Sync.py +++ b/notebooks/py_scripts/02_Data_Sync.py @@ -15,7 +15,7 @@ # # Sync Data # -# DEV note: +# DEV note: # - set up as host, then as client # - test as collaborator @@ -24,10 +24,10 @@ # This notebook will cover ... # -# 1. [General Kachery information](#intro) -# 2. Setting up Kachery as a [host](#host-setup). If you'll use an existing host, +# 1. [General Kachery information](#intro) +# 2. Setting up Kachery as a [host](#host-setup). If you'll use an existing host, # skip this. -# 3. Setting up Kachery in your [database](#database-setup). If you're using an +# 3. Setting up Kachery in your [database](#database-setup). If you're using an # existing database, skip this. # 4. Adding Kachery [data](#data-setup). # @@ -36,7 +36,7 @@ # # This is one notebook in a multi-part series on Spyglass. Before running, be sure -# to [setup your environment](./00_Setup.ipynb) and run some analyses (e.g. 
+# to [setup your environment](./00_Setup.ipynb) and run some analyses (e.g. # [LFP](./12_LFP.ipynb)). # # ### Cloud @@ -46,8 +46,8 @@ # makes it possible to share analysis results, stored in NWB files. When a user # tries to access a file, Spyglass does the following: # -# 1. Try to load from the local file system/store. -# 2. If unavailable, check if it is in the relevant sharing table (i.e., +# 1. Try to load from the local file system/store. +# 2. If unavailable, check if it is in the relevant sharing table (i.e., # `NwbKachery` or `AnalysisNWBKachery`). # 3. If present, attempt to download from the associated Kachery Resource. # @@ -64,7 +64,7 @@ # 2. `franklab.collaborator`: File sharing with collaborating labs. # 3. `franklab.public`: Public file sharing (not yet active) # -# Setting your zone can either be done as as an environment variable or an item +# Setting your zone can either be done as as an environment variable or an item # in a DataJoint config. # # - Environment variable: @@ -89,9 +89,9 @@ # See # [instructions](https://github.com/flatironinstitute/kachery-cloud/blob/main/doc/create_kachery_zone.md) # for setting up new Kachery Zones, including creating a cloud bucket and -# registering it with the Kachery team. +# registering it with the Kachery team. # -# _Notes:_ +# _Notes:_ # # - Bucket names cannot include periods, so we substitute a dash, as in # `franklab-default`. @@ -100,7 +100,7 @@ # ### Resources # # See [instructions](https://github.com/scratchrealm/kachery-resource/blob/main/README.md) -# for setting up zone resources. This allows for sharing files on demand. We +# for setting up zone resources. This allows for sharing files on demand. We # suggest using the same name for the zone and resource. # # _Note:_ For each zone, you need to run the local daemon that listens for @@ -167,7 +167,7 @@ # Once the zone exists, we can add `AnalysisNWB` files we want to share by adding # entries to the `AnalysisNwbfileKacherySelection` table. # -# _Note:_ This step depends on having previously run an analysis on the example +# _Note:_ This step depends on having previously run an analysis on the example # file. # + @@ -192,13 +192,13 @@ sgs.AnalysisNwbfileKachery.populate() # + [markdown] jupyter={"outputs_hidden": true} -# If all of that worked, +# If all of that worked, # # 1. go to https://kachery-gateway.figurl.org/admin?zone=your_zone # (changing your_zone to the name of your zone) # 2. Go to the Admin/Authorization Settings tab -# 3. Add the GitHub login names and permissions for the users you want to share -# with. +# 3. Add the GitHub login names and permissions for the users you want to share +# with. # # If those users can connect to your database, they should now be able to use the # `.fetch_nwb()` method to download any `AnalysisNwbfiles` that have been shared diff --git a/notebooks/py_scripts/11_Curation.py b/notebooks/py_scripts/11_Curation.py index dd22aacf8..d4d0b0c5e 100644 --- a/notebooks/py_scripts/11_Curation.py +++ b/notebooks/py_scripts/11_Curation.py @@ -50,6 +50,7 @@ dj.config.load("dj_local_conf.json") # load config for database connection info from spyglass.spikesorting import SpikeSorting + # - # ## Spikes Sorted @@ -80,21 +81,29 @@ f"https://sortingview.vercel.app/workspace?workspace={workspace_uri}&channel=franklab" ) -# This will take you to a workspace on the `sortingview` app. The workspace, which you can think of as a list of recording and associated sorting objects, was created at the end of spike sorting. 
On the workspace view, you will see a set of recordings that have been added to the workspace. +# This will take you to a workspace on the `sortingview` app. The workspace, which +# you can think of as a list of recording and associated sorting objects, was +# created at the end of spike sorting. On the workspace view, you will see a set +# of recordings that have been added to the workspace. # # ![Workspace view](./../notebook-images/workspace.png) # -# Clicking on a recording then takes you to a page that gives you information about the recording as well as the associated sorting objects. +# Clicking on a recording then takes you to a page that gives you information +# about the recording as well as the associated sorting objects. # # ![Recording view](./../notebook-images/recording.png) # -# Click on a sorting to see the curation view. Try exploring the many visualization widgets. +# Click on a sorting to see the curation view. Try exploring the many +# visualization widgets. # # ![Unit table](./../notebook-images/unittable.png) # -# The most important is the `Units Table` and the `Curation` menu, which allows you to give labels to the units. The curation labels will persist even if you suddenly lose connection to the app; this is because the curaiton actions are appended to the workspace as soon as they are created. Note that if you are not logged in with your Google account, `Curation` menu may not be visible. Log in and refresh the page to access this feature. +# The most important is the `Units Table` and the `Curation` menu, which allows +# you to give labels to the units. The curation labels will persist even if you +# suddenly lose connection to the app; this is because the curation actions are +# appended to the workspace as soon as they are created. Note that if you are not +# logged in with your Google account, `Curation` menu may not be visible. Log in +# and refresh the page to access this feature. # # ![Curation](./../notebook-images/curation.png) # - - diff --git a/notebooks/py_scripts/14_Theta.py b/notebooks/py_scripts/14_Theta.py index 0ce46dd92..5ca3eaaba 100644 --- a/notebooks/py_scripts/14_Theta.py +++ b/notebooks/py_scripts/14_Theta.py @@ -27,9 +27,9 @@ # - For additional info on DataJoint syntax, including table definitions and # inserts, see # [the Insert Data notebook](./01_Insert_Data.ipynb) -# - To run this notebook, you should have already completed the -# [LFP](./12_LFP.ipynb) notebook and populated the `LFPBand` table. -# +# - To run this notebook, you should have already completed the +# [LFP](./12_LFP.ipynb) notebook and populated the `LFPBand` table. +# # In this tutorial, we demonstrate how to generate analytic signals from the LFP # data, as well as how to compute theta phases and power. @@ -55,7 +55,6 @@ warnings.simplefilter("ignore", category=DeprecationWarning) warnings.simplefilter("ignore", category=ResourceWarning) - # - # ## Acquire Signal @@ -76,7 +75,7 @@ # We do not need all electrodes for theta phase/power, so we define a list for # analyses. When working with full data, this list might limit to hippocampal -# reference electrodes. +# reference electrodes. # # Make sure that the chosen electrodes already exist in the LFPBand data; if not, # go to the LFP tutorial to generate them. 
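Before computing the analytic signal, it can help to confirm that every requested electrode is actually present in the LFPBand data. The sketch below expands on the `np.isin` membership check shown in the next hunk; the helper name and the toy electrode ids are hypothetical and not part of the notebook.

```python
import numpy as np


def validate_electrodes(requested, available):
    """Return only the requested electrode ids present in `available`."""
    requested = np.asarray(requested)
    present = np.isin(requested, available)
    if not present.all():
        missing = requested[~present].tolist()
        print(f"Not in LFPBand data (see the LFP tutorial): {missing}")
    return requested[present].tolist()


# Toy example: electrode 5 was never filtered, so it is dropped with a warning.
print(validate_electrodes([0, 5], available=[0, 1, 2, 3]))  # -> [0]
```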
@@ -84,25 +83,25 @@ # + electrode_list = [0] -all_electrodes = ( # All available electrode ids +all_electrodes = ( # All available electrode ids (lfp_band.LFPBandV1() & lfp_key).fetch_nwb()[0]["lfp_band"] ).electrodes.data[:] -np.isin(electrode_list, all_electrodes) # Check if our list is in 'all' +np.isin(electrode_list, all_electrodes) # Check if our list is in 'all' # - # Next, we'll compute the theta analytic signal. # + -theta_analytic_signal = (lfp_band.LFPBandV1() & lfp_key).compute_analytic_signal( - electrode_list=electrode_list -) +theta_analytic_signal = ( + lfp_band.LFPBandV1() & lfp_key +).compute_analytic_signal(electrode_list=electrode_list) theta_analytic_signal # - # In the dataframe above, the index is the timestamps, and the columns are the -# analytic sinals of theta band (complex numbers) for each electrode. +# analytic signals of theta band (complex numbers) for each electrode. # ## Compute phase and power # @@ -162,9 +161,9 @@ fig.tight_layout() ax1.set_title( - f"Theta band amplitude and phase, electode {electrode_id}", + f"Theta band amplitude and phase, electrode {electrode_id}", fontsize=20, -); +) # - # We can also plot the theta power. @@ -184,9 +183,7 @@ ) ax.tick_params(axis="y", labelcolor="k") ax.set_title( - f"Theta band power, electode {electrode_id}", + f"Theta band power, electrode {electrode_id}", fontsize=20, -); +) # - - - diff --git a/notebooks/py_scripts/20_Position_Trodes.py b/notebooks/py_scripts/20_Position_Trodes.py index 5392c6677..bb6faf809 100644 --- a/notebooks/py_scripts/20_Position_Trodes.py +++ b/notebooks/py_scripts/20_Position_Trodes.py @@ -76,22 +76,22 @@ nwb_copy_file_name = sgu.nwb_helper_fn.get_nwb_copy_filename(nwb_file_name) sgc.common_behav.RawPosition() & {"nwb_file_name": nwb_copy_file_name} -# ## Setting parameters +# ## Setting parameters # # Parameters are set by the `TrodesPosParams` table, with a `default` set # available. To adjust the default, insert a new set into this table. The # parameters are... # -# - `max_separation`, default 9 cm: maxmium acceptable distance between red and -# green LEDs. -# - If exceeded, the times are marked as NaNs and inferred by interpolation. +# - `max_separation`, default 9 cm: maximium acceptable distance between red and +# green LEDs. +# - If exceeded, the times are marked as NaNs and inferred by interpolation. # - Useful when the inferred LED position tracks a reflection instead of the # true position. -# - `max_speed`, default 300.0 cm/s: maximum speed the animal can move. -# - If exceeded, times are marked as NaNs and inferred by interpolation. -# - Useful to prevent big jumps in position. +# - `max_speed`, default 300.0 cm/s: maximum speed the animal can move. +# - If exceeded, times are marked as NaNs and inferred by interpolation. +# - Useful to prevent big jumps in position. # - `position_smoothing_duration`, default 0.100 s: LED position smoothing before -# computing average position to get head position. +# computing average position to get head position. # - `speed_smoothing_std_dev`, default 0.100 s: standard deviation of the Gaussian # kernel used to smooth the head speed. # - `front_led1`, default 1 (True), use `xloc`/`yloc`: Which LED is the front LED @@ -120,7 +120,7 @@ # ## Select interval # Later, we'll pair the above parameters with an interval from our NWB file and -# insert into `TrodesPosSelection`. +# insert into `TrodesPosSelection`. # # First, let's select an interval from the `IntervalList` table. # @@ -133,7 +133,7 @@ # the video itself. 
# # `fetch1_dataframe` returns the position of the LEDs as a pandas dataframe where -# time is the index. +# time is the index. interval_list_name = "pos 0 valid times" # pos # is epoch # minus 1 raw_position_df = ( @@ -263,7 +263,7 @@ # ## Upsampling position # -# Sometimes we need the position data in smaller in time bins, which can be +# Sometimes we need the position data in smaller in time bins, which can be # achieved with upsampling using the following parameters. # # - `is_upsampled`, default 0 (False): If 1, perform upsampling. @@ -361,5 +361,3 @@ axes[1].set_ylabel("y-velocity [cm/s]", fontsize=18) axes[1].set_title("Upsampled Head Velocity", fontsize=28) # - - - diff --git a/notebooks/py_scripts/21_Position_DLC_1.py b/notebooks/py_scripts/21_Position_DLC_1.py index 23af36d86..6782fa49f 100644 --- a/notebooks/py_scripts/21_Position_DLC_1.py +++ b/notebooks/py_scripts/21_Position_DLC_1.py @@ -29,12 +29,12 @@ # inserts, see # [the Insert Data notebook](./01_Insert_Data.ipynb) # -# This tutorial will extract position via DeepLabCut (DLC). It will walk through... +# This tutorial will extract position via DeepLabCut (DLC). It will walk through... # - creating a DLC project # - extracting and labeling frames # - training your model # -# If you already have a pretrained project, you can either skip to the +# If you already have a pretrained project, you can either skip to the # [next tutorial](./22_Position_DLC_2.ipynb) to load it into the database, or skip # to the [following tutorial](./23_Position_DLC_3.ipynb) to start pose estimation # with a model that is already inserted. @@ -82,11 +82,11 @@ #
# Notes: @@ -125,7 +125,7 @@ # - A team name, as shown in `LabTeam` for setting permissions. Here, we'll # use "LorenLab". # - A `project_name`, as a unique identifier for this DLC project. Here, we'll use -# __"tutorial_scratch_yourinitials"__ +# __"tutorial_scratch_yourinitials"__ # - `bodyparts` is a list of body parts for which we want to extract position. # The pre-labeled frames we're using include the bodyparts listed below. # - Number of frames to extract/label as `frames_per_video`. A true project might @@ -171,14 +171,14 @@ # This step and beyond should be run on a GPU-enabled machine. #
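Since this step and everything after it expects a GPU, a quick sanity check before launching training can save time. A minimal sketch, assuming TensorFlow (the backend DeepLabCut uses) is installed; it is not part of the notebook itself.

```python
import tensorflow as tf

# List the GPUs TensorFlow can see; an empty list means training will fall
# back to CPU and be extremely slow.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for gpu in gpus:
        print(f"Visible GPU: {gpu.name}")
else:
    print("No GPU visible to TensorFlow -- check drivers/CUDA before training.")
```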
-# #### [DLCModelTraining](#ToC) +# #### [DLCModelTraining](#ToC) # # Please make sure you're running this notebook on a GPU-enabled machine. # # Now that we've imported existing frames, we can get ready to train our model. # # First, we'll need to define a set of parameters for `DLCModelTrainingParams`, which will get used by DeepLabCut during training. Let's start with `gputouse`, -# which determines which GPU core to use. +# which determines which GPU core to use. # # The cell below determines which core has space and set the `gputouse` variable # accordingly. @@ -298,7 +298,7 @@ # ### Next Steps # -# With our trained model in place, we're ready to move on to +# With our trained model in place, we're ready to move on to # [pose estimation](./23_Position_DLC_3.ipynb). # ### [Return To Table of Contents](#TableOfContents)
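As a companion to the `gputouse` discussion in the training section above, here is a hedged sketch of picking the least-busy GPU by parsing `nvidia-smi` output. It assumes `nvidia-smi` is on the PATH and is only one way to make this choice; the notebook's own cell may do it differently.

```python
import subprocess

# Query per-GPU memory use: one "index, memory.used" line per GPU.
result = subprocess.run(
    [
        "nvidia-smi",
        "--query-gpu=index,memory.used",
        "--format=csv,noheader,nounits",
    ],
    capture_output=True,
    text=True,
    check=True,
)

usage = []
for line in result.stdout.strip().splitlines():
    index, memory_used = (int(field) for field in line.split(","))
    usage.append((memory_used, index))

gputouse = min(usage)[1]  # GPU id with the least memory currently in use
print(f"Setting gputouse = {gputouse}")
```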
diff --git a/notebooks/py_scripts/22_Position_DLC_2.py b/notebooks/py_scripts/22_Position_DLC_2.py index 0d1661f5a..6876a217d 100644 --- a/notebooks/py_scripts/22_Position_DLC_2.py +++ b/notebooks/py_scripts/22_Position_DLC_2.py @@ -31,7 +31,7 @@ # # This is a tutorial will cover how to extract position given a pre-trained DeepLabCut (DLC) model. It will walk through adding your DLC model to Spyglass. # -# If you already have a model in the database, skip to the +# If you already have a model in the database, skip to the # [next tutorial](./23_Position_DLC_3.ipynb). # ## Imports @@ -86,7 +86,7 @@ #
# Notes: @@ -119,7 +119,7 @@ # #### [DLCModel](#ToC) -# The `DLCModelInput` table has `dlc_model_name` and `project_name` as primary keys and `project_path` as a secondary key. +# The `DLCModelInput` table has `dlc_model_name` and `project_name` as primary keys and `project_path` as a secondary key. sgp.DLCModelInput() @@ -186,7 +186,7 @@ # ### Next Steps # -# With our trained model in place, we're ready to move on to +# With our trained model in place, we're ready to move on to # [pose estimation](./23_Position_DLC_3.ipynb). # ### [`Return To Table of Contents`](#ToC)
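To recap the registration step in this notebook: a pre-trained project is added through `DLCModelInput` using the keys described above. The values below (model name, project name, path) are placeholders rather than a real lab entry; the downstream model tables are then populated as shown in the hunks above.

```python
import spyglass.position as sgp

model_key = {
    "dlc_model_name": "tutorial_model_DG",  # placeholder model name
    "project_name": "tutorial_scratch_DG",  # an existing DLC project
    "project_path": "/path/to/dlc/project/",  # placeholder path to the project
}

# Register the pre-trained project with Spyglass.
sgp.DLCModelInput().insert1(model_key, skip_duplicates=True)
```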
diff --git a/notebooks/py_scripts/23_Position_DLC_3.py b/notebooks/py_scripts/23_Position_DLC_3.py index 0f2aef08a..0fb6a4f8f 100644 --- a/notebooks/py_scripts/23_Position_DLC_3.py +++ b/notebooks/py_scripts/23_Position_DLC_3.py @@ -24,7 +24,7 @@ # inserts, see # [the Insert Data notebook](./01_Insert_Data.ipynb) # -# This tutorial will extract position via DeepLabCut (DLC). It will walk through... +# This tutorial will extract position via DeepLabCut (DLC). It will walk through... # - executing pose estimation # - processing the pose estimation output to extract a centroid and orientation # - inserting the resulting information into the `IntervalPositionInfo` table @@ -287,7 +287,7 @@ # #### [DLCOrientation](#TableOfContents) -# We'll go through a similar process for orientation. +# We'll go through a similar process for orientation. pprint(sgp.DLCOrientationParams.get_default()) dlc_orientation_params_name = "default" @@ -311,7 +311,7 @@ # #### [DLCPos](#TableOfContents) -# After processing the position data, we have to do a few table manipulations to standardize various outputs. +# After processing the position data, we have to do a few table manipulations to standardize various outputs. # # To summarize, we brought in a pretrained DLC project, used that model to run pose estimation on a new behavioral video, smoothed and interpolated the result, formed a cohort of bodyparts, and determined the centroid and orientation of this cohort. # diff --git a/notebooks/py_scripts/24_Linearization.py b/notebooks/py_scripts/24_Linearization.py index 1d246bb3d..0e4f3b7c1 100644 --- a/notebooks/py_scripts/24_Linearization.py +++ b/notebooks/py_scripts/24_Linearization.py @@ -252,7 +252,7 @@ # Running `fetch1_dataframe` will retrieve the linear position data, including... # -# - `time`: datafame index +# - `time`: dataframe index # - `linear_position`: 1D linearized position # - `track_segment_id`: index number of the edges given to track graph # - `projected_{x,y}_position`: 2D position projected to the track graph diff --git a/notebooks/py_scripts/30_Ripple_Detection.py b/notebooks/py_scripts/30_Ripple_Detection.py index be703b25d..581418747 100644 --- a/notebooks/py_scripts/30_Ripple_Detection.py +++ b/notebooks/py_scripts/30_Ripple_Detection.py @@ -71,7 +71,7 @@ # ?sgr.RippleLFPSelection.set_lfp_electrodes -# We'll need the `nwb_file_name`, an `electrode_list`, and to a `group_name`. +# We'll need the `nwb_file_name`, an `electrode_list`, and to a `group_name`. # # - By default, `group_name` is set to CA1 for ripple detection, but we could # alternatively use PFC. @@ -162,10 +162,10 @@ # - `speed_name`: the name of the speed parameters in `IntervalPositionInfo` # # For the `Kay_ripple_detector` (options are currently Kay and Karlsson, see `ripple_detection` package for specifics) the parameters are: -# -# - `speed_threshold` (cm/s): maxmimum speed the animal can move +# +# - `speed_threshold` (cm/s): maximum speed the animal can move # - `minimum_duration` (s): minimum time above threshold -# - `zscore_threshold` (std): mimimum value to be considered a ripple, in standard +# - `zscore_threshold` (std): minimum value to be considered a ripple, in standard # deviations from mean # - `smoothing_sigma` (s): how much to smooth the signal in time # - `close_ripple_threshold` (s): exclude ripples closer than this amount @@ -187,7 +187,7 @@ # We'll use the `head_speed` above as part of `RippleParameters`. 
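Pulling the parameter descriptions above into one place, the sketch below shows what a Kay-detector parameter set might look like. The values are illustrative defaults only, not lab-endorsed settings, and the dictionary is not inserted anywhere.

```python
ripple_params = {
    "speed_name": "head_speed",      # speed column in IntervalPositionInfo
    "speed_threshold": 4.0,          # cm/s: ignore times the animal is running
    "minimum_duration": 0.015,       # s: minimum time above threshold
    "zscore_threshold": 2.0,         # std above the mean envelope
    "smoothing_sigma": 0.004,        # s: Gaussian smoothing of the envelope
    "close_ripple_threshold": 0.0,   # s: exclude ripples closer than this
}
print(ripple_params)
```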
-# ## Run Ripple Detection +# ## Run Ripple Detection # # diff --git a/notebooks/py_scripts/31_Extract_Mark_Indicators.py b/notebooks/py_scripts/31_Extract_Mark_Indicators.py index 6e9c9485f..d07f8f047 100644 --- a/notebooks/py_scripts/31_Extract_Mark_Indicators.py +++ b/notebooks/py_scripts/31_Extract_Mark_Indicators.py @@ -20,7 +20,7 @@ # _Developer Note:_ if you may make a PR in the future, be sure to copy this # notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. # -# This is one notebook in a multi-part series on clusterless decoding in Spyglass +# This is one notebook in a multi-part series on clusterless decoding in Spyglass # # - To set up your Spyglass environment and database, see # [the Setup notebook](./00_Setup.ipynb) diff --git a/notebooks/py_scripts/32_Decoding_with_GPUs.py b/notebooks/py_scripts/32_Decoding_with_GPUs.py index d7c972963..8127ac028 100644 --- a/notebooks/py_scripts/32_Decoding_with_GPUs.py +++ b/notebooks/py_scripts/32_Decoding_with_GPUs.py @@ -21,10 +21,10 @@ # notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. # # This is one notebook in a multi-part series on decoding in Spyglass. To set up -# your Spyglass environment and database, see +# your Spyglass environment and database, see # [the Setup notebook](./00_Setup.ipynb). # -# In this tutorial, we'll set up GPU access for subsequent decoding analyses. While this notebook doesn't have any direct prerequisites, you will need +# In this tutorial, we'll set up GPU access for subsequent decoding analyses. While this notebook doesn't have any direct prerequisites, you will need # [Spike Sorting](./02_Spike_Sorting.ipynb) data for the next step. # @@ -32,7 +32,7 @@ # # ### Connecting -# +# # Members of the Frank Lab have access to two GPU cluster, `breeze` and `zephyr`. # To access them, specify the cluster when you `ssh`, with the default port: # @@ -40,7 +40,7 @@ # # There are currently 10 available GPUs, each with 80 GB RAM, each referred to by their IDs (0 - 9). # -# +# # # ### Selecting a GPU # @@ -58,7 +58,7 @@ # ### Which GPU? # # You can see which GPUs are occupied by running the command `nvidia-smi` in -# a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage. +# a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage. # # In the output below, GPUs 1, 4, 6, and 7 have low memory use and power draw (~42W), are probably not in use. @@ -70,7 +70,7 @@ # # Other ways to monitor GPU usage are: # -# - A +# - A # [jupyter widget by nvidia](https://github.com/rapidsai/jupyterlab-nvdashboard) # to monitor GPU usage in the notebook # - A [terminal program](https://github.com/peci1/nvidia-htop) like nvidia-smi @@ -185,7 +185,7 @@ # conda install -c rapidsai -c nvidia -c conda-forge dask-cuda # ``` # -# We will set up a client to select GPUs. By default, this is all available +# We will set up a client to select GPUs. By default, this is all available # GPUs. Below, we select a subset using the `CUDA_VISIBLE_DEVICES`. # + @@ -202,6 +202,7 @@ # # In the example below, we run `test_gpu` on each item of `data` where each item is processed on a different GPU. + # + def setup_logger(name_logfile, path_logfile): """Sets up a logger for each function that outputs @@ -250,5 +251,3 @@ def test_gpu(x, ind): # - # This example also shows how to create a log file for each item in data with the `setup_logger` function. 
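Tying the GPU-selection pieces of this notebook together, here is a hedged sketch of a Dask client restricted to a subset of GPUs via `CUDA_VISIBLE_DEVICES`. The device ids are placeholders, and the commented dispatch lines only suggest how `test_gpu` might be mapped over the data.

```python
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

# Only expose GPUs 1 and 4 to the cluster (placeholder ids; pick idle ones
# from `nvidia-smi`).
cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="1,4")
client = Client(cluster)
print(client)

# Hypothetical dispatch: one item of `data` per worker/GPU.
# futures = client.map(test_gpu, data, range(len(data)))
# results = client.gather(futures)
```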
- - diff --git a/notebooks/py_scripts/33_Decoding_Clusterless.py b/notebooks/py_scripts/33_Decoding_Clusterless.py index a6a3ad0fc..db0ecedcf 100644 --- a/notebooks/py_scripts/33_Decoding_Clusterless.py +++ b/notebooks/py_scripts/33_Decoding_Clusterless.py @@ -24,15 +24,15 @@ # # - To set up your Spyglass environment and database, see # [the Setup notebook](./00_Setup.ipynb) -# - This tutorial assumes you've already -# [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded -# position data. If 1D decodint, this data should also be +# - This tutorial assumes you've already +# [extracted marks](./31_Extract_Mark_Indicators.ipynb), as well as loaded +# position data. If 1D decoding, this data should also be # [linearized](./24_Linearization.ipynb). # - Ths tutorial also assumes you're familiar with how to run processes on GPU, as # presented in [this notebook](./32_Decoding_with_GPUs.ipynb) # # Clusterless decoding can be performed on either 1D or 2D data. A few steps in -# this notebook will refer to a `decode_1d` variable set in +# this notebook will refer to a `decode_1d` variable set in # [select data](#select-data) to include these steps. # @@ -115,7 +115,7 @@ # for items that look overly correlated (strong diagonal on the off-diagonal # plots) and extreme amplitudes. # -# For tutorial purposes, we only look at the first 2 plots, but removing this +# For tutorial purposes, we only look at the first 2 plots, but removing this # argument will show all plots. sgd_clusterless.UnitMarksIndicator.plot_all_marks(marks, plot_limit=2) @@ -125,7 +125,7 @@ # ### Get position -# Next, we'll grab the 2D position data from `IntervalPositionInfo` table. +# Next, we'll grab the 2D position data from `IntervalPositionInfo` table. # # _Note:_ Position will need to be upsampled to our decoding frequency (500 Hz). # See [this notebook](./20_Position_Trodes.ipynb#upsampling-position) for more @@ -261,7 +261,7 @@ # ## Decoding # -# After sanity checks, we can finally get to decoding. +# After sanity checks, we can finally get to decoding. # # _Note:_ Portions of the code below have been integrated into # `spyglass.decoding`, but are presented here in full. 
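Before moving on to the source-code changes, one step the clusterless notebook mentions is upsampling position to the 500 Hz decoding grid. Here is a minimal pandas sketch with toy data; the column name and sampling rates are placeholders, not the notebook's actual values.

```python
import numpy as np
import pandas as pd

# Toy 30 Hz position trace indexed by time in seconds.
camera_time = np.arange(0, 1, 1 / 30)
position = pd.DataFrame(
    {"head_position_x": np.linspace(0, 10, camera_time.size)},
    index=pd.Index(camera_time, name="time"),
)

# Target 500 Hz grid, then interpolate using the time index.
decode_time = np.arange(camera_time[0], camera_time[-1], 1 / 500)
upsampled = (
    position.reindex(position.index.union(decode_time))
    .interpolate(method="index")
    .reindex(decode_time)
)
print(upsampled.shape)  # one row every 2 ms
```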
diff --git a/src/spyglass/__init__.py b/src/spyglass/__init__.py index 81b2e02b1..3f8887d2d 100644 --- a/src/spyglass/__init__.py +++ b/src/spyglass/__init__.py @@ -1,4 +1,4 @@ -from .settings import load_config +from .settings import config try: import ndx_franklab_novela @@ -11,5 +11,3 @@ pass __all__ = ["ndx_franklab_novela", "__version__", "config"] - -config = load_config() diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 4054de52e..949f3cd4e 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -1,4 +1,3 @@ -import os import datajoint as dj from .common_device import CameraDevice, DataAcquisitionDevice, Probe @@ -6,6 +5,7 @@ from .common_nwbfile import Nwbfile from .common_subject import Subject from ..utils.nwb_helper_fn import get_nwb_file, get_config +from ..settings import config schema = dj.schema("common_session") @@ -243,11 +243,17 @@ def get_group_sessions(session_group_name: str): def create_spyglass_view(session_group_name: str): import figurl as fig - FIGURL_CHANNEL = os.getenv("FIGURL_CHANNEL") - assert FIGURL_CHANNEL, "Environment variable not set: FIGURL_CHANNEL" - data = {"type": "spyglassview", "sessionGroupName": session_group_name} - F = fig.Figure(view_url="gs://figurl/spyglassview-1", data=data) - return F + FIGURL_CHANNEL = config.get("FIGURL_CHANNEL") + if not FIGURL_CHANNEL: + raise ValueError("FIGURL_CHANNEL conifg/env variagle not set") + + return fig.Figure( + view_url="gs://figurl/spyglassview-1", + data={ + "type": "spyglassview", + "sessionGroupName": session_group_name, + }, + ) # The reason this is not implemented as a dj.Part is that diff --git a/src/spyglass/common/prepopulate/prepopulate.py b/src/spyglass/common/prepopulate/prepopulate.py index ce36a33be..39896bc20 100644 --- a/src/spyglass/common/prepopulate/prepopulate.py +++ b/src/spyglass/common/prepopulate/prepopulate.py @@ -5,20 +5,13 @@ import datajoint as dj import yaml -from ...settings import load_config +from ...settings import base_dir def prepopulate_default(): - """Prepopulate the database with the default values in SPYGLASS_BASE_DIR/entries.yaml.""" - - base_dir = os.getenv("SPYGLASS_BASE_DIR", None) or load_config().get( - "SPYGLASS_BASE_DIR" - ) - if not base_dir: - raise ValueError( - "You must set SPYGLASS_BASE_DIR or provide the base_dir argument" - ) - + """ + Populate the database with default values in SPYGLASS_BASE_DIR/entries.yaml + """ yaml_path = pathlib.Path(base_dir) / "entries.yaml" if os.path.exists(yaml_path): populate_from_yaml(yaml_path) diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 9c9cc906e..0a80d0a52 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -20,7 +20,6 @@ def insert_sessions(nwb_file_names: Union[str, List[str]]): File names in raw directory ($SPYGLASS_RAW_DIR) pointing to existing .nwb files. Each file represents a session. 
""" - _ = load_config() if not isinstance(nwb_file_names, list): nwb_file_names = [nwb_file_names] diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 84a5384d9..eb2cf2d1b 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -18,6 +18,8 @@ import pandas as pd from tqdm import tqdm as tqdm +from ...settings import raw_dir + def _set_permissions(directory, mode, username: str, groupname: str = None): """ @@ -325,9 +327,8 @@ def get_video_path(key): VideoFile() & {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} ).fetch1() - nwb_path = ( - f"{os.getenv('SPYGLASS_BASE_DIR')}/raw/{video_info['nwb_file_name']}" - ) + nwb_path = f"{raw_dir}/{video_info['nwb_file_name']}" + with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out: nwb_file = in_out.read() nwb_video = nwb_file.objects[video_info["video_file_object_id"]] @@ -338,6 +339,7 @@ def get_video_path(key): video_filename = video_filepath.split(video_dir)[-1] meters_per_pixel = nwb_video.device.meters_per_pixel timestamps = np.asarray(nwb_video.timestamps) + return video_dir, video_filename, meters_per_pixel, timestamps diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 37f00ae5d..da366fb7d 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -20,17 +20,18 @@ sorting="spikesorting", # "SPYGLASS_SORTING_DIR" waveforms="waveforms", temp="tmp", + video="video", ), kachery=dict( - cloud="kachery-storage", - storage="kachery-storage", + cloud="kachery_storage", + storage="kachery_storage", temp="tmp", ), ) def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: - """Gets syglass dirs from dj.config or environment variables. + """Gets Spyglass dirs from dj.config or environment variables. 
Uses a relative_dirs dict defined in settings.py to (a) gather user settings from dj.config or os environment variables or defaults relative to @@ -72,6 +73,7 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: raise ValueError( "SPYGLASS_BASE_DIR not defined in dj.config or os env vars" ) + config_dirs = {"SPYGLASS_BASE_DIR": resolved_base} for prefix, dirs in relative_dirs.items(): for dir, dir_str in dirs.items(): @@ -94,16 +96,56 @@ def load_config(base_dir: Path = None, force_reload: bool = False) -> dict: ) } - _set_env_with_dict({**config_dirs, **kachery_zone_dict}) + loaded_env = _load_env_vars(env_defaults) + _set_env_with_dict({**config_dirs, **kachery_zone_dict, **loaded_env}) _mkdirs_from_dict_vals(config_dirs) _set_dj_config_stores(config_dirs) - config = dict(**config_defaults, **config_dirs, **kachery_zone_dict) + config = dict( + **config_defaults, **config_dirs, **kachery_zone_dict, **loaded_env + ) config_loaded = True return config -def base_dir() -> str: +def _load_env_vars(env_dict: dict) -> dict: + """Loads env vars from dict {str: Any}.""" + loaded_dict = {} + for var, val in env_dict.items(): + loaded_dict[var] = os.getenv(var, val) + return loaded_dict + + +def _set_env_with_dict(env_dict: dict): + """Sets env vars from dict {str: Any} where Any is convertible to str.""" + for var, val in env_dict.items(): + os.environ[var] = str(val) + + +def _mkdirs_from_dict_vals(dir_dict: dict): + for dir_str in dir_dict.values(): + Path(dir_str).mkdir(exist_ok=True) + + +def _set_dj_config_stores(dir_dict: dict): + raw_dir = dir_dict["SPYGLASS_RAW_DIR"] + analysis_dir = dir_dict["SPYGLASS_ANALYSIS_DIR"] + + dj.config["stores"] = { + "raw": { + "protocol": "file", + "location": str(raw_dir), + "stage": str(raw_dir), + }, + "analysis": { + "protocol": "file", + "location": str(analysis_dir), + "stage": str(analysis_dir), + }, + } + + +def load_base_dir() -> str: """Retrieve the base directory from the configuration. Returns @@ -117,13 +159,13 @@ def base_dir() -> str: return config.get("SPYGLASS_BASE_DIR") -def raw_dir() -> str: - """Retrieve the base directory from the configuration. +def load_raw_dir() -> str: + """Retrieve the raw directory from the configuration. Returns ------- str - The base directory path. + The raw directory path. """ global config if not config_loaded or not config: @@ -131,31 +173,79 @@ def raw_dir() -> str: return config.get("SPYGLASS_RAW_DIR") -def _set_env_with_dict(env_dict: dict): - """Sets env vars from dict {str: Any} where Any is convertible to str.""" - env_to_set = {**env_defaults, **env_dict} - for var, val in env_to_set.items(): - os.environ[var] = str(val) +def load_analysis_dir() -> str: + """Retrieve the analysis directory from the configuration. + Returns + ------- + str + The recording directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_ANALYSIS_DIR") -def _mkdirs_from_dict_vals(dir_dict: dict): - for dir_str in dir_dict.values(): - Path(dir_str).mkdir(exist_ok=True) +def load_recording_dir() -> str: + """Retrieve the recording directory from the configuration. -def _set_dj_config_stores(dir_dict: dict): - raw_dir = dir_dict["SPYGLASS_RAW_DIR"] - analysis_dir = dir_dict["SPYGLASS_ANALYSIS_DIR"] + Returns + ------- + str + The recording directory path. 
+ """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_RECORDING_DIR") - dj.config["stores"] = { - "raw": { - "protocol": "file", - "location": str(raw_dir), - "stage": str(raw_dir), - }, - "analysis": { - "protocol": "file", - "location": str(analysis_dir), - "stage": str(analysis_dir), - }, - } + +def load_sorting_dir() -> str: + """Retrieve the sorting directory from the configuration. + + Returns + ------- + str + The sorting directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_SORTING_DIR") + + +def load_temp_dir() -> str: + """Retrieve the temp directory from the configuration. + + Returns + ------- + str + The temp directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_TEMP_DIR") + + +def load_waveform_dir() -> str: + """Retrieve the temp directory from the configuration. + + Returns + ------- + str + The temp directory path. + """ + global config + if not config_loaded or not config: + config = load_config() + return config.get("SPYGLASS_TEMP_DIR") + + +base_dir = load_base_dir() +raw_dir = load_raw_dir() +recording_dir = load_recording_dir() +temp_dir = load_temp_dir() +analysis_dir = load_analysis_dir() +sorting_dir = load_sorting_dir() diff --git a/src/spyglass/spikesorting/spikesorting_sorting.py b/src/spyglass/spikesorting/spikesorting_sorting.py index afc261f1c..08c807626 100644 --- a/src/spyglass/spikesorting/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/spikesorting_sorting.py @@ -14,8 +14,7 @@ from ..common.common_lab import LabMember, LabTeam from ..common.common_nwbfile import AnalysisNwbfile -from ..settings import load_config -from ..utils.dj_helper_fn import fetch_nwb +from ..settings import temp_dir, sorting_dir from .spikesorting_artifact import ArtifactRemovedIntervalList from .spikesorting_recording import ( SpikeSortingRecording, @@ -141,7 +140,7 @@ def make(self, key: dict): # first, get the timestamps timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - fs = recording.get_sampling_frequency() + _ = recording.get_sampling_frequency() # then concatenate the recordings # Note: the timestamps are lost upon concatenation, # i.e. concat_recording.get_times() doesn't return true timestamps anymore. 
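The spike-sorting changes above lean on the directory names that `spyglass.settings` now exports at module level. A small sketch of inspecting them after this patch; it assumes the config or environment variables described in `settings.py` are set so that `load_config()` succeeds.

```python
from spyglass.settings import (
    analysis_dir,
    base_dir,
    raw_dir,
    recording_dir,
    sorting_dir,
    temp_dir,
)

# Each name resolves via load_config(), which reads dj.config / env vars and
# falls back to paths relative to SPYGLASS_BASE_DIR.
for label, path in [
    ("base", base_dir),
    ("raw", raw_dir),
    ("analysis", analysis_dir),
    ("recording", recording_dir),
    ("sorting", sorting_dir),
    ("temp", temp_dir),
]:
    print(f"{label:>9}: {path}")
```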
@@ -192,9 +191,7 @@ def make(self, key: dict): "sorter", "sorter_params" ) - sorter_temp_dir = tempfile.TemporaryDirectory( - dir=os.getenv("SPYGLASS_TEMP_DIR") - ) + sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) # add tempdir option for mountainsort sorter_params["tempdir"] = sorter_temp_dir.name @@ -229,7 +226,7 @@ def make(self, key: dict): print("Saving sorting results...") - sorting_folder = Path(load_config().get("SPYGLASS_SORTING_DIR")) + sorting_folder = Path(sorting_dir) sorting_name = self._get_sorting_name(key) key["sorting_path"] = str(sorting_folder / Path(sorting_name)) @@ -290,16 +287,14 @@ def nightly_cleanup(self): This should be run after AnalysisNwbFile().nightly_cleanup() """ # get a list of the files in the spike sorting storage directory - dir_names = next(os.walk(os.environ["SPYGLASS_SORTING_DIR"]))[1] + dir_names = next(os.walk(sorting_dir))[1] # now retrieve a list of the currently used analysis nwb files analysis_file_names = self.fetch("analysis_file_name") for dir in dir_names: if dir not in analysis_file_names: - full_path = str(Path(os.environ["SPYGLASS_SORTING_DIR"]) / dir) + full_path = str(Path(sorting_dir) / dir) print(f"removing {full_path}") - shutil.rmtree( - str(Path(os.environ["SPYGLASS_SORTING_DIR"]) / dir) - ) + shutil.rmtree(str(Path(sorting_dir) / dir)) @staticmethod def _get_sorting_name(key): diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index cc399cc30..f79dde154 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -3,7 +3,7 @@ from pprint import pprint import datajoint as dj -from datajoint.condition import AndList, make_condition +from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.preview import repr_html from datajoint.utils import from_camel_case, to_camel_case @@ -39,16 +39,16 @@ def __init__(self): if not self.is_declared: if self.definition != merge_def: print( - "WARNING: merge table declared with non-default definition\n\t" + "WARNING: merge table with non-default definition\n\t" + f"Expected: {merge_def.strip()}\n\t" + f"Actual : {self.definition.strip()}" ) for part in self.parts(as_objects=True): if part.primary_key != self.primary_key: print( - f"WARNING: unexpected primary key for {part.table_name}\n\t" - + f"Expected: {self.primary_key}\n\t" - + f"Actual : {part.primary_key}" + f"WARNING: unexpected primary key in {part.table_name}" + + f"\n\tExpected: {self.primary_key}" + + f"\n\tActual : {part.primary_key}" ) @classmethod @@ -187,7 +187,7 @@ def _merge_restrict_parents( @classmethod def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: - """Merged view, including null entries for columns unique to one part table. + """Merged view, including null entries for columns unique to one part. Parameters --------- @@ -237,7 +237,7 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: def _merge_insert( cls, rows: list, part_name: str = None, mutual_exclusvity=True, **kwargs ) -> None: - """Insert rows into merge table, ensuring db integrity and mutual exclusivity + """Insert rows into merge, ensuring db integrity and mutual exclusivity Parameters --------- @@ -433,8 +433,8 @@ def merge_delete_parent( Optional restriction to apply before deletion from parents. If not provided, delete all entries present in Merge Table. dry_run: bool - Default True. If true, return list of tables with entries that would be - deleted. Otherwise, table entries. 
+ Default True. If true, return list of tables with entries that would + be deleted. Otherwise, table entries. kwargs: dict Additional keyword arguments for DataJoint delete. """ @@ -645,7 +645,7 @@ def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list: if not results: print( "No merge_fetch results.\n\t" - + "If not restriction, try: `M.merge_fetch(True,'attr')\n\t" + + "If not restricting, try: `M.merge_fetch(True,'attr')\n\t" + "If restricting by source, use dict: " + "`M.merge_fetch({'source':'X'})" ) @@ -665,14 +665,12 @@ def merge_populate(source: str, key=None): # Aliased because underscore otherwise excludes from API docs. -_Merge = Merge - -# Underscore as class name avoids errors when this included in a Diagram -# Aliased because underscore otherwise excludes from API docs. - - def delete_downstream_merge( - table: dj.Table, restriction: str = True, dry_run=True, **kwargs + table: dj.Table, + restriction: str = True, + dry_run=True, + recurse_level=2, + **kwargs, ) -> list: """Given a table/restriction, id or delete relevant downstream merge entries @@ -686,6 +684,8 @@ def delete_downstream_merge( dry_run: bool Default True. If true, return list of tuples, merge/part tables downstream of table input. Otherwise, delete merge/part table entries. + recurse_level: int + Default 2. Depth to recurse into table descendants. kwargs: dict Additional keyword arguments for DataJoint delete. @@ -694,21 +694,24 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. """ - restriction = AndList((table.restriction, restriction)) + if table.restriction: + print( + f"Warning: ignoring table restriction: {table.restriction}.\n\t" + + "Please pass restrictions as an arg" + ) - if not restriction: - restriction = True + descendants = _unique_descendants(table, recurse_level) # Adapted from Spyglass PR 535 # dj.utils.get_master could maybe help here, but it uses names, not objs - merge_pairs = [ # get each merge/part table + merge_table_pairs = [ # get each merge/part table (master, descendant.restrict(restriction)) - for descendant in table.descendants(as_objects=True) # given tbl desc + for descendant in descendants for master in descendant.parents(as_objects=True) # and those parents # if is a part table (using a dunder not immediately after schema name) if "__" in descendant.full_table_name.replace("`.`__", "") # and it is not in not in direct descendants - and master.full_table_name not in table.descendants(as_objects=False) + and master.full_table_name not in descendants # and it uses our reserved primary key in attributes and RESERVED_PRIMARY_KEY in master.heading.attributes.keys() ] @@ -716,7 +719,7 @@ def delete_downstream_merge( # restrict the merge table based on uuids in part merge_pairs = [ (merge & uuids, part) # don't need part for del, but show on dry_run - for merge, part in merge_pairs + for merge, part in merge_table_pairs for uuids in part.fetch(RESERVED_PRIMARY_KEY, as_dict=True) ] @@ -725,3 +728,43 @@ def delete_downstream_merge( for merge_table, _ in merge_pairs: merge_table.delete(**kwargs) + + +def _unique_descendants( + table: dj.Table, recurse_level: int, return_names: bool = False +) -> list: + """Recurisively find unique descendants of a given table + + Parameters + ---------- + table: dj.Table + The node in the tree from which to find descendants. + recurse_level: int + The maximum level of descendants to find. + return_names: bool + If True, return names of descendants found. 
+ + Returns + ------- + List[dj.Table] + List descendants found when recurisively called to recurse_level + """ + + if recurse_level == 0: + return [] + + descendants = {} + + def recurse_descendants(sub_table, level): + for descendant in sub_table.descendants(as_objects=True): + if descendant.full_table_name not in descendants: + descendants[descendant.full_table_name] = descendant + if level > 1: + recurse_descendants(descendant, level - 1) + + recurse_descendants(table, recurse_level) + + if return_names: + return list(descendants.keys()) + + return list(descendants.values())
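Closing out the merge-table changes, a hedged usage sketch of `delete_downstream_merge` with the new `recurse_level` behavior. The upstream table and file name are placeholders; the dry run only reports the merge/part entries that would be removed.

```python
import spyglass.common as sgc
from spyglass.utils.dj_merge_tables import delete_downstream_merge

restriction = {"nwb_file_name": "example_file_.nwb"}  # placeholder restriction

# Dry run: returns (merge table, part table) pairs downstream of Nwbfile that
# match the restriction, recursing two levels of descendants by default.
pairs = delete_downstream_merge(sgc.Nwbfile(), restriction, dry_run=True)
for merge_table, part_table in pairs:
    print(merge_table, part_table)

# Once the list looks right, repeat with dry_run=False to delete the entries.
# delete_downstream_merge(sgc.Nwbfile(), restriction, dry_run=False)
```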