Skip to content

Commit

Permalink
Blackify 24.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 30, 2024
1 parent c5bf75b commit 6fa248b
Show file tree
Hide file tree
Showing 25 changed files with 115 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ repos:
- tomli

- repo: https://github.com/ambv/black
rev: 23.11.0
rev: 24.1.1
hooks:
- id: black
language_version: python3.9
4 changes: 1 addition & 3 deletions config/dj_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def main(*args):
save_method = (
"local"
if filename == "dj_local_conf.json"
else "global"
if filename is None
else "custom"
else "global" if filename is None else "custom"
)

config.save_dj_config(
Expand Down
10 changes: 6 additions & 4 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,12 @@ def _get_column_names(rp, pos_id):
INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1)
n_pos_dims = rp.data.shape[1]
column_names = [
col # use existing columns if already numbered
if "1" in rp.description or "2" in rp.description
# else number them by id
else col + str(pos_id + INDEX_ADJUST)
(
col # use existing columns if already numbered
if "1" in rp.description or "2" in rp.description
# else number them by id
else col + str(pos_id + INDEX_ADJUST)
)
for col in rp.description.split(", ")
]
if len(column_names) != n_pos_dims:
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ def __read_ndx_probe_data(
{
"probe_id": nwb_probe_obj.probe_type,
"probe_type": nwb_probe_obj.probe_type,
"contact_side_numbering": "True"
if nwb_probe_obj.contact_side_numbering
else "False",
"contact_side_numbering": (
"True" if nwb_probe_obj.contact_side_numbering else "False"
),
}
)
# go through the shanks and add each one to the Shank table
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/common/common_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ def filter_data(
for ii, (start, stop) in enumerate(indices):
extracted_ts = timestamps[start:stop:decimation]

new_timestamps[
ts_offset : ts_offset + len(extracted_ts)
] = extracted_ts
new_timestamps[ts_offset : ts_offset + len(extracted_ts)] = (
extracted_ts
)
ts_offset += len(extracted_ts)

# finally ready to filter data!
Expand Down
8 changes: 5 additions & 3 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):
for _, epoch_data in epochs.iterrows():
epoch_dict = {
"nwb_file_name": nwb_file_name,
"interval_list_name": epoch_data.tags[0]
if epoch_data.tags
else f"interval_{epoch_data[0]}",
"interval_list_name": (
epoch_data.tags[0]
if epoch_data.tags
else f"interval_{epoch_data[0]}"
),
"valid_times": np.asarray(
[[epoch_data.start_time, epoch_data.stop_time]]
),
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Schema for institution, lab team/name/members. Session-independent."""

import datajoint as dj

from spyglass.utils import SpyglassMixin, logger
Expand Down
11 changes: 6 additions & 5 deletions src/spyglass/decoding/v0/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world
speeds. eLife 10, e64505 (2021).
"""

import os
import shutil
import uuid
Expand Down Expand Up @@ -654,11 +655,11 @@ def make(self, key):
key["nwb_file_name"]
)

key[
"multiunit_firing_rate_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=multiunit_firing_rate.reset_index(),
key["multiunit_firing_rate_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=multiunit_firing_rate.reset_index(),
)
)

nwb_analysis_file.add(
Expand Down
37 changes: 19 additions & 18 deletions src/spyglass/decoding/v0/dj_decoder_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Converts decoder classes into dictionaries and dictionaries into classes
so that datajoint can store them in tables."""

from spyglass.utils import logger

try:
Expand Down Expand Up @@ -116,17 +117,17 @@ def restore_classes(params: dict) -> dict:
_convert_env_dict(env_params)
for env_params in params["classifier_params"]["environments"]
]
params["classifier_params"][
"discrete_transition_type"
] = _convert_dict_to_class(
params["classifier_params"]["discrete_transition_type"],
discrete_state_transition_types,
params["classifier_params"]["discrete_transition_type"] = (
_convert_dict_to_class(
params["classifier_params"]["discrete_transition_type"],
discrete_state_transition_types,
)
)
params["classifier_params"][
"initial_conditions_type"
] = _convert_dict_to_class(
params["classifier_params"]["initial_conditions_type"],
initial_conditions_types,
params["classifier_params"]["initial_conditions_type"] = (
_convert_dict_to_class(
params["classifier_params"]["initial_conditions_type"],
initial_conditions_types,
)
)

if params["classifier_params"].get("observation_models"):
Expand Down Expand Up @@ -176,10 +177,10 @@ def convert_classes_to_dict(key: dict) -> dict:
key["classifier_params"]["environments"]
)
]
key["classifier_params"][
"continuous_transition_types"
] = _convert_transitions_to_dict(
key["classifier_params"]["continuous_transition_types"]
key["classifier_params"]["continuous_transition_types"] = (
_convert_transitions_to_dict(
key["classifier_params"]["continuous_transition_types"]
)
)
key["classifier_params"]["discrete_transition_type"] = _to_dict(
key["classifier_params"]["discrete_transition_type"]
Expand All @@ -194,10 +195,10 @@ def convert_classes_to_dict(key: dict) -> dict:
]

try:
key["classifier_params"][
"clusterless_algorithm_params"
] = _convert_algorithm_params(
key["classifier_params"]["clusterless_algorithm_params"]
key["classifier_params"]["clusterless_algorithm_params"] = (
_convert_algorithm_params(
key["classifier_params"]["clusterless_algorithm_params"]
)
)
except KeyError:
pass
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/decoding/v0/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
speeds. eLife 10, e64505 (2021).
"""

import pprint

import datajoint as dj
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/decoding/v0/visualization_2D_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_static_track_animation(
"xmin": np.min(ul_corners[0]),
"xmax": np.max(ul_corners[0]) + track_rect_width,
"ymin": np.min(ul_corners[1]),
"ymax": np.max(ul_corners[1]) + track_rect_height
"ymax": np.max(ul_corners[1]) + track_rect_height,
# Speed: should this be displayed?
# TODO: Better approach for accommodating further data streams
}
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def make(self, key):
vars(classifier).get("discrete_transition_coefficients_")
is not None
):
results[
"discrete_transition_coefficients"
] = classifier.discrete_transition_coefficients_
results["discrete_transition_coefficients"] = (
classifier.discrete_transition_coefficients_
)

