Skip to content

Commit

Permalink
Update data type
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Oct 1, 2023
1 parent 93935dd commit 3a7f8f1
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def insert_curation(
sorting_id: str,
parent_curation_id: int = -1,
labels: Union[None, Dict[str, List[str]]] = None,
merge_groups: Union[None, List[List[int]]] = None,
merge_groups: Union[None, List[List[str]]] = None,
apply_merge: bool = False,
metrics: Union[None, Dict[str, Dict[int, float]]] = None,
metrics: Union[None, Dict[str, Dict[str, float]]] = None,
description: str = "",
):
"""Given a sorting_id and the parent_sorting_id (and optional
Expand Down Expand Up @@ -186,19 +186,22 @@ def get_merged_sorting(key: dict) -> si.BaseSorting:
nwb_sorting = nwbfile.objects[object_id]
merge_groups = nwb_sorting["merge_groups"][:]
if merge_groups:
units_to_merge = _union_intersecting_lists(_reverse_associations(merge_groups))
units_to_merge = _union_intersecting_lists(
_reverse_associations(merge_groups)
)
units_to_merge = [lst for lst in units_to_merge if len(lst) >= 2]
return sc.MergeUnitsSorting(
parent_sorting=si_sorting, units_to_merge=units_to_merge
)
else:
return si_sorting


def _write_sorting_to_nwb_with_curation(
sorting_id: str,
labels: Union[None, Dict[str, List[str]]] = None,
merge_groups: Union[None, List[List[int]]] = None,
metrics: Union[None, Dict[str, Dict[int, float]]] = None,
merge_groups: Union[None, List[List[str]]] = None,
metrics: Union[None, Dict[str, Dict[str, float]]] = None,
apply_merge: bool = False,
):
"""Save sorting to NWB with curation information.
Expand Down Expand Up @@ -239,7 +242,7 @@ def _write_sorting_to_nwb_with_curation(
sorting = se.read_nwb_sorting(sorting_analysis_file_abs_path)
if apply_merge:
sorting = sc.MergeUnitsSorting(
parent_sorting=sorting, parent_sorting=merge_groups
parent_sorting=sorting, units_to_merge=merge_groups
)
merge_groups = None

Expand Down Expand Up @@ -269,7 +272,7 @@ def _write_sorting_to_nwb_with_curation(
data=label_values,
)
if merge_groups is not None:
merge_groups_dict = _extract_associations(merge_groups,unit_ids)
merge_groups_dict = _extract_associations(merge_groups, unit_ids)
nwbf.add_unit_column(
name="merge_groups",
description="merge groups",
Expand Down Expand Up @@ -323,6 +326,7 @@ def _union_intersecting_lists(lists):

return result


def _extract_associations(lists_of_strings, target_strings):
lists_of_strings = _union_intersecting_lists(lists_of_strings)
result = {string: [] for string in target_strings}
Expand All @@ -334,6 +338,7 @@ def _extract_associations(lists_of_strings, target_strings):

return result


def _reverse_associations(assoc_dict):
result = []

Expand Down

0 comments on commit 3a7f8f1

Please sign in to comment.