
Commit

Merge branch 'spikesorting_v1' of github.com:khl02007/nwb_datajoint into spikesorting_v1
khl02007 committed Dec 8, 2023
2 parents f78e96a + 996a8de, commit 3592f0d
Showing 5 changed files with 37 additions and 45 deletions.
9 changes: 4 additions & 5 deletions src/spyglass/spikesorting/v1/artifact.py
@@ -393,11 +393,10 @@ def _check_artifact_thresholds(
         ValueError: if signal thresholds are negative
     """
     # amplitude or zscore thresholds should be negative, as they are applied to an absolute signal
-    signal_thresholds = [
-        t for t in [amplitude_thresh_uV, zscore_thresh] if t is not None
-    ]
-    for t in signal_thresholds:
-        if t < 0:
+    def is_negative(value):
+        return value < 0 if value is not None else False
+
+    if is_negative(amplitude_thresh_uV) or is_negative(zscore_thresh):
         raise ValueError(
             "Amplitude and Z-Score thresholds must be >= 0, or None"
         )
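
The refactor replaces the loop over non-None thresholds with a small helper. A minimal standalone sketch (check_thresholds is a hypothetical stand-in for _check_artifact_thresholds, with its other arguments omitted) showing that the new check rejects the same inputs as the old loop:

    def is_negative(value):
        # None means "threshold not set" and is never treated as negative
        return value < 0 if value is not None else False

    def check_thresholds(amplitude_thresh_uV=None, zscore_thresh=None):
        if is_negative(amplitude_thresh_uV) or is_negative(zscore_thresh):
            raise ValueError(
                "Amplitude and Z-Score thresholds must be >= 0, or None"
            )

    check_thresholds(amplitude_thresh_uV=100.0, zscore_thresh=None)  # passes
    try:
        check_thresholds(zscore_thresh=-3.0)
    except ValueError as err:
        print(err)  # Amplitude and Z-Score thresholds must be >= 0, or None
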
55 changes: 26 additions & 29 deletions src/spyglass/spikesorting/v1/figurl_curation.py
@@ -58,7 +58,7 @@ def insert_selection(cls, key: dict):

     @staticmethod
     def generate_curation_uri(key: Dict) -> str:
-        """Generates a kachery-cloud URI containing curation info from a row in CurationV1 table
+        """Generates a kachery-cloud URI from a row in CurationV1 table

         Parameters
         ----------

@@ -82,30 +82,29 @@ def generate_curation_uri(key: Dict) -> str:

         unit_ids = [str(unit_id) for unit_id in unit_ids]

-        if labels:
-            labels_dict = {
-                unit_id: list(label) for unit_id, label in zip(unit_ids, labels)
-            }
-        else:
-            labels_dict = {}
+        labels_dict = (
+            {unit_id: list(label) for unit_id, label in zip(unit_ids, labels)}
+            if labels
+            else {}
+        )

-        if merge_groups:
-            merge_groups_dict = dict(zip(unit_ids, merge_groups))
-            merge_groups_list = _merge_dict_to_list(merge_groups_dict)
-            merge_groups_list = [
+        merge_groups_list = (
+            [
                 [str(unit_id) for unit_id in merge_group]
-                for merge_group in merge_groups_list
-            ]
-        else:
-            merge_groups_list = []
+                for merge_group in _merge_dict_to_list(
+                    dict(zip(unit_ids, merge_groups))
+                )
+            ]
+            if merge_groups
+            else []
+        )

-        curation_dict = {
-            "labelsByUnit": labels_dict,
-            "mergeGroups": merge_groups_list,
-        }
-        curation_uri = kcl.store_json(curation_dict)
-        return curation_uri
+        return kcl.store_json(
+            {
+                "labelsByUnit": labels_dict,
+                "mergeGroups": merge_groups_list,
+            }
+        )


 @schema
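
A rough standalone sketch of the new conditional-expression construction with made-up unit data. kcl.store_json needs a configured kachery-cloud client and _merge_dict_to_list is an internal helper, so json.dumps stands in for the upload and the merge-group list is written out directly:

    import json

    # Made-up stand-ins for the values fetched from CurationV1
    unit_ids = ["0", "1", "2"]
    labels = [["accept"], ["reject", "noise"], []]
    merge_groups = []  # no merges in this toy example

    labels_dict = (
        {unit_id: list(label) for unit_id, label in zip(unit_ids, labels)}
        if labels
        else {}
    )
    # The real code builds this via _merge_dict_to_list; a plain list of
    # unit-id groups is assumed here for illustration
    merge_groups_list = [["0", "1"]] if merge_groups else []

    # json.dumps stands in for kcl.store_json, which uploads the blob and
    # returns a kachery-cloud URI
    print(json.dumps({"labelsByUnit": labels_dict, "mergeGroups": merge_groups_list}))
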
@@ -279,18 +278,16 @@ def _generate_figurl(


 def _reformat_metrics(metrics: Dict[str, Dict[str, float]]) -> List[Dict]:
-    for metric_name in metrics:
-        metrics[metric_name] = {
-            str(unit_id): metric_value
-            for unit_id, metric_value in metrics[metric_name].items()
-        }
-    new_external_metrics = [
+    return [
         {
             "name": metric_name,
             "label": metric_name,
             "tooltip": metric_name,
-            "data": metric,
+            "data": {
+                str(unit_id): metric_value
+                for unit_id, metric_value in metric.items()
+            },
         }
         for metric_name, metric in metrics.items()
     ]
-    return new_external_metrics
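
Besides flattening the function into a single comprehension, the rewrite no longer mutates the caller's metrics dict (the old loop overwrote metrics[metric_name] in place). A small sketch with toy metric values:

    # Toy metrics keyed by integer unit id, as _reformat_metrics might receive them
    metrics = {"snr": {1: 4.2, 2: 7.9}, "isi_violation": {1: 0.01, 2: 0.0}}

    reformatted = [
        {
            "name": metric_name,
            "label": metric_name,
            "tooltip": metric_name,
            "data": {
                str(unit_id): metric_value
                for unit_id, metric_value in metric.items()
            },
        }
        for metric_name, metric in metrics.items()
    ]

    print(reformatted[0]["data"])  # {'1': 4.2, '2': 7.9}
    print(metrics["snr"])          # {1: 4.2, 2: 7.9} -- input left unchanged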

10 changes: 4 additions & 6 deletions src/spyglass/spikesorting/v1/metric_curation.py
@@ -575,12 +575,10 @@ def _write_metric_curation_to_nwb(
             )
         if metrics is not None:
             for metric, metric_dict in metrics.items():
-                metric_values = []
-                for unit_id in unit_ids:
-                    if unit_id not in metric_dict:
-                        metric_values.append([])
-                    else:
-                        metric_values.append(metric_dict[unit_id])
+                metric_values = [
+                    metric_dict[unit_id] if unit_id in metric_dict else []
+                    for unit_id in unit_ids
+                ]
                 nwbf.add_unit_column(
                     name=metric,
                     description=metric,
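
A quick sketch of the replacement comprehension with toy data: unit ids missing from a metric fall back to an empty list, exactly as the removed loop did. dict.get would express the same fallback, noted only as an aside:

    # Toy data: unit 2 has no value for this metric
    unit_ids = [1, 2, 3]
    metric_dict = {1: 0.5, 3: 0.9}

    metric_values = [
        metric_dict[unit_id] if unit_id in metric_dict else []
        for unit_id in unit_ids
    ]
    assert metric_values == [0.5, [], 0.9]

    # Equivalent spelling (not what the commit uses):
    assert [metric_dict.get(unit_id, []) for unit_id in unit_ids] == metric_values
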
6 changes: 2 additions & 4 deletions src/spyglass/spikesorting/v1/metric_utils.py
@@ -61,11 +61,9 @@ def get_peak_channel(
         peak_sign=peak_sign,
         **metric_params,
     )
-    peak_channel = {key: int(val) for key, val in peak_channel_dict.items()}
-    return peak_channel
+    return {key: int(val) for key, val in peak_channel_dict.items()}


 def get_num_spikes(waveform_extractor: si.WaveformExtractor, this_unit_id: str):
     """Computes the number of spikes for each unit."""
-    num_spikes = sq.compute_num_spikes(waveform_extractor)
-    return num_spikes[this_unit_id]
+    return sq.compute_num_spikes(waveform_extractor)[this_unit_id]
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/v1/recording.py
@@ -176,7 +176,7 @@ def set_group_by_shank(
 @schema
 class SpikeSortingPreprocessingParameters(dj.Lookup):
     definition = """
-    # Parameters for denoising (filtering and referencing/whitening) a recording prior to spike sorting.
+    # Parameters for denoising a recording prior to spike sorting.
     preproc_param_name: varchar(200)
     ---
     preproc_params: blob
