From 9765f345edc788eb37c923be5142184c3a65a9bb Mon Sep 17 00:00:00 2001 From: Kyu Hyun Lee Date: Fri, 6 Oct 2023 16:58:30 -0700 Subject: [PATCH] Add comments --- src/spyglass/spikesorting/v1/curation.py | 3 ++- .../spikesorting/v1/figurl_curation.py | 24 +++++++++++++++++-- .../spikesorting/v1/metric_curation.py | 9 ++++--- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index 0727e6798..3daaa9df2 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -19,8 +19,9 @@ @schema class Curation(dj.Manual): definition = """ + # Curation of a SpikeSorting. Use `insert_curation` to insert rows if possible. -> SpikeSorting - curation_id=0: int # a number corresponding to the index of this curation + curation_id=0: int --- parent_curation_id=-1: int -> AnalysisNwbfile diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py index dbfc26c95..7f8144a62 100644 --- a/src/spyglass/spikesorting/v1/figurl_curation.py +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -24,6 +24,7 @@ class FigURLCurationSelection(dj.Manual): metrics_figurl: longblob # metrics to display in the figURL """ + @staticmethod def generate_curation_uri(key: Dict) -> str: """Generates a kachery-cloud URI containing curation info from a row in Curation table @@ -44,12 +45,20 @@ def generate_curation_uri(key: Dict) -> str: unit_ids = nwb_sorting["id"][:] labels = nwb_sorting["labels"][:] merge_groups = nwb_sorting["merge_groups"][:] + + unit_ids = [str(unit_id) for unit_id in unit_ids] + if labels: labels_dict = dict(zip(unit_ids, labels)) else: labels_dict = {} + if merge_groups: merge_groups_list = _merge_dict_to_list(merge_groups) + merge_groups_list = [ + [str(unit_id) for unit_id in merge_group] + for merge_group in merge_groups_list + ] else: merge_groups_list = [] @@ -70,7 +79,7 @@ class FigURLCuration(dj.Computed): url: varchar(1000) """ - def make(self, key: Dict): + def make(self, key: dict): # FETCH sorting_analysis_file_name = (Curation & key).fetch1( "analysis_file_name" @@ -101,6 +110,8 @@ def make(self, key: Dict): unit_metrics = _reformat_metrics(metric_dict) + # TODO: figure out a way to specify the similarity metrics + # Generate the figURL key["url"] = _generate_figurl( R=recording, @@ -114,6 +125,14 @@ def make(self, key: Dict): # INSERT self.insert1(key, skip_duplicates=True) + @classmethod + def get_labels(cls): + return NotImplementedError + + @classmethod + def get_merge_groups(cls): + return NotImplementedError + def _generate_figurl( R: si.BaseRecording, @@ -143,6 +162,7 @@ def _generate_figurl( max_num_snippets_per_segment=max_num_snippets_per_segment, channel_neighborhood_size=channel_neighborhood_size, ) + # create a fake unit similarity matrix (for future reference) # similarity_scores = [] # for u1 in X.unit_ids: @@ -154,7 +174,7 @@ def _generate_figurl( # similarity=similarity_matrix[(X.unit_ids==u1),(X.unit_ids==u2)] # ) # ) - # Create the similarity matrix view + # # Create the similarity matrix view # unit_similarity_matrix_view = vv.UnitSimilarityMatrix( # unit_ids=X.unit_ids, # similarity_scores=similarity_scores diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index d1cc81f05..baa66884b 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -210,6 +210,7 @@ def make(self, key): os.mkdir(waveforms_dir) except FileExistsError: pass + print("Extracting waveforms...") waveforms = si.extract_waveforms( recording=recording, sorting=sorting, @@ -217,16 +218,18 @@ def make(self, key): **waveform_param, ) # compute metrics + print("Computing metrics...") metrics = {} for metric_name, metric_param_dict in metric_param.items(): metrics[metric_name] = self._compute_metric( nwb_file_name, waveforms, metric_name, **metric_param_dict ) - # generate labels and merge groups - labels = self._compute_labels(metrics, label_param) + print("Applying curation...") + labels = self._compute_labels(metrics, label_param) merge_groups = self._compute_merge_groups(metrics, merge_param) - # save everything in NWB + + print("Saving to NWB...") ( key["analysis_file_name"], key["object_id"],