Skip to content

Commit

Permalink
Update artifact detection
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Oct 30, 2023
1 parent 6afc02b commit 52ab117
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 73 deletions.
31 changes: 31 additions & 0 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,34 @@ def interval_set_difference_inds(intervals1, intervals2):
i += 1
result += intervals1[i:]
return result


def interval_list_complement(intervals1, intervals2):
"Finds intervals in intervals1 that are not in intervals2"
result = []

for start1, end1 in intervals1:
subtracted = [(start1, end1)]

for start2, end2 in intervals2:
new_subtracted = []

for s, e in subtracted:
if start2 <= s and e <= end2:
continue

if e <= start2 or end2 <= s:
new_subtracted.append((s, e))
continue

if start2 > s:
new_subtracted.append((s, start2))

if end2 < e:
new_subtracted.append((end2, e))

subtracted = new_subtracted

result.extend(subtracted)

return np.asarray(result)
100 changes: 27 additions & 73 deletions src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from functools import reduce
from typing import Union
from typing import Union, List

import datajoint as dj
import numpy as np
Expand All @@ -14,6 +14,7 @@
IntervalList,
_union_concat,
interval_from_inds,
interval_list_complement,
)
from spyglass.spikesorting.v1.utils import generate_nwb_uuid
from spyglass.spikesorting.v1.recording import (
Expand Down Expand Up @@ -119,7 +120,19 @@ def make(self, key):
* ArtifactDetectionSelection
& key
).fetch1("artifact_params", "analysis_file_name")

sort_interval_valid_times = (
IntervalList
& {
"nwb_file_name": (
SpikeSortingRecordingSelection * ArtifactDetectionSelection
& key
).fetch1("nwb_file_name"),
"interval_list_name": (
SpikeSortingRecordingSelection * ArtifactDetectionSelection
& key
).fetch1("interval_list_name"),
}
).fetch1("valid_times")
# DO:
# - load recording
recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
Expand All @@ -128,9 +141,11 @@ def make(self, key):
recording = se.read_nwb_recording(
recording_analysis_nwb_file_abs_path, load_time_vector=True
)

# - detect artifacts
artifact_removed_valid_times, _ = _get_artifact_times(
recording,
sort_interval_valid_times,
**artifact_params,
)

Expand All @@ -153,6 +168,7 @@ def make(self, key):

def _get_artifact_times(
recording: si.BaseRecording,
sort_interval_valid_times: List[List],
zscore_thresh: Union[float, None] = None,
amplitude_thresh_uV: Union[float, None] = None,
proportion_above_thresh: float = 1.0,
Expand All @@ -171,6 +187,8 @@ def _get_artifact_times(
Parameters
----------
recording : si.BaseRecording
sort_interval_valid_times : List[List]
The sort interval for the recording, unit: seconds
zscore_thresh : float, optional
Stdev threshold for exclusion, should be >=0, defaults to None
amplitude_thresh_uV : float, optional
Expand Down Expand Up @@ -243,6 +261,7 @@ def _get_artifact_times(

artifact_frames = executor.run()
artifact_frames = np.concatenate(artifact_frames)
print(f"artifact_frames: {artifact_frames}")

# turn ms to remove total into s to remove from either side of each detected artifact
half_removal_window_s = removal_window_ms / 2 / 1000
Expand All @@ -257,6 +276,7 @@ def _get_artifact_times(

# convert indices to intervals
artifact_intervals = interval_from_inds(artifact_frames)
print(f"artifact_intervals: {artifact_intervals}")

# convert to seconds and pad with window
artifact_intervals_s = np.zeros(
Expand All @@ -282,8 +302,11 @@ def _get_artifact_times(
artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)

# find non-artifact intervals in timestamps
artifact_removed_valid_times = find_missing_intervals(
artifact_intervals_s, valid_timestamps
artifact_removed_valid_times = interval_list_complement(
sort_interval_valid_times, artifact_intervals_s
)
artifact_removed_valid_times = reduce(
_union_concat, artifact_removed_valid_times
)

return artifact_removed_valid_times, artifact_intervals_s
Expand Down Expand Up @@ -401,75 +424,6 @@ def _check_artifact_thresholds(
return amplitude_thresh_uV, zscore_thresh, proportion_above_thresh


def find_missing_intervals(intervals, timestamps):
"""Given a list of intervals each of which is [start_time, end_time] and an array of timestamps,
find intervals are not contained in the input list of intervals but contained in the array of timestamps.
Note that the start and stop times of such intervals must be explicitly contained in the array of timestamps
Parameters
----------
intervals : _type_
_description_
timestamps : _type_
_description_
Returns
-------
_type_
_description_
"""
# Sort the list of intervals and timestamps
intervals.sort()
timestamps.sort()

missing_intervals = []
timestamp_idx = 0

# Initialize an empty interval
new_interval = []

for start, end in intervals:
# Look for potential missing intervals
while (
timestamp_idx < len(timestamps)
and timestamps[timestamp_idx] < start
):
new_interval.append(timestamps[timestamp_idx])
timestamp_idx += 1

if len(new_interval) == 1:
continue

if timestamps[timestamp_idx] > new_interval[-1]:
new_interval.append(timestamps[timestamp_idx - 1])
missing_intervals.append(new_interval)
new_interval = []

# Move the index to the point after the end of the current interval
while (
timestamp_idx < len(timestamps) and timestamps[timestamp_idx] <= end
):
timestamp_idx += 1

# Check for any remaining missing intervals
while timestamp_idx < len(timestamps):
new_interval.append(timestamps[timestamp_idx])
timestamp_idx += 1

if len(new_interval) == 1:
continue

if (
timestamp_idx == len(timestamps)
or timestamps[timestamp_idx] > new_interval[-1]
):
new_interval.append(timestamps[timestamp_idx - 1])
missing_intervals.append(new_interval)
new_interval = []

return np.asarray(missing_intervals)


def merge_intervals(intervals):
"""Takes a list of intervals each of which is [start_time, stop_time]
and takes union over intervals that are intersecting
Expand Down

0 comments on commit 52ab117

Please sign in to comment.