Skip to content

Commit

Permalink
Remove methods
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Sep 7, 2023
1 parent 184cb82 commit c1d83eb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 167 deletions.
2 changes: 0 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
"editor.formatOnSave": true,
"files.trimTrailingWhitespace": true,
"files.trimFinalNewlines": true,
"editor.multiCursorModifier": "ctrlCmd",
"autoDocstring.docstringFormat": "numpy",
"python.formatting.provider": "none",
"remote.SSH.remoteServerListenOnSocket": true,
"git.confirmSync": false,
"python.analysis.typeCheckingMode": "off",
Expand Down
197 changes: 32 additions & 165 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class SortGroup(dj.Manual):
definition = """
# Set of electrodes to spike sort together
-> Session
sort_group_id: int # identifier for a group of electrodes
sort_group_id: str # identifier for a group of electrodes
---
sort_reference_electrode_id = -1: int # the electrode to use for reference. -1: no reference, -2: common median
sort_reference_electrode_id = -1: int # the electrode to use for referencing
# -1: no reference, -2: common median
"""

class SortGroupElectrode(dj.Part):
Expand All @@ -49,8 +50,7 @@ def set_group_by_shank(
omit_ref_electrode_group=False,
omit_unitrode=True,
):
"""Divides electrodes into groups based on their shank position.
"""Create sort group for each shank.
* Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a
single group
* Electrodes from probes with multiple shanks (e.g. polymer probes) are
Expand Down Expand Up @@ -163,177 +163,44 @@ def set_group_by_shank(
cls.SortGroupElectrode.insert1(sge_key)
sort_group += 1

def set_group_by_electrode_group(self, nwb_file_name: str):
"""Assign groups to all non-bad channel electrodes based on their electrode group
and sets the reference for each group to the reference for the first channel of the group.
Parameters
----------
nwb_file_name: str
the name of the nwb whose electrodes should be put into sorting groups
"""
# delete any current groups
(SortGroup & {"nwb_file_name": nwb_file_name}).delete()
# get the electrodes from this NWB file
electrodes = (
Electrode()
& {"nwb_file_name": nwb_file_name}
& {"bad_channel": "False"}
).fetch()
e_groups = np.unique(electrodes["electrode_group_name"])
sg_key = dict()
sge_key = dict()
sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name
sort_group = 0
for e_group in e_groups:
sge_key["electrode_group_name"] = e_group
# sg_key['sort_group_id'] = sge_key['sort_group_id'] = sort_group
# TEST
sg_key["sort_group_id"] = sge_key["sort_group_id"] = int(e_group)
# get the list of references and make sure they are all the same
shank_elect_ref = electrodes["original_reference_electrode"][
electrodes["electrode_group_name"] == e_group
]
if np.max(shank_elect_ref) == np.min(shank_elect_ref):
sg_key["sort_reference_electrode_id"] = shank_elect_ref[0]
else:
ValueError(
f"Error in electrode group {e_group}: reference electrodes are not all the same"
)
self.insert1(sg_key)

shank_elect = electrodes["electrode_id"][
electrodes["electrode_group_name"] == e_group
]
for elect in shank_elect:
sge_key["electrode_id"] = elect
self.SortGroupElectrode().insert1(sge_key)
sort_group += 1

def set_reference_from_list(self, nwb_file_name, sort_group_ref_list):
"""
Set the reference electrode from a list containing sort groups and reference electrodes
:param: sort_group_ref_list - 2D array or list where each row is [sort_group_id reference_electrode]
:param: nwb_file_name - The name of the NWB file whose electrodes' references should be updated
:return: Null
"""
key = dict()
key["nwb_file_name"] = nwb_file_name
sort_group_list = (SortGroup() & key).fetch1()
for sort_group in sort_group_list:
key["sort_group_id"] = sort_group
self.insert(
dj_replace(
sort_group_list,
sort_group_ref_list,
"sort_group_id",
"sort_reference_electrode_id",
),
replace="True",
)

"""
Returns a list with the x,y coordinates of the electrodes in the sort group
for use with the SpikeInterface package.
Converts z locations to y where appropriate.
Parameters
----------
sort_group_id : int
nwb_file_name : str
prb_file_name : str
Returns
-------
geometry : list
List of coordinate pairs, one per electrode
"""

# create the channel_groups dictiorary
channel_group = dict()
key = dict()
key["nwb_file_name"] = nwb_file_name
electrodes = (Electrode() & key).fetch()

key["sort_group_id"] = sort_group_id
sort_group_electrodes = (SortGroup.SortGroupElectrode() & key).fetch()
electrode_group_name = sort_group_electrodes["electrode_group_name"][0]
probe_id = (
ElectrodeGroup
& {
"nwb_file_name": nwb_file_name,
"electrode_group_name": electrode_group_name,
}
).fetch1("probe_id")
channel_group[sort_group_id] = dict()
channel_group[sort_group_id]["channels"] = sort_group_electrodes[
"electrode_id"
].tolist()

n_chan = len(channel_group[sort_group_id]["channels"])

geometry = np.zeros((n_chan, 2), dtype="float")
tmp_geom = np.zeros((n_chan, 3), dtype="float")
for i, electrode_id in enumerate(
channel_group[sort_group_id]["channels"]
):
# get the relative x and y locations of this channel from the probe table
probe_electrode = int(
electrodes["probe_electrode"][
electrodes["electrode_id"] == electrode_id
]
)
rel_x, rel_y, rel_z = (
Probe().Electrode()
& {"probe_id": probe_id, "probe_electrode": probe_electrode}
).fetch("rel_x", "rel_y", "rel_z")
# TODO: Fix this HACK when we can use probeinterface:
rel_x = float(rel_x)
rel_y = float(rel_y)
rel_z = float(rel_z)
tmp_geom[i, :] = [rel_x, rel_y, rel_z]

# figure out which columns have coordinates
n_found = 0
for i in range(3):
if np.any(np.nonzero(tmp_geom[:, i])):
if n_found < 2:
geometry[:, n_found] = tmp_geom[:, i]
n_found += 1
else:
Warning(
"Relative electrode locations have three coordinates; only two are currenlty supported"
)
return np.ndarray.tolist(geometry)


@schema
class SpikeSortingPreprocessingParameter(dj.Manual):
class SpikeSortingPreprocessingParameter(dj.Lookup):
definition = """
# Parameter for denoising (filtering and referencing/whitening) recording
# prior to spike sorting
preproc_param_name: varchar(200)
---
preproc_param: blob
"""
freq_min = 300 # high pass filter value
freq_max = 6000 # low pass filter value
margin_ms = 5 # margin in ms on border to avoid border effect
seed = 0 # random seed for whitening

contents = [
[
"default",
{
"frequency_min": freq_min,
"frequency_max": freq_max,
"margin_ms": margin_ms,
"seed": seed,
},
]
]

def insert_default(self):
# set up the default filter parameters
freq_min = 300 # high pass filter value
freq_max = 6000 # low pass filter value
margin_ms = 5 # margin in ms on border to avoid border effect
seed = 0 # random seed for whitening

@classmethod
def insert_default(cls):
key = dict()
key["preproc_params_name"] = "default"
key["preproc_params"] = {
"frequency_min": freq_min,
"frequency_max": freq_max,
"margin_ms": margin_ms,
"seed": seed,
"frequency_min": cls.freq_min,
"frequency_max": cls.freq_max,
"margin_ms": cls.margin_ms,
"seed": cls.seed,
}
self.insert1(key, skip_duplicates=True)
cls.insert1(key, skip_duplicates=True)


@schema
Expand Down Expand Up @@ -417,7 +284,7 @@ def _get_recording_timestamps(recording):
timestamps = recording.get_times()
return timestamps

def _get_sort_interval_valid_times(self, key):
def _get_sort_interval_valid_times(self, key: dict):
"""Identifies the intersection between sort interval specified by the user
and the valid times (times for which neural data exist)
Expand All @@ -429,14 +296,14 @@ def _get_sort_interval_valid_times(self, key):
Returns
-------
sort_interval_valid_times: ndarray of tuples
(start, end) times for valid stretches of the sorting interval
(start, end) times for valid intervals in the sort interval
"""
sort_interval = (
SortInterval
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"sort_interval_name": key["sort_interval_name"],
"interval_list_name": key["sort_interval_name"],
}
).fetch1("sort_interval")
interval_list_name = (SpikeSortingRecordingSelection & key).fetch1(
Expand Down

0 comments on commit c1d83eb

Please sign in to comment.