# Insert results
# in future use https://github.com/rly/ndx-xarray and analysis nwb file?
Expand Down
1 change: 0 additions & 1 deletion src/spyglass/decoding/v1/dj_decoder_conversion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Converts decoder classes into dictionaries and dictionaries into classes
so that datajoint can store them in tables."""


import copy

import datajoint as dj
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def make(self, key):
vars(classifier).get("discrete_transition_coefficients_")
is not None
):
results[
"discrete_transition_coefficients"
] = classifier.discrete_transition_coefficients_
results["discrete_transition_coefficients"] = (
classifier.discrete_transition_coefficients_
)

# Insert results
# in future use https://github.com/rly/ndx-xarray and analysis nwb file?
Expand Down
18 changes: 9 additions & 9 deletions src/spyglass/position/v1/position_dlc_orient.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,15 @@ def interp_orientation(orientation, spans_to_interp, **kwargs):
# TODO: add parameters to refine interpolation
for ind, (span_start, span_stop) in enumerate(spans_to_interp):
if (span_stop + 1) >= len(orientation):
orientation.loc[
idx[span_start:span_stop], idx["orientation"]
] = np.nan
orientation.loc[idx[span_start:span_stop], idx["orientation"]] = (
np.nan
)
print(f"ind: {ind} has no endpoint with which to interpolate")
continue
if span_start < 1:
orientation.loc[
idx[span_start:span_stop], idx["orientation"]
] = np.nan
orientation.loc[idx[span_start:span_stop], idx["orientation"]] = (
np.nan
)
print(f"ind: {ind} has no startpoint with which to interpolate")
continue
orient = [
Expand All @@ -263,7 +263,7 @@ def interp_orientation(orientation, spans_to_interp, **kwargs):
xp=[start_time, stop_time],
fp=[orient[0], orient[-1]],
)
orientation.loc[
idx[start_time:stop_time], idx["orientation"]
] = orientnew
orientation.loc[idx[start_time:stop_time], idx["orientation"]] = (
orientnew
)
return orientation
20 changes: 10 additions & 10 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,17 @@ def make(self, key):
description="video_frame_ind",
)
nwb_analysis_file = AnalysisNwbfile()
key[
"dlc_pose_estimation_position_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
key["dlc_pose_estimation_position_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
)
)
key[
"dlc_pose_estimation_likelihood_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=likelihood,
key["dlc_pose_estimation_likelihood_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=likelihood,
)
)
nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
Expand Down
20 changes: 10 additions & 10 deletions src/spyglass/position/v1/position_dlc_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ def make(self, key):
comments="no comments",
description="video_frame_ind",
)
key[
"dlc_smooth_interp_position_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
key["dlc_smooth_interp_position_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
)
)
key[
"dlc_smooth_interp_info_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=video_frame_ind,
key["dlc_smooth_interp_info_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=video_frame_ind,
)
)
nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/sharing/sharing_kachery.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def set_resource_url(key: dict):
def reset_resource_url():
KacheryZone.reset_zone()
if default_kachery_resource_url is not None:
os.environ[
kachery_resource_url_envar
] = default_kachery_resource_url
os.environ[kachery_resource_url_envar] = (
default_kachery_resource_url
)


@schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,16 @@ def prepare_spikesortingview_data(
channel_neighborhood_size=channel_neighborhood_size,
)
if len(spike_train) >= 10:
unit_peak_channel_ids[
str(unit_id)
] = peak_channel_id
unit_peak_channel_ids[str(unit_id)] = (
peak_channel_id
)
else:
fallback_unit_peak_channel_ids[
str(unit_id)
] = peak_channel_id
unit_channel_neighborhoods[
str(unit_id)
] = channel_neighborhood
fallback_unit_peak_channel_ids[str(unit_id)] = (
peak_channel_id
)
unit_channel_neighborhoods[str(unit_id)] = (
channel_neighborhood
)
for unit_id in unit_ids:
peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None)
if peak_channel_id is None:
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def _get_recording_timestamps(recording):

timestamps = np.zeros((total_frames,))
for i in range(recording.get_num_segments()):
timestamps[
cumsum_frames[i] : cumsum_frames[i + 1]
] = recording.get_times(segment_index=i)
timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = (
recording.get_times(segment_index=i)
)
else:
timestamps = recording.get_times()
return timestamps
Expand Down
9 changes: 6 additions & 3 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper functions for manipulating information from DataJoint fetch calls."""

import inspect
import os
from typing import Type
Expand Down Expand Up @@ -193,9 +194,11 @@ def get_child_tables(table):
return [
dj.FreeTable(
table.connection,
s
if not s.isdigit()
else next(iter(table.connection.dependencies.children(s))),
(
s
if not s.isdigit()
else next(iter(table.connection.dependencies.children(s)))
),
)
for s in table.children()
]
Loading

0 comments on commit 6fa248b

Please sign in to comment.