diff --git a/.vscode/settings.json b/.vscode/settings.json index 2c05895e3..4a1bb2175 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,9 +2,7 @@ "editor.formatOnSave": true, "files.trimTrailingWhitespace": true, "files.trimFinalNewlines": true, - "editor.multiCursorModifier": "ctrlCmd", "autoDocstring.docstringFormat": "numpy", - "python.formatting.provider": "none", "remote.SSH.remoteServerListenOnSocket": true, "git.confirmSync": false, "python.analysis.typeCheckingMode": "off", diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 79273c4a3..d7d8d61da 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -30,9 +30,10 @@ class SortGroup(dj.Manual): definition = """ # Set of electrodes to spike sort together -> Session - sort_group_id: int # identifier for a group of electrodes + sort_group_id: str # identifier for a group of electrodes --- - sort_reference_electrode_id = -1: int # the electrode to use for reference. -1: no reference, -2: common median + sort_reference_electrode_id = -1: int # the electrode to use for referencing + # -1: no reference, -2: common median """ class SortGroupElectrode(dj.Part): @@ -49,8 +50,7 @@ def set_group_by_shank( omit_ref_electrode_group=False, omit_unitrode=True, ): - """Divides electrodes into groups based on their shank position. - + """Create sort group for each shank. * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a single group * Electrodes from probes with multiple shanks (e.g. polymer probes) are @@ -163,153 +163,9 @@ def set_group_by_shank( cls.SortGroupElectrode.insert1(sge_key) sort_group += 1 - def set_group_by_electrode_group(self, nwb_file_name: str): - """Assign groups to all non-bad channel electrodes based on their electrode group - and sets the reference for each group to the reference for the first channel of the group. - - Parameters - ---------- - nwb_file_name: str - the name of the nwb whose electrodes should be put into sorting groups - """ - # delete any current groups - (SortGroup & {"nwb_file_name": nwb_file_name}).delete() - # get the electrodes from this NWB file - electrodes = ( - Electrode() - & {"nwb_file_name": nwb_file_name} - & {"bad_channel": "False"} - ).fetch() - e_groups = np.unique(electrodes["electrode_group_name"]) - sg_key = dict() - sge_key = dict() - sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name - sort_group = 0 - for e_group in e_groups: - sge_key["electrode_group_name"] = e_group - # sg_key['sort_group_id'] = sge_key['sort_group_id'] = sort_group - # TEST - sg_key["sort_group_id"] = sge_key["sort_group_id"] = int(e_group) - # get the list of references and make sure they are all the same - shank_elect_ref = electrodes["original_reference_electrode"][ - electrodes["electrode_group_name"] == e_group - ] - if np.max(shank_elect_ref) == np.min(shank_elect_ref): - sg_key["sort_reference_electrode_id"] = shank_elect_ref[0] - else: - ValueError( - f"Error in electrode group {e_group}: reference electrodes are not all the same" - ) - self.insert1(sg_key) - - shank_elect = electrodes["electrode_id"][ - electrodes["electrode_group_name"] == e_group - ] - for elect in shank_elect: - sge_key["electrode_id"] = elect - self.SortGroupElectrode().insert1(sge_key) - sort_group += 1 - - def set_reference_from_list(self, nwb_file_name, sort_group_ref_list): - """ - Set the reference electrode from a list containing sort groups and reference electrodes - :param: sort_group_ref_list - 2D array or list where each row is [sort_group_id reference_electrode] - :param: nwb_file_name - The name of the NWB file whose electrodes' references should be updated - :return: Null - """ - key = dict() - key["nwb_file_name"] = nwb_file_name - sort_group_list = (SortGroup() & key).fetch1() - for sort_group in sort_group_list: - key["sort_group_id"] = sort_group - self.insert( - dj_replace( - sort_group_list, - sort_group_ref_list, - "sort_group_id", - "sort_reference_electrode_id", - ), - replace="True", - ) - - """ - Returns a list with the x,y coordinates of the electrodes in the sort group - for use with the SpikeInterface package. - - Converts z locations to y where appropriate. - - Parameters - ---------- - sort_group_id : int - nwb_file_name : str - prb_file_name : str - - Returns - ------- - geometry : list - List of coordinate pairs, one per electrode - """ - - # create the channel_groups dictiorary - channel_group = dict() - key = dict() - key["nwb_file_name"] = nwb_file_name - electrodes = (Electrode() & key).fetch() - - key["sort_group_id"] = sort_group_id - sort_group_electrodes = (SortGroup.SortGroupElectrode() & key).fetch() - electrode_group_name = sort_group_electrodes["electrode_group_name"][0] - probe_id = ( - ElectrodeGroup - & { - "nwb_file_name": nwb_file_name, - "electrode_group_name": electrode_group_name, - } - ).fetch1("probe_id") - channel_group[sort_group_id] = dict() - channel_group[sort_group_id]["channels"] = sort_group_electrodes[ - "electrode_id" - ].tolist() - - n_chan = len(channel_group[sort_group_id]["channels"]) - - geometry = np.zeros((n_chan, 2), dtype="float") - tmp_geom = np.zeros((n_chan, 3), dtype="float") - for i, electrode_id in enumerate( - channel_group[sort_group_id]["channels"] - ): - # get the relative x and y locations of this channel from the probe table - probe_electrode = int( - electrodes["probe_electrode"][ - electrodes["electrode_id"] == electrode_id - ] - ) - rel_x, rel_y, rel_z = ( - Probe().Electrode() - & {"probe_id": probe_id, "probe_electrode": probe_electrode} - ).fetch("rel_x", "rel_y", "rel_z") - # TODO: Fix this HACK when we can use probeinterface: - rel_x = float(rel_x) - rel_y = float(rel_y) - rel_z = float(rel_z) - tmp_geom[i, :] = [rel_x, rel_y, rel_z] - - # figure out which columns have coordinates - n_found = 0 - for i in range(3): - if np.any(np.nonzero(tmp_geom[:, i])): - if n_found < 2: - geometry[:, n_found] = tmp_geom[:, i] - n_found += 1 - else: - Warning( - "Relative electrode locations have three coordinates; only two are currenlty supported" - ) - return np.ndarray.tolist(geometry) - @schema -class SpikeSortingPreprocessingParameter(dj.Manual): +class SpikeSortingPreprocessingParameter(dj.Lookup): definition = """ # Parameter for denoising (filtering and referencing/whitening) recording # prior to spike sorting @@ -317,23 +173,34 @@ class SpikeSortingPreprocessingParameter(dj.Manual): --- preproc_param: blob """ + freq_min = 300 # high pass filter value + freq_max = 6000 # low pass filter value + margin_ms = 5 # margin in ms on border to avoid border effect + seed = 0 # random seed for whitening + + contents = [ + [ + "default", + { + "frequency_min": freq_min, + "frequency_max": freq_max, + "margin_ms": margin_ms, + "seed": seed, + }, + ] + ] - def insert_default(self): - # set up the default filter parameters - freq_min = 300 # high pass filter value - freq_max = 6000 # low pass filter value - margin_ms = 5 # margin in ms on border to avoid border effect - seed = 0 # random seed for whitening - + @classmethod + def insert_default(cls): key = dict() key["preproc_params_name"] = "default" key["preproc_params"] = { - "frequency_min": freq_min, - "frequency_max": freq_max, - "margin_ms": margin_ms, - "seed": seed, + "frequency_min": cls.freq_min, + "frequency_max": cls.freq_max, + "margin_ms": cls.margin_ms, + "seed": cls.seed, } - self.insert1(key, skip_duplicates=True) + cls.insert1(key, skip_duplicates=True) @schema @@ -417,7 +284,7 @@ def _get_recording_timestamps(recording): timestamps = recording.get_times() return timestamps - def _get_sort_interval_valid_times(self, key): + def _get_sort_interval_valid_times(self, key: dict): """Identifies the intersection between sort interval specified by the user and the valid times (times for which neural data exist) @@ -429,14 +296,14 @@ def _get_sort_interval_valid_times(self, key): Returns ------- sort_interval_valid_times: ndarray of tuples - (start, end) times for valid stretches of the sorting interval + (start, end) times for valid intervals in the sort interval """ sort_interval = ( - SortInterval + IntervalList & { "nwb_file_name": key["nwb_file_name"], - "sort_interval_name": key["sort_interval_name"], + "interval_list_name": key["sort_interval_name"], } ).fetch1("sort_interval") interval_list_name = (SpikeSortingRecordingSelection & key).fetch1(