Skip to content

Commit

Permalink
Add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Oct 5, 2023
1 parent 652bcee commit b1cd74c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
47 changes: 38 additions & 9 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,19 @@ def _write_sorting_to_nwb_with_curation(
load_namespaces=True,
) as io:
nwbf = io.read()

# write sorting to the nwb file
for unit_id in unit_ids:
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=spike_times,
id=unit_id,
)
# add labels, merge groups, metrics
if labels is not None:
label_values = []
for unit_id in unit_ids:
if unit_id not in labels:
label_values.append("")
label_values.append([""])
else:
label_values.append(labels[unit_id])
nwbf.add_unit_column(
Expand Down Expand Up @@ -297,13 +303,6 @@ def _write_sorting_to_nwb_with_curation(
data=metric_values,
)

# write sorting to the nwb file
for unit_id in unit_ids:
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=spike_times,
id=unit_id,
)
units_object_id = nwbf.units.object_id
io.write(nwbf)
return analysis_nwb_file, units_object_id
Expand Down Expand Up @@ -333,6 +332,23 @@ def _union_intersecting_lists(lists):


def _list_to_merge_dict(lists_of_strings, target_strings):
"""Converts a list of merge groups to a dict.
The keys of the dict (unit ids) are provided separately in case
the merge groups do not contain all the unit ids.
Example: [[1,2,3],[4,5]], [1,2,3,4,5,6] -> {1: [2, 3], 2:[1,3], 3:[1,2] 4: [5], 5: [4], 6: []}
Parameters
----------
lists_of_strings : _type_
_description_
target_strings : _type_
_description_
Returns
-------
_type_
_description_
"""
lists_of_strings = _union_intersecting_lists(lists_of_strings)
result = {string: [] for string in target_strings}

Expand All @@ -357,6 +373,19 @@ def _reverse_associations(assoc_dict):


def _merge_dict_to_list(merge_groups):
"""Converts dict of merge groups to list of merge groups.
Example: {1: [2, 3], 4: [5]} -> [[1, 2, 3], [4, 5]]
Parameters
----------
merge_groups : _type_
_description_
Returns
-------
_type_
_description_
"""
units_to_merge = _union_intersecting_lists(
_reverse_associations(merge_groups)
)
Expand Down
14 changes: 8 additions & 6 deletions src/spyglass/spikesorting/v1/metric_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def _compute_labels(
("<", 1, ["noise"])]
}
This indicates that units with values of the "snr" quality metric
greater than 5 should be given the labels "good" and "mua".
greater than 1 should be given the labels "good" and "mua" and values
less than 1 should be given the label "noise".
Returns
-------
Expand All @@ -309,12 +310,14 @@ def _compute_labels(
Warning(f"{metric} not found in quality metrics; skipping")
else:
for condition in label_param[metric]:
assert len(condition) == 3, f"Condition {condition} must be of length 3"
compare = _comparison_to_function[label_param[metric][0]]
assert (
len(condition) == 3
), f"Condition {condition} must be of length 3"
compare = _comparison_to_function[condition[0]]
for unit_id in unit_ids:
if compare(
metrics[metric][unit_id],
label_param[metric][1],
condition[1],
):
labels[unit_id].extend(label_param[metric][2])
return labels
Expand All @@ -324,8 +327,7 @@ def _compute_merge_groups(
metrics: Dict[str, Dict[str, Union[float, List[float]]]],
merge_param: Dict[str, List[Any]],
) -> Dict[str, List[str]]:
"""Identifies units to be merged based on the metrics and
merge parameters.
"""Identifies units to be merged based on the metrics and merge parameters.
Parameters
---------
Expand Down

0 comments on commit b1cd74c

Please sign in to comment.