Skip to content

Commit

Permalink
Merge pull request #242 from catalystneuro/update_suite2psegmentation…
Browse files Browse the repository at this point in the history
…extractor

Update `Suite2pSegmentationExtractor` to support multi channel and multi plane outputs
  • Loading branch information
CodyCBakerPhD authored Nov 6, 2023
2 parents 80ecc1f + c26c4ca commit 8041fb0
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 78 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Upcoming

### Features

* Updated `Suite2pSegmentationExtractor` to support multi channel and multi plane data. [PR #242](https://github.com/catalystneuro/roiextractors/pull/242)



# v0.5.4

### Features
Expand Down
246 changes: 173 additions & 73 deletions src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import shutil
from pathlib import Path
from typing import Optional

from warnings import warn
import os
import numpy as np

from ...extraction_tools import PathType, IntType
from ...extraction_tools import PathType
from ...extraction_tools import _image_mask_extractor
from ...multisegmentationextractor import MultiSegmentationExtractor
from ...segmentationextractor import SegmentationExtractor
Expand All @@ -23,134 +24,233 @@ class Suite2pSegmentationExtractor(SegmentationExtractor):
extractor_name = "Suite2pSegmentationExtractor"
installed = True # check at class level if installed or not
is_writable = False
mode = "file"
mode = "folder"
installation_mesg = "" # error message when not installed

@classmethod
def get_available_channels(cls, folder_path: PathType):
"""Get the available channel names from the folder paths produced by Suite2p.
Parameters
----------
file_path : PathType
Path to Suite2p output path.
Returns
-------
channel_names: list
List of channel names.
"""
plane_names = cls.get_available_planes(folder_path=folder_path)

channel_names = ["chan1"]
second_channel_paths = list((Path(folder_path) / plane_names[0]).glob("F_chan2.npy"))
if not second_channel_paths:
return channel_names
channel_names.append("chan2")

return channel_names

@classmethod
def get_available_planes(cls, folder_path: PathType):
"""Get the available plane names from the folder produced by Suite2p.
Parameters
----------
file_path : PathType
Path to Suite2p output path.
Returns
-------
plane_names: list
List of plane names.
"""
from natsort import natsorted

folder_path = Path(folder_path)
prefix = "plane"
plane_paths = natsorted(folder_path.glob(pattern=prefix + "*"))
assert len(plane_paths), f"No planes found in '{folder_path}'."
plane_names = [plane_path.stem for plane_path in plane_paths]
return plane_names

def __init__(
self,
folder_path: Optional[PathType] = None,
combined: bool = False,
plane_no: IntType = 0,
file_path: Optional[PathType] = None,
folder_path: PathType,
channel_name: Optional[str] = None,
plane_name: Optional[str] = None,
combined: Optional[bool] = None, # TODO: to be removed
plane_no: Optional[int] = None, # TODO: to be removed
):
"""Create SegmentationExtractor object out of suite 2p data type.
Parameters
----------
folder_path: str or Path
~/suite2p folder location on disk
combined: bool
if the plane is a combined plane as in the Suite2p pipeline
plane_no: int
the plane for which to extract segmentation for.
file_path: str or Path [Deprecated]
~/suite2p folder location on disk
The path to the 'suite2p' folder.
channel_name: str, optional
The name of the channel to load, to determine what channels are available use Suite2pSegmentationExtractor.get_available_channels(folder_path).
plane_name: str, optional
The name of the plane to load, to determine what planes are available use Suite2pSegmentationExtractor.get_available_planes(folder_path).
"""
from warnings import warn

if file_path is not None:
if combined:
warning_string = "Keyword argument 'combined' is deprecated and will be removed on or after Nov, 2023. "
warn(
message=warning_string,
category=DeprecationWarning,
)
if plane_no:
warning_string = (
"The keyword argument 'file_path' is being deprecated on or after August, 2022 in favor of 'folder_path'. "
"'folder_path' takes precence over 'file_path'."
"Keyword argument 'plane_no' is deprecated and will be removed on or after Nov, 2023 in favor of 'plane_name'."
"Specify which stream you wish to load with the 'plane_name' keyword argument."
)
warn(
message=warning_string,
category=DeprecationWarning,
)
folder_path = file_path if folder_path is None else folder_path

SegmentationExtractor.__init__(self)
self.combined = combined
self.plane_no = plane_no
channel_names = self.get_available_channels(folder_path=folder_path)
if channel_name is None:
if len(channel_names) > 1:
# For backward compatibility maybe it is better to warn first
warn(
"More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. "
"To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`.",
UserWarning,
)
channel_name = channel_names[0]

self.channel_name = channel_name
if self.channel_name not in channel_names:
raise ValueError(
f"The selected channel '{channel_name}' is not a valid channel name. To see what channels are available, "
f"call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`."
)

plane_names = self.get_available_planes(folder_path=folder_path)
if plane_name is None:
if len(plane_names) > 1:
# For backward compatibility maybe it is better to warn first
warn(
"More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. "
"To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`.",
UserWarning,
)
plane_name = plane_names[0]

if plane_name not in plane_names:
raise ValueError(
f"The selected plane '{plane_name}' is not a valid plane name. To see what planes are available, "
f"call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`."
)
self.plane_name = plane_name

super().__init__()

self.folder_path = Path(folder_path)

self.stat = self._load_npy("stat.npy")
self._roi_response_raw = self._load_npy("F.npy", mmap_mode="r").T
self._roi_response_neuropil = self._load_npy("Fneu.npy", mmap_mode="r").T
self._roi_response_deconvolved = self._load_npy("spks.npy", mmap_mode="r").T
options = self._load_npy(file_name="ops.npy")
self.options = options.item() if options is not None else options
self._sampling_frequency = self.options["fs"]
self._num_frames = self.options["nframes"]
self._image_size = (self.options["Ly"], self.options["Lx"])

self.stat = self._load_npy(file_name="stat.npy")

fluorescence_traces_file_name = "F.npy" if channel_name == "chan1" else "F_chan2.npy"
neuropil_traces_file_name = "Fneu.npy" if channel_name == "chan1" else "Fneu_chan2.npy"
self._roi_response_raw = self._load_npy(file_name=fluorescence_traces_file_name, mmap_mode="r", transpose=True)
self._roi_response_neuropil = self._load_npy(file_name=neuropil_traces_file_name, mmap_mode="r", transpose=True)
self._roi_response_deconvolved = (
self._load_npy(file_name="spks.npy", mmap_mode="r", transpose=True) if channel_name == "chan1" else None
)

self.iscell = self._load_npy("iscell.npy", mmap_mode="r")
self.ops = self._load_npy("ops.npy").item()

self._channel_names = [f"OpticalChannel{i}" for i in range(self.ops["nchannels"])]
self._sampling_frequency = self.ops["fs"] * [2 if self.combined else 1][0]
self._raw_movie_file_location = self.ops.get("filelist", [None])[0]
self._image_correlation = self._summary_image_read("Vcorr")
self._image_mean = self._summary_image_read("meanImg")
channel_name = "OpticalChannel" if len(channel_names) == 1 else channel_name.capitalize()
self._channel_names = [channel_name]

self._image_correlation = self._correlation_image_read()
image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2"
self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None
roi_indices = list(range(self.get_num_rois()))
self._image_masks = _image_mask_extractor(
self.get_roi_pixel_masks(),
roi_indices,
self.get_image_size(),
)

def _load_npy(self, filename, mmap_mode=None):
"""Load a .npy file with specified filename.
def _load_npy(self, file_name: str, mmap_mode=None, transpose: bool = False):
"""Load a .npy file with specified filename. Returns None if file is missing.
Parameters
----------
filename: str
file_name: str
The name of the .npy file to load.
mmap_mode: str
The mode to use for memory mapping. See numpy.load for details.
transpose: bool, optional
Whether to transpose the loaded array.
Returns
-------
The loaded .npy file.
The loaded .npy file.
"""
file_path = self.folder_path / f"plane{self.plane_no}" / filename
return np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None)
file_path = self.folder_path / self.plane_name / file_name
if not file_path.exists():
return

data = np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None)
if transpose:
return data.T

return data

def get_num_frames(self) -> int:
return self._num_frames

def get_accepted_list(self):
return list(np.where(self.iscell[:, 0] == 1)[0])

def get_rejected_list(self):
return list(np.where(self.iscell[:, 0] == 0)[0])

def _summary_image_read(self, bstr="meanImg"):
"""Read summary image from ops (settings) dict.
Parameters
----------
bstr: str
The name of the summary image to read.
def _correlation_image_read(self):
"""Read correlation image from ops (settings) dict.
Returns
-------
img : numpy.ndarray | None
The summary image if bstr is in ops, else None.
The correlation image.
"""
img = None
if bstr in self.ops:
if bstr == "Vcorr" or bstr == "max_proj":
img = np.zeros((self.ops["Ly"], self.ops["Lx"]), np.float32)
img[
(self.ops["Ly"] - self.ops["yrange"][-1]) : (self.ops["Ly"] - self.ops["yrange"][0]),
self.ops["xrange"][0] : self.ops["xrange"][-1],
] = self.ops[bstr]
else:
img = self.ops[bstr]
if "Vcorr" not in self.options:
return None

correlation_image = self.options["Vcorr"]
if (self.options["yrange"][-1], self.options["xrange"][-1]) == self._image_size:
return correlation_image

img = np.zeros(self._image_size, correlation_image.dtype)
img[
(self.options["Ly"] - self.options["yrange"][-1]) : (self.options["Ly"] - self.options["yrange"][0]),
self.options["xrange"][0] : self.options["xrange"][-1],
] = correlation_image

return img

@property
def roi_locations(self):
"""Returns the center locations (x, y) of each ROI."""
return np.array([j["med"] for j in self.stat]).T.astype(int)

def get_roi_image_masks(self, roi_ids=None):
if roi_ids is None:
roi_idx_ = range(self.get_num_rois())
else:
roi_idx = [np.where(np.array(i) == self.get_roi_ids())[0] for i in roi_ids]
ele = [i for i, j in enumerate(roi_idx) if j.size == 0]
roi_idx_ = [j[0] for i, j in enumerate(roi_idx) if i not in ele]
return _image_mask_extractor(
self.get_roi_pixel_masks(roi_ids=roi_idx_),
list(range(len(roi_idx_))),
self.get_image_size(),
)

def get_roi_pixel_masks(self, roi_ids=None):
pixel_mask = []
for i in range(self.get_num_rois()):
pixel_mask.append(
np.vstack(
[
self.ops["Ly"] - 1 - self.stat[i]["ypix"],
self.stat[i]["ypix"],
self.stat[i]["xpix"],
self.stat[i]["lam"],
]
Expand All @@ -165,7 +265,7 @@ def get_roi_pixel_masks(self, roi_ids=None):
return [pixel_mask[i] for i in roi_idx_]

def get_image_size(self):
return [self.ops["Ly"], self.ops["Lx"]]
return self._image_size

@staticmethod
def write_segmentation(segmentation_object: SegmentationExtractor, save_path: PathType, overwrite=True):
Expand Down Expand Up @@ -238,7 +338,7 @@ def write_segmentation(segmentation_object: SegmentationExtractor, save_path: Pa
for no, i in enumerate(stat):
stat[no] = {
"med": roi_locs[no, :].tolist(),
"ypix": segmentation_object.get_image_size()[0] - 1 - pixel_masks[no][:, 0],
"ypix": pixel_masks[no][:, 0],
"xpix": pixel_masks[no][:, 1],
"lam": pixel_masks[no][:, 2],
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def test_imaging_extractors_canonical_shape(self, extractor_class, extractor_kwa
),
param(
extractor_class=Suite2pSegmentationExtractor,
extractor_kwargs=dict(folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")),
),
param(
extractor_class=Suite2pSegmentationExtractor,
extractor_kwargs=dict(file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")),
extractor_kwargs=dict(
folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"),
channel_name="chan1",
plane_name="plane0",
),
),
]

Expand Down
Loading

0 comments on commit 8041fb0

Please sign in to comment.