diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e28d995..974f6067 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Upcoming +### Features + +* Updated `Suite2pSegmentationExtractor` to support multi channel and multi plane data. [PR #242](https://github.com/catalystneuro/roiextractors/pull/242) + + + # v0.5.4 ### Features diff --git a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py index b5dfa62e..4ae952c0 100644 --- a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py +++ b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py @@ -8,10 +8,11 @@ import shutil from pathlib import Path from typing import Optional - +from warnings import warn +import os import numpy as np -from ...extraction_tools import PathType, IntType +from ...extraction_tools import PathType from ...extraction_tools import _image_mask_extractor from ...multisegmentationextractor import MultiSegmentationExtractor from ...segmentationextractor import SegmentationExtractor @@ -23,77 +24,191 @@ class Suite2pSegmentationExtractor(SegmentationExtractor): extractor_name = "Suite2pSegmentationExtractor" installed = True # check at class level if installed or not is_writable = False - mode = "file" + mode = "folder" installation_mesg = "" # error message when not installed + @classmethod + def get_available_channels(cls, folder_path: PathType): + """Get the available channel names from the folder paths produced by Suite2p. + + Parameters + ---------- + folder_path : PathType + Path to the Suite2p output folder. + + Returns + ------- + channel_names: list + List of channel names. 
+ """ + plane_names = cls.get_available_planes(folder_path=folder_path) + + channel_names = ["chan1"] + second_channel_paths = list((Path(folder_path) / plane_names[0]).glob("F_chan2.npy")) + if not second_channel_paths: + return channel_names + channel_names.append("chan2") + + return channel_names + + @classmethod + def get_available_planes(cls, folder_path: PathType): + """Get the available plane names from the folder produced by Suite2p. + + Parameters + ---------- + folder_path : PathType + Path to the Suite2p output folder. + + Returns + ------- + plane_names: list + List of plane names. + """ + from natsort import natsorted + + folder_path = Path(folder_path) + prefix = "plane" + plane_paths = natsorted(folder_path.glob(pattern=prefix + "*")) + assert len(plane_paths), f"No planes found in '{folder_path}'." + plane_names = [plane_path.stem for plane_path in plane_paths] + return plane_names + def __init__( self, - folder_path: Optional[PathType] = None, - combined: bool = False, - plane_no: IntType = 0, - file_path: Optional[PathType] = None, + folder_path: PathType, + channel_name: Optional[str] = None, + plane_name: Optional[str] = None, + combined: Optional[bool] = None, # TODO: to be removed + plane_no: Optional[int] = None, # TODO: to be removed ): """Create SegmentationExtractor object out of suite 2p data type. Parameters ---------- folder_path: str or Path - ~/suite2p folder location on disk - combined: bool - if the plane is a combined plane as in the Suite2p pipeline - plane_no: int - the plane for which to extract segmentation for. - file_path: str or Path [Deprecated] - ~/suite2p folder location on disk + The path to the 'suite2p' folder. + channel_name: str, optional + The name of the channel to load, to determine what channels are available use Suite2pSegmentationExtractor.get_available_channels(folder_path). 
+ plane_name: str, optional + The name of the plane to load, to determine what planes are available use Suite2pSegmentationExtractor.get_available_planes(folder_path). """ - from warnings import warn - - if file_path is not None: + if combined: + warning_string = "Keyword argument 'combined' is deprecated and will be removed on or after Nov, 2023. " + warn( + message=warning_string, + category=DeprecationWarning, + ) + if plane_no: warning_string = ( - "The keyword argument 'file_path' is being deprecated on or after August, 2022 in favor of 'folder_path'. " - "'folder_path' takes precence over 'file_path'." + "Keyword argument 'plane_no' is deprecated and will be removed on or after Nov, 2023 in favor of 'plane_name'." + "Specify which stream you wish to load with the 'plane_name' keyword argument." ) warn( message=warning_string, category=DeprecationWarning, ) - folder_path = file_path if folder_path is None else folder_path - SegmentationExtractor.__init__(self) - self.combined = combined - self.plane_no = plane_no + channel_names = self.get_available_channels(folder_path=folder_path) + if channel_name is None: + if len(channel_names) > 1: + # For backward compatibility maybe it is better to warn first + warn( + "More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. " + "To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`.", + UserWarning, + ) + channel_name = channel_names[0] + + self.channel_name = channel_name + if self.channel_name not in channel_names: + raise ValueError( + f"The selected channel '{channel_name}' is not a valid channel name. To see what channels are available, " + f"call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`." 
+ ) + + plane_names = self.get_available_planes(folder_path=folder_path) + if plane_name is None: + if len(plane_names) > 1: + # For backward compatibility maybe it is better to warn first + warn( + "More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. " + "To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`.", + UserWarning, + ) + plane_name = plane_names[0] + + if plane_name not in plane_names: + raise ValueError( + f"The selected plane '{plane_name}' is not a valid plane name. To see what planes are available, " + f"call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`." + ) + self.plane_name = plane_name + + super().__init__() + self.folder_path = Path(folder_path) - self.stat = self._load_npy("stat.npy") - self._roi_response_raw = self._load_npy("F.npy", mmap_mode="r").T - self._roi_response_neuropil = self._load_npy("Fneu.npy", mmap_mode="r").T - self._roi_response_deconvolved = self._load_npy("spks.npy", mmap_mode="r").T + options = self._load_npy(file_name="ops.npy") + self.options = options.item() if options is not None else options + self._sampling_frequency = self.options["fs"] + self._num_frames = self.options["nframes"] + self._image_size = (self.options["Ly"], self.options["Lx"]) + + self.stat = self._load_npy(file_name="stat.npy") + + fluorescence_traces_file_name = "F.npy" if channel_name == "chan1" else "F_chan2.npy" + neuropil_traces_file_name = "Fneu.npy" if channel_name == "chan1" else "Fneu_chan2.npy" + self._roi_response_raw = self._load_npy(file_name=fluorescence_traces_file_name, mmap_mode="r", transpose=True) + self._roi_response_neuropil = self._load_npy(file_name=neuropil_traces_file_name, mmap_mode="r", transpose=True) + self._roi_response_deconvolved = ( + self._load_npy(file_name="spks.npy", mmap_mode="r", transpose=True) if channel_name == "chan1" else None + ) + self.iscell = self._load_npy("iscell.npy", 
mmap_mode="r") - self.ops = self._load_npy("ops.npy").item() - self._channel_names = [f"OpticalChannel{i}" for i in range(self.ops["nchannels"])] - self._sampling_frequency = self.ops["fs"] * [2 if self.combined else 1][0] - self._raw_movie_file_location = self.ops.get("filelist", [None])[0] - self._image_correlation = self._summary_image_read("Vcorr") - self._image_mean = self._summary_image_read("meanImg") + channel_name = "OpticalChannel" if len(channel_names) == 1 else channel_name.capitalize() + self._channel_names = [channel_name] + + self._image_correlation = self._correlation_image_read() + image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2" + self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None + roi_indices = list(range(self.get_num_rois())) + self._image_masks = _image_mask_extractor( + self.get_roi_pixel_masks(), + roi_indices, + self.get_image_size(), + ) - def _load_npy(self, filename, mmap_mode=None): - """Load a .npy file with specified filename. + def _load_npy(self, file_name: str, mmap_mode=None, transpose: bool = False): + """Load a .npy file with specified filename. Returns None if file is missing. Parameters ---------- - filename: str + file_name: str The name of the .npy file to load. mmap_mode: str The mode to use for memory mapping. See numpy.load for details. + transpose: bool, optional + Whether to transpose the loaded array. Returns ------- - The loaded .npy file. + The loaded .npy file. 
""" - file_path = self.folder_path / f"plane{self.plane_no}" / filename - return np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None) + file_path = self.folder_path / self.plane_name / file_name + if not file_path.exists(): + return + + data = np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None) + if transpose: + return data.T + + return data + + def get_num_frames(self) -> int: + return self._num_frames def get_accepted_list(self): return list(np.where(self.iscell[:, 0] == 1)[0]) @@ -101,29 +216,27 @@ def get_accepted_list(self): def get_rejected_list(self): return list(np.where(self.iscell[:, 0] == 0)[0]) - def _summary_image_read(self, bstr="meanImg"): - """Read summary image from ops (settings) dict. - - Parameters - ---------- - bstr: str - The name of the summary image to read. + def _correlation_image_read(self): + """Read correlation image from ops (settings) dict. Returns ------- img : numpy.ndarray | None - The summary image if bstr is in ops, else None. + The correlation image. 
""" - img = None - if bstr in self.ops: - if bstr == "Vcorr" or bstr == "max_proj": - img = np.zeros((self.ops["Ly"], self.ops["Lx"]), np.float32) - img[ - (self.ops["Ly"] - self.ops["yrange"][-1]) : (self.ops["Ly"] - self.ops["yrange"][0]), - self.ops["xrange"][0] : self.ops["xrange"][-1], - ] = self.ops[bstr] - else: - img = self.ops[bstr] + if "Vcorr" not in self.options: + return None + + correlation_image = self.options["Vcorr"] + if (self.options["yrange"][-1], self.options["xrange"][-1]) == self._image_size: + return correlation_image + + img = np.zeros(self._image_size, correlation_image.dtype) + img[ + (self.options["Ly"] - self.options["yrange"][-1]) : (self.options["Ly"] - self.options["yrange"][0]), + self.options["xrange"][0] : self.options["xrange"][-1], + ] = correlation_image + return img @property @@ -131,26 +244,13 @@ def roi_locations(self): """Returns the center locations (x, y) of each ROI.""" return np.array([j["med"] for j in self.stat]).T.astype(int) - def get_roi_image_masks(self, roi_ids=None): - if roi_ids is None: - roi_idx_ = range(self.get_num_rois()) - else: - roi_idx = [np.where(np.array(i) == self.get_roi_ids())[0] for i in roi_ids] - ele = [i for i, j in enumerate(roi_idx) if j.size == 0] - roi_idx_ = [j[0] for i, j in enumerate(roi_idx) if i not in ele] - return _image_mask_extractor( - self.get_roi_pixel_masks(roi_ids=roi_idx_), - list(range(len(roi_idx_))), - self.get_image_size(), - ) - def get_roi_pixel_masks(self, roi_ids=None): pixel_mask = [] for i in range(self.get_num_rois()): pixel_mask.append( np.vstack( [ - self.ops["Ly"] - 1 - self.stat[i]["ypix"], + self.stat[i]["ypix"], self.stat[i]["xpix"], self.stat[i]["lam"], ] @@ -165,7 +265,7 @@ def get_roi_pixel_masks(self, roi_ids=None): return [pixel_mask[i] for i in roi_idx_] def get_image_size(self): - return [self.ops["Ly"], self.ops["Lx"]] + return self._image_size @staticmethod def write_segmentation(segmentation_object: SegmentationExtractor, save_path: PathType, 
overwrite=True): @@ -238,7 +338,7 @@ def write_segmentation(segmentation_object: SegmentationExtractor, save_path: Pa for no, i in enumerate(stat): stat[no] = { "med": roi_locs[no, :].tolist(), - "ypix": segmentation_object.get_image_size()[0] - 1 - pixel_masks[no][:, 0], + "ypix": pixel_masks[no][:, 0], "xpix": pixel_masks[no][:, 1], "lam": pixel_masks[no][:, 2], } diff --git a/tests/test_io.py b/tests/test_io.py index f2692b2d..e1cbb64b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -138,11 +138,11 @@ def test_imaging_extractors_canonical_shape(self, extractor_class, extractor_kwa ), param( extractor_class=Suite2pSegmentationExtractor, - extractor_kwargs=dict(folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")), - ), - param( - extractor_class=Suite2pSegmentationExtractor, - extractor_kwargs=dict(file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")), + extractor_kwargs=dict( + folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), + channel_name="chan1", + plane_name="plane0", + ), ), ] diff --git a/tests/test_suite2psegmentationextractor.py b/tests/test_suite2psegmentationextractor.py new file mode 100644 index 00000000..59dae738 --- /dev/null +++ b/tests/test_suite2psegmentationextractor.py @@ -0,0 +1,127 @@ +import shutil +import tempfile +from pathlib import Path + +import numpy as np +from hdmf.testing import TestCase +from numpy.testing import assert_array_equal + +from roiextractors import Suite2pSegmentationExtractor +from roiextractors.extraction_tools import _image_mask_extractor +from tests.setup_paths import OPHYS_DATA_PATH + + +class TestSuite2pSegmentationExtractor(TestCase): + @classmethod + def setUpClass(cls): + folder_path = str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p") + cls.channel_names = ["chan1", "chan2"] + cls.plane_names = ["plane0", "plane1"] + + cls.folder_path = Path(folder_path) + + extractor = Suite2pSegmentationExtractor(folder_path=folder_path, 
channel_name="chan1", plane_name="plane0") + cls.extractor = extractor + + cls.test_dir = Path(tempfile.mkdtemp()) + + cls.first_channel_raw_traces = np.load(cls.folder_path / "plane0" / "F.npy").T + cls.second_channel_raw_traces = np.load(cls.folder_path / "plane0" / "F_chan2.npy").T + + cls.image_size = (128, 128) + cls.num_rois = 15 + + pixel_masks = cls.extractor.get_roi_pixel_masks() + image_masks = np.zeros(shape=(*cls.image_size, cls.num_rois)) + for roi_ind, pixel_mask in enumerate(pixel_masks): + for y, x, wt in pixel_mask: + image_masks[int(y), int(x), roi_ind] = wt + cls.image_masks = image_masks + + @classmethod + def tearDownClass(cls): + # remove the temporary directory and its contents + shutil.rmtree(cls.test_dir) + + def test_channel_names(self): + self.assertEqual( + Suite2pSegmentationExtractor.get_available_channels(folder_path=self.folder_path), self.channel_names + ) + + def test_plane_names(self): + self.assertEqual( + Suite2pSegmentationExtractor.get_available_planes(folder_path=self.folder_path), self.plane_names + ) + + def test_multi_channel_warns(self): + exc_msg = "More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`." + with self.assertWarnsWith(warn_type=UserWarning, exc_msg=exc_msg): + Suite2pSegmentationExtractor(folder_path=self.folder_path) + + def test_multi_plane_warns(self): + exc_msg = "More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`." 
+ with self.assertWarnsWith(warn_type=UserWarning, exc_msg=exc_msg): + Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="chan2") + + def test_incorrect_plane_name_raises(self): + exc_msg = "The selected plane 'plane2' is not a valid plane name. To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`." + with self.assertRaisesWith(exc_type=ValueError, exc_msg=exc_msg): + Suite2pSegmentationExtractor(folder_path=self.folder_path, plane_name="plane2") + + def test_incorrect_channel_name_raises(self): + exc_msg = "The selected channel 'test' is not a valid channel name. To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`." + with self.assertRaisesWith(exc_type=ValueError, exc_msg=exc_msg): + Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="test") + + def test_incomplete_extractor_load(self): + """Check extractor can be initialized when not all traces are available.""" + # temporary directory for testing assertion when some of the files are missing + files_to_copy = ["stat.npy", "ops.npy", "iscell.npy", "Fneu.npy"] + (self.test_dir / "plane0").mkdir(exist_ok=True) + [ + shutil.copy(Path(self.folder_path) / "plane0" / file, self.test_dir / "plane0" / file) + for file in files_to_copy + ] + + extractor = Suite2pSegmentationExtractor(folder_path=self.test_dir) + traces_dict = extractor.get_traces_dict() + self.assertEqual(traces_dict["raw"], None) + self.assertEqual(traces_dict["dff"], None) + self.assertEqual(traces_dict["deconvolved"], None) + + def test_image_size(self): + self.assertEqual(self.extractor.get_image_size(), self.image_size) + + def test_num_frames(self): + self.assertEqual(self.extractor.get_num_frames(), 250) + + def test_sampling_frequency(self): + self.assertEqual(self.extractor.get_sampling_frequency(), 10.0) + + def test_channel_names(self): + 
self.assertEqual(self.extractor.get_channel_names(), ["Chan1"]) + + def test_num_channels(self): + self.assertEqual(self.extractor.get_num_channels(), 1) + + def test_num_rois(self): + self.assertEqual(self.extractor.get_num_rois(), self.num_rois) + + def test_extractor_first_channel_raw_traces(self): + assert_array_equal(self.extractor.get_traces(name="raw"), self.first_channel_raw_traces) + + def test_extractor_second_channel(self): + extractor = Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="chan2") + self.assertEqual(extractor.get_channel_names(), ["Chan2"]) + traces = extractor.get_traces_dict() + self.assertEqual(traces["deconvolved"], None) + assert_array_equal(traces["raw"], self.second_channel_raw_traces) + + def test_extractor_image_masks(self): + """Test that the image masks are correctly extracted.""" + assert_array_equal(self.extractor.get_roi_image_masks(), self.image_masks) + + def test_extractor_image_masks_selected_rois(self): + """Test that the image masks are correctly extracted for a subset of ROIs.""" + roi_indices = list(range(5)) + assert_array_equal(self.extractor.get_roi_image_masks(roi_ids=roi_indices), self.image_masks[..., roi_indices])