diff --git a/src/spyglass/spikesorting/__init__.py b/src/spyglass/spikesorting/__init__.py index c6cf9ae55..05c7e9948 100644 --- a/src/spyglass/spikesorting/__init__.py +++ b/src/spyglass/spikesorting/__init__.py @@ -1,3 +1,4 @@ +from .curation_figurl import CurationFigurl, CurationFigurlSelection from .sortingview import SortingviewWorkspace, SortingviewWorkspaceSelection from .spikesorting_artifact import ( ArtifactDetection, @@ -20,6 +21,10 @@ Waveforms, WaveformSelection, ) +from .spikesorting_populator import ( + SpikeSortingPipelineParameters, + spikesorting_pipeline_populator, +) from .spikesorting_recording import ( SortGroup, SortInterval, @@ -32,10 +37,3 @@ SpikeSorting, SpikeSortingSelection, ) - -from .curation_figurl import CurationFigurlSelection, CurationFigurl - -from .spikesorting_populator import ( - spikesorting_pipeline_populator, - SpikeSortingPipelineParameters, -) diff --git a/src/spyglass/spikesorting/spikesorting_populator.py b/src/spyglass/spikesorting/spikesorting_populator.py index 57ef1d076..32d5753b0 100644 --- a/src/spyglass/spikesorting/spikesorting_populator.py +++ b/src/spyglass/spikesorting/spikesorting_populator.py @@ -1,29 +1,30 @@ import datajoint as dj -from ..common import IntervalList, ElectrodeGroup -from .spikesorting_recording import ( - SortGroup, - SortInterval, - SpikeSortingRecordingSelection, - SpikeSortingRecording, -) + +from ..common import ElectrodeGroup, IntervalList +from .curation_figurl import CurationFigurl, CurationFigurlSelection from .spikesorting_artifact import ( - ArtifactDetectionSelection, ArtifactDetection, + ArtifactDetectionSelection, ArtifactRemovedIntervalList, ) -from .spikesorting_sorting import SpikeSortingSelection, SpikeSorting from .spikesorting_curation import ( + AutomaticCuration, + AutomaticCurationSelection, + CuratedSpikeSorting, + CuratedSpikeSortingSelection, Curation, - WaveformSelection, - Waveforms, MetricSelection, QualityMetrics, - AutomaticCurationSelection, - AutomaticCuration, - CuratedSpikeSortingSelection, - CuratedSpikeSorting, + Waveforms, + WaveformSelection, ) -from .curation_figurl import CurationFigurlSelection, CurationFigurl +from .spikesorting_recording import ( + SortGroup, + SortInterval, + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) +from .spikesorting_sorting import SpikeSorting, SpikeSortingSelection schema = dj.schema("spikesorting_sorting") @@ -46,8 +47,8 @@ class SpikeSortingPipelineParameters(dj.Manual): def spikesorting_pipeline_populator( nwb_file_name: str, team_name: str, - fig_url_repo: str, - interval_list_name: str, + fig_url_repo: str = None, + interval_list_name: str = None, sort_interval_name: str = None, pipeline_parameters_name: str = None, probe_restriction: dict = {}, @@ -59,7 +60,7 @@ def spikesorting_pipeline_populator( metric_params_name: str = "peak_offest_num_spikes_2", auto_curation_params_name: str = "mike_noise_03_offset_2_isi_0025_mua", ): - """Function top auomatically populate the spike sorting pipeline for a given epoch + """Automatically populate the spike sorting pipeline for a given epoch Parameters ---------- @@ -67,156 +68,139 @@ def spikesorting_pipeline_populator( Session ID team_name : str Which team to assign the spike sorting to - fig_url_repo : str - Whewre to store the curation figurl json files (e.x. 'gh://LorenFrankLab/sorting-curations/main/sambray/'), leave empty string to skip figurl + fig_url_repo : str, optional + Where to store the curation figurl json files (e.g., + 'gh://LorenFrankLab/sorting-curations/main/user/'). Default None to + skip figurl interval_list_name : str, - if sort_interval_name not provided, will create a sort interval for the given interval with the same name + if sort_interval_name not provided, will create a sort interval for the + given interval with the same name sort_interval_name : str, default None - if provided, will use the given sort interval, requires making this interval yourself + if provided, will use the given sort interval, requires making this + interval yourself pipeline_parameters_name : str, optional - If provided, will lookup pipeline parameters from the SpikeSortingPipelineParameters table, superceeds other values provided, by default None + If provided, will lookup pipeline parameters from the + SpikeSortingPipelineParameters table, supersedes other values provided, + by default None restrict_probe_type : dict, optional - Restricts analysis to sort groups with matching keys. Can use keys from the SortGroup and ElectrodeGroup Tables (e.g. electrode_group_name, probe_id, target_hemisphere), by default {} + Restricts analysis to sort groups with matching keys. Can use keys from + the SortGroup and ElectrodeGroup Tables (e.g. electrode_group_name, + probe_id, target_hemisphere), by default {} artifact_parameters : str, optional parameter set for artifact detection, by default "ampl_2000_prop_75" preproc_params_name : str, optional - parameter set for spikesorting recording, by default "franklab_tetrode_hippocampus" + parameter set for spikesorting recording, by default + "franklab_tetrode_hippocampus" sorter : str, optional which spikesorting algorithm to use, by default "mountainsort4" sorter_params_name : str, optional - parameters for the spike sorting algorithm, by default "franklab_tetrode_hippocampus_30KHz_tmp" + parameters for the spike sorting algorithm, by default + "franklab_tetrode_hippocampus_30KHz_tmp" waveform_params_name : str, optional - Parameters for spike waveform extraction. If empty string, will skip automatic curation steps, by default "default_whitened" + Parameters for spike waveform extraction. If empty string, will skip + automatic curation steps, by default "default_whitened" metric_params_name : str, optional - Parameters defining which QualityMetrics to calculate and how. If empty string, will skip automatic curation steps, by default "peak_offest_num_spikes_2" + Parameters defining which QualityMetrics to calculate and how. If empty + string, will skip automatic curation steps, by default + "peak_offest_num_spikes_2" auto_curation_params_name : str, optional - Thresholds applied to Quality metrics for automatic unit curation. If empty string, will skip automatic curation steps, by default "mike_noise_03_offset_2_isi_0025_mua" + Thresholds applied to Quality metrics for automatic unit curation. If + empty string, will skip automatic curation steps, by default + "mike_noise_03_offset_2_isi_0025_mua" """ - + nwbf_dict = dict(nwb_file_name=nwb_file_name) # Define pipeline parameters if pipeline_parameters_name is not None: print(f"Using pipeline parameters {pipeline_parameters_name}") - artifact_parameters = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("artifact_parameters") - preproc_params_name = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("preproc_params_name") - sorter = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("sorter") - sorter_params_name = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("sorter_params_name") - waveform_params_name = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("waveform_params_name") - metric_params_name = ( - SpikeSortingPipelineParameters - & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("metric_params_name") - auto_curation_params_name = ( + ( + artifact_parameters, + preproc_params_name, + sorter, + sorter_params_name, + waveform_params_name, + metric_params_name, + auto_curation_params_name, + ) = ( SpikeSortingPipelineParameters & {"pipeline_parameters_name": pipeline_parameters_name} - ).fetch1("auto_curation_params_name") - - ## Sorting - ## Sort groups - ## Sort intervals - ## Spike sorting recording - ## Artifact detection - ## Spike sorting + ).fetch1( + "artifact_parameters", + "preproc_params_name", + "sorter", + "sorter_params_name", + "waveform_params_name", + "metric_params_name", + "auto_curation_params_name", + ) - # make sort groups only if not currently available (don't overwrite existing ones!) - if len(SortGroup() & {"nwb_file_name": nwb_file_name}) == 0: + # make sort groups only if not currently available + # don't overwrite existing ones! + if not SortGroup() & nwbf_dict: print("Generating sort groups") SortGroup().set_group_by_shank(nwb_file_name) - # find desired sort group(s) for these settings - sort_group_id_list = ( - (SortGroup.SortGroupElectrode * ElectrodeGroup) - & {"nwb_file_name": nwb_file_name} - & probe_restriction - ).fetch("sort_group_id") # Define sort interval + interval_dict = dict(**nwbf_dict, interval_list_name=interval_list_name) + if sort_interval_name is not None: print(f"Using sort interval {sort_interval_name}") - if ( - len( - SortInterval() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ) - == 0 + if not ( + SortInterval + & nwbf_dict + & {"sort_interval_name": sort_interval_name} ): raise KeyError(f"Sort interval {sort_interval_name} not found") else: print(f"Generating sort interval from {interval_list_name}") - interval_list = ( - IntervalList - & { - "nwb_file_name": nwb_file_name, - "interval_list_name": interval_list_name, - } - ).fetch1("valid_times")[0] + interval_list = (IntervalList & interval_dict).fetch1("valid_times")[0] + sort_interval_name = interval_list_name sort_interval = interval_list + SortInterval.insert1( { - "nwb_file_name": nwb_file_name, + **nwbf_dict, "sort_interval_name": sort_interval_name, "sort_interval": sort_interval, }, skip_duplicates=True, ) + sort_dict = dict(**nwbf_dict, sort_interval_name=sort_interval_name) + + # find desired sort group(s) for these settings + sort_group_id_list = ( + (SortGroup.SortGroupElectrode * ElectrodeGroup) + & nwbf_dict + & probe_restriction + ).fetch("sort_group_id") + # make spike sorting recording print("Generating spike sorting recording") for sort_group_id in sort_group_id_list: ssr_key = dict( - nwb_file_name=nwb_file_name, + **sort_dict, sort_group_id=sort_group_id, # See SortGroup - sort_interval_name=sort_interval_name, # First N seconds above preproc_params_name=preproc_params_name, # See preproc_params interval_list_name=interval_list_name, team_name=team_name, ) SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True) - ssr_pj = ( - SpikeSortingRecordingSelection() - & { - "nwb_file_name": nwb_file_name, - "interval_list_name": interval_list_name, - } - ).proj() - SpikeSortingRecording.populate([ssr_pj]) + + SpikeSortingRecording.populate(interval_dict) # Artifact detection print("Running artifact detection") - artifact_key_list = (ssr_pj).fetch("KEY") - for artifact_key in artifact_key_list: - artifact_key["artifact_params_name"] = artifact_parameters - ArtifactDetectionSelection().insert1(artifact_key, skip_duplicates=True) - - art_pj = ( - ArtifactDetectionSelection() - & { - "nwb_file_name": nwb_file_name, - "interval_list_name": interval_list_name, - } - ).proj() - ArtifactDetection.populate([art_pj]) + artifact_keys = [ + {**k, "artifact_params_name": artifact_parameters} + for k in (SpikeSortingRecordingSelection() & interval_dict).fetch("KEY") + ] + ArtifactDetectionSelection().insert(artifact_keys, skip_duplicates=True) + ArtifactDetection.populate(interval_dict) # Spike sorting print("Running spike sorting") - for artifact_key in artifact_key_list: + for artifact_key in artifact_keys: ss_key = dict( **(ArtifactDetection & artifact_key).fetch1("KEY"), **(ArtifactRemovedIntervalList() & artifact_key).fetch1("KEY"), @@ -225,33 +209,11 @@ def spikesorting_pipeline_populator( ) ss_key.pop("artifact_params_name") SpikeSortingSelection.insert1(ss_key, skip_duplicates=True) - ss_proj = ( - SpikeSortingSelection - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).proj() - SpikeSorting.populate([ss_proj]) - - ## Curation - ## Initial curation - ## Extract waveforms - ## Quality Metrics - ## Automatic Curation - ## Curated Spike Sorting - ## Curation Figurl + SpikeSorting.populate(sort_dict) # initial curation print("Beginning curation") - sorting_key_list = ( - SpikeSorting() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("KEY") - for sorting_key in sorting_key_list: + for sorting_key in (SpikeSorting() & sort_dict).fetch("KEY"): Curation.insert_curation(sorting_key) # Calculate quality metrics and perform automatic curation if specified @@ -262,77 +224,38 @@ def spikesorting_pipeline_populator( ): # Extract waveforms print("Extracting waveforms") - curation_key_list = ( - Curation() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("KEY") - for curation_key in curation_key_list: - curation_key["waveform_params_name"] = waveform_params_name - WaveformSelection.insert1(curation_key, skip_duplicates=True) - wave_pj = ( - WaveformSelection() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).proj() - Waveforms.populate([wave_pj]) + curation_keys = [ + {**k, "waveform_params_name": waveform_params_name} + for k in (Curation() & sort_dict).fetch("KEY") + ] + WaveformSelection.insert(curation_keys, skip_duplicates=True) + Waveforms.populate(sort_dict) # Quality Metrics print("Calculating quality metrics") - waveform_key_list = ( - Waveforms() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("KEY") - for waveform_key in waveform_key_list: - waveform_key["metric_params_name"] = metric_params_name - MetricSelection.insert1(waveform_key, skip_duplicates=True) - metrics_pj = ( - MetricSelection() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).proj() - QualityMetrics().populate([metrics_pj]) + waveform_keys = [ + {**k, "metric_params_name": metric_params_name} + for k in (Waveforms() & sort_dict).fetch("KEY") + ] + MetricSelection.insert(waveform_keys, skip_duplicates=True) + QualityMetrics().populate(sort_dict) # Automatic Curation print("Creating automatic curation") - metric_key_list = ( - QualityMetrics() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("KEY") - for metric_key in metric_key_list: - metric_key["auto_curation_params_name"] = auto_curation_params_name - AutomaticCurationSelection.insert1(metric_key, skip_duplicates=True) - auto_pj = ( - AutomaticCurationSelection - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).proj() - AutomaticCuration().populate([auto_pj]) + metric_keys = [ + {**k, "auto_curation_name": auto_curation_params_name} + for k in (QualityMetrics() & sort_dict).fetch("KEY") + ] + AutomaticCurationSelection.insert(metric_keys, skip_duplicates=True) + AutomaticCuration().populate(sort_dict) - # get curation keys of the automatic curation to populate into curated spike sorting selection # Curated Spike Sorting + # get curation keys of the automatic curation to populate into curated + # spike sorting selection print("Creating curated spike sorting") - auto_key_list = ( - AutomaticCuration() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("auto_curation_key") + auto_key_list = (AutomaticCuration() & sort_dict).fetch( + "auto_curation_key" + ) for auto_key in auto_key_list: curation_auto_key = (Curation() & auto_key).fetch1("KEY") CuratedSpikeSortingSelection.insert1( @@ -340,51 +263,37 @@ def spikesorting_pipeline_populator( ) else: - # Perform no automatic curation, just populate curated spike sorting selection with the initial curation. - # Used in case of clusterless decoding + # Perform no automatic curation, just populate curated spike sorting + # selection with the initial curation. Used in case of clusterless + # decoding print("Creating curated spike sorting") - # list of just initial curations - curation_key_list = ( - Curation() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("KEY") - for curation_key in curation_key_list: + curation_keys = (Curation() & sort_dict).fetch("KEY") + for curation_key in curation_keys: CuratedSpikeSortingSelection.insert1( curation_auto_key, skip_duplicates=True ) # Populate curated spike sorting - cur_proj = CuratedSpikeSortingSelection() & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - CuratedSpikeSorting.populate([cur_proj]) + CuratedSpikeSorting.populate(sort_dict) - if len(fig_url_repo) > 0: + if fig_url_repo: # Curation Figurl print("Creating curation figurl") - sort_interval_name = interval_list_name + f"_entire" - for auto_id in ( - AutomaticCuration() - & { - "nwb_file_name": nwb_file_name, - "sort_interval_name": sort_interval_name, - } - ).fetch("auto_curation_key"): - tetrode = auto_id["sort_group_id"] - session_id = nwb_file_name + "_" + sort_interval_name - github_url = ( - fig_url_repo - + str(session_id) - + "/" - + str(tetrode) - + "/curation.json" + sort_interval_name = interval_list_name + "_entire" + gh_url = ( + fig_url_repo + + str(nwb_file_name + "_" + sort_interval_name) # session id + + "/{}" # tetrode using auto_id['sort_group_id'] + + "/curation.json" + ) + + for auto_id in (AutomaticCuration() & sort_dict).fetch( + "auto_curation_key" + ): + auto_curation_out_key = dict( + **(Curation() & auto_id).fetch1("KEY"), + new_curation_uri=gh_url.format(str(auto_id["sort_group_id"])), ) - auto_curation_out_key = (Curation() & auto_id).fetch1("KEY") - auto_curation_out_key["new_curation_uri"] = github_url CurationFigurlSelection.insert1( auto_curation_out_key, skip_duplicates=True )