diff --git a/.github/workflows/check-docstrings.yaml b/.github/workflows/check-docstrings.yaml new file mode 100644 index 00000000..1cc2eb7c --- /dev/null +++ b/.github/workflows/check-docstrings.yaml @@ -0,0 +1,12 @@ +name: Check Docstrings +on: + workflow_dispatch: + pull_request: + +jobs: + check-docstrings: + uses: catalystneuro/.github/.github/workflows/check_docstrings.yaml@main + with: + python-version: '3.10' + repository: 'catalystneuro/roiextractors' + package-name: 'roiextractors' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 456112df..e1f37e64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.1 hooks: - id: black exclude: ^docs/ diff --git a/src/roiextractors/extractorlist.py b/src/roiextractors/extractorlist.py index 091265c2..73d0bf6e 100644 --- a/src/roiextractors/extractorlist.py +++ b/src/roiextractors/extractorlist.py @@ -15,6 +15,8 @@ from .extractors.tiffimagingextractors import ( TiffImagingExtractor, ScanImageTiffImagingExtractor, + ScanImageTiffSinglePlaneImagingExtractor, + ScanImageTiffMultiPlaneImagingExtractor, BrukerTiffMultiPlaneImagingExtractor, BrukerTiffSinglePlaneImagingExtractor, MicroManagerTiffImagingExtractor, @@ -25,12 +27,15 @@ from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor from .multisegmentationextractor import MultiSegmentationExtractor from .multiimagingextractor import MultiImagingExtractor +from .volumetricimagingextractor import VolumetricImagingExtractor imaging_extractor_full_list = [ NumpyImagingExtractor, Hdf5ImagingExtractor, TiffImagingExtractor, ScanImageTiffImagingExtractor, + ScanImageTiffSinglePlaneImagingExtractor, + ScanImageTiffMultiPlaneImagingExtractor, BrukerTiffMultiPlaneImagingExtractor, BrukerTiffSinglePlaneImagingExtractor, MicroManagerTiffImagingExtractor, @@ -39,6 +44,7 @@ SbxImagingExtractor, NumpyMemmapImagingExtractor, MemmapImagingExtractor, + VolumetricImagingExtractor, ] segmentation_extractor_full_list = [ diff --git a/src/roiextractors/extractors/tiffimagingextractors/__init__.py b/src/roiextractors/extractors/tiffimagingextractors/__init__.py index b8f5cf54..3f3a3618 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/__init__.py +++ b/src/roiextractors/extractors/tiffimagingextractors/__init__.py @@ -16,7 +16,11 @@ TiffImagingExtractor A ImagingExtractor for TIFF files. ScanImageTiffImagingExtractor - Specialized extractor for reading TIFF files produced via ScanImage. + Legacy extractor for reading TIFF files produced via ScanImage v3.8. +ScanImageTiffSinglePlaneImagingExtractor + Specialized extractor for reading single-plane TIFF files produced via ScanImage. +ScanImageTiffMultiPlaneImagingExtractor + Specialized extractor for reading multi-plane TIFF files produced via ScanImage. BrukerTiffMultiPlaneImagingExtractor Specialized extractor for reading TIFF files produced via Bruker. BrukerTiffSinglePlaneImagingExtractor @@ -25,6 +29,10 @@ Specialized extractor for reading TIFF files produced via Micro-Manager. """ from .tiffimagingextractor import TiffImagingExtractor -from .scanimagetiffimagingextractor import ScanImageTiffImagingExtractor +from .scanimagetiffimagingextractor import ( + ScanImageTiffImagingExtractor, + ScanImageTiffMultiPlaneImagingExtractor, + ScanImageTiffSinglePlaneImagingExtractor, +) from .brukertiffimagingextractor import BrukerTiffMultiPlaneImagingExtractor, BrukerTiffSinglePlaneImagingExtractor from .micromanagertiffimagingextractor import MicroManagerTiffImagingExtractor diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py new file mode 100644 index 00000000..691722dc --- /dev/null +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py @@ -0,0 +1,193 @@ +"""Utility functions for ScanImage TIFF Extractors.""" +import numpy as np +from ...extraction_tools import PathType, get_package + + +def _get_scanimage_reader() -> type: + """Import the scanimage-tiff-reader package and return the ScanImageTiffReader class.""" + return get_package( + package_name="ScanImageTiffReader", installation_instructions="pip install scanimage-tiff-reader" + ).ScanImageTiffReader + + +def extract_extra_metadata( + file_path: PathType, +) -> dict: # TODO: Refactor neuroconv to reference this implementation to avoid duplication + """Extract metadata from a ScanImage TIFF file. + + Parameters + ---------- + file_path : PathType + Path to the TIFF file. + + Returns + ------- + extra_metadata: dict + Dictionary of metadata extracted from the TIFF file. + + Notes + ----- + Known to work on SI versions v3.8.0, v2019bR0, v2022.0.0, and v2023.0.0 + """ + ScanImageTiffReader = _get_scanimage_reader() + io = ScanImageTiffReader(str(file_path)) + extra_metadata = {} + for metadata_string in (io.description(iframe=0), io.metadata()): + metadata_dict = { + x.split("=")[0].strip(): x.split("=")[1].strip() + for x in metadata_string.replace("\n", "\r").split("\r") + if "=" in x + } + extra_metadata = dict(**extra_metadata, **metadata_dict) + return extra_metadata + + +def parse_matlab_vector(matlab_vector: str) -> list: + """Parse a MATLAB vector string into a list of integer values. + + Parameters + ---------- + matlab_vector : str + MATLAB vector string. + + Returns + ------- + vector: list of int + List of integer values. + + Raises + ------ + ValueError + If the MATLAB vector string cannot be parsed. + + Notes + ----- + MATLAB vector string is of the form "[1 2 3 ... N]" or "[1,2,3,...,N]" or "[1;2;3;...;N]". + There may or may not be whitespace between the values. Ex. "[1, 2, 3]" or "[1,2,3]". + """ + vector = matlab_vector.strip("[]") + if ";" in vector: + vector = vector.split(";") + elif "," in vector: + vector = vector.split(",") + elif " " in vector: + vector = vector.split(" ") + elif len(vector) == 1: + pass + else: + raise ValueError(f"Could not parse vector from {matlab_vector}.") + vector = [int(x.strip()) for x in vector if x != ""] + return vector + + +def parse_metadata(metadata: dict) -> dict: + """Parse metadata dictionary to extract relevant information and store it standard keys for ImagingExtractors. + + Currently supports + - sampling_frequency + - num_planes + - frames_per_slice + - channel_names + - num_channels + + Parameters + ---------- + metadata : dict + Dictionary of metadata extracted from the TIFF file. + + Returns + ------- + metadata_parsed: dict + Dictionary of parsed metadata. + + Notes + ----- + Known to work on SI versions v2019bR0, v2022.0.0, and v2023.0.0. Fails on v3.8.0. + SI.hChannels.channelsActive = string of MATLAB-style vector with channel integers (see parse_matlab_vector). + SI.hChannels.channelName = "{'channel_name_1' 'channel_name_2' ... 'channel_name_M'}" + where M is the number of channels (active or not). + """ + sampling_frequency = float(metadata["SI.hRoiManager.scanFrameRate"]) + num_planes = int(metadata["SI.hStackManager.numSlices"]) + frames_per_slice = int(metadata["SI.hStackManager.framesPerSlice"]) + active_channels = parse_matlab_vector(metadata["SI.hChannels.channelsActive"]) + channel_indices = np.array(active_channels) - 1 # Account for MATLAB indexing + channel_names = np.array(metadata["SI.hChannels.channelName"].split("'")[1::2]) + channel_names = channel_names[channel_indices].tolist() + num_channels = len(channel_names) + metadata_parsed = dict( + sampling_frequency=sampling_frequency, + num_channels=num_channels, + num_planes=num_planes, + frames_per_slice=frames_per_slice, + channel_names=channel_names, + ) + return metadata_parsed + + +def parse_metadata_v3_8(metadata: dict) -> dict: + """Parse metadata dictionary to extract relevant information and store it standard keys for ImagingExtractors. + + Requires old version of metadata (v3.8). + Currently supports + - sampling frequency + - num_channels + - num_planes + + Parameters + ---------- + metadata : dict + Dictionary of metadata extracted from the TIFF file. + + Returns + ------- + metadata_parsed: dict + Dictionary of parsed metadata. + """ + sampling_frequency = float(metadata["state.acq.frameRate"]) + num_channels = int(metadata["state.acq.numberOfChannelsSave"]) + num_planes = int(metadata["state.acq.numberOfZSlices"]) + metadata_parsed = dict( + sampling_frequency=sampling_frequency, + num_channels=num_channels, + num_planes=num_planes, + ) + return metadata_parsed + + +def extract_timestamps_from_file(file_path: PathType) -> np.ndarray: + """Extract the frame timestamps from a ScanImage TIFF file. + + Parameters + ---------- + file_path : PathType + Path to the TIFF file. + + Returns + ------- + timestamps : numpy.ndarray + Array of frame timestamps in seconds. + + Raises + ------ + AssertionError + If the frame timestamps are not found in the TIFF file. + + Notes + ----- + Known to work on SI versions v2019bR0, v2022.0.0, and v2023.0.0. Fails on v3.8.0. + """ + ScanImageTiffReader = _get_scanimage_reader() + io = ScanImageTiffReader(str(file_path)) + assert "frameTimestamps_sec" in io.description(iframe=0), "frameTimestamps_sec not found in TIFF file" + num_frames = io.shape()[0] + timestamps = np.zeros(num_frames) + for iframe in range(num_frames): + description = io.description(iframe=iframe) + description_lines = description.split("\n") + for line in description_lines: + if "frameTimestamps_sec" in line: + timestamps[iframe] = float(line.split("=")[1].strip()) + break + + return timestamps diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index 83a51916..fea05b62 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -6,29 +6,355 @@ Specialized extractor for reading TIFF files produced via ScanImage. """ from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Iterable from warnings import warn - import numpy as np -from ...extraction_tools import PathType, FloatType, ArrayType, get_package +from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package from ...imagingextractor import ImagingExtractor +from ...volumetricimagingextractor import VolumetricImagingExtractor +from .scanimagetiff_utils import ( + extract_extra_metadata, + parse_metadata, + extract_timestamps_from_file, + _get_scanimage_reader, +) + + +class ScanImageTiffMultiPlaneImagingExtractor(VolumetricImagingExtractor): + """Specialized extractor for reading multi-plane (volumetric) TIFF files produced via ScanImage.""" + extractor_name = "ScanImageTiffMultiPlaneImaging" + is_writable = True + mode = "file" -def _get_scanimage_reader() -> type: - """Import the scanimage-tiff-reader package and return the ScanImageTiffReader class.""" - return get_package( - package_name="ScanImageTiffReader", installation_instructions="pip install scanimage-tiff-reader" - ).ScanImageTiffReader + def __init__( + self, + file_path: PathType, + channel_name: Optional[str] = None, + ) -> None: + self.file_path = Path(file_path) + self.metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata(self.metadata) + num_planes = parsed_metadata["num_planes"] + channel_names = parsed_metadata["channel_names"] + if channel_name is None: + channel_name = channel_names[0] + imaging_extractors = [] + for plane in range(num_planes): + imaging_extractor = ScanImageTiffSinglePlaneImagingExtractor( + file_path=file_path, channel_name=channel_name, plane_name=str(plane) + ) + imaging_extractors.append(imaging_extractor) + super().__init__(imaging_extractors=imaging_extractors) + assert all( + imaging_extractor.get_num_planes() == self._num_planes for imaging_extractor in imaging_extractors + ), "All imaging extractors must have the same number of planes." -class ScanImageTiffImagingExtractor(ImagingExtractor): +class ScanImageTiffSinglePlaneImagingExtractor(ImagingExtractor): """Specialized extractor for reading TIFF files produced via ScanImage.""" extractor_name = "ScanImageTiffImaging" is_writable = True mode = "file" + @classmethod + def get_available_channels(cls, file_path): + """Get the available channel names from a TIFF file produced by ScanImage. + + Parameters + ---------- + file_path : PathType + Path to the TIFF file. + + Returns + ------- + channel_names: list + List of channel names. + """ + metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata(metadata) + channel_names = parsed_metadata["channel_names"] + return channel_names + + @classmethod + def get_available_planes(cls, file_path): + """Get the available plane names from a TIFF file produced by ScanImage. + + Parameters + ---------- + file_path : PathType + Path to the TIFF file. + + Returns + ------- + plane_names: list + List of plane names. + """ + metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata(metadata) + num_planes = parsed_metadata["num_planes"] + plane_names = [f"{i}" for i in range(num_planes)] + return plane_names + + def __init__( + self, + file_path: PathType, + channel_name: str, + plane_name: str, + ) -> None: + """Create a ScanImageTiffImagingExtractor instance from a TIFF file produced by ScanImage. + + The underlying data is stored in a round-robin format collapsed into 3 dimensions (frames, rows, columns). + I.e. the first frame of each channel and each plane is stored, and then the second frame of each channel and + each plane, etc. + If framesPerSlice > 1, then multiple frames are acquired per slice before moving to the next slice. + Ex. for 2 channels, 2 planes, and 2 framesPerSlice: + ``` + [channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, + channel_1_plane_2_frame_1, channel_2_plane_2_frame_1, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, + channel_1_plane_1_frame_3, channel_2_plane_1_frame_3, channel_1_plane_1_frame_4, channel_2_plane_1_frame_4, + channel_1_plane_2_frame_3, channel_2_plane_2_frame_3, channel_1_plane_2_frame_4, channel_2_plane_2_frame_4, ... + channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N] + ``` + This file structured is accessed by ScanImageTiffImagingExtractor for a single channel and plane. + + Parameters + ---------- + file_path : PathType + Path to the TIFF file. + channel_name : str + Name of the channel for this extractor (default=None). + plane_name : str + Name of the plane for this extractor (default=None). + """ + self.file_path = Path(file_path) + self.metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata(self.metadata) + self._sampling_frequency = parsed_metadata["sampling_frequency"] + self._num_channels = parsed_metadata["num_channels"] + self._num_planes = parsed_metadata["num_planes"] + self._frames_per_slice = parsed_metadata["frames_per_slice"] + self._channel_names = parsed_metadata["channel_names"] + self._plane_names = [f"{i}" for i in range(self._num_planes)] + self.channel_name = channel_name + self.plane_name = plane_name + if channel_name not in self._channel_names: + raise ValueError(f"Channel name ({channel_name}) not found in channel names ({self._channel_names}).") + self.channel = self._channel_names.index(channel_name) + if plane_name not in self._plane_names: + raise ValueError(f"Plane name ({plane_name}) not found in plane names ({self._plane_names}).") + self.plane = self._plane_names.index(plane_name) + + ScanImageTiffReader = _get_scanimage_reader() + with ScanImageTiffReader(str(self.file_path)) as io: + shape = io.shape() # [frames, rows, columns] + if len(shape) == 3: + self._total_num_frames, self._num_rows, self._num_columns = shape + self._num_raw_per_plane = self._frames_per_slice * self._num_channels + self._num_raw_per_cycle = self._num_raw_per_plane * self._num_planes + self._num_frames = self._total_num_frames // (self._num_planes * self._num_channels) + self._num_cycles = self._total_num_frames // self._num_raw_per_cycle + else: + raise NotImplementedError( + "Extractor cannot handle 4D ScanImageTiff data. Please raise an issue to request this feature: " + "https://github.com/catalystneuro/roiextractors/issues " + ) + timestamps = extract_timestamps_from_file(file_path) + index = [self.frame_to_raw_index(iframe) for iframe in range(self._num_frames)] + self._times = timestamps[index] + + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: + """Get specific video frames from indices (not necessarily continuous). + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + frames: numpy.ndarray + The video frames. + """ + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + self.check_frame_inputs(frame_idxs[-1]) + + if not all(np.diff(frame_idxs) == 1): + return np.concatenate([self._get_single_frame(frame=idx) for idx in frame_idxs]) + else: + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) + + # Data accessed through an open ScanImageTiffReader io gets scrambled if there are multiple calls. + # Thus, open fresh io in context each time something is needed. + def _get_single_frame(self, frame: int) -> np.ndarray: + """Get a single frame of data from the TIFF file. + + Parameters + ---------- + frame : int + The index of the frame to retrieve. + + Returns + ------- + frame: numpy.ndarray + The frame of data. + """ + self.check_frame_inputs(frame) + ScanImageTiffReader = _get_scanimage_reader() + raw_index = self.frame_to_raw_index(frame) + with ScanImageTiffReader(str(self.file_path)) as io: + return io.data(beg=raw_index, end=raw_index + 1) + + def get_video(self, start_frame=None, end_frame=None) -> np.ndarray: + """Get the video frames. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + video: numpy.ndarray + The video frames. + """ + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self._num_frames + end_frame_inclusive = end_frame - 1 + self.check_frame_inputs(end_frame_inclusive) + self.check_frame_inputs(start_frame) + raw_start = self.frame_to_raw_index(start_frame) + raw_end_inclusive = self.frame_to_raw_index(end_frame_inclusive) # frame_to_raw_index requires inclusive frame + raw_end = raw_end_inclusive + 1 + + ScanImageTiffReader = _get_scanimage_reader() + with ScanImageTiffReader(filename=str(self.file_path)) as io: + raw_video = io.data(beg=raw_start, end=raw_end) + + start_cycle = np.ceil(start_frame / self._frames_per_slice).astype("int") + end_cycle = end_frame // self._frames_per_slice + num_cycles = end_cycle - start_cycle + start_frame_in_cycle = start_frame % self._frames_per_slice + end_frame_in_cycle = end_frame % self._frames_per_slice + start_left_in_cycle = (self._frames_per_slice - start_frame_in_cycle) % self._frames_per_slice + end_left_in_cycle = (self._frames_per_slice - end_frame_in_cycle) % self._frames_per_slice + index = [] + for j in range(start_left_in_cycle): # Add remaining frames from first (incomplete) cycle + index.append(j * self._num_channels) + for i in range(num_cycles): + for j in range(self._frames_per_slice): + index.append( + (j - start_frame_in_cycle) * self._num_channels + + (i + bool(start_left_in_cycle)) * self._num_raw_per_cycle + ) + for j in range(end_left_in_cycle): # Add remaining frames from last (incomplete) cycle) + index.append((j - start_frame_in_cycle) * self._num_channels + num_cycles * self._num_raw_per_cycle) + video = raw_video[index] + return video + + def get_image_size(self) -> Tuple[int, int]: + return (self._num_rows, self._num_columns) + + def get_num_frames(self) -> int: + return self._num_frames + + def get_sampling_frequency(self) -> float: + return self._sampling_frequency + + def get_channel_names(self) -> list: + return self._channel_names + + def get_num_channels(self) -> int: + return self._num_channels + + def get_num_planes(self) -> int: + """Get the number of depth planes. + + Returns + ------- + _num_planes: int + The number of depth planes. + """ + return self._num_planes + + def get_dtype(self) -> DtypeType: + return self.get_frames(0).dtype + + def check_frame_inputs(self, frame) -> None: + """Check that the frame index is valid. Raise ValueError if not. + + Parameters + ---------- + frame : int + The index of the frame to retrieve. + + Raises + ------ + ValueError + If the frame index is invalid. + """ + if frame >= self._num_frames: + raise ValueError(f"Frame index ({frame}) exceeds number of frames ({self._num_frames}).") + if frame < 0: + raise ValueError(f"Frame index ({frame}) must be greater than or equal to 0.") + + def frame_to_raw_index(self, frame): + """Convert a frame index to the raw index in the TIFF file. + + Parameters + ---------- + frame : int + The index of the frame to retrieve. + + Returns + ------- + raw_index: int + The raw index of the frame in the TIFF file. + + Notes + ----- + The underlying data is stored in a round-robin format collapsed into 3 dimensions (frames, rows, columns). + I.e. the first frame of each channel and each plane is stored, and then the second frame of each channel and + each plane, etc. + If framesPerSlice > 1, then multiple frames are acquired per slice before moving to the next slice. + Ex. for 2 channels, 2 planes, and 2 framesPerSlice: + ``` + [channel_1_plane_1_frame_1, channel_2_plane_1_frame_1, channel_1_plane_1_frame_2, channel_2_plane_1_frame_2, + channel_1_plane_2_frame_1, channel_2_plane_2_frame_1, channel_1_plane_2_frame_2, channel_2_plane_2_frame_2, + channel_1_plane_1_frame_3, channel_2_plane_1_frame_3, channel_1_plane_1_frame_4, channel_2_plane_1_frame_4, + channel_1_plane_2_frame_3, channel_2_plane_2_frame_3, channel_1_plane_2_frame_4, channel_2_plane_2_frame_4, ... + channel_1_plane_1_frame_N, channel_2_plane_1_frame_N, channel_1_plane_2_frame_N, channel_2_plane_2_frame_N] + ``` + """ + cycle = frame // self._frames_per_slice + frame_in_cycle = frame % self._frames_per_slice + raw_index = ( + cycle * self._num_raw_per_cycle + + self.plane * self._num_raw_per_plane + + frame_in_cycle * self._num_channels + + self.channel + ) + return raw_index + + +class ScanImageTiffImagingExtractor(ImagingExtractor): # TODO: Remove this extractor on/after December 2023 + """Specialized extractor for reading TIFF files produced via ScanImage. + + This implementation is for legacy purposes and is not recommended for use. + Please use ScanImageTiffSinglePlaneImagingExtractor or ScanImageTiffMultiPlaneImagingExtractor instead. + """ + + extractor_name = "ScanImageTiffImaging" + is_writable = True + mode = "file" + def __init__( self, file_path: PathType, @@ -47,6 +373,12 @@ def __init__( sampling_frequency : float The frequency at which the frames were sampled, in Hz. """ + deprecation_message = """ + This extractor is being deprecated on or after December 2023 in favor of + ScanImageTiffMultiPlaneImagingExtractor or ScanImageTiffSinglePlaneImagingExtractor. Please use one of these + extractors instead. + """ + warn(deprecation_message, category=FutureWarning) ScanImageTiffReader = _get_scanimage_reader() super().__init__() diff --git a/src/roiextractors/imagingextractor.py b/src/roiextractors/imagingextractor.py index 66ede81a..a74503f7 100644 --- a/src/roiextractors/imagingextractor.py +++ b/src/roiextractors/imagingextractor.py @@ -109,6 +109,20 @@ def get_video( ------- video: numpy.ndarray The video frames. + + Notes + ----- + Importantly, we follow the convention that the dimensions of the array are returned in their matrix order, + More specifically: + (time, height, width) + + Which is equivalent to: + (samples, rows, columns) + + Note that this does not match the cartesian convention: + (t, x, y) + + Where x is the columns width or and y is the rows or height. """ pass diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index d1720898..47d694d9 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -53,6 +53,7 @@ def generate_dummy_imaging_extractor( num_channels: int = 1, sampling_frequency: float = 30, dtype: DtypeType = "uint16", + channel_names: Optional[list] = None, ): """Generate a dummy imaging extractor for testing. @@ -78,7 +79,8 @@ def generate_dummy_imaging_extractor( ImagingExtractor An imaging extractor with random data fed into `NumpyImagingExtractor`. """ - channel_names = [f"channel_num_{num}" for num in range(num_channels)] + if channel_names is None: + channel_names = [f"channel_num_{num}" for num in range(num_channels)] size = (num_frames, num_rows, num_columns, num_channels) video = generate_dummy_video(size=size, dtype=dtype) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py new file mode 100644 index 00000000..2abf0c1a --- /dev/null +++ b/src/roiextractors/volumetricimagingextractor.py @@ -0,0 +1,169 @@ +"""Base class definition for volumetric imaging extractors.""" + +from typing import Tuple, List, Iterable, Optional +import numpy as np + +from .extraction_tools import ArrayType, DtypeType +from .imagingextractor import ImagingExtractor + + +class VolumetricImagingExtractor(ImagingExtractor): + """Class to combine multiple ImagingExtractor objects by depth plane.""" + + extractor_name = "VolumetricImaging" + installed = True + installatiuon_mesage = "" + + def __init__(self, imaging_extractors: List[ImagingExtractor]): + """Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + """ + super().__init__() + assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" + assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) + self._check_consistency_between_imaging_extractors(imaging_extractors) + self._imaging_extractors = imaging_extractors + self._num_planes = len(imaging_extractors) + + def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): + """Check that essential properties are consistent between extractors so that they can be combined appropriately. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + + Raises + ------ + AssertionError + If any of the properties are not consistent between extractors. + + Notes + ----- + This method checks the following properties: + - sampling frequency + - image size + - number of channels + - channel names + - data type + - num_frames + """ + properties_to_check = dict( + get_sampling_frequency="The sampling frequency", + get_image_size="The size of a frame", + get_num_channels="The number of channels", + get_channel_names="The name of the channels", + get_dtype="The data type", + get_num_frames="The number of frames", + ) + for method, property_message in properties_to_check.items(): + values = [getattr(extractor, method)() for extractor in imaging_extractors] + unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) + assert ( + len(unique_values) == 1 + ), f"{property_message} is not consistent over the files (found {unique_values})." + + def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: + """Get the video frames. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + video: numpy.ndarray + The 3D video frames (num_frames, num_rows, num_columns, num_planes). + """ + if start_frame is None: + start_frame = 0 + elif start_frame < 0: + start_frame = self.get_num_frames() + start_frame + elif start_frame >= self.get_num_frames(): + raise ValueError( + f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}" + ) + if end_frame is None: + end_frame = self.get_num_frames() + elif end_frame < 0: + end_frame = self.get_num_frames() + end_frame + elif end_frame > self.get_num_frames(): + raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}") + if end_frame <= start_frame: + raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}") + + video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + video[..., i] = imaging_extractor.get_video(start_frame, end_frame) + return video + + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: + """Get specific video frames from indices (not necessarily continuous). + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + frames: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + for frame_idx in frame_idxs: + if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames(): + raise ValueError(f"frame_idx {frame_idx} is out of bounds") + + # Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int + if not all(np.diff(frame_idxs) == 1): + frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + frames[..., i] = imaging_extractor.get_frames(frame_idxs) + return frames + else: + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) + + def get_image_size(self) -> Tuple: + """Get the size of a single frame. + + Returns + ------- + image_size: tuple + The size of a single frame (num_rows, num_columns, num_planes). + """ + image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) + return image_size + + def get_num_planes(self) -> int: + """Get the number of depth planes. + + Returns + ------- + _num_planes: int + The number of depth planes. + """ + return self._num_planes + + def get_num_frames(self) -> int: + return self._imaging_extractors[0].get_num_frames() + + def get_sampling_frequency(self) -> float: + return self._imaging_extractors[0].get_sampling_frequency() + + def get_channel_names(self) -> list: + return self._imaging_extractors[0].get_channel_names() + + def get_num_channels(self) -> int: + return self._imaging_extractors[0].get_num_channels() + + def get_dtype(self) -> DtypeType: + return self._imaging_extractors[0].get_dtype() diff --git a/tests/test_scanimage_utils.py b/tests/test_scanimage_utils.py new file mode 100644 index 00000000..fd55c671 --- /dev/null +++ b/tests/test_scanimage_utils.py @@ -0,0 +1,133 @@ +import pytest +from numpy.testing import assert_array_equal +from ScanImageTiffReader import ScanImageTiffReader +from roiextractors.extractors.tiffimagingextractors.scanimagetiff_utils import ( + _get_scanimage_reader, + extract_extra_metadata, + parse_matlab_vector, + parse_metadata, + parse_metadata_v3_8, + extract_timestamps_from_file, +) + +from .setup_paths import OPHYS_DATA_PATH + + +def test_get_scanimage_reader(): + ScanImageTiffReader = _get_scanimage_reader() + assert ScanImageTiffReader is not None + + +@pytest.mark.parametrize( + "filename, expected_key, expected_value", + [ + ("sample_scanimage_version_3_8.tiff", "state.software.version", "3.8"), + ("scanimage_20220801_single.tif", "SI.VERSION_MAJOR", "2022"), + ("scanimage_20220923_roi.tif", "SI.VERSION_MAJOR", "2023"), + ], +) +def test_extract_extra_metadata(filename, expected_key, expected_value): + file_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / filename + metadata = extract_extra_metadata(file_path) + assert metadata[expected_key] == expected_value + + +@pytest.mark.parametrize( + "matlab_vector, expected_vector", + [ + ("[1 2 3]", [1, 2, 3]), + ("[1,2,3]", [1, 2, 3]), + ("[1, 2, 3]", [1, 2, 3]), + ("[1;2;3]", [1, 2, 3]), + ("[1; 2; 3]", [1, 2, 3]), + ], +) +def test_parse_matlab_vector(matlab_vector, expected_vector): + vector = parse_matlab_vector(matlab_vector) + assert vector == expected_vector + + +def test_parse_matlab_vector_invalid(): + with pytest.raises(ValueError): + parse_matlab_vector("Invalid") + + +@pytest.mark.parametrize( + "filename, expected_metadata", + [ + ( + "scanimage_20220801_single.tif", + { + "sampling_frequency": 15.2379, + "num_channels": 1, + "num_planes": 20, + "frames_per_slice": 24, + "channel_names": ["Channel 1"], + }, + ), + ( + "scanimage_20220923_roi.tif", + { + "sampling_frequency": 29.1248, + "num_channels": 2, + "num_planes": 2, + "frames_per_slice": 2, + "channel_names": ["Channel 1", "Channel 4"], + }, + ), + ], +) +def test_parse_metadata(filename, expected_metadata): + file_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / filename + metadata = extract_extra_metadata(file_path) + metadata = parse_metadata(metadata) + assert metadata == expected_metadata + + +def test_parse_metadata_v3_8(): + file_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / "sample_scanimage_version_3_8.tiff" + metadata = extract_extra_metadata(file_path) + metadata = parse_metadata_v3_8(metadata) + expected_metadata = {"sampling_frequency": 3.90625, "num_channels": 1, "num_planes": 1} + assert metadata == expected_metadata + + +@pytest.mark.parametrize( + "filename, expected_timestamps", + [ + ("scanimage_20220801_single.tif", [0.45951611, 0.98468446, 1.50985974]), + ( + "scanimage_20220923_roi.tif", + [ + 0.0, + 0.0, + 0.03433645, + 0.03433645, + 1.04890375, + 1.04890375, + 1.08324025, + 1.08324025, + 2.12027815, + 2.12027815, + 2.15461465, + 2.15461465, + 2.7413649, + 2.7413649, + 2.7757014, + 2.7757014, + 3.23987545, + 3.23987545, + 3.27421195, + 3.27421195, + 3.844804, + 3.844804, + 3.87914055, + 3.87914055, + ], + ), + ], +) +def test_extract_timestamps_from_file(filename, expected_timestamps): + file_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / filename + timestamps = extract_timestamps_from_file(file_path) + assert_array_equal(timestamps, expected_timestamps) diff --git a/tests/test_scanimagetiffimagingextractor.py b/tests/test_scanimagetiffimagingextractor.py new file mode 100644 index 00000000..4c0889a9 --- /dev/null +++ b/tests/test_scanimagetiffimagingextractor.py @@ -0,0 +1,232 @@ +import pytest +from numpy.testing import assert_array_equal +from ScanImageTiffReader import ScanImageTiffReader +from roiextractors import ScanImageTiffSinglePlaneImagingExtractor, ScanImageTiffMultiPlaneImagingExtractor + +from .setup_paths import OPHYS_DATA_PATH + + +@pytest.fixture(scope="module") +def file_path(): + return OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / "scanimage_20220923_roi.tif" + + +@pytest.fixture(scope="module") +def expected_properties(): + return dict( + sampling_frequency=29.1248, + num_channels=2, + num_planes=2, + frames_per_slice=2, + channel_names=["Channel 1", "Channel 4"], + plane_names=["0", "1"], + image_size=(528, 256), + num_frames=6, + dtype="int16", + ) + + +@pytest.fixture( + scope="module", + params=[ + dict(channel_name="Channel 1", plane_name="0"), + dict(channel_name="Channel 1", plane_name="1"), + dict(channel_name="Channel 4", plane_name="0"), + dict(channel_name="Channel 4", plane_name="1"), + ], +) +def scan_image_tiff_single_plane_imaging_extractor(request, file_path): + return ScanImageTiffSinglePlaneImagingExtractor(file_path=file_path, **request.param) + + +@pytest.mark.parametrize("channel_name, plane_name", [("Invalid Channel", "0"), ("Channel 1", "Invalid Plane")]) +def test_ScanImageTiffSinglePlaneImagingExtractor__init__invalid(file_path, channel_name, plane_name): + with pytest.raises(ValueError): + ScanImageTiffSinglePlaneImagingExtractor(file_path=file_path, channel_name=channel_name, plane_name=plane_name) + + +@pytest.mark.parametrize("frame_idxs", (0, [0, 1, 2], [0, 2, 5])) +def test_get_frames(scan_image_tiff_single_plane_imaging_extractor, frame_idxs, expected_properties): + frames = scan_image_tiff_single_plane_imaging_extractor.get_frames(frame_idxs=frame_idxs) + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor.plane + channel = scan_image_tiff_single_plane_imaging_extractor.channel + num_planes = expected_properties["num_planes"] + num_channels = expected_properties["num_channels"] + frames_per_slice = expected_properties["frames_per_slice"] + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + raw_idxs = [] + for idx in frame_idxs: + cycle = idx // frames_per_slice + frame_in_cycle = idx % frames_per_slice + raw_idx = ( + cycle * num_planes * num_channels * frames_per_slice + + plane * num_channels * frames_per_slice + + num_channels * frame_in_cycle + + channel + ) + raw_idxs.append(raw_idx) + with ScanImageTiffReader(file_path) as io: + assert_array_equal(frames, io.data()[raw_idxs]) + + +@pytest.mark.parametrize("frame_idxs", ([-1], [50])) +def test_get_frames_invalid(scan_image_tiff_single_plane_imaging_extractor, frame_idxs): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor.get_frames(frame_idxs=frame_idxs) + + +@pytest.mark.parametrize("frame_idx", (1, 3, 5)) +def test_get_single_frame(scan_image_tiff_single_plane_imaging_extractor, expected_properties, frame_idx): + frame = scan_image_tiff_single_plane_imaging_extractor._get_single_frame(frame=frame_idx) + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor.plane + channel = scan_image_tiff_single_plane_imaging_extractor.channel + num_planes = expected_properties["num_planes"] + num_channels = expected_properties["num_channels"] + frames_per_slice = expected_properties["frames_per_slice"] + cycle = frame_idx // frames_per_slice + frame_in_cycle = frame_idx % frames_per_slice + raw_idx = ( + cycle * num_planes * num_channels * frames_per_slice + + plane * num_channels * frames_per_slice + + num_channels * frame_in_cycle + + channel + ) + with ScanImageTiffReader(file_path) as io: + assert_array_equal(frame, io.data()[raw_idx : raw_idx + 1]) + + +@pytest.mark.parametrize("frame", (-1, 50)) +def test_get_single_frame_invalid(scan_image_tiff_single_plane_imaging_extractor, frame): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor._get_single_frame(frame=frame) + + +@pytest.mark.parametrize("start_frame, end_frame", [(0, None), (None, 6), (1, 4), (0, 6)]) +def test_get_video( + scan_image_tiff_single_plane_imaging_extractor, + expected_properties, + start_frame, + end_frame, +): + video = scan_image_tiff_single_plane_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = expected_properties["num_frames"] + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor.plane + channel = scan_image_tiff_single_plane_imaging_extractor.channel + num_planes = expected_properties["num_planes"] + num_channels = expected_properties["num_channels"] + frames_per_slice = expected_properties["frames_per_slice"] + raw_idxs = [] + for idx in range(start_frame, end_frame): + cycle = idx // frames_per_slice + frame_in_cycle = idx % frames_per_slice + raw_idx = ( + cycle * num_planes * num_channels * frames_per_slice + + plane * num_channels * frames_per_slice + + num_channels * frame_in_cycle + + channel + ) + raw_idxs.append(raw_idx) + with ScanImageTiffReader(file_path) as io: + assert_array_equal(video, io.data()[raw_idxs]) + + +@pytest.mark.parametrize("start_frame, end_frame", [(-1, 2), (0, 50)]) +def test_get_video_invalid( + scan_image_tiff_single_plane_imaging_extractor, + start_frame, + end_frame, +): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + + +def test_get_image_size(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + image_size = scan_image_tiff_single_plane_imaging_extractor.get_image_size() + assert image_size == expected_properties["image_size"] + + +def test_get_num_frames(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + num_frames = scan_image_tiff_single_plane_imaging_extractor.get_num_frames() + assert num_frames == expected_properties["num_frames"] + + +def test_get_sampling_frequency(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + sampling_frequency = scan_image_tiff_single_plane_imaging_extractor.get_sampling_frequency() + assert sampling_frequency == expected_properties["sampling_frequency"] + + +def test_get_num_channels(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() + assert num_channels == expected_properties["num_channels"] + + +def test_get_available_planes(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + plane_names = ScanImageTiffSinglePlaneImagingExtractor.get_available_planes(file_path) + assert plane_names == expected_properties["plane_names"] + + +def test_get_available_channels(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + channel_names = ScanImageTiffSinglePlaneImagingExtractor.get_available_channels(file_path) + assert channel_names == expected_properties["channel_names"] + + +def test_get_num_planes(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() + assert num_planes == expected_properties["num_planes"] + + +def test_get_dtype(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + dtype = scan_image_tiff_single_plane_imaging_extractor.get_dtype() + assert dtype == expected_properties["dtype"] + + +def test_check_frame_inputs_valid(scan_image_tiff_single_plane_imaging_extractor): + scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=0) + + +def test_check_frame_inputs_invalid(scan_image_tiff_single_plane_imaging_extractor, expected_properties): + num_frames = expected_properties["num_frames"] + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=num_frames + 1) + + +@pytest.mark.parametrize("frame", (0, 3, 5)) +def test_frame_to_raw_index( + scan_image_tiff_single_plane_imaging_extractor, + frame, + expected_properties, +): + raw_index = scan_image_tiff_single_plane_imaging_extractor.frame_to_raw_index(frame=frame) + plane = scan_image_tiff_single_plane_imaging_extractor.plane + channel = scan_image_tiff_single_plane_imaging_extractor.channel + num_planes = expected_properties["num_planes"] + num_channels = expected_properties["num_channels"] + frames_per_slice = expected_properties["frames_per_slice"] + cycle = frame // frames_per_slice + frame_in_cycle = frame % frames_per_slice + expected_index = ( + cycle * num_planes * num_channels * frames_per_slice + + plane * num_channels * frames_per_slice + + num_channels * frame_in_cycle + + channel + ) + assert raw_index == expected_index + + +def test_ScanImageTiffMultiPlaneImagingExtractor__init__(file_path): + extractor = ScanImageTiffMultiPlaneImagingExtractor(file_path=file_path) + assert extractor.file_path == file_path + + +def test_ScanImageTiffMultiPlaneImagingExtractor__init__invalid(file_path): + with pytest.raises(ValueError): + ScanImageTiffMultiPlaneImagingExtractor(file_path=file_path, channel_name="Invalid Channel") diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py new file mode 100644 index 00000000..62424dae --- /dev/null +++ b/tests/test_volumetricimagingextractor.py @@ -0,0 +1,121 @@ +import pytest +import numpy as np +from roiextractors.testing import generate_dummy_imaging_extractor +from roiextractors import VolumetricImagingExtractor + +num_frames = 10 + + +@pytest.fixture(scope="module", params=[1, 2]) +def imaging_extractors(request): + num_channels = request.param + return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)] + + +@pytest.fixture(scope="module") +def volumetric_imaging_extractor(imaging_extractors): + return VolumetricImagingExtractor(imaging_extractors) + + +@pytest.mark.parametrize( + "params", + [ + [dict(sampling_frequency=1), dict(sampling_frequency=2)], + [dict(num_rows=1), dict(num_rows=2)], + [dict(num_channels=1), dict(num_channels=2)], + [dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)], + [dict(dtype=np.int16), dict(dtype=np.float32)], + [dict(num_frames=1), dict(num_frames=2)], + ], +) +def test_check_consistency_between_imaging_extractors(params): + imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params] + with pytest.raises(AssertionError): + VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + + +@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)]) +def test_get_video(volumetric_imaging_extractor, start_frame, end_frame): + video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + expected_video = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame)) + expected_video = np.array(expected_video) + expected_video = np.moveaxis(expected_video, 0, -1) + assert np.all(video == expected_video) + + +@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)]) +def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + + +@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]]) +def test_get_frames(volumetric_imaging_extractor, frame_idxs): + frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + expected_frames = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs)) + expected_frames = np.array(expected_frames) + expected_frames = np.moveaxis(expected_frames, 0, -1) + assert np.all(frames == expected_frames) + + +@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]]) +def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + + +@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)]) +def test_get_image_size(num_rows, num_columns, num_planes): + imaging_extractors = [ + generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes) + + +@pytest.mark.parametrize("num_planes", [1, 2, 3]) +def test_get_num_planes(num_planes): + imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_planes() == num_planes + + +@pytest.mark.parametrize("num_frames", [1, 2, 3]) +def test_get_num_frames(num_frames): + imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_frames() == num_frames + + +@pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) +def test_get_sampling_frequency(sampling_frequency): + imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency + + +@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]]) +def test_get_channel_names(channel_names): + imaging_extractors = [ + generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names)) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_channel_names() == channel_names + + +@pytest.mark.parametrize("num_channels", [1, 2, 3]) +def test_get_num_channels(num_channels): + imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_channels() == num_channels + + +@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8]) +def test_get_dtype(dtype): + imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_dtype() == dtype