From b1cd74c2f934c529ab0c332b5c0e4542f15897bd Mon Sep 17 00:00:00 2001 From: Kyu Hyun Lee Date: Thu, 5 Oct 2023 12:21:57 -0700 Subject: [PATCH] Add docstring --- src/spyglass/spikesorting/v1/curation.py | 47 +++++++++++++++---- .../spikesorting/v1/metric_curation.py | 14 +++--- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index d6bf5a8ba..857e5af37 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -262,13 +262,19 @@ def _write_sorting_to_nwb_with_curation( load_namespaces=True, ) as io: nwbf = io.read() - + # write sorting to the nwb file + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id) + nwbf.add_unit( + spike_times=spike_times, + id=unit_id, + ) # add labels, merge groups, metrics if labels is not None: label_values = [] for unit_id in unit_ids: if unit_id not in labels: - label_values.append("") + label_values.append([""]) else: label_values.append(labels[unit_id]) nwbf.add_unit_column( @@ -297,13 +303,6 @@ def _write_sorting_to_nwb_with_curation( data=metric_values, ) - # write sorting to the nwb file - for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id) - nwbf.add_unit( - spike_times=spike_times, - id=unit_id, - ) units_object_id = nwbf.units.object_id io.write(nwbf) return analysis_nwb_file, units_object_id @@ -333,6 +332,23 @@ def _union_intersecting_lists(lists): def _list_to_merge_dict(lists_of_strings, target_strings): + """Converts a list of merge groups to a dict. + The keys of the dict (unit ids) are provided separately in case + the merge groups do not contain all the unit ids. + Example: [[1,2,3],[4,5]], [1,2,3,4,5,6] -> {1: [2, 3], 2:[1,3], 3:[1,2] 4: [5], 5: [4], 6: []} + + Parameters + ---------- + lists_of_strings : _type_ + _description_ + target_strings : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ lists_of_strings = _union_intersecting_lists(lists_of_strings) result = {string: [] for string in target_strings} @@ -357,6 +373,19 @@ def _reverse_associations(assoc_dict): def _merge_dict_to_list(merge_groups): + """Converts dict of merge groups to list of merge groups. + Example: {1: [2, 3], 4: [5]} -> [[1, 2, 3], [4, 5]] + + Parameters + ---------- + merge_groups : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ units_to_merge = _union_intersecting_lists( _reverse_associations(merge_groups) ) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index 31aa5f0ac..1f88fd497 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -291,7 +291,8 @@ def _compute_labels( ("<", 1, ["noise"])] } This indicates that units with values of the "snr" quality metric - greater than 5 should be given the labels "good" and "mua". + greater than 1 should be given the labels "good" and "mua" and values + less than 1 should be given the label "noise". Returns ------- @@ -309,12 +310,14 @@ def _compute_labels( Warning(f"{metric} not found in quality metrics; skipping") else: for condition in label_param[metric]: - assert len(condition) == 3, f"Condition {condition} must be of length 3" - compare = _comparison_to_function[label_param[metric][0]] + assert ( + len(condition) == 3 + ), f"Condition {condition} must be of length 3" + compare = _comparison_to_function[condition[0]] for unit_id in unit_ids: if compare( metrics[metric][unit_id], - label_param[metric][1], + condition[1], ): labels[unit_id].extend(label_param[metric][2]) return labels @@ -324,8 +327,7 @@ def _compute_merge_groups( metrics: Dict[str, Dict[str, Union[float, List[float]]]], merge_param: Dict[str, List[Any]], ) -> Dict[str, List[str]]: - """Identifies units to be merged based on the metrics and - merge parameters. + """Identifies units to be merged based on the metrics and merge parameters. Parameters ---------