Skip to content

Commit

Permalink
Minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
khl02007 committed Oct 30, 2023
1 parent 52ab117 commit 7fabe29
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
16 changes: 13 additions & 3 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,16 @@ def interval_set_difference_inds(intervals1, intervals2):
return result


def interval_list_complement(intervals1, intervals2):
"Finds intervals in intervals1 that are not in intervals2"
def interval_list_complement(intervals1, intervals2, min_length=0.0):
"""
Finds intervals in intervals1 that are not in intervals2
Parameters
----------
min_length : float, optional
Minimum interval length in seconds. Defaults to 0.0.
"""

result = []

for start1, end1 in intervals1:
Expand All @@ -536,4 +544,6 @@ def interval_list_complement(intervals1, intervals2):

result.extend(subtracted)

return np.asarray(result)
return intervals_by_length(
np.asarray(result), min_length=min_length, max_length=1e100
)
4 changes: 1 addition & 3 deletions src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def _get_artifact_times(

artifact_frames = executor.run()
artifact_frames = np.concatenate(artifact_frames)
print(f"artifact_frames: {artifact_frames}")

# turn ms to remove total into s to remove from either side of each detected artifact
half_removal_window_s = removal_window_ms / 2 / 1000
Expand All @@ -276,7 +275,6 @@ def _get_artifact_times(

# convert indices to intervals
artifact_intervals = interval_from_inds(artifact_frames)
print(f"artifact_intervals: {artifact_intervals}")

# convert to seconds and pad with window
artifact_intervals_s = np.zeros(
Expand All @@ -303,7 +301,7 @@ def _get_artifact_times(

# find non-artifact intervals in timestamps
artifact_removed_valid_times = interval_list_complement(
sort_interval_valid_times, artifact_intervals_s
sort_interval_valid_times, artifact_intervals_s, min_length=1
)
artifact_removed_valid_times = reduce(
_union_concat, artifact_removed_valid_times
Expand Down
5 changes: 4 additions & 1 deletion src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def make(self, key):
"valid_times": sort_interval_valid_times,
}
)
AnalysisNwbfile().add((SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), key["analysis_file_name"])
AnalysisNwbfile().add(
(SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"),
key["analysis_file_name"],
)
self.insert1(key)

@classmethod
Expand Down

0 comments on commit 7fabe29

Please sign in to comment.