Skip to content

Commit

Permalink
Alternative implementation of causal filter in filter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPimientoCaicedo authored Jul 3, 2024
1 parent 2af38a3 commit f5cbc4d
Showing 1 changed file with 74 additions and 9 deletions.
83 changes: 74 additions & 9 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class FilterRecording(BasePreprocessor):
Generic filter class based on:
* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfilt
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt
* scipy.signal.lfilt or scipy.signal.sosfilt when causal_mode = True
BandpassFilterRecording is built on top of it.
Expand Down Expand Up @@ -56,6 +57,10 @@ class FilterRecording(BasePreprocessor):
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
causal_mode : Bool, default: False
If true, filtering is applied in just one direction.
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering
Returns
-------
Expand All @@ -77,6 +82,8 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,
causal_mode=False,
direction="forward"
):
import scipy.signal

Expand Down Expand Up @@ -108,7 +115,7 @@ def __init__(
for parent_segment in recording._recording_segments:
self.add_recording_segment(
FilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding
parent_segment, filter_coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=add_reflect_padding
)
)

Expand All @@ -123,14 +130,18 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
causal_mode=causal_mode,
direction=direction,
)


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False):
def __init__(self, parent_recording_segment, coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=False):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.causal_mode = causal_mode
self.direction = direction
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
Expand All @@ -151,12 +162,26 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk = traces_chunk.astype("float32")

import scipy.signal

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if self.causal_mode:
if self.direction == "backward":
traces_chunk = np.flip(traces_chunk, axis = 0)

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)

if self.direction == "backward":
filtered_traces = np.flip(filtered_traces, axis = 0)

else:
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -290,12 +315,52 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)

class Causal_filter(FilterRecording):
"""
Performs causal filtering using:
* scipy.signal.lfilt or scipy.signal.sosfilt
Parameters
----------
recording : Recording
The recording extractor to be re-referenced
band : float or list, default: [300.0, 6000.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list. band (low, high) in Hz for "bandpass" filter type
margin_ms : float
Margin in ms on border to avoid border effect
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
causal_mode : Bool, default: True
If true, filtering is applied in just one direction.
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering
Returns
-------
filter_recording : CausalFilterRecording
The causal-filtered recording extractor object
{}
"""
name = "causal_filter"

def __init__(self, recording, band=[300.0, 6000.0], margin_ms=5.0, dtype=None,causal_mode = True, direction = "Forward", **filter_kwargs):
FilterRecording.__init__(
self, recording, band=band, margin_ms=margin_ms, dtype=dtype, causal_mode = causal_mode, direction = direction ,**filter_kwargs
)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(
recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str, causal_mode = causal_mode, direction = direction
)
self._kwargs.update(filter_kwargs)

# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")
causal_filter = define_function_from_class(source_class=Causal_filter, name="causal_filter")

bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs)
highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)
Expand Down

0 comments on commit f5cbc4d

Please sign in to comment.