Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 3, 2024
1 parent f5cbc4d commit e595458
Showing 1 changed file with 64 additions and 23 deletions.
87 changes: 64 additions & 23 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
coeff=None,
dtype=None,
causal_mode=False,
direction="forward"
direction="forward",
):
import scipy.signal

Expand Down Expand Up @@ -115,7 +115,14 @@ def __init__(
for parent_segment in recording._recording_segments:
self.add_recording_segment(
FilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=add_reflect_padding
parent_segment,
filter_coeff,
filter_mode,
causal_mode,
direction,
margin,
dtype,
add_reflect_padding=add_reflect_padding,
)
)

Expand All @@ -136,7 +143,17 @@ def __init__(


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, causal_mode, direction, margin, dtype, add_reflect_padding=False):
def __init__(
self,
parent_recording_segment,
coeff,
filter_mode,
causal_mode,
direction,
margin,
dtype,
add_reflect_padding=False,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
Expand All @@ -162,26 +179,26 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk = traces_chunk.astype("float32")

import scipy.signal

if self.causal_mode:
if self.direction == "backward":
traces_chunk = np.flip(traces_chunk, axis = 0)
traces_chunk = np.flip(traces_chunk, axis=0)

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)

if self.direction == "backward":
filtered_traces = np.flip(filtered_traces, axis = 0)
filtered_traces = np.flip(filtered_traces, axis=0)

else:
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -315,11 +332,12 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)


class Causal_filter(FilterRecording):
"""
Performs causal filtering using:
* scipy.signal.lfilt or scipy.signal.sosfilt
Parameters
----------
recording : Recording
Expand All @@ -334,27 +352,50 @@ class Causal_filter(FilterRecording):
causal_mode : Bool, default: True
If true, filtering is applied in just one direction.
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering
when causal_mode = True, defines the direction of the filtering
Returns
-------
filter_recording : CausalFilterRecording
The causal-filtered recording extractor object
{}
"""

name = "causal_filter"

def __init__(self, recording, band=[300.0, 6000.0], margin_ms=5.0, dtype=None,causal_mode = True, direction = "Forward", **filter_kwargs):
def __init__(
self,
recording,
band=[300.0, 6000.0],
margin_ms=5.0,
dtype=None,
causal_mode=True,
direction="Forward",
**filter_kwargs,
):
FilterRecording.__init__(
self, recording, band=band, margin_ms=margin_ms, dtype=dtype, causal_mode = causal_mode, direction = direction ,**filter_kwargs
self,
recording,
band=band,
margin_ms=margin_ms,
dtype=dtype,
causal_mode=causal_mode,
direction=direction,
**filter_kwargs,
)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(
recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str, causal_mode = causal_mode, direction = direction
recording=recording,
band=band,
margin_ms=margin_ms,
dtype=dtype.str,
causal_mode=causal_mode,
direction=direction,
)
self._kwargs.update(filter_kwargs)



# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
Expand Down

0 comments on commit e595458

Please sign in to comment.