From f46f9c9391bdf3a11a0fef953c2eb9d097dd15ab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 11:55:08 +0200 Subject: [PATCH 1/3] Fix serializability of InjectDriftingTemplatesRecording --- src/spikeinterface/core/base.py | 11 +++++------ src/spikeinterface/core/generate.py | 2 ++ src/spikeinterface/generation/drift_tools.py | 3 +++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 5800166f39..304a85e74f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -7,7 +7,6 @@ import weakref import json import pickle -import os import random import string from packaging.version import parse @@ -928,13 +927,14 @@ def save_to_folder( folder.mkdir(parents=True, exist_ok=False) # dump provenance - provenance_file = folder / f"provenance.json" if self.check_serializability("json"): + provenance_file = folder / f"provenance.json" + self.dump(provenance_file) + elif self.check_serializability("pickle"): + provenance_file = folder / f"provenance.pkl" self.dump(provenance_file) else: - provenance_file.write_text( - json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" - ) + warnings.warn("The extractor is not serializable to file. The provenance will not be saved.") self.save_metadata_to_folder(folder) @@ -1001,7 +1001,6 @@ def save_to_zarr( cached: ZarrExtractor Saved copy of the extractor. """ - import zarr from .zarrextractors import read_zarr save_kwargs.pop("format", None) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 11909bce0e..4e265f3766 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1738,6 +1738,8 @@ def __init__( ) self.add_recording_segment(recording_segment) + # to discuss: maybe we could set json serializability to False always + # because templates could be large! if not sorting.check_serializability("json"): self._serializability["json"] = False if parent_recording is not None: diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index cce2e08b58..70e13160f4 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -458,6 +458,9 @@ def __init__( self.set_probe(drifting_templates.probe, in_place=True) + # templates are too large, we don't serialize them to JSON + self._serializability["json"] = False + self._kwargs = { "sorting": sorting, "drifting_templates": drifting_templates, From b0b8b9aac2e480e102bbdc4980955d28778bd919 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 12:55:58 +0200 Subject: [PATCH 2/3] Fix select peaks --- .../sortingcomponents/peak_selection.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index 1ccfbc4d22..fed026b6a7 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -76,19 +76,18 @@ def select_peaks( selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs) selected_peaks = peaks[selected_indices] + num_segments = len(np.unique(selected_peaks["segment_index"])) if margin is not None: to_keep = np.zeros(len(selected_peaks), dtype=bool) - offset = 0 - for segment_index in range(recording.get_num_segments()): - duration = recording.get_num_frames(segment_index) + for segment_index in range(num_segments): + num_samples_in_segment = recording.get_num_samples(segment_index) i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1]) - while selected_peaks["sample_index"][i0] <= margin[0] + offset: + while selected_peaks["sample_index"][i0] <= margin[0]: i0 += 1 - while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset: + while selected_peaks["sample_index"][i1 - 1] >= (num_samples_in_segment - margin[1]): i1 -= 1 to_keep[i0:i1] = True - offset += duration selected_indices = selected_indices[to_keep] selected_peaks = peaks[selected_indices] @@ -284,7 +283,9 @@ def select_peak_indices(peaks, method, seed, **method_kwargs): ) selected_indices = np.concatenate(selected_indices) - selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])] + selected_indices = selected_indices[ + np.lexsort((peaks[selected_indices]["sample_index"], peaks[selected_indices]["segment_index"])) + ] return selected_indices From 6501252bbea1ed8f6f3fa955ca89f2363f08c409 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 19:57:11 +0200 Subject: [PATCH 3/3] Revert "Fix select peaks" This reverts commit b0b8b9aac2e480e102bbdc4980955d28778bd919. --- .../sortingcomponents/peak_selection.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index fed026b6a7..1ccfbc4d22 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -76,18 +76,19 @@ def select_peaks( selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs) selected_peaks = peaks[selected_indices] - num_segments = len(np.unique(selected_peaks["segment_index"])) if margin is not None: to_keep = np.zeros(len(selected_peaks), dtype=bool) - for segment_index in range(num_segments): - num_samples_in_segment = recording.get_num_samples(segment_index) + offset = 0 + for segment_index in range(recording.get_num_segments()): + duration = recording.get_num_frames(segment_index) i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1]) - while selected_peaks["sample_index"][i0] <= margin[0]: + while selected_peaks["sample_index"][i0] <= margin[0] + offset: i0 += 1 - while selected_peaks["sample_index"][i1 - 1] >= (num_samples_in_segment - margin[1]): + while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset: i1 -= 1 to_keep[i0:i1] = True + offset += duration selected_indices = selected_indices[to_keep] selected_peaks = peaks[selected_indices] @@ -283,9 +284,7 @@ def select_peak_indices(peaks, method, seed, **method_kwargs): ) selected_indices = np.concatenate(selected_indices) - selected_indices = selected_indices[ - np.lexsort((peaks[selected_indices]["sample_index"], peaks[selected_indices]["segment_index"])) - ] + selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])] return selected_indices