Skip to content

Commit

Permalink
Expose clear_Cache argument in KS4
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 5, 2024
1 parent 0ed4876 commit 8fbf100
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
24 changes: 24 additions & 0 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,30 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp
with pytest.raises(AssertionError):
check_sortings_equal(default_kilosort_sorting, sorting_si)

def test_clear_cache(self,recording_and_paths, tmp_path):
"""
Test clear_cache parameter in kilosort4.run_kilosort
"""
recording, paths = recording_and_paths

spikeinterface_output_dir = tmp_path / "spikeinterface_output_clear"
sorting_si_clear = si.run_sorter(
"kilosort4",
recording,
remove_existing_folder=True,
folder=spikeinterface_output_dir,
clear_cache=True
)
spikeinterface_output_dir = tmp_path / "spikeinterface_output_no_clear"
sorting_si_no_clear = si.run_sorter(
"kilosort4",
recording,
remove_existing_folder=True,
folder=spikeinterface_output_dir,
clear_cache=False
)
check_sortings_equal(sorting_si_clear, sorting_si_no_clear)

def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
"""
Test the SpikeInterface wrappers `do_correction` argument. We set
Expand Down
23 changes: 20 additions & 3 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Kilosort4Sorter(BaseSorter):
"save_preprocessed_copy": False,
"torch_device": "auto",
"bad_channels": None,
"clear_cache": False,
"use_binary_file": None,
"delete_recording_dat": True,
}
Expand Down Expand Up @@ -111,6 +112,7 @@ class Kilosort4Sorter(BaseSorter):
"save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data",
"torch_device": "Select the torch device auto/cuda/cpu",
"bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.",
"clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.",
"use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. "
"If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. "
"Default is None.",
Expand Down Expand Up @@ -284,6 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
data_dir = ""
results_dir = sorter_output_folder
bad_channels = params["bad_channels"]
clear_cache = params["clear_cache"]

filename, data_dir, results_dir, probe = set_files(
settings=settings,
Expand Down Expand Up @@ -347,17 +350,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# this function applies both preprocessing and drift correction
ops, bfile, st0 = compute_drift_correction(
ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
ops=ops,
device=device,
tic0=tic0,
progress_bar=progress_bar,
file_object=file_object,
clear_cache=clear_cache,
)

if save_preprocessed_copy:
save_preprocessing(results_dir / "temp_wh.dat", ops, bfile)

# Sort spikes and save results
st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar)
st, tF, _, _ = detect_spikes(
ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache
)

clu, Wall = cluster_spikes(
st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar
st=st,
tF=tF,
ops=ops,
device=device,
bfile=bfile,
tic0=tic0,
progress_bar=progress_bar,
clear_cache=clear_cache,
)

if params["skip_kilosort_preprocessing"]:
Expand Down

0 comments on commit 8fbf100

Please sign in to comment.