Skip to content

Commit

Permalink
change causal mode for direction forward-backward
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPimientoCaicedo authored Jul 5, 2024
1 parent 25d25fd commit 7c5956e
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class FilterRecording(BasePreprocessor):
Generic filter class based on:
* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt
* scipy.signal.lfilt or scipy.signal.sosfilt when causal_mode = True
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward"
* scipy.signal.lfilt or scipy.signal.sosfilt
BandpassFilterRecording is built on top of it.
Expand Down Expand Up @@ -57,10 +57,10 @@ class FilterRecording(BasePreprocessor):
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
causal_mode : bool, default: False
If true, filtering is applied in just one direction.
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering
direction : "forward" | "backward" | "forward-backward", default: "forward-backward"
Direction of filtering:
- forward and backward filter in just one direction, creating phase shifts in the signal.
- forward-backward filters in both directions, a zero-phase filtering.
Returns
-------
Expand All @@ -82,8 +82,7 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,
causal_mode=False,
direction="forward",
direction="forward-backward",
):
import scipy.signal

Expand Down Expand Up @@ -121,7 +120,6 @@ def __init__(
margin,
dtype,
add_reflect_padding=add_reflect_padding,
causal_mode=causal_mode,
direction=direction,
)
)
Expand All @@ -137,7 +135,6 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
causal_mode=causal_mode,
direction=direction,
)

Expand All @@ -151,13 +148,11 @@ def __init__(
margin,
dtype,
add_reflect_padding=False,
causal_mode=False,
direction="forward",
direction="forward-backward",
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.causal_mode = causal_mode
self.direction = direction
self.margin = margin
self.add_reflect_padding = add_reflect_padding
Expand All @@ -180,7 +175,13 @@ def get_traces(self, start_frame, end_frame, channel_indices):

import scipy.signal

if self.causal_mode:
if self.direction == "forward-backward":
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
else:
if self.direction == "backward":
traces_chunk = np.flip(traces_chunk, axis=0)

Expand All @@ -191,14 +192,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)

if self.direction == "backward":
filtered_traces = np.flip(filtered_traces, axis=0)

else:
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
filtered_traces = np.flip(filtered_traces, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -378,7 +372,6 @@ def __init__(
band=band,
margin_ms=margin_ms,
dtype=dtype,
causal_mode=True,
direction=direction,
**filter_kwargs,
)
Expand All @@ -388,7 +381,6 @@ def __init__(
band=band,
margin_ms=margin_ms,
dtype=dtype.str,
causal_mode=True,
direction=direction,
)
self._kwargs.update(filter_kwargs)
Expand Down

0 comments on commit 7c5956e

Please sign in to comment.