diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 93462ac5d8..c0cfb4aa75 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -27,7 +27,8 @@ class FilterRecording(BasePreprocessor): Generic filter class based on: * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt + * scipy.signal.lfilt or scipy.signal.sosfilt when causal_mode = True BandpassFilterRecording is built on top of it. @@ -56,6 +57,10 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + causal_mode : Bool, default: False + If true, filtering is applied in just one direction. + direction : "forward" | "backward", default: "forward" + when causal_mode = True, defines the direction of the filtering Returns ------- @@ -77,6 +82,8 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + causal_mode=False, + direction="forward" ): import scipy.signal @@ -108,7 +115,7 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, filter_coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=add_reflect_padding ) ) @@ -123,14 +130,18 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + causal_mode=causal_mode, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__(self, parent_recording_segment, coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=False): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.causal_mode = causal_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -151,12 +162,26 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces_chunk = traces_chunk.astype("float32") import scipy.signal - - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + + if self.causal_mode: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis = 0) + + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis = 0) + + else: + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -290,12 +315,52 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) +class Causal_filter(FilterRecording): + """ + Performs causal filtering using: + * scipy.signal.lfilt or scipy.signal.sosfilt + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + margin_ms : float + Margin in ms on border to avoid border effect + dtype : dtype or None + The dtype of the returned traces. If None, the dtype of the parent recording is used + causal_mode : Bool, default: True + If true, filtering is applied in just one direction. + direction : "forward" | "backward", default: "forward" + when causal_mode = True, defines the direction of the filtering + Returns + ------- + filter_recording : CausalFilterRecording + The causal-filtered recording extractor object + + {} + + """ + name = "causal_filter" + def __init__(self, recording, band=[300.0, 6000.0], margin_ms=5.0, dtype=None,causal_mode = True, direction = "Forward", **filter_kwargs): + FilterRecording.__init__( + self, recording, band=band, margin_ms=margin_ms, dtype=dtype, causal_mode = causal_mode, direction = direction ,**filter_kwargs + ) + dtype = fix_dtype(recording, dtype) + self._kwargs = dict( + recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str, causal_mode = causal_mode, direction = direction + ) + self._kwargs.update(filter_kwargs) + # functions for API filter = define_function_from_class(source_class=FilterRecording, name="filter") bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter") notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +causal_filter = define_function_from_class(source_class=Causal_filter, name="causal_filter") bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)