Skip to content

Commit

Permalink
LFP artifact optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Aug 22, 2023
1 parent 1b45d8d commit 34b7211
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 165 deletions.
218 changes: 108 additions & 110 deletions src/spyglass/lfp/v1/lfp_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,72 +29,80 @@ class LFPArtifactDetectionParameters(dj.Manual):
"""

def insert_default(self):
"""Insert the default artifact parameters with an appropriate parameter dict."""
artifact_params = {
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": 500, # must be None or >= 0
"proportion_above_thresh_1st": 0.1,
"amplitude_thresh_2nd": 1000, # must be None or >= 0
"proportion_above_thresh_2nd": 0.05,
"removal_window_ms": 10.0, # in milliseconds
"local_window_ms": 40.0, # in milliseconds
"""Insert the default artifact parameters."""
diff_params = [
"default_difference",
{
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": 500, # must be None or >= 0
"proportion_above_thresh_1st": 0.1,
"amplitude_thresh_2nd": 1000, # must be None or >= 0
"proportion_above_thresh_2nd": 0.05,
"removal_window_ms": 10.0, # in milliseconds
"local_window_ms": 40.0, # in milliseconds
},
},
}
]

self.insert1(
["default_difference", artifact_params], skip_duplicates=True
)
diff_ref_params = [
"default_difference_ref",
{
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": 500, # must be None or >= 0
"proportion_above_thresh_1st": 0.1,
"amplitude_thresh_2nd": 1000, # must be None or >= 0
"proportion_above_thresh_2nd": 0.05,
"removal_window_ms": 10.0, # in milliseconds
"local_window_ms": 40.0, # in milliseconds
},
"referencing": {
"ref_on": 1,
"reference_list": [0, 0, 0, 0, 0],
"electrode_list": [0, 0],
},
},
]

artifact_params = {
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": 500, # must be None or >= 0
"proportion_above_thresh_1st": 0.1,
"amplitude_thresh_2nd": 1000, # must be None or >= 0
"proportion_above_thresh_2nd": 0.05,
"removal_window_ms": 10.0, # in milliseconds
"local_window_ms": 40.0, # in milliseconds
no_params = [
"none",
{
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": None, # must be None or >= 0
"proportion_above_thresh_1st": None,
"amplitude_thresh_2nd": None, # must be None or >= 0
"proportion_above_thresh_2nd": None,
"removal_window_ms": None, # in milliseconds
"local_window_ms": None, # in milliseconds
},
},
"referencing": {
"ref_on": 1,
"reference_list": [0, 0, 0, 0, 0],
"electrode_list": [0, 0],
]

mad_params = [
"default_mad",
{
"artifact_detection_algorithm": "mad",
"artifact_detection_algorithm_params": {
# akin to z-score std dev if the distribution is normal
"mad_thresh": 6.0,
"proportion_above_thresh": 0.1,
"removal_window_ms": 10.0, # in milliseconds
},
},
}
]

self.insert1(
["default_difference_ref", artifact_params], skip_duplicates=True
self.insert(
[diff_params, diff_ref_params, no_params, mad_params],
skip_duplicates=True,
)

artifact_params_none = {
"artifact_detection_algorithm": "difference",
"artifact_detection_algorithm_params": {
"amplitude_thresh_1st": None, # must be None or >= 0
"proportion_above_thresh_1st": None,
"amplitude_thresh_2nd": None, # must be None or >= 0
"proportion_above_thresh_2nd": None,
"removal_window_ms": None, # in milliseconds
"local_window_ms": None, # in milliseconds
},
}
self.insert1(["none", artifact_params_none], skip_duplicates=True)

artifact_params_mad = {
"artifact_detection_algorithm": "mad",
"artifact_detection_algorithm_params": {
"mad_thresh": 6.0, # akin to z-score standard deviations if the distribution is normal
"proportion_above_thresh": 0.1,
"removal_window_ms": 10.0, # in milliseconds
},
}
self.insert1(["default_mad", artifact_params_mad], skip_duplicates=True)


@schema
class LFPArtifactDetectionSelection(dj.Manual):
definition = """
# Specifies artifact detection parameters to apply to a sort group's recording.
# Artifact detection parameters to apply to a sort group's recording.
-> LFPV1
-> LFPArtifactDetectionParameters
---
Expand All @@ -108,8 +116,9 @@ class LFPArtifactDetection(dj.Computed):
-> LFPArtifactDetectionSelection
---
artifact_times: longblob # np array of artifact intervals
artifact_removed_valid_times: longblob # np array of valid no-artifact intervals
artifact_removed_interval_list_name: varchar(200) # name of the array of no-artifact valid time intervals
artifact_removed_valid_times: longblob # np array of no-artifact intervals
artifact_removed_interval_list_name: varchar(200)
# name of the array of no-artifact valid time intervals
"""

def make(self, key):
Expand All @@ -118,18 +127,13 @@ def make(self, key):
& {"artifact_params_name": key["artifact_params_name"]}
).fetch1("artifact_params")

artifact_detection_algorithm = artifact_params[
"artifact_detection_algorithm"
]
artifact_detection_params = artifact_params[
"artifact_detection_algorithm_params"
]

algorithm = artifact_params["artifact_detection_algorithm"]
params = artifact_params["artifact_detection_algorithm_params"]
lfp_band_ref_id = artifact_params["referencing"]["reference_list"]

# get LFP data
lfp_eseries = (LFPV1() & key).fetch_nwb()[0]["lfp"]
sampling_frequency = (LFPV1() & key).fetch("lfp_sampling_rate")[0]
lfp_eseries = (LFPV1 & key).fetch_nwb()[0]["lfp"]
sampling_frequency = (LFPV1 & key).fetch("lfp_sampling_rate")[0]

# do referencing at this step
lfp_data = np.asarray(
Expand All @@ -145,67 +149,61 @@ def make(self, key):
# maybe this lfp_elec_list is supposed to be a list on indices

for index, elect_index in enumerate(lfp_band_elect_index):
if lfp_band_ref_id[index] != -1:
lfp_data[:, elect_index] = (
lfp_data[:, elect_index]
- lfp_data[:, lfp_band_ref_index[index]]
)

if artifact_detection_algorithm == "difference":
(
artifact_removed_valid_times,
artifact_times,
) = ARTIFACT_DETECTION_ALGORITHMS[artifact_detection_algorithm](
lfp_data,
timestamps=lfp_eseries.timestamps,
**artifact_detection_params,
sampling_frequency=sampling_frequency,
referencing=artifact_params["referencing"]["ref_on"],
)
else:
(
artifact_removed_valid_times,
artifact_times,
) = ARTIFACT_DETECTION_ALGORITHMS[artifact_detection_algorithm](
lfp_eseries,
**artifact_detection_params,
sampling_frequency=sampling_frequency,
if lfp_band_ref_id[index] == -1:
continue
lfp_data[:, elect_index] = (
lfp_data[:, elect_index]
- lfp_data[:, lfp_band_ref_index[index]]
)

key["artifact_times"] = artifact_times
key["artifact_removed_valid_times"] = artifact_removed_valid_times

# set up a name for no-artifact times using recording id
# we need some name here for recording_name
key["artifact_removed_interval_list_name"] = "_".join(
[
key["nwb_file_name"],
key["target_interval_list_name"],
"LFP",
key["artifact_params_name"],
"artifact_removed_valid_times",
]
is_diff = algorithm == "difference"
data = lfp_data if is_diff else lfp_eseries
ref = artifact_params["referencing"]["ref_on"] if is_diff else None

(
artifact_removed_valid_times,
artifact_times,
) = ARTIFACT_DETECTION_ALGORITHMS[algorithm](
data,
**params,
sampling_frequency=sampling_frequency,
timestamps=lfp_eseries.timestamps if is_diff else None,
referencing=ref,
)

LFPArtifactRemovedIntervalList.insert1(key, replace=True)
key.update(
dict(
artifact_times=artifact_times,
artifact_removed_valid_times=artifact_removed_valid_times,
# name for no-artifact time name using recording id
artifact_removed_interval_list_name="_".join(
[
key["nwb_file_name"],
key["target_interval_list_name"],
"LFP",
key["artifact_params_name"],
"artifact_removed_valid_times",
]
),
)
)

# also insert into IntervalList
interval_key = {
interval_key = { # also insert into IntervalList
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["artifact_removed_interval_list_name"],
"valid_times": key["artifact_removed_valid_times"],
}
IntervalList.insert1(interval_key, replace=True)

# insert into computed table
LFPArtifactRemovedIntervalList.insert1(key, replace=True)
IntervalList.insert1(interval_key, replace=True)
self.insert1(key)


@schema
class LFPArtifactRemovedIntervalList(dj.Manual):
definition = """
# Stores intervals without detected artifacts.
# Note that entries can come from either ArtifactDetection() or alternative artifact removal analyses.
# Stores intervals without detected artifacts. Entries can come from either
# ArtifactDetection() or alternative artifact removal analyses.
artifact_removed_interval_list_name: varchar(200)
---
-> LFPArtifactDetectionSelection
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def mad_artifact_detector(
proportion_above_thresh: float = 0.1,
removal_window_ms: float = 10.0,
sampling_frequency: float = 1000.0,
*args,
**kwargs,
) -> tuple[np.ndarray, np.ndarray]:
"""Detect LFP artifacts using the median absolute deviation method.
Expand Down
Loading

0 comments on commit 34b7211

Please sign in to comment.