From 6fa248b47bc6e2809a442a70d1a8f9cd7c460700 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 30 Jan 2024 09:52:54 -0600 Subject: [PATCH] Blackify 24.1.1 --- .pre-commit-config.yaml | 2 +- config/dj_config.py | 4 +- src/spyglass/common/common_behav.py | 10 +++-- src/spyglass/common/common_device.py | 6 +-- src/spyglass/common/common_filter.py | 6 +-- src/spyglass/common/common_interval.py | 8 ++-- src/spyglass/common/common_lab.py | 1 + src/spyglass/decoding/v0/clusterless.py | 11 +++--- .../decoding/v0/dj_decoder_conversion.py | 37 ++++++++++--------- src/spyglass/decoding/v0/sorted_spikes.py | 1 + .../decoding/v0/visualization_2D_view.py | 2 +- src/spyglass/decoding/v1/clusterless.py | 6 +-- .../decoding/v1/dj_decoder_conversion.py | 1 - src/spyglass/decoding/v1/sorted_spikes.py | 6 +-- .../position/v1/position_dlc_orient.py | 18 ++++----- .../v1/position_dlc_pose_estimation.py | 20 +++++----- .../position/v1/position_dlc_position.py | 20 +++++----- src/spyglass/sharing/sharing_kachery.py | 6 +-- .../prepare_spikesortingview_data.py | 18 ++++----- src/spyglass/spikesorting/v1/recording.py | 6 +-- src/spyglass/utils/dj_helper_fn.py | 9 +++-- src/spyglass/utils/dj_merge_tables.py | 8 ++-- src/spyglass/utils/dj_mixin.py | 4 +- src/spyglass/utils/logging.py | 1 + src/spyglass/utils/nwb_helper_fn.py | 8 ++-- 25 files changed, 115 insertions(+), 104 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7643f74c2..0d7b19203 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,7 +77,7 @@ repos: - tomli - repo: https://github.com/ambv/black - rev: 23.11.0 + rev: 24.1.1 hooks: - id: black language_version: python3.9 diff --git a/config/dj_config.py b/config/dj_config.py index 55fd8a9ad..6b25a5d57 100755 --- a/config/dj_config.py +++ b/config/dj_config.py @@ -15,9 +15,7 @@ def main(*args): save_method = ( "local" if filename == "dj_local_conf.json" - else "global" - if filename is None - else "custom" + else "global" if filename is None else "custom" ) config.save_dj_config( diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index ed9673ecb..d7d4759fb 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -191,10 +191,12 @@ def _get_column_names(rp, pos_id): INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) n_pos_dims = rp.data.shape[1] column_names = [ - col # use existing columns if already numbered - if "1" in rp.description or "2" in rp.description - # else number them by id - else col + str(pos_id + INDEX_ADJUST) + ( + col # use existing columns if already numbered + if "1" in rp.description or "2" in rp.description + # else number them by id + else col + str(pos_id + INDEX_ADJUST) + ) for col in rp.description.split(", ") ] if len(column_names) != n_pos_dims: diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py index 2dd03c822..96fa11d44 100644 --- a/src/spyglass/common/common_device.py +++ b/src/spyglass/common/common_device.py @@ -476,9 +476,9 @@ def __read_ndx_probe_data( { "probe_id": nwb_probe_obj.probe_type, "probe_type": nwb_probe_obj.probe_type, - "contact_side_numbering": "True" - if nwb_probe_obj.contact_side_numbering - else "False", + "contact_side_numbering": ( + "True" if nwb_probe_obj.contact_side_numbering else "False" + ), } ) # go through the shanks and add each one to the Shank table diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index 988266d0d..59870f266 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -500,9 +500,9 @@ def filter_data( for ii, (start, stop) in enumerate(indices): extracted_ts = timestamps[start:stop:decimation] - new_timestamps[ - ts_offset : ts_offset + len(extracted_ts) - ] = extracted_ts + new_timestamps[ts_offset : ts_offset + len(extracted_ts)] = ( + extracted_ts + ) ts_offset += len(extracted_ts) # finally ready to filter data! diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index d754261fc..24b143ad6 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -56,9 +56,11 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name): for _, epoch_data in epochs.iterrows(): epoch_dict = { "nwb_file_name": nwb_file_name, - "interval_list_name": epoch_data.tags[0] - if epoch_data.tags - else f"interval_{epoch_data[0]}", + "interval_list_name": ( + epoch_data.tags[0] + if epoch_data.tags + else f"interval_{epoch_data[0]}" + ), "valid_times": np.asarray( [[epoch_data.start_time, epoch_data.stop_time]] ), diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index 177fc4424..ca9d4359a 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -1,4 +1,5 @@ """Schema for institution, lab team/name/members. Session-independent.""" + import datajoint as dj from spyglass.utils import SpyglassMixin, logger diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index f6fd9df37..e5577225d 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -6,6 +6,7 @@ [1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). """ + import os import shutil import uuid @@ -654,11 +655,11 @@ def make(self, key): key["nwb_file_name"] ) - key[ - "multiunit_firing_rate_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=multiunit_firing_rate.reset_index(), + key["multiunit_firing_rate_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=multiunit_firing_rate.reset_index(), + ) ) nwb_analysis_file.add( diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py b/src/spyglass/decoding/v0/dj_decoder_conversion.py index 73566f24a..edcb0d637 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -1,5 +1,6 @@ """Converts decoder classes into dictionaries and dictionaries into classes so that datajoint can store them in tables.""" + from spyglass.utils import logger try: @@ -116,17 +117,17 @@ def restore_classes(params: dict) -> dict: _convert_env_dict(env_params) for env_params in params["classifier_params"]["environments"] ] - params["classifier_params"][ - "discrete_transition_type" - ] = _convert_dict_to_class( - params["classifier_params"]["discrete_transition_type"], - discrete_state_transition_types, + params["classifier_params"]["discrete_transition_type"] = ( + _convert_dict_to_class( + params["classifier_params"]["discrete_transition_type"], + discrete_state_transition_types, + ) ) - params["classifier_params"][ - "initial_conditions_type" - ] = _convert_dict_to_class( - params["classifier_params"]["initial_conditions_type"], - initial_conditions_types, + params["classifier_params"]["initial_conditions_type"] = ( + _convert_dict_to_class( + params["classifier_params"]["initial_conditions_type"], + initial_conditions_types, + ) ) if params["classifier_params"].get("observation_models"): @@ -176,10 +177,10 @@ def convert_classes_to_dict(key: dict) -> dict: key["classifier_params"]["environments"] ) ] - key["classifier_params"][ - "continuous_transition_types" - ] = _convert_transitions_to_dict( - key["classifier_params"]["continuous_transition_types"] + key["classifier_params"]["continuous_transition_types"] = ( + _convert_transitions_to_dict( + key["classifier_params"]["continuous_transition_types"] + ) ) key["classifier_params"]["discrete_transition_type"] = _to_dict( key["classifier_params"]["discrete_transition_type"] @@ -194,10 +195,10 @@ def convert_classes_to_dict(key: dict) -> dict: ] try: - key["classifier_params"][ - "clusterless_algorithm_params" - ] = _convert_algorithm_params( - key["classifier_params"]["clusterless_algorithm_params"] + key["classifier_params"]["clusterless_algorithm_params"] = ( + _convert_algorithm_params( + key["classifier_params"]["clusterless_algorithm_params"] + ) ) except KeyError: pass diff --git a/src/spyglass/decoding/v0/sorted_spikes.py b/src/spyglass/decoding/v0/sorted_spikes.py index abe7ec207..acfc501cf 100644 --- a/src/spyglass/decoding/v0/sorted_spikes.py +++ b/src/spyglass/decoding/v0/sorted_spikes.py @@ -7,6 +7,7 @@ speeds. eLife 10, e64505 (2021). """ + import pprint import datajoint as dj diff --git a/src/spyglass/decoding/v0/visualization_2D_view.py b/src/spyglass/decoding/v0/visualization_2D_view.py index 52338ea78..14dcd204c 100644 --- a/src/spyglass/decoding/v0/visualization_2D_view.py +++ b/src/spyglass/decoding/v0/visualization_2D_view.py @@ -38,7 +38,7 @@ def create_static_track_animation( "xmin": np.min(ul_corners[0]), "xmax": np.max(ul_corners[0]) + track_rect_width, "ymin": np.min(ul_corners[1]), - "ymax": np.max(ul_corners[1]) + track_rect_height + "ymax": np.max(ul_corners[1]) + track_rect_height, # Speed: should this be displayed? # TODO: Better approach for accommodating further data streams } diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 3b179d7ee..1751ff8eb 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -240,9 +240,9 @@ def make(self, key): vars(classifier).get("discrete_transition_coefficients_") is not None ): - results[ - "discrete_transition_coefficients" - ] = classifier.discrete_transition_coefficients_ + results["discrete_transition_coefficients"] = ( + classifier.discrete_transition_coefficients_ + ) # Insert results # in future use https://github.com/rly/ndx-xarray and analysis nwb file? diff --git a/src/spyglass/decoding/v1/dj_decoder_conversion.py b/src/spyglass/decoding/v1/dj_decoder_conversion.py index 2795f8be9..c52c95a72 100644 --- a/src/spyglass/decoding/v1/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v1/dj_decoder_conversion.py @@ -1,7 +1,6 @@ """Converts decoder classes into dictionaries and dictionaries into classes so that datajoint can store them in tables.""" - import copy import datajoint as dj diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 9f968d768..40041691a 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -232,9 +232,9 @@ def make(self, key): vars(classifier).get("discrete_transition_coefficients_") is not None ): - results[ - "discrete_transition_coefficients" - ] = classifier.discrete_transition_coefficients_ + results["discrete_transition_coefficients"] = ( + classifier.discrete_transition_coefficients_ + ) # Insert results # in future use https://github.com/rly/ndx-xarray and analysis nwb file? diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index 421e5330e..9b226d1a0 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -241,15 +241,15 @@ def interp_orientation(orientation, spans_to_interp, **kwargs): # TODO: add parameters to refine interpolation for ind, (span_start, span_stop) in enumerate(spans_to_interp): if (span_stop + 1) >= len(orientation): - orientation.loc[ - idx[span_start:span_stop], idx["orientation"] - ] = np.nan + orientation.loc[idx[span_start:span_stop], idx["orientation"]] = ( + np.nan + ) print(f"ind: {ind} has no endpoint with which to interpolate") continue if span_start < 1: - orientation.loc[ - idx[span_start:span_stop], idx["orientation"] - ] = np.nan + orientation.loc[idx[span_start:span_stop], idx["orientation"]] = ( + np.nan + ) print(f"ind: {ind} has no startpoint with which to interpolate") continue orient = [ @@ -263,7 +263,7 @@ def interp_orientation(orientation, spans_to_interp, **kwargs): xp=[start_time, stop_time], fp=[orient[0], orient[-1]], ) - orientation.loc[ - idx[start_time:stop_time], idx["orientation"] - ] = orientnew + orientation.loc[idx[start_time:stop_time], idx["orientation"]] = ( + orientnew + ) return orientation diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 63603932c..500f888b5 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -309,17 +309,17 @@ def make(self, key): description="video_frame_ind", ) nwb_analysis_file = AnalysisNwbfile() - key[ - "dlc_pose_estimation_position_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=position, + key["dlc_pose_estimation_position_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=position, + ) ) - key[ - "dlc_pose_estimation_likelihood_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=likelihood, + key["dlc_pose_estimation_likelihood_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=likelihood, + ) ) nwb_analysis_file.add( nwb_file_name=key["nwb_file_name"], diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 0e1ae4ef5..0916115e5 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -248,17 +248,17 @@ def make(self, key): comments="no comments", description="video_frame_ind", ) - key[ - "dlc_smooth_interp_position_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=position, + key["dlc_smooth_interp_position_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=position, + ) ) - key[ - "dlc_smooth_interp_info_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=video_frame_ind, + key["dlc_smooth_interp_info_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=video_frame_ind, + ) ) nwb_analysis_file.add( nwb_file_name=key["nwb_file_name"], diff --git a/src/spyglass/sharing/sharing_kachery.py b/src/spyglass/sharing/sharing_kachery.py index e3b9111ec..5aa4ebe56 100644 --- a/src/spyglass/sharing/sharing_kachery.py +++ b/src/spyglass/sharing/sharing_kachery.py @@ -105,9 +105,9 @@ def set_resource_url(key: dict): def reset_resource_url(): KacheryZone.reset_zone() if default_kachery_resource_url is not None: - os.environ[ - kachery_resource_url_envar - ] = default_kachery_resource_url + os.environ[kachery_resource_url_envar] = ( + default_kachery_resource_url + ) @schema diff --git a/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py b/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py index 46138b696..c43031225 100644 --- a/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py +++ b/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py @@ -102,16 +102,16 @@ def prepare_spikesortingview_data( channel_neighborhood_size=channel_neighborhood_size, ) if len(spike_train) >= 10: - unit_peak_channel_ids[ - str(unit_id) - ] = peak_channel_id + unit_peak_channel_ids[str(unit_id)] = ( + peak_channel_id + ) else: - fallback_unit_peak_channel_ids[ - str(unit_id) - ] = peak_channel_id - unit_channel_neighborhoods[ - str(unit_id) - ] = channel_neighborhood + fallback_unit_peak_channel_ids[str(unit_id)] = ( + peak_channel_id + ) + unit_channel_neighborhoods[str(unit_id)] = ( + channel_neighborhood + ) for unit_id in unit_ids: peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None) if peak_channel_id is None: diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index d795c8fe3..996611d9a 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -318,9 +318,9 @@ def _get_recording_timestamps(recording): timestamps = np.zeros((total_frames,)) for i in range(recording.get_num_segments()): - timestamps[ - cumsum_frames[i] : cumsum_frames[i + 1] - ] = recording.get_times(segment_index=i) + timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = ( + recording.get_times(segment_index=i) + ) else: timestamps = recording.get_times() return timestamps diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 390bb2add..4a0495778 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -1,4 +1,5 @@ """Helper functions for manipulating information from DataJoint fetch calls.""" + import inspect import os from typing import Type @@ -193,9 +194,11 @@ def get_child_tables(table): return [ dj.FreeTable( table.connection, - s - if not s.isdigit() - else next(iter(table.connection.dependencies.children(s))), + ( + s + if not s.isdigit() + else next(iter(table.connection.dependencies.children(s))) + ), ) for s in table.children() ] diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index c37122c70..0e0681782 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -591,9 +591,11 @@ def merge_get_part( ) parts = [ - getattr(cls, source)().restrict(restriction) - if restrict_part # Re-apply restriction or don't - else getattr(cls, source)() + ( + getattr(cls, source)().restrict(restriction) + if restrict_part # Re-apply restriction or don't + else getattr(cls, source)() + ) for source in sources ] if join_master: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 490274fe0..00163b605 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -86,9 +86,7 @@ def _nwb_table_tuple(self): self._nwb_table_resolved = ( AnalysisNwbfile if "-> AnalysisNwbfile" in self.definition - else Nwbfile - if "-> Nwbfile" in self.definition - else None + else Nwbfile if "-> Nwbfile" in self.definition else None ) if getattr(self, "_nwb_table_resolved", None) is None: diff --git a/src/spyglass/utils/logging.py b/src/spyglass/utils/logging.py index e16706f45..1771a160f 100644 --- a/src/spyglass/utils/logging.py +++ b/src/spyglass/utils/logging.py @@ -1,4 +1,5 @@ """Logging configuration based on datajoint/logging.py""" + import logging import sys diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 6b7947b2d..a5b184635 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -383,9 +383,11 @@ def get_electrode_indices(nwb_object, electrode_ids): # that if it's there and invalid_electrode_index if not. return [ - selected_elect_ids.index(elect_id) - if elect_id in selected_elect_ids - else invalid_electrode_index + ( + selected_elect_ids.index(elect_id) + if elect_id in selected_elect_ids + else invalid_electrode_index + ) for elect_id in electrode_ids ]