From 34b7211383e6c63db3c1d2829889714c3822cb34 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 22 Aug 2023 14:22:00 -0700 Subject: [PATCH] LFP artifact optimizations --- src/spyglass/lfp/v1/lfp_artifact.py | 218 +++++++++--------- .../lfp/v1/lfp_artifact_MAD_detection.py | 2 + .../v1/lfp_artifact_difference_detection.py | 104 ++++----- 3 files changed, 159 insertions(+), 165 deletions(-) diff --git a/src/spyglass/lfp/v1/lfp_artifact.py b/src/spyglass/lfp/v1/lfp_artifact.py index 5666559bd..3c6389af4 100644 --- a/src/spyglass/lfp/v1/lfp_artifact.py +++ b/src/spyglass/lfp/v1/lfp_artifact.py @@ -29,72 +29,80 @@ class LFPArtifactDetectionParameters(dj.Manual): """ def insert_default(self): - """Insert the default artifact parameters with an appropriate parameter dict.""" - artifact_params = { - "artifact_detection_algorithm": "difference", - "artifact_detection_algorithm_params": { - "amplitude_thresh_1st": 500, # must be None or >= 0 - "proportion_above_thresh_1st": 0.1, - "amplitude_thresh_2nd": 1000, # must be None or >= 0 - "proportion_above_thresh_2nd": 0.05, - "removal_window_ms": 10.0, # in milliseconds - "local_window_ms": 40.0, # in milliseconds + """Insert the default artifact parameters.""" + diff_params = [ + "default_difference", + { + "artifact_detection_algorithm": "difference", + "artifact_detection_algorithm_params": { + "amplitude_thresh_1st": 500, # must be None or >= 0 + "proportion_above_thresh_1st": 0.1, + "amplitude_thresh_2nd": 1000, # must be None or >= 0 + "proportion_above_thresh_2nd": 0.05, + "removal_window_ms": 10.0, # in milliseconds + "local_window_ms": 40.0, # in milliseconds + }, }, - } + ] - self.insert1( - ["default_difference", artifact_params], skip_duplicates=True - ) + diff_ref_params = [ + "default_difference_ref", + { + "artifact_detection_algorithm": "difference", + "artifact_detection_algorithm_params": { + "amplitude_thresh_1st": 500, # must be None or >= 0 + "proportion_above_thresh_1st": 0.1, + "amplitude_thresh_2nd": 
1000, # must be None or >= 0 + "proportion_above_thresh_2nd": 0.05, + "removal_window_ms": 10.0, # in milliseconds + "local_window_ms": 40.0, # in milliseconds + }, + "referencing": { + "ref_on": 1, + "reference_list": [0, 0, 0, 0, 0], + "electrode_list": [0, 0], + }, + }, + ] - artifact_params = { - "artifact_detection_algorithm": "difference", - "artifact_detection_algorithm_params": { - "amplitude_thresh_1st": 500, # must be None or >= 0 - "proportion_above_thresh_1st": 0.1, - "amplitude_thresh_2nd": 1000, # must be None or >= 0 - "proportion_above_thresh_2nd": 0.05, - "removal_window_ms": 10.0, # in milliseconds - "local_window_ms": 40.0, # in milliseconds + no_params = [ + "none", + { + "artifact_detection_algorithm": "difference", + "artifact_detection_algorithm_params": { + "amplitude_thresh_1st": None, # must be None or >= 0 + "proportion_above_thresh_1st": None, + "amplitude_thresh_2nd": None, # must be None or >= 0 + "proportion_above_thresh_2nd": None, + "removal_window_ms": None, # in milliseconds + "local_window_ms": None, # in milliseconds + }, }, - "referencing": { - "ref_on": 1, - "reference_list": [0, 0, 0, 0, 0], - "electrode_list": [0, 0], + ] + + mad_params = [ + "default_mad", + { + "artifact_detection_algorithm": "mad", + "artifact_detection_algorithm_params": { + # akin to z-score std dev if the distribution is normal + "mad_thresh": 6.0, + "proportion_above_thresh": 0.1, + "removal_window_ms": 10.0, # in milliseconds + }, }, - } + ] - self.insert1( - ["default_difference_ref", artifact_params], skip_duplicates=True + self.insert( + [diff_params, diff_ref_params, no_params, mad_params], + skip_duplicates=True, ) - artifact_params_none = { - "artifact_detection_algorithm": "difference", - "artifact_detection_algorithm_params": { - "amplitude_thresh_1st": None, # must be None or >= 0 - "proportion_above_thresh_1st": None, - "amplitude_thresh_2nd": None, # must be None or >= 0 - "proportion_above_thresh_2nd": None, - "removal_window_ms": None, # 
in milliseconds - "local_window_ms": None, # in milliseconds - }, - } - self.insert1(["none", artifact_params_none], skip_duplicates=True) - - artifact_params_mad = { - "artifact_detection_algorithm": "mad", - "artifact_detection_algorithm_params": { - "mad_thresh": 6.0, # akin to z-score standard deviations if the distribution is normal - "proportion_above_thresh": 0.1, - "removal_window_ms": 10.0, # in milliseconds - }, - } - self.insert1(["default_mad", artifact_params_mad], skip_duplicates=True) - @schema class LFPArtifactDetectionSelection(dj.Manual): definition = """ - # Specifies artifact detection parameters to apply to a sort group's recording. + # Artifact detection parameters to apply to a sort group's recording. -> LFPV1 -> LFPArtifactDetectionParameters --- @@ -108,8 +116,9 @@ class LFPArtifactDetection(dj.Computed): -> LFPArtifactDetectionSelection --- artifact_times: longblob # np array of artifact intervals - artifact_removed_valid_times: longblob # np array of valid no-artifact intervals - artifact_removed_interval_list_name: varchar(200) # name of the array of no-artifact valid time intervals + artifact_removed_valid_times: longblob # np array of no-artifact intervals + artifact_removed_interval_list_name: varchar(200) + # name of the array of no-artifact valid time intervals """ def make(self, key): @@ -118,18 +127,13 @@ def make(self, key): & {"artifact_params_name": key["artifact_params_name"]} ).fetch1("artifact_params") - artifact_detection_algorithm = artifact_params[ - "artifact_detection_algorithm" - ] - artifact_detection_params = artifact_params[ - "artifact_detection_algorithm_params" - ] - + algorithm = artifact_params["artifact_detection_algorithm"] + params = artifact_params["artifact_detection_algorithm_params"] lfp_band_ref_id = artifact_params["referencing"]["reference_list"] # get LFP data - lfp_eseries = (LFPV1() & key).fetch_nwb()[0]["lfp"] - sampling_frequency = (LFPV1() & key).fetch("lfp_sampling_rate")[0] + lfp_eseries = 
(LFPV1 & key).fetch_nwb()[0]["lfp"] + sampling_frequency = (LFPV1 & key).fetch("lfp_sampling_rate")[0] # do referencing at this step lfp_data = np.asarray( @@ -145,67 +149,61 @@ def make(self, key): # maybe this lfp_elec_list is supposed to be a list on indices for index, elect_index in enumerate(lfp_band_elect_index): - if lfp_band_ref_id[index] != -1: - lfp_data[:, elect_index] = ( - lfp_data[:, elect_index] - - lfp_data[:, lfp_band_ref_index[index]] - ) - - if artifact_detection_algorithm == "difference": - ( - artifact_removed_valid_times, - artifact_times, - ) = ARTIFACT_DETECTION_ALGORITHMS[artifact_detection_algorithm]( - lfp_data, - timestamps=lfp_eseries.timestamps, - **artifact_detection_params, - sampling_frequency=sampling_frequency, - referencing=artifact_params["referencing"]["ref_on"], - ) - else: - ( - artifact_removed_valid_times, - artifact_times, - ) = ARTIFACT_DETECTION_ALGORITHMS[artifact_detection_algorithm]( - lfp_eseries, - **artifact_detection_params, - sampling_frequency=sampling_frequency, + if lfp_band_ref_id[index] == -1: + continue + lfp_data[:, elect_index] = ( + lfp_data[:, elect_index] + - lfp_data[:, lfp_band_ref_index[index]] ) - key["artifact_times"] = artifact_times - key["artifact_removed_valid_times"] = artifact_removed_valid_times - - # set up a name for no-artifact times using recording id - # we need some name here for recording_name - key["artifact_removed_interval_list_name"] = "_".join( - [ - key["nwb_file_name"], - key["target_interval_list_name"], - "LFP", - key["artifact_params_name"], - "artifact_removed_valid_times", - ] + is_diff = algorithm == "difference" + data = lfp_data if is_diff else lfp_eseries + ref = artifact_params["referencing"]["ref_on"] if is_diff else None + + ( + artifact_removed_valid_times, + artifact_times, + ) = ARTIFACT_DETECTION_ALGORITHMS[algorithm]( + data, + **params, + sampling_frequency=sampling_frequency, + timestamps=lfp_eseries.timestamps if is_diff else None, + referencing=ref, ) - 
LFPArtifactRemovedIntervalList.insert1(key, replace=True) + key.update( + dict( + artifact_times=artifact_times, + artifact_removed_valid_times=artifact_removed_valid_times, + # name for no-artifact time name using recording id + artifact_removed_interval_list_name="_".join( + [ + key["nwb_file_name"], + key["target_interval_list_name"], + "LFP", + key["artifact_params_name"], + "artifact_removed_valid_times", + ] + ), + ) + ) - # also insert into IntervalList - interval_key = { + interval_key = { # also insert into IntervalList "nwb_file_name": key["nwb_file_name"], "interval_list_name": key["artifact_removed_interval_list_name"], "valid_times": key["artifact_removed_valid_times"], } - IntervalList.insert1(interval_key, replace=True) - # insert into computed table + LFPArtifactRemovedIntervalList.insert1(key, replace=True) + IntervalList.insert1(interval_key, replace=True) self.insert1(key) @schema class LFPArtifactRemovedIntervalList(dj.Manual): definition = """ - # Stores intervals without detected artifacts. - # Note that entries can come from either ArtifactDetection() or alternative artifact removal analyses. + # Stores intervals without detected artifacts. Entries can come from either + # ArtifactDetection() or alternative artifact removal analyses. artifact_removed_interval_list_name: varchar(200) --- -> LFPArtifactDetectionSelection diff --git a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py index d5775653a..eac1943a5 100644 --- a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py +++ b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py @@ -9,6 +9,8 @@ def mad_artifact_detector( proportion_above_thresh: float = 0.1, removal_window_ms: float = 10.0, sampling_frequency: float = 1000.0, + *args, + **kwargs, ) -> tuple[np.ndarray, np.ndarray]: """Detect LFP artifacts using the median absolute deviation method. 
diff --git a/src/spyglass/lfp/v1/lfp_artifact_difference_detection.py b/src/spyglass/lfp/v1/lfp_artifact_difference_detection.py index a77ec9d7c..449b8c56d 100644 --- a/src/spyglass/lfp/v1/lfp_artifact_difference_detection.py +++ b/src/spyglass/lfp/v1/lfp_artifact_difference_detection.py @@ -27,34 +27,38 @@ def difference_artifact_detector( ): """Detects times during which artifacts do and do not occur. - Artifacts are defined as periods where the absolute value of the change in LFP exceeds - amplitude change thresholds on the proportion of channels specified, - with the period extended by the removal_window_ms/2 on each side. amplitude change - threshold values of None are ignored. + Artifacts are defined as periods where the absolute value of the change in + LFP exceeds amplitude change thresholds on the proportion of channels + specified, with the period extended by the removal_window_ms/2 on each side. + amplitude change threshold values of None are ignored. Parameters ---------- - recording : lfp eseries - zscore_thresh : float, optional + recording : lfp eseries zscore_thresh : float, optional Stdev threshold for exclusion, should be >=0, defaults to None amplitude_thresh : float, optional - Amplitude (ad units) threshold for exclusion, should be >=0, defaults to None + Amplitude (ad units) threshold for exclusion, should be >=0, defaults to + None proportion_above_thresh : float, optional, should be>0 and <=1 - Proportion of electrodes that need to have threshold crossings, defaults to 1 + Proportion of electrodes that need to have threshold crossings, defaults + to 1 removal_window_ms : float, optional - Width of the window in milliseconds to mask out per artifact - (window/2 removed on each side of threshold crossing), defaults to 1 ms + Width of the window in milliseconds to mask out per artifact (window/2 + removed on each side of threshold crossing), defaults to 1 ms Returns ------- artifact_removed_valid_times : np.ndarray - Intervals of valid times 
where artifacts were not detected, unit: seconds + Intervals of valid times where artifacts were not detected, unit: + seconds artifact_intervals : np.ndarray - Intervals in which artifacts are detected (including removal windows), unit: seconds + Intervals in which artifacts are detected (including removal windows), + unit: seconds """ - # NOTE: 7-17-23 updated to remove recording.data, since it will converted to numpy array before referencing - # check for referencing flag + # NOTE: 7-17-23 updated to remove recording.data, since it will be converted + # to numpy array before referencing. Check for referencing flag. + if referencing == 1: print("referencing activated. may be set to -1") @@ -92,24 +96,21 @@ def difference_artifact_detector( print("num tets 1", nelect_above_1st, "num tets 2", nelect_above_2nd) print("data shape", recording.shape) - # find the artifact occurrences using one or both thresholds, across channels + # find the artifact occurrences using one or both thresholds, across + # channels if amplitude_thresh_1st is not None: # first find times with large amp change: sum diff over several timebins diff_array = np.diff(recording, axis=0) - print("updated script") - if referencing == 0: - print("referencing off") - window = np.ones((15, 1)) - - elif referencing == 1: - print("referencing on") - window = np.ones((3, 1)) + window = np.ones((3, 1)) if referencing else np.ones((15, 1)) + # sum differences over bins using convolution for speed width = int((window.size - 1) / 2) diff_array = np.pad( - diff_array, pad_width=((width, width), (0, 0)), mode="constant" + diff_array, + pad_width=((width, width), (0, 0)), + mode="constant", ) diff_array_5 = scipy.signal.convolve(diff_array, window, mode="valid") @@ -159,7 +160,7 @@ def difference_artifact_detector( artifact_frames = above_thresh.copy() print("detected ", artifact_frames.shape[0], " artifacts") - # turn ms to remove total into s to remove from either side of each detected artifact + # Convert to s to 
remove from either side of each detected artifact half_removal_window_s = removal_window_ms / 1000 * 0.5 if len(artifact_frames) == 0: @@ -175,6 +176,7 @@ def difference_artifact_detector( artifact_intervals_s = np.zeros( (len(artifact_intervals), 2), dtype=np.float64 ) + for interval_idx, interval in enumerate(artifact_intervals): artifact_intervals_s[interval_idx] = [ valid_timestamps[interval[0]] - half_removal_window_s, @@ -218,7 +220,7 @@ def _check_artifact_thresholds( proportion_above_thresh_1st, proportion_above_thresh_2nd, ): - """Alerts user to likely unintended parameters. Not an exhaustive verification. + """Alerts user to likely unintended parameters. Not an exhaustive check. Parameters ---------- @@ -238,40 +240,32 @@ def _check_artifact_thresholds( ------ ValueError: if signal thresholds are negative """ - # amplitude or zscore thresholds should be not negative, as they are applied to an absolute signal - signal_thresholds = [ + signal_thresholds = [ # amplitude or zscore thresh should not be negative t for t in [amplitude_thresh_1st, amplitude_thresh_1st] if t is not None ] + for t in signal_thresholds: if t < 0: - raise ValueError("Amplitude thresholds must be >= 0, or None") + raise ValueError( + f"Amplitude thresholds must be >= 0, or None. Received {t}" + ) - # proportion_above_threshold should be in [0:1] inclusive - if proportion_above_thresh_1st < 0: - warnings.warn( - "Warning: proportion_above_thresh must be a proportion >0 and <=1." - f" Using proportion_above_thresh = 0.01 instead of {str(proportion_above_thresh_1st)}" - ) - proportion_above_thresh_1st = 0.01 - elif proportion_above_thresh_1st > 1: - warnings.warn( - "Warning: proportion_above_thresh must be a proportion >0 and <=1. 
" - f"Using proportion_above_thresh = 1 instead of {str(proportion_above_thresh_1st)}" - ) - proportion_above_thresh_1st = 1 - # proportion_above_threshold should be in [0:1] inclusive - if proportion_above_thresh_2nd < 0: - warnings.warn( - "Warning: proportion_above_thresh must be a proportion >0 and <=1." - f" Using proportion_above_thresh = 0.01 instead of {str(proportion_above_thresh_2nd)}" - ) - proportion_above_thresh_2nd = 0.01 - elif proportion_above_thresh_2nd > 1: - warnings.warn( - "Warning: proportion_above_thresh must be a proportion >0 and <=1. " - f"Using proportion_above_thresh = 1 instead of {str(proportion_above_thresh_2nd)}" - ) - proportion_above_thresh_2nd = 1 + bound_warn = ( + "Warning: proportion_above_thresh must be a proportion >0 and <=1.\n" + + "Replacing {} with {}" + ) + + def clamp(n, min_n=0.01, max_n=1): # replace n outside bounds with bound + if n < 0: + warnings.warn(bound_warn.format(n, min_n)) + elif n < min_n: # handle case where n is btwn 0 and low bound + return n + elif n > max_n: + warnings.warn(bound_warn.format(n, max_n)) + return max(min(n, max_n), min_n) + + proportion_above_thresh_1st = clamp(proportion_above_thresh_1st) + proportion_above_thresh_2nd = clamp(proportion_above_thresh_2nd) return ( amplitude_thresh_1st,