Skip to content

Commit

Permalink
Streamline metric
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Oct 2, 2023
1 parent e52aec7 commit d578e65
Showing 1 changed file with 43 additions and 104 deletions.
147 changes: 43 additions & 104 deletions src/spyglass/spikesorting/v1/metric_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
import datajoint as dj
import numpy as np
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as sip
import spikeinterface.qualitymetrics as sq

from spyglass.common.common_interval import IntervalList
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.spikesorting.v1.metric_utils import (
get_num_spikes,
Expand Down Expand Up @@ -94,17 +92,8 @@ class MetricParameter(dj.Lookup):
---
metric_param: blob
"""
available_metrics = [
"snr",
"isi_violation",
"nn_isolation",
"nn_noise_overlap",
"peak_offset",
"peak_channel",
"num_spikes",
]

metric_default_params = {
metric_default_param_name = "franklab_default"
metric_default_param = {
"snr": {
"peak_sign": "neg",
"random_chunk_kwargs_dict": {
Expand Down Expand Up @@ -133,7 +122,7 @@ class MetricParameter(dj.Lookup):
"peak_channel": {"peak_sign": "neg"},
"num_spikes": {},
}
contents = [["franklab_default", metric_default_params]]
contents = [[metric_default_param_name, metric_default_param]]

@classmethod
def insert_default(cls):
Expand All @@ -142,83 +131,30 @@ def insert_default(cls):
skip_duplicates=True,
)

def get_metric_default_params(self, metric: str):
"Returns default params for the given metric"
return self.metric_default_params(metric)

def get_available_metrics(self):
@classmethod
def show_available_metrics(self):
for metric in _metric_name_to_func:
if metric in self.available_metrics:
metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
metric_string = ("{metric_name} : {metric_doc}").format(
metric_name=metric, metric_doc=metric_doc
)
print(metric_string + "\n")

# TODO
def _validate_metrics_list(self, key):
"""Checks whether a row to be inserted contains only the available metrics"""
# get available metrics list
# get metric list from key
# compare
return NotImplementedError
metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
print(f"{metric} : {metric_doc}\n")


@schema
class AutomaticCurationParameter(dj.Lookup):
definition = """
auto_curation_param_name: varchar(200) # name of this parameter set
auto_curation_param_name: varchar(200)
---
merge_params: blob # dictionary of params to merge units
label_params: blob # dictionary params to label units
merge_params: blob # dict of param to merge units
label_params: blob # dict of param to label units
"""

def insert1(self, key, **kwargs):
# validate the labels and then insert
# TODO: add validation for merge_params
for metric in key["label_params"]:
if metric not in _metric_name_to_func:
raise Exception(f"{metric} not in list of available metrics")
comparison_list = key["label_params"][metric]
if comparison_list[0] not in _comparison_to_function:
raise Exception(
f'{metric}: "{comparison_list[0]}" '
f"not in list of available comparisons"
)
if not isinstance(comparison_list[1], (int, float)):
raise Exception(
f"{metric}: {comparison_list[1]} is of type "
f"{type(comparison_list[1])} and not a number"
)
for label in comparison_list[2]:
if label not in valid_labels:
raise Exception(
f'{metric}: "{label}" '
f"not in list of valid labels: {valid_labels}"
)
super().insert1(key, **kwargs)

def insert_default(self):
# label_params parsing: Each key is the name of a metric,
# the contents are a three value list with the comparison, a value,
# and a list of labels to apply if the comparison is true
default_params = {
"auto_curation_params_name": "default",
"merge_params": {},
"label_params": {
"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]
},
}
self.insert1(default_params, skip_duplicates=True)
contents = [
["default", {}, {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}],
["none", {}, {}],
]

# Second default parameter set for not applying any labels,
# or merges, but adding metrics
no_label_params = {
"auto_curation_params_name": "none",
"merge_params": {},
"label_params": {},
}
self.insert1(no_label_params, skip_duplicates=True)
@classmethod
def insert_default(cls):
cls.insert1(cls.contents, skip_duplicates=True)


@schema
Expand All @@ -237,42 +173,45 @@ class MetricCuration(dj.Computed):
metric_curation_id: varchar(32)
---
-> AutomaticCurationSelection
-> AnalysisNwbfile
object_id: varchar(40) # Object ID for the metrics in NWB file
metrics: longblob
labels: longblob
merge_groups: longblob
"""

def make(self, key):
metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path")
with open(metrics_path) as f:
quality_metrics = json.load(f)

# get the curation information and the curated sorting
parent_curation = (Curation & key).fetch(as_dict=True)[0]
parent_merge_groups = parent_curation["merge_groups"]
parent_labels = parent_curation["curation_labels"]
parent_curation_id = parent_curation["curation_id"]
parent_sorting = Curation.get_curated_sorting(key)

merge_params = (AutomaticCurationParameters & key).fetch1(
"merge_params"
)
merge_groups, units_merged = self.get_merge_groups(
parent_sorting, parent_merge_groups, quality_metrics, merge_params
)
# FETCH
# load recording and sorting

label_params = (AutomaticCurationParameters & key).fetch1(
"label_params"
# DO
# create uuid for this metric curation
# extract waveforms
waveforms = si.extract_waveforms(
recording=recording,
sorting=sorting,
folder=os.environ.get("SPYGLASS_TEMP_DIR"),
**waveform_params,
)
# compute metrics
params = (MetricParameters & key).fetch1("metric_params")
for metric_name, metric_params in params.items():
metric = self._compute_metric(
waveform_extractor, metric_name, **metric_params
)
qm[metric_name] = metric
# generate labels and merge groups
labels = self.get_labels(
parent_sorting, parent_labels, quality_metrics, label_params
)

# keep the quality metrics only if no merging occurred.
metrics = quality_metrics if not units_merged else None
merge_groups, units_merged = self.get_merge_groups(
parent_sorting, parent_merge_groups, quality_metrics, merge_params
)
# save everything in NWB

# INSERT

# insert this sorting into the CuratedSpikeSorting Table
# first remove keys that aren't part of the Sorting (the primary key of curation)
c_key = (SpikeSorting & key).fetch("KEY")[0]
curation_key = {item: key[item] for item in key if item in c_key}
key["auto_curation_key"] = Curation.insert_curation(
Expand Down

0 comments on commit d578e65

Please sign in to comment.