diff --git a/clinicadl/dataset/caps_dataset.py b/clinicadl/dataset/caps_dataset.py deleted file mode 100644 index d45dc5aa6..000000000 --- a/clinicadl/dataset/caps_dataset.py +++ /dev/null @@ -1,817 +0,0 @@ -# coding: utf8 -# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? -import abc -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import Dataset - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.config.extraction import ( - ExtractionImageConfig, - ExtractionPatchConfig, - ExtractionROIConfig, - ExtractionSliceConfig, -) -from clinicadl.dataset.prepare_data.prepare_data_utils import ( - compute_discarded_slices, - extract_patch_path, - extract_patch_tensor, - extract_roi_path, - extract_roi_tensor, - extract_slice_path, - extract_slice_tensor, - find_mask_path, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.enum import ( - Pattern, - Preprocessing, - SliceDirection, - SliceMode, - Template, -) -from clinicadl.utils.exceptions import ( - ClinicaDLCAPSError, - ClinicaDLTSVError, -) - -logger = getLogger("clinicadl") - - -################################# -# Datasets loaders -################################# -class CapsDataset(Dataset): - """Abstract class for all derived CapsDatasets.""" - - def __init__( - self, - config: CapsDatasetConfig, - label_presence: bool, - preprocessing_dict: Dict[str, Any], - ): - self.label_presence = label_presence - self.eval_mode = False - self.config = config - self.preprocessing_dict = preprocessing_dict - - if not hasattr(self, "elem_index"): - raise AttributeError( - "Child class of CapsDataset must set elem_index attribute." - ) - if not hasattr(self, "mode"): - raise AttributeError("Child class of CapsDataset, must set mode attribute.") - - self.df = self.config.data.data_df - mandatory_col = { - "participant_id", - "session_id", - "cohort", - } - if label_presence and self.config.data.label is not None: - mandatory_col.add(self.config.data.label) - - if not mandatory_col.issubset(set(self.df.columns.values)): - raise ClinicaDLTSVError( - f"the data file is not in the correct format." - f"Columns should include {mandatory_col}" - ) - self.elem_per_image = self.num_elem_per_image() - self.size = self[0]["image"].size() - - @property - @abc.abstractmethod - def elem_index(self): - pass - - def label_fn(self, target: Union[str, float, int]) -> Union[float, int, None]: - """ - Returns the label value usable in criterion. - - Args: - target: value of the target. - Returns: - label: value of the label usable in criterion. - """ - # Reconstruction case (no label) - if self.config.data.label is None: - return None - # Regression case (no label code) - elif self.config.data.label_code is None: - return np.float32([target]) - # Classification case (label + label_code dict) - else: - return self.config.data.label_code[str(target)] - - def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: - """ - Returns the label value usable in criterion. - - """ - domain_code = {"t1": 0, "flair": 1} - return domain_code[str(target)] - - def __len__(self) -> int: - return len(self.df) * self.elem_per_image - - def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: - """ - Gets the path to the tensor image (*.pt) - - Args: - participant: ID of the participant. 
- session: ID of the session. - cohort: Name of the cohort. - Returns: - image_path: path to the tensor containing the whole image. - """ - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - - # Try to find .nii.gz file - try: - folder, file_type = self.config.compute_folder_and_file_type() - - results = clinicadl_file_reader( - [participant], - [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), - ) - logger.debug(f"clinicadl_file_reader output: {results}") - filepath = Path(results[0][0]) - image_filename = filepath.name.replace(".nii.gz", ".pt") - - image_dir = ( - self.config.data.caps_dict[cohort] - / "subjects" - / participant - / session - / "deeplearning_prepare_data" - / "image_based" - / folder - ) - image_path = image_dir / image_filename - # Try to find .pt file - except ClinicaDLCAPSError: - folder, file_type = self.config.compute_folder_and_file_type() - file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt") - results = clinicadl_file_reader( - [participant], - [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), - ) - filepath = results[0] - image_path = Path(filepath[0]) - - return image_path - - def _get_meta_data( - self, idx: int - ) -> Tuple[str, str, str, Union[float, int, None], int]: - """ - Gets all meta data necessary to compute the path with _get_image_path - - Args: - idx (int): row number of the meta-data contained in self.df - Returns: - participant (str): ID of the participant. - session (str): ID of the session. - cohort (str): Name of the cohort. - elem_index (int): Index of the part of the image. - label (str or float or int): value of the label to be used in criterion. - """ - image_idx = idx // self.elem_per_image - participant = self.df.at[image_idx, "participant_id"] - session = self.df.at[image_idx, "session_id"] - cohort = self.df.at[image_idx, "cohort"] - - if self.elem_index is None: - elem_idx = idx % self.elem_per_image - else: - elem_idx = self.elem_index - if self.label_presence and self.config.data.label is not None: - target = self.df.at[image_idx, self.config.data.label] - label = self.label_fn(target) - else: - label = -1 - - if "domain" in self.df.columns: - domain = self.df.at[image_idx, "domain"] - domain = self.domain_fn(domain) - else: - domain = "" # TO MODIFY - return participant, session, cohort, elem_idx, label, domain - - def _get_full_image(self) -> torch.Tensor: - """ - Allows to get the an example of the image mode corresponding to the dataset. - Useful to compute the number of elements if mode != image. - - Returns: - image tensor of the full image first image. - """ - import nibabel as nib - - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - - participant_id = self.df.loc[0, "participant_id"] - session_id = self.df.loc[0, "session_id"] - cohort = self.df.loc[0, "cohort"] - - try: - image_path = self._get_image_path(participant_id, session_id, cohort) - image = torch.load(image_path, weights_only=True) - except IndexError: - file_type = self.config.extraction.file_type - results = clinicadl_file_reader( - [participant_id], - [session_id], - self.config.data.caps_dict[cohort], - file_type.model_dump(), - ) - image_nii = nib.loadsave.load(results[0]) - image_np = image_nii.get_fdata() - image = ToTensor()(image_np) - - return image - - @abc.abstractmethod - def __getitem__(self, idx: int) -> Dict[str, Any]: - """ - Gets the sample containing all the information needed for training and testing tasks. 
- - Args: - idx: row number of the meta-data contained in self.df - Returns: - dictionary with following items: - - "image" (torch.Tensor): the input given to the model, - - "label" (int or float): the label used in criterion, - - "participant_id" (str): ID of the participant, - - "session_id" (str): ID of the session, - - f"{self.mode}_id" (int): number of the element, - - "image_path": path to the image loaded in CAPS. - - """ - pass - - @abc.abstractmethod - def num_elem_per_image(self) -> int: - """Computes the number of elements per image based on the full image.""" - pass - - def eval(self): - """Put the dataset on evaluation mode (data augmentation is not performed).""" - self.eval_mode = True - return self - - def train(self): - """Put the dataset on training mode (data augmentation is performed).""" - self.eval_mode = False - return self - - -class CapsDatasetImage(CapsDataset): - """Dataset of MRI organized in a CAPS folder.""" - - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - """ - - self.mode = "image" - self.config = config - self.label_presence = label_presence - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return None - - def __getitem__(self, idx): - participant, session, cohort, _, label, domain = self._get_meta_data(idx) - - image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path, weights_only=True) - - train_trf, trf = self.config.transforms.get_transforms() - - image = trf(image) - if self.config.transforms.train_transformations and not self.eval_mode: - image = train_trf(image) - - sample = { - "image": image, - "label": label, - "participant_id": participant, - "session_id": session, - "image_id": 0, - "image_path": image_path.as_posix(), - "domain": domain, - } - - return sample - - def num_elem_per_image(self): - return 1 - - -class CapsDatasetPatch(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - patch_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. 
- """ - self.patch_index = patch_index - self.mode = "patch" - self.config = config - self.label_presence = label_presence - - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.patch_index - - def __getitem__(self, idx): - participant, session, cohort, patch_idx, label, domain = self._get_meta_data( - idx - ) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - patch_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - patch_filename = extract_patch_path( - image_path, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, - ) - patch_tensor = torch.load( - Path(patch_dir).resolve() / patch_filename, weights_only=True - ) - - else: - image = torch.load(image_path, weights_only=True) - patch_tensor = extract_patch_tensor( - image, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, - ) - - train_trf, trf = self.config.transforms.get_transforms() - patch_tensor = trf(patch_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - patch_tensor = train_trf(patch_tensor) - - sample = { - "image": patch_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "patch_id": patch_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - image = self._get_full_image() - - patches_tensor = ( - image.unfold( - 1, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 2, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 3, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .contiguous() - ) - patches_tensor = patches_tensor.view( - -1, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - ) - num_patches = patches_tensor.shape[0] - return num_patches - - -class CapsDatasetRoi(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - roi_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - roi_index: If a value is given the same region will be extracted for each image. - else the dataset will load all the regions possible for one image. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. 
- - """ - self.roi_index = roi_index - self.mode = "roi" - self.config = config - self.label_presence = label_presence - self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors( - self.config.data.caps_directory, preprocessing_dict - ) - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.roi_index - - def __getitem__(self, idx): - participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.roi_list is None: - raise NotImplementedError( - "Default regions are not available anymore in ClinicaDL. " - "Please define appropriate masks and give a roi_list." - ) - - if self.config.extraction.save_features: - mask_path = self.mask_paths[roi_idx] - roi_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - roi_filename = extract_roi_path( - image_path, mask_path, self.config.extraction.roi_uncrop_output - ) - roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True) - - else: - image = torch.load(image_path, weights_only=True) - mask_array = self.mask_arrays[roi_idx] - roi_tensor = extract_roi_tensor( - image, mask_array, self.config.extraction.uncropped_roi - ) - - train_trf, trf = self.config.transforms.get_transforms() - - roi_tensor = trf(roi_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - roi_tensor = train_trf(roi_tensor) - - sample = { - "image": roi_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "roi_id": roi_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - if self.config.extraction.roi_list is None: - return 2 - else: - return len(self.config.extraction.roi_list) - - def _get_mask_paths_and_tensors( - self, - caps_directory: Path, - preprocessing_dict: Dict[str, Any], - ) -> Tuple[List[str], List]: - """Loads the masks necessary to regions extraction""" - import nibabel as nib - - caps_dict = self.config.data.caps_dict - if len(caps_dict) > 1: - caps_directory = caps_dict[next(iter(caps_dict))] - logger.warning( - f"The equality of masks is not assessed for multi-cohort training. " - f"The masks stored in {caps_directory} will be used." - ) - - try: - preprocessing_ = Preprocessing(preprocessing_dict["preprocessing"]) - except NotImplementedError: - print( - f"Template of preprocessing {preprocessing_dict['preprocessing']} " - f"is not defined." - ) - # Find template name and pattern - if preprocessing_.value == "custom": - template_name = preprocessing_dict["roi_custom_template"] - if template_name is None: - raise ValueError( - "Please provide a name for the template when preprocessing is `custom`." - ) - - pattern = preprocessing_dict["roi_custom_mask_pattern"] - if pattern is None: - raise ValueError( - "Please provide a pattern for the masks when preprocessing is `custom`." 
- ) - - else: - for template_ in Template: - if preprocessing_.name == template_.name: - template_name = template_ - - for pattern_ in Pattern: - if preprocessing_.name == pattern_.name: - pattern = pattern_ - - mask_location = caps_directory / "masks" / f"tpl-{template_name}" - - mask_paths, mask_arrays = list(), list() - for roi in self.config.extraction.roi_list: - logger.info(f"Find mask for roi {roi}.") - mask_path, desc = find_mask_path(mask_location, roi, pattern, True) - if mask_path is None: - raise FileNotFoundError(desc) - mask_nii = nib.loadsave.load(mask_path) - mask_paths.append(Path(mask_path)) - mask_arrays.append(mask_nii.get_fdata()) - - return mask_paths, mask_arrays - - -class CapsDatasetSlice(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - slice_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - slice_index: If a value is given the same slice will be extracted for each image. - else the dataset will load all the slices possible for one image. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - """ - self.slice_index = slice_index - self.mode = "slice" - self.config = config - self.label_presence = label_presence - self.preprocessing_dict = preprocessing_dict - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.slice_index - - def __getitem__(self, idx): - participant, session, cohort, slice_idx, label, domain = self._get_meta_data( - idx - ) - slice_idx = slice_idx + self.config.extraction.discarded_slices[0] - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - slice_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - slice_filename = extract_slice_path( - image_path, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - slice_tensor = torch.load( - Path(slice_dir) / slice_filename, weights_only=True - ) - - else: - image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path, weights_only=True) - slice_tensor = extract_slice_tensor( - image, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - - train_trf, trf = self.config.transforms.get_transforms() - - slice_tensor = trf(slice_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - slice_tensor = train_trf(slice_tensor) - - sample = { - "image": slice_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "slice_id": slice_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - if self.config.extraction.num_slices is not None: - return 
self.config.extraction.num_slices - - image = self._get_full_image() - return ( - image.size(int(self.config.extraction.slice_direction) + 1) - - self.config.extraction.discarded_slices[0] - - self.config.extraction.discarded_slices[1] - ) - - -def return_dataset( - input_dir: Path, - data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, - cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, -) -> CapsDataset: - """ - Return appropriate Dataset according to given options. - Args: - input_dir: path to a directory containing a CAPS structure. - data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied during training only. - all_transformations: Optional transform to be applied during training and evaluation. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - cnn_index: Index of the CNN in a multi-CNN paradigm (optional). - label_presence: If True the diagnosis will be extracted from the given DataFrame. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - Returns: - the corresponding dataset. - """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": - raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." - ) - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - preprocessing_type=preprocessing_dict["preprocessing"], - preprocessing=preprocessing_dict["preprocessing"], - extraction=preprocessing_dict["mode"], - caps_directory=input_dir, - data_df=data_df, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - config.transforms = transforms_config - - if preprocessing_dict["mode"] == "image": - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetImage( - config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "patch": - assert isinstance(config.extraction, ExtractionPatchConfig) - config.extraction.patch_size = preprocessing_dict["patch_size"] - config.extraction.stride_size = preprocessing_dict["stride_size"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetPatch( - config, - patch_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "roi": - assert isinstance(config.extraction, ExtractionROIConfig) - config.extraction.roi_list = preprocessing_dict["roi_list"] - config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetRoi( - config, - roi_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "slice": - assert isinstance(config.extraction, ExtractionSliceConfig) - 
config.extraction.slice_direction = SliceDirection( - str(preprocessing_dict["slice_direction"]) - ) - config.extraction.slice_mode = SliceMode(preprocessing_dict["slice_mode"]) - config.extraction.discarded_slices = compute_discarded_slices( - preprocessing_dict["discarded_slices"] - ) - config.extraction.num_slices = ( - None - if "num_slices" not in preprocessing_dict - else preprocessing_dict["num_slices"] - ) - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetSlice( - config, - slice_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - else: - raise NotImplementedError( - f"Mode {preprocessing_dict['mode']} is not implemented." - ) diff --git a/clinicadl/dataset/caps_dataset_config.py b/clinicadl/dataset/caps_dataset_config.py deleted file mode 100644 index 0eac3ffd3..000000000 --- a/clinicadl/dataset/caps_dataset_config.py +++ /dev/null @@ -1,127 +0,0 @@ -from pathlib import Path -from typing import Optional, Tuple, Union - -from pydantic import BaseModel, ConfigDict - -from clinicadl.dataset.config import extraction -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.data_config import DataConfig -from clinicadl.dataset.dataloader_config import DataLoaderConfig -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.enum import ExtractionMethod, Preprocessing -from clinicadl.utils.iotools.clinica_utils import FileType - - -def get_extraction(extract_method: ExtractionMethod): - if extract_method == ExtractionMethod.ROI: - return extraction.ExtractionROIConfig - elif extract_method == ExtractionMethod.SLICE: - return extraction.ExtractionSliceConfig - elif extract_method == ExtractionMethod.IMAGE: - return extraction.ExtractionImageConfig - elif extract_method == ExtractionMethod.PATCH: - return extraction.ExtractionPatchConfig - else: - raise ValueError(f"Preprocessing {extract_method.value} is not implemented.") - - -def get_preprocessing(preprocessing_type: Preprocessing): - if preprocessing_type == Preprocessing.T1_LINEAR: - return T1PreprocessingConfig - elif preprocessing_type == Preprocessing.PET_LINEAR: - return PETPreprocessingConfig - elif preprocessing_type == Preprocessing.FLAIR_LINEAR: - return FlairPreprocessingConfig - elif preprocessing_type == Preprocessing.CUSTOM: - return CustomPreprocessingConfig - elif preprocessing_type == Preprocessing.DWI_DTI: - return DTIPreprocessingConfig - else: - raise ValueError( - f"Preprocessing {preprocessing_type.value} is not implemented." - ) - - -class CapsDatasetConfig(BaseModel): - """Config class for CapsDataset object. - - caps_directory, preprocessing_json, extract_method, preprocessing - are arguments that must be passed by the user. 
- - transforms isn't optional because there is always at least one transform (NanRemoval) - """ - - data: DataConfig - dataloader: DataLoaderConfig - extraction: extraction.ExtractionConfig - preprocessing: PreprocessingConfig - transforms: TransformsConfig - - # pydantic config - model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) - - @classmethod - def from_preprocessing_and_extraction_method( - cls, - preprocessing_type: Union[str, Preprocessing], - extraction: Union[str, ExtractionMethod], - **kwargs, - ): - return cls( - data=DataConfig(**kwargs), - dataloader=DataLoaderConfig(**kwargs), - preprocessing=get_preprocessing(Preprocessing(preprocessing_type))( - **kwargs - ), - extraction=get_extraction(ExtractionMethod(extraction))(**kwargs), - transforms=TransformsConfig(**kwargs), - ) - - def compute_folder_and_file_type( - self, from_bids: Optional[Path] = None - ) -> Tuple[str, FileType]: - preprocessing = self.preprocessing.preprocessing - if from_bids is not None: - if isinstance(self.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(self.preprocessing) - - elif preprocessing not in Preprocessing: - raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(self.preprocessing, T1PreprocessingConfig) or isinstance( - self.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(self.preprocessing) - elif isinstance(self.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type diff --git a/clinicadl/dataset/caps_dataset_utils.py b/clinicadl/dataset/caps_dataset_utils.py deleted file mode 100644 index b54ba373d..000000000 --- a/clinicadl/dataset/caps_dataset_utils.py +++ /dev/null @@ -1,193 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.utils.enum import Preprocessing -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.iotools.clinica_utils import FileType - - -def compute_folder_and_file_type( - config: CapsDatasetConfig, from_bids: Optional[Path] = None -) -> Tuple[str, FileType]: - preprocessing = config.preprocessing.preprocessing - if from_bids is not None: - if isinstance(config.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(config.preprocessing) - - elif preprocessing not in Preprocessing: - 
raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(config.preprocessing, T1PreprocessingConfig) or isinstance( - config.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(config.preprocessing) - elif isinstance(config.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type - - -def find_file_type(config: CapsDatasetConfig) -> FileType: - if isinstance(config.preprocessing, T1PreprocessingConfig): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - if ( - config.preprocessing.tracer is None - or config.preprocessing.suvr_reference_region is None - ): - raise ClinicaDLArgumentError( - "`tracer` and `suvr_reference_region` must be defined " - "when using `pet-linear` preprocessing." - ) - file_type = pet_linear_nii(config.preprocessing) - else: - raise NotImplementedError( - f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" - ) - - return file_type - - -def read_json(json_path: Path) -> Dict[str, Any]: - """ - Ensures retro-compatibility between the different versions of ClinicaDL. - - Parameters - ---------- - json_path: Path - path to the JSON file summing the parameters of a MAPS. - - Returns - ------- - A dictionary of training parameters. 
- """ - from clinicadl.utils.iotools.utils import path_decoder - - with json_path.open(mode="r") as f: - parameters = json.load(f, object_hook=path_decoder) - # Types of retro-compatibility - # Change arg name: ex network --> model - # Change arg value: ex for preprocessing: mni --> t1-extensive - # New arg with default hard-coded value --> discarded_slice --> 20 - retro_change_name = { - "model": "architecture", - "multi": "multi_network", - "minmaxnormalization": "normalize", - "num_workers": "n_proc", - "mode": "extract_method", - } - - retro_add = { - "optimizer": "Adam", - "loss": None, - } - - for old_name, new_name in retro_change_name.items(): - if old_name in parameters: - parameters[new_name] = parameters[old_name] - del parameters[old_name] - - for name, value in retro_add.items(): - if name not in parameters: - parameters[name] = value - - if "extract_method" in parameters: - parameters["mode"] = parameters["extract_method"] - # Value changes - if "use_cpu" in parameters: - parameters["gpu"] = not parameters["use_cpu"] - del parameters["use_cpu"] - if "nondeterministic" in parameters: - parameters["deterministic"] = not parameters["nondeterministic"] - del parameters["nondeterministic"] - - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - - from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=parameters["mode"], - preprocessing_type=parameters["preprocessing"], - **parameters, - ) - if "file_type" not in parameters["preprocessing_dict"]: - _, file_type = compute_folder_and_file_type(config) - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() - - return parameters diff --git a/clinicadl/dataset/caps_reader.py b/clinicadl/dataset/caps_reader.py deleted file mode 100644 index 14199616e..000000000 --- a/clinicadl/dataset/caps_reader.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path -from typing import Optional - -from clinicadl.dataset.caps_dataset import CapsDataset -from clinicadl.dataset.config.extraction import ( - ExtractionConfig, - ExtractionImageConfig, - ExtractionPatchConfig, - ExtractionROIConfig, - ExtractionSliceConfig, -) -from 
clinicadl.dataset.config.preprocessing import PreprocessingConfig -from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.transforms.config import TransformsConfig - - -class CapsReader: - def __init__(self, caps_directory: Path, manager: ExperimentManager): - """TO COMPLETE""" - pass - - def get_dataset( - self, - extraction: ExtractionConfig, - preprocessing: PreprocessingConfig, - sub_ses_tsv: Path, - transforms: TransformsConfig, - ) -> CapsDataset: - return CapsDataset(extraction, preprocessing, sub_ses_tsv, transforms) - - def get_preprocessing(self, preprocessing: str) -> PreprocessingConfig: - """TO COMPLETE""" - - return PreprocessingConfig() - - def extract_slice( - self, preprocessing: PreprocessingConfig, arg_slice: Optional[int] = None - ) -> ExtractionSliceConfig: - """TO COMPLETE""" - - return ExtractionSliceConfig() - - def extract_patch( - self, preprocessing: PreprocessingConfig, arg_patch: Optional[int] = None - ) -> ExtractionPatchConfig: - """TO COMPLETE""" - - return ExtractionPatchConfig() - - def extract_roi( - self, preprocessing: PreprocessingConfig, arg_roi: Optional[int] = None - ) -> ExtractionROIConfig: - """TO COMPLETE""" - - return ExtractionROIConfig() - - def extract_image( - self, preprocessing: PreprocessingConfig, arg_image: Optional[int] = None - ) -> ExtractionImageConfig: - """TO COMPLETE""" - - return ExtractionImageConfig() diff --git a/clinicadl/dataset/concat.py b/clinicadl/dataset/concat.py deleted file mode 100644 index f0b420dfe..000000000 --- a/clinicadl/dataset/concat.py +++ /dev/null @@ -1,6 +0,0 @@ -from clinicadl.dataset.caps_dataset import CapsDataset - - -class ConcatDataset(CapsDataset): - def __init__(self, list_: list[CapsDataset]): - """TO COMPLETE""" diff --git a/clinicadl/dataset/config/__init__.py b/clinicadl/dataset/config/__init__.py index e69de29bb..820581577 100644 --- a/clinicadl/dataset/config/__init__.py +++ b/clinicadl/dataset/config/__init__.py @@ -0,0 +1,22 @@ +from .data import DataConfig +from .extraction import ( + ExtractionConfig, + ExtractionImageConfig, + ExtractionPatchConfig, + ExtractionROIConfig, + ExtractionSliceConfig, +) +from .preprocessing import ( + CustomPreprocessingConfig, + FlairPreprocessingConfig, + PETPreprocessingConfig, + PreprocessingConfig, + T1PreprocessingConfig, + T2PreprocessingConfig, +) +from .utils import ( + get_extraction, + get_preprocessing, + get_preprocessing_and_mode_from_json, + get_preprocessing_and_mode_from_parameters, +) diff --git a/clinicadl/dataset/config/data.py b/clinicadl/dataset/config/data.py new file mode 100644 index 000000000..b99dfba24 --- /dev/null +++ b/clinicadl/dataset/config/data.py @@ -0,0 +1,74 @@ +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import pandas as pd +from pydantic import BaseModel, ConfigDict, field_validator + +from clinicadl.dataset.data_utils import load_data_test +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLTSVError, +) + +logger = getLogger("clinicadl.data_config") + + +class DataConfig(BaseModel): # TODO : put in data module + """Config class to specify the data. + + caps_directory and preprocessing_json are arguments + that must be passed by the user. 
+ """ + + caps_directory: Optional[Path] = None + baseline: bool = False + mask_path: Optional[Path] = None + data_tsv: Optional[Path] = None + n_subjects: int = 300 + # pydantic config + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + @field_validator("diagnoses", mode="before") + def validator_diagnoses(cls, v): + """Transforms a list to a tuple.""" + if isinstance(v, list): + return tuple(v) + return v # TODO : check if columns are in tsv + + def create_groupe_df(self): + group_df = None + if self.data_tsv is not None and self.data_tsv.is_file(): + group_df = load_data_test( + self.data_tsv, + multi_cohort=False, + ) + return group_df + + def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): + return ( + self.label is not None + and self.label != "" + and self.label != _label + and _label_code == "default" + ) + + def check_label(self, _label: str): + if not self.label: + self.label = _label + + @field_validator("data_tsv", mode="before") + @classmethod + def check_data_tsv(cls, v) -> Path: + if v is not None: + if not isinstance(v, Path): + v = Path(v) + if not v.is_file(): + raise ClinicaDLTSVError( + "The participants_list you gave is not a file. Please give an existing file." + ) + if v.stat().st_size == 0: + raise ClinicaDLTSVError( + "The participants_list you gave is empty. Please give a non-empty file." + ) + return v diff --git a/clinicadl/dataset/config/extraction.py b/clinicadl/dataset/config/extraction.py index f3619590f..9eecf5fed 100644 --- a/clinicadl/dataset/config/extraction.py +++ b/clinicadl/dataset/config/extraction.py @@ -1,19 +1,31 @@ +from abc import ABC, abstractmethod from logging import getLogger +from pathlib import Path from time import time -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +import nibabel as nib +import numpy as np +import torch from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt from clinicadl.utils.enum import ( ExtractionMethod, + Pattern, + Preprocessing, SliceDirection, SliceMode, + Template, ) +from clinicadl.utils.exceptions import ClinicaDLArgumentError from clinicadl.utils.iotools.clinica_utils import FileType logger = getLogger("clinicadl.preprocessing_config") +NII_GZ = ".nii.gz" +PT = ".pt" + class ExtractionConfig(BaseModel): """ @@ -21,47 +33,438 @@ class ExtractionConfig(BaseModel): """ extract_method: ExtractionMethod - file_type: Optional[FileType] = None + extract_json: str = f"extract_{int(time())}.json" + use_uncropped_image: bool = True save_features: bool = False - extract_json: Optional[str] = None # pydantic config model_config = ConfigDict(validate_assignment=True) @field_validator("extract_json", mode="before") def compute_extract_json(cls, v: str): - if v is None: - return f"extract_{int(time())}.json" + """Ensures the extract_json filename has a .json extension.""" + if isinstance(v, Path): + v = str(v) elif not v.endswith(".json"): - return f"{v}.json" - else: - return v + v = f"{v}.json" + return v + + def extract_image(self, input_img: Path) -> torch.Tensor: + """Loads a NIfTI image and returns it as a float32 tensor.""" + image_array = nib.loadsave.load(input_img).get_fdata(dtype="float32") # type: ignore + return torch.from_numpy(image_array).unsqueeze(0).float() + + @abstractmethod + def extract_tensor( + self, + image_tensor: torch.Tensor, + index: int, + object_tensors: Optional[torch.Tensor] = None, + ): + """Extracts specific data from an image 
tensor.""" + pass + + @abstractmethod + def extract_path(self, image_path, index): + """Defines path for saving extracted elements.""" + pass + + @abstractmethod + def extract(self, nii_path: Path): + """Performs extraction based on the implemented method.""" + pass + + @abstractmethod + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + """Returns the number of extracted elements per image.""" + pass class ExtractionImageConfig(ExtractionConfig): + """ + Configuration class for full image extraction as a single tensor. + """ + extract_method: ExtractionMethod = ExtractionMethod.IMAGE + def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]: + """Extracts the full image as a single tensor file and saves it.""" + image_tensor = self.extract_image(nii_path) + output_file = Path(Path(nii_path.stem).stem + PT), image_tensor.clone() + return [output_file] + + def extract_tensor( + self, + image_tensor: torch.Tensor, + index: int, + ): + return image_tensor + + def extract_path(self, image_path, index): + return image_path + + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + return 1 + class ExtractionPatchConfig(ExtractionConfig): + """ + Configuration class for patch extraction from an image with defined patch size and stride. + """ + patch_size: int = 50 stride_size: int = 50 extract_method: ExtractionMethod = ExtractionMethod.PATCH + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + """Returns the total number of patches generated from the image.""" + if elem_index is not None: + return 1 + + return self.create_patches(image).shape[0] + + def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: + """Extracts patches from a NIfTI image tensor.""" + + image_tensor = self.extract_image(nii_path) + patches_tensor = self.create_patches(image_tensor) + patch_list = [ + (self.extract_path(nii_path, i), patches_tensor[i].unsqueeze(0)) + for i in range(patches_tensor.size(0)) + ] + return patch_list + + def extract_tensor( + self, image_tensor: torch.Tensor, patch_index: int + ) -> torch.Tensor: + """Extracts a single patch from image_tensor""" + patches_tensor = self.create_patches(image_tensor) + return patches_tensor[patch_index, ...].unsqueeze_(0).clone() + + def extract_path(self, img_path: Path, patch_index: int) -> Path: + """Constructs the save path for a given patch.""" + prefix_suffix = img_path.name.rsplit("_", 1) + return Path( + f"{prefix_suffix[0]}_patchsize-{self.patch_size}_stride-{self.stride_size}_patch-{patch_index}{prefix_suffix[1].replace(NII_GZ, PT)}" + ) + + def create_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: + """Creates a tensor of patches from the image using `unfold`.""" + patches_tensor = ( + image_tensor.unfold(1, self.patch_size, self.stride_size) + .unfold(2, self.patch_size, self.stride_size) + .unfold(3, self.patch_size, self.stride_size) + .contiguous() + ) + return patches_tensor.view( + -1, self.patch_size, self.patch_size, self.patch_size + ) + class ExtractionSliceConfig(ExtractionConfig): + """ + Configuration class for slice extraction from an image in specified directions. 
+ """ + slice_direction: SliceDirection = SliceDirection.SAGITTAL slice_mode: SliceMode = SliceMode.RGB - num_slices: Optional[NonNegativeInt] = None - discarded_slices: Tuple[NonNegativeInt, NonNegativeInt] = (0, 0) + # num_slices: Optional[NonNegativeInt] = None # not sure it is needed + discarded_slices: Tuple[int, int] = (0, 0) extract_method: ExtractionMethod = ExtractionMethod.SLICE + @field_validator("discarded_slices", mode="before") + def validate_discarded_slice(cls, v: Union[int, Tuple]) -> Tuple[int, int]: + if isinstance(v, int): + return (v, v) + elif len(v) == 1: + return (v[0], v[0]) + elif len(v) == 2: + return v + else: + raise IndexError( + f"Maximum two number of discarded slices can be defined. " + f"You gave discarded slices = {v}." + ) + + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + if elem_index is not None: + return 1 + # if self.num_slices is not None: + # return self.num_slices + direction = int(self.slice_direction) + return image.size(direction + 1) - sum(self.discarded_slices) + + def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: + """Extracts slices from the image in the specified direction.""" + image_tensor = self.extract_image(nii_path) + start, end = self.discarded_slices + slices = [] + for i in range( + start, image_tensor.size(int(self.slice_direction.value) + 1) - end + ): + slice_tensor = self.extract_tensor(image_tensor, i) + slices.append((self.extract_path(nii_path, i), slice_tensor)) + return slices + + def extract_tensor( + self, image_tensor: torch.Tensor, slice_index: int + ) -> torch.Tensor: + idx_tuple = tuple( + [slice(None)] * (int(self.slice_direction) + 1) + + [slice_index + self.discarded_slices[0]] + + [slice(None)] * (2 - int(self.slice_direction)) + ) + slice_tensor = image_tensor[idx_tuple] # shape is 1 * W * L + if self.slice_mode == SliceMode.RGB: + slice_tensor = torch.cat([slice_tensor] * 3) # shape is 3 * W * L + return slice_tensor.clone() + + def extract_path(self, img_path: Path, slice_index: int) -> Path: + """Constructs the save path for a given slice.""" + prefix_suffix = img_path.name.rsplit("_", 1) + slice_dict = {0: "sag", 1: "cor", 2: "axi"} + + return Path( + f"{prefix_suffix[0]}_axis-{slice_dict[int(self.slice_direction.value)]}" + f"_channel-{self.slice_mode.value}_slice-{slice_index}{prefix_suffix[1].replace(NII_GZ, PT)}" + ) + class ExtractionROIConfig(ExtractionConfig): + """ + Configuration class for extracting regions of interest (ROIs) from images using masks. 
+ """ + roi_list: List[str] = [] - roi_uncrop_output: bool = False - roi_custom_template: str = "" - roi_custom_pattern: str = "" - roi_custom_suffix: str = "" - roi_custom_mask_pattern: str = "" - roi_background_value: int = 0 + roi_crop_input: bool = True + roi_crop_output: bool = True + roi_template: str = "" + roi_mask_pattern: str = "" + roi_mask_location: Path + + # roi_custom_template: str = "" + # roi_custom_mask_pattern: str = "" extract_method: ExtractionMethod = ExtractionMethod.ROI + + @field_validator("roi_mask_pattern", "before") + def validate_roi_mask_pattern(cls, v: str) -> str: + """Check that pattern begins and ends with _ to avoid mixing keys""" + if not v: + raise ClinicaDLArgumentError("A mask pattern must be defined.") + if not v.startswith("_"): + v = "_" + v + if not v.endswith("_"): + v = v + "_" + return v + + @field_validator("roi_list", mode="before") + def validate_roi_list(cls, v: List[str]) -> List[str]: + if not v: + raise NotImplementedError( + "Default regions are not available anymore in ClinicaDL. " + "Please define appropriate masks and give a roi_list." + ) + if len(v) == 0: + raise ClinicaDLArgumentError("A list of regions of interest must be given.") + + return v + + def num_elem_per_image( + self, image: torch.Tensor, elem_index: Optional[int] = None + ) -> int: + return 1 if elem_index is not None else len(self.roi_list) + + def check_with_preprocessing(self, preprocessing: Preprocessing): + if preprocessing == Preprocessing.CUSTOM: + if not self.roi_template: + raise ClinicaDLArgumentError( + "A custom template must be defined when the modality is set to custom." + ) + # self.roi_template = self.roi_custom_template + # self.roi_mask_pattern = self.roi_custom_mask_pattern + else: + if preprocessing == Preprocessing.T1_LINEAR: + self.roi_template = Template.T1_LINEAR + self.roi_mask_pattern = Pattern.T1_LINEAR + elif preprocessing == Preprocessing.PET_LINEAR: + self.roi_template = Template.PET_LINEAR + self.roi_mask_pattern = Pattern.PET_LINEAR + elif preprocessing == Preprocessing.FLAIR_LINEAR: + self.roi_template = Template.FLAIR_LINEAR + self.roi_mask_pattern = Pattern.FLAIR_LINEAR + + def check_mask_list(self, masks_location: Path) -> None: + for roi in self.roi_list: + roi_path, desc = self.find_mask_path(masks_location, roi) + if roi_path is None: + raise FileNotFoundError( + f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}" + ) + roi_mask = nib.loadsave.load(roi_path).get_fdata() # type: ignore + mask_values = set(np.unique(roi_mask)) + if mask_values != {0, 1}: + raise ValueError( + "The ROI masks used should be binary (composed of 0 and 1 only)." + ) + + def find_mask_path(self, masks_location: Path, roi: str) -> Tuple[Path, str]: + """ + Finds masks corresponding to the pattern asked and containing the adequate self.roi_crop_input description + + Parameters + ---------- + masks_location: Path + Directory containing the masks. + roi: str + Name of the region. + + Returns + ------- + path of the mask or None if nothing was found. + a human-friendly description of the pattern looked for. + """ + + candidates_pattern = f"*{self.roi_mask_pattern}*_roi-{roi}_mask.nii*" + + desc = f"The mask should follow the pattern {candidates_pattern}. 
" + candidates = [e for e in masks_location.glob(candidates_pattern)] + if self.roi_crop_input is None: + # pass + candidates2 = candidates + elif self.roi_crop_input: + candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name] + desc += "and contain '_desc-Crop_' string." + else: + candidates2 = [ + mask for mask in candidates if "_desc-Crop_" not in mask.name + ] + desc += "and not contain '_desc-Crop_' string." + + if len(candidates2) == 0: + raise FileNotFoundError( + f"Could not find any masks corresponding to the pattern asked and containing the adequate {self.roi_crop_input} description " + ) + # return None, desc + else: + return min(candidates2), desc + + def compute_output_pattern(self, mask_path: Path): + """ + Computes the output pattern of the region cropped (without the source file prefix) + Parameters + ---------- + mask_path: Path + Path to the masks + self.roi_crop_output: bool + If True the output is cropped, and the descriptor CropRoi must exist + + Returns + ------- + the output pattern + """ + + mask_filename = mask_path.name + template_id = mask_filename.split("_")[0].split("-")[1] + mask_descriptors = mask_filename.split("_")[1:-2:] + roi_id = mask_filename.split("_")[-2].split("-")[1] + if "desc-Crop" not in mask_descriptors and not self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + elif "desc-Crop" in mask_descriptors: + mask_descriptors = [ + descriptor + for descriptor in mask_descriptors + if descriptor != "desc-Crop" + ] + if self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + else: + mask_descriptors = ["desc-CropImage"] + mask_descriptors + + mask_pattern = "_".join(mask_descriptors) + + if mask_pattern == "": + output_pattern = f"space-{template_id}_roi-{roi_id}" + else: + output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}" + + return output_pattern + + def extract(self, nii_path: Path) -> List[Tuple[str, torch.Tensor]]: + """Extracts roi from a NIfTI image tensor.""" + image_tensor = self.extract_image(nii_path) + roi_list = [] + for roi_name in self.roi_list: + mask_path, _ = self.find_mask_path(self.roi_mask_location, roi_name) + mask_np = nib.loadsave.load(mask_path).get_fdata() # type: ignore + roi_list.append( + ( + self.extract_tensor(image_tensor, mask_np), + self.extract_path(nii_path, mask_path), + ) + ) + return roi_list + + def extract_tensor(self, image_tensor: torch.Tensor, roi_idx: int) -> torch.Tensor: + _, mask_arrays = self._get_mask_paths_and_tensors() + mask_np = mask_arrays[roi_idx] + + if len(mask_np.shape) == 3: + mask_np = np.expand_dims(mask_np, axis=0) + elif len(mask_np.shape) == 4: + assert mask_np.shape[0] == 1 + else: + raise ValueError( + "ROI masks must be 3D or 4D tensors. " + f"The dimension of your ROI mask is {len(mask_np.shape)}." 
+ ) + + roi_tensor = image_tensor * mask_np + if self.roi_crop_output: + roi_tensor = roi_tensor[ + np.ix_( + mask_np.any((1, 2, 3)), + mask_np.any((0, 2, 3)), + mask_np.any((0, 1, 3)), + mask_np.any((0, 1, 2)), + ) + ] + return roi_tensor.float().clone() + + def extract_path(self, img_path: Path, mask_path: Path) -> str: + input_img_filename = img_path.name + + sub_ses_prefix = "_".join(input_img_filename.split("_")[0:3:]) + if not sub_ses_prefix.endswith("_T1w"): + sub_ses_prefix = "_".join(input_img_filename.split("_")[0:2:]) + input_suffix = input_img_filename.split("_")[-1].split(".")[0] + + output_pattern = self.compute_output_pattern(mask_path) + + return f"{sub_ses_prefix}_{output_pattern}_{input_suffix}{PT}" + + def _get_mask_paths_and_tensors(self) -> Tuple[List[str], List]: + """Loads the masks necessary to regions extraction""" + mask_location = ( + self.roi_mask_location + / f"tpl-{self.roi_template}" # caps_directory / "masks" = mask_location + ) + + mask_paths, mask_arrays = list(), list() + for roi in self.roi_list: + logger.info(f"Find mask for roi {roi}.") + mask_path, desc = self.find_mask_path(mask_location, roi) + if mask_path is None: + raise FileNotFoundError(desc) + mask_nii = nib.loadsave.load(mask_path) + mask_paths.append(Path(mask_path)) + mask_arrays.append(mask_nii.get_fdata()) # type: ignore + + return mask_paths, mask_arrays + + +ALL_EXTRACTION_TYPES = Union[ + ExtractionImageConfig, + ExtractionROIConfig, + ExtractionSliceConfig, + ExtractionPatchConfig, +] diff --git a/clinicadl/dataset/config/preprocessing.py b/clinicadl/dataset/config/preprocessing.py index ad8db765e..8b9314a6d 100644 --- a/clinicadl/dataset/config/preprocessing.py +++ b/clinicadl/dataset/config/preprocessing.py @@ -1,16 +1,20 @@ +import abc from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Tuple, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, computed_field from clinicadl.utils.enum import ( DTIMeasure, DTISpace, + ImageModality, + LinearModality, Preprocessing, SUVRReferenceRegions, Tracer, ) +from clinicadl.utils.iotools.clinica_utils import FileType logger = getLogger("clinicadl.modality_config") @@ -20,38 +24,203 @@ class PreprocessingConfig(BaseModel): Abstract config class for the preprocessing procedure. """ - tsv_file: Optional[Path] = None + from_bids: bool = False preprocessing: Preprocessing use_uncropped_image: bool = False # pydantic config - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + @abc.abstractmethod + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + """Abstract method to get the BIDS filetype.""" + pass + + @abc.abstractmethod + def caps_nii(self) -> tuple: + """Abstract method to retrieve CAPS file information.""" + pass + + @abc.abstractmethod + def get_filetype(self) -> FileType: + """Abstract method to obtain FileType details.""" + pass + + def compute_folder(self, from_bids: bool = False) -> str: + return ( + self.preprocessing.value + if from_bids + else self.preprocessing.value.replace("-", "_") + ) + + @computed_field + @property + def file_type(self) -> FileType: + if self.from_bids: + return self.bids_nii() + elif self.preprocessing not in Preprocessing: + raise NotImplementedError( + f"Extraction of preprocessing {self.preprocessing.value} is not implemented from CAPS directory." 
+ ) + else: + return self.get_filetype() + + def linear_nii(self) -> FileType: + """ + Constructs the file type for linear caps image data + """ + needed_pipeline, modality = self.caps_nii() + desc_crop = "" if self.use_uncropped_image else "_desc-Crop" + + file_type = FileType( + pattern=f"*space-MNI152NLin2009cSym{desc_crop}_res-1x1x1_{modality.value}.nii.gz", + description=f"{modality.value} Image registered in MNI152NLin2009cSym space using {needed_pipeline.value} pipeline " + + ( + "" + if self.use_uncropped_image + else "and cropped (matrix size 169×208×179, 1 mm isotropic voxels)" + ), + needed_pipeline=needed_pipeline, + ) + return file_type class PETPreprocessingConfig(PreprocessingConfig): + """ + Configuration for PET image preprocessing + """ + tracer: Tracer = Tracer.FFDG suvr_reference_region: SUVRReferenceRegions = SUVRReferenceRegions.CEREBELLUMPONS2 preprocessing: Preprocessing = Preprocessing.PET_LINEAR + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + trc, rec, description = "", "", "PET data" + if self.tracer: + description += f" with {self.tracer.value} tracer" + trc = f"_trc-{self.tracer.value}" + if reconstruction: + description += f" and reconstruction method {reconstruction}" + rec = f"_rec-{reconstruction}" + + return FileType(pattern=f"pet/*{trc}{rec}_pet.nii*", description=description) + + def caps_nii(self) -> Tuple[Preprocessing, ImageModality]: + return (self.preprocessing, ImageModality.PET) + + def get_filetype(self) -> FileType: + des_crop = "" if self.use_uncropped_image else "_desc-Crop" + + return FileType( + pattern=f"pet_linear/*_trc-{self.tracer.value}_space-MNI152NLin2009cSym{des_crop}_res-1x1x1_suvr-{self.suvr_reference_region.value}_pet.nii.gz", + description="", + needed_pipeline="pet-linear", + ) + class CustomPreprocessingConfig(PreprocessingConfig): + """ + Configuration for custom preprocessing with a user-defined suffix. + """ + custom_suffix: str = "" preprocessing: Preprocessing = Preprocessing.CUSTOM + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + return FileType( + pattern=f"*{self.custom_suffix}", + description="Custom suffix", + ) + + def caps_nii(self) -> tuple: + return (self.preprocessing, ImageModality.CUSTOM) + + def get_filetype(self) -> FileType: + return self.bids_nii() + class DTIPreprocessingConfig(PreprocessingConfig): + """ + Configuration for DTI-based preprocessing + """ + dti_measure: DTIMeasure = DTIMeasure.FRACTIONAL_ANISOTROPY dti_space: DTISpace = DTISpace.ALL preprocessing: Preprocessing = Preprocessing.DWI_DTI + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + return FileType(pattern="dwi/sub-*_ses-*_dwi.nii*", description="DWI NIfTI") + + def caps_nii(self) -> tuple: + return (self.preprocessing, ImageModality.DWI) + + def get_filetype(self) -> FileType: + """Return the query dict required to capture DWI DTI images. 
+ + Parameters + ---------- + config: DTIPreprocessingConfig + + Returns + ------- + FileType : + """ + measure = self.dti_measure + space = self.dti_space + + return FileType( + pattern=f"dwi/dti_based_processing/*/*_space-{space}_{measure.value}.nii.gz", + description=f"DTI-based {measure.value} in space {space}.", + needed_pipeline="dwi_dti", + ) + class T1PreprocessingConfig(PreprocessingConfig): preprocessing: Preprocessing = Preprocessing.T1_LINEAR + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + return FileType(pattern="anat/sub-*_ses-*_T1w.nii*", description="T1w MRI") + + def caps_nii(self) -> tuple: + return (self.preprocessing, LinearModality.T1W) + + def get_filetype(self) -> FileType: + return self.linear_nii() + class FlairPreprocessingConfig(PreprocessingConfig): preprocessing: Preprocessing = Preprocessing.FLAIR_LINEAR + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + return FileType(pattern="sub-*_ses-*_flair.nii*", description="FLAIR T2w MRI") + + def caps_nii(self) -> tuple: + return (self.preprocessing, LinearModality.T2W) + + def get_filetype(self) -> FileType: + return self.linear_nii() + class T2PreprocessingConfig(PreprocessingConfig): preprocessing: Preprocessing = Preprocessing.T2_LINEAR + + def bids_nii(self, reconstruction: Optional[str] = None) -> FileType: + raise NotImplementedError( + f"Extraction of preprocessing {self.preprocessing.value} is not implemented from BIDS directory." + ) + + def caps_nii(self) -> tuple: + return (self.preprocessing, LinearModality.FLAIR) + + def get_filetype(self) -> FileType: + return self.linear_nii() + + +ALL_PREPROCESSING_TYPES = Union[ + T1PreprocessingConfig, + T2PreprocessingConfig, + FlairPreprocessingConfig, + PETPreprocessingConfig, + CustomPreprocessingConfig, + DTIPreprocessingConfig, +] diff --git a/clinicadl/dataset/config/utils.py b/clinicadl/dataset/config/utils.py new file mode 100644 index 000000000..8d34d2836 --- /dev/null +++ b/clinicadl/dataset/config/utils.py @@ -0,0 +1,98 @@ +# coding: utf8 +# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? 
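+# Note: the helpers below rebuild the preprocessing, extraction and transforms
+# configuration objects from the parameters saved in a preprocessing JSON
+# (written in the CAPS `tensor_extraction` folder at preparation time).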
+from logging import getLogger
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import torch
+from pydantic import BaseModel, ConfigDict
+
+from clinicadl.dataset.config import extraction, preprocessing
+from clinicadl.dataset.transforms.transforms import Transforms
+from clinicadl.utils.enum import ExtractionMethod, Preprocessing
+from clinicadl.utils.iotools.utils import read_preprocessing
+
+logger = getLogger("clinicadl")
+
+
+def get_extraction(
+    extract_method: Union[str, ExtractionMethod],
+) -> type[extraction.ALL_EXTRACTION_TYPES]:
+    extract_method = ExtractionMethod(extract_method)
+    if extract_method == ExtractionMethod.ROI:
+        return extraction.ExtractionROIConfig
+    elif extract_method == ExtractionMethod.SLICE:
+        return extraction.ExtractionSliceConfig
+    elif extract_method == ExtractionMethod.IMAGE:
+        return extraction.ExtractionImageConfig
+    elif extract_method == ExtractionMethod.PATCH:
+        return extraction.ExtractionPatchConfig
+    else:
+        raise ValueError(
+            f"Extraction method {extract_method.value} is not implemented."
+        )
+
+
+def get_preprocessing(
+    preprocessing_type: Union[str, Preprocessing],
+) -> type[preprocessing.ALL_PREPROCESSING_TYPES]:
+    preprocessing_type = Preprocessing(preprocessing_type)
+    if preprocessing_type == Preprocessing.T1_LINEAR:
+        return preprocessing.T1PreprocessingConfig
+    elif preprocessing_type == Preprocessing.PET_LINEAR:
+        return preprocessing.PETPreprocessingConfig
+    elif preprocessing_type == Preprocessing.FLAIR_LINEAR:
+        return preprocessing.FlairPreprocessingConfig
+    elif preprocessing_type == Preprocessing.CUSTOM:
+        return preprocessing.CustomPreprocessingConfig
+    elif preprocessing_type == Preprocessing.DWI_DTI:
+        return preprocessing.DTIPreprocessingConfig
+    else:
+        raise ValueError(
+            f"Preprocessing {preprocessing_type.value} is not implemented."
+        )
+
+
+def get_infos_from_json(
+    json_path: Path,
+) -> Tuple[
+    preprocessing.ALL_PREPROCESSING_TYPES, extraction.ALL_EXTRACTION_TYPES, Transforms
+]:
+    """
+    Extracts the preprocessing, extraction and transforms configurations from a JSON file.
+
+    Parameters
+    ----------
+    json_path : Path
+        Path to the JSON file containing the preprocessing and extraction parameters.
+
+    Returns
+    -------
+    Tuple[PreprocessingConfig, ExtractionConfig, Transforms]
+        The preprocessing, extraction and transforms configurations extracted from the JSON file.
+    """
+
+    dict_ = read_preprocessing(json_path)
+    return get_infos_from_parameters(**dict_)
+
+
+def get_infos_from_parameters(
+    **kwargs,
+) -> Tuple[
+    preprocessing.ALL_PREPROCESSING_TYPES, extraction.ALL_EXTRACTION_TYPES, Transforms
+]:
+    """
+    Extracts the preprocessing, extraction and transforms configurations from keyword parameters.
+
+    Returns
+    -------
+    Tuple[PreprocessingConfig, ExtractionConfig, Transforms]
+        The preprocessing, extraction and transforms configurations built from the parameters.
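+
+    Examples
+    --------
+    Minimal sketch (values are illustrative; they assume the usual enum string
+    values such as ``"t1-linear"`` and ``"image"``)::
+
+        preprocessing_config, extraction_config, transforms = get_infos_from_parameters(
+            preprocessing="t1-linear",
+            extract_method="image",
+        )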
+ """ + + if "preprocessing_dict" in kwargs: + kwargs = kwargs["preprocessing_dict"] + + preprocessing = Preprocessing(kwargs["preprocessing"]) + mode = ExtractionMethod(kwargs["extract_method"]) + extraction = get_extraction(mode)(**kwargs) + transforms = Transforms(extraction=extraction, **kwargs) + return get_preprocessing(preprocessing)(**kwargs), extraction, transforms diff --git a/clinicadl/dataset/data_config.py b/clinicadl/dataset/data_config.py deleted file mode 100644 index 39e6a6254..000000000 --- a/clinicadl/dataset/data_config.py +++ /dev/null @@ -1,164 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union - -import pandas as pd -from pydantic import BaseModel, ConfigDict, computed_field, field_validator - -from clinicadl.utils.enum import Mode -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLTSVError, -) -from clinicadl.utils.iotools.clinica_utils import check_caps_folder -from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test -from clinicadl.utils.iotools.utils import read_preprocessing - -logger = getLogger("clinicadl.data_config") - - -class DataConfig(BaseModel): # TODO : put in data module - """Config class to specify the data. - - caps_directory and preprocessing_json are arguments - that must be passed by the user. - """ - - caps_directory: Optional[Path] = None - baseline: bool = False - diagnoses: Tuple[str, ...] = ("AD", "CN") - data_df: Optional[pd.DataFrame] = None - label: Optional[str] = None - label_code: Union[str, Dict[str, int], None] = {} - multi_cohort: bool = False - mask_path: Optional[Path] = None - preprocessing_json: Optional[Path] = None - data_tsv: Optional[Path] = None - n_subjects: int = 300 - # pydantic config - model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) - - @field_validator("diagnoses", mode="before") - def validator_diagnoses(cls, v): - """Transforms a list to a tuple.""" - if isinstance(v, list): - return tuple(v) - return v # TODO : check if columns are in tsv - - def create_groupe_df(self): - group_df = None - if self.data_tsv is not None and self.data_tsv.is_file(): - group_df = load_data_test( - self.data_tsv, - self.diagnoses, - multi_cohort=self.multi_cohort, - ) - return group_df - - def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): - return ( - self.label is not None - and self.label != "" - and self.label != _label - and _label_code == "default" - ) - - def check_label(self, _label: str): - if not self.label: - self.label = _label - - @field_validator("data_tsv", mode="before") - @classmethod - def check_data_tsv(cls, v) -> Path: - if v is not None: - if not isinstance(v, Path): - v = Path(v) - if not v.is_file(): - raise ClinicaDLTSVError( - "The participants_list you gave is not a file. Please give an existing file." - ) - if v.stat().st_size == 0: - raise ClinicaDLTSVError( - "The participants_list you gave is empty. Please give a non-empty file." - ) - return v - - @computed_field - @property - def caps_dict(self) -> Dict[str, Path]: - if self.multi_cohort: - if self.caps_directory.suffix != ".tsv": - raise ClinicaDLArgumentError( - "If multi_cohort is True, the CAPS_DIRECTORY argument should be a path to a TSV file." 
- ) - else: - caps_df = pd.read_csv(self.caps_directory, sep="\t") - check_multi_cohort_tsv(caps_df, "CAPS") - caps_dict = dict() - for idx in range(len(caps_df)): - cohort = caps_df.loc[idx, "cohort"] - caps_path = Path(caps_df.at[idx, "path"]) - check_caps_folder(caps_path) - caps_dict[cohort] = caps_path - else: - check_caps_folder(self.caps_directory) - caps_dict = {"single": self.caps_directory} - - return caps_dict - - @computed_field - @property - def preprocessing_dict(self) -> Dict[str, Any]: - """ - Gets the preprocessing dictionary from a preprocessing json file. - - Returns - ------- - Dict[str, Any] - The preprocessing dictionary. - - Raises - ------ - ValueError - In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. - """ - - if self.preprocessing_json is not None: - if not self.multi_cohort: - preprocessing_json = ( - self.caps_directory / "tensor_extraction" / self.preprocessing_json - ) - else: - caps_dict = self.caps_dict - json_found = False - for caps_name, caps_path in caps_dict.items(): - preprocessing_json = ( - caps_path / "tensor_extraction" / self.preprocessing_json - ) - if preprocessing_json.is_file(): - logger.info( - f"Preprocessing JSON {preprocessing_json} found in CAPS {caps_name}." - ) - json_found = True - if not json_found: - raise ValueError( - f"Preprocessing JSON {self.preprocessing_json} was not found for any CAPS " - f"in {caps_dict}." - ) - - preprocessing_dict = read_preprocessing(preprocessing_json) - - if ( - preprocessing_dict["mode"] == "roi" - and "roi_background_value" not in preprocessing_dict - ): - preprocessing_dict["roi_background_value"] = 0 - - return preprocessing_dict - else: - return None - - @computed_field - @property - def mode(self) -> Mode: - return Mode(self.preprocessing_dict["mode"]) diff --git a/clinicadl/dataset/dataloader_config.py b/clinicadl/dataset/dataloader_config.py deleted file mode 100644 index cc01ba9a9..000000000 --- a/clinicadl/dataset/dataloader_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from logging import getLogger - -from pydantic import BaseModel, ConfigDict -from pydantic.types import PositiveInt - -from clinicadl.utils.enum import Sampler - -logger = getLogger("clinicadl.dataloader_config") - - -class DataLoaderConfig(BaseModel): # TODO : put in data/splitter module - """Config class to configure the DataLoader.""" - - batch_size: PositiveInt = 8 - n_proc: PositiveInt = 2 - sampler: Sampler = Sampler.RANDOM - # pydantic config - model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/dataset/datasets/caps_dataset.py new file mode 100644 index 000000000..ee0b01e97 --- /dev/null +++ b/clinicadl/dataset/datasets/caps_dataset.py @@ -0,0 +1,260 @@ +# coding: utf8 +# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? 
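+# Illustrative usage (argument names taken from the constructor below):
+#     dataset = CapsDataset(caps_directory, data_df, preprocessing, transforms)
+#     sample = dataset[0]  # CapsDatasetOutput with .image, .participant_id, etc.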
+import abc
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import nibabel as nib
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from clinicadl.dataset.config.extraction import (
+    ExtractionConfig,
+    ExtractionImageConfig,
+    ExtractionPatchConfig,
+    ExtractionROIConfig,
+    ExtractionSliceConfig,
+)
+from clinicadl.dataset.config.preprocessing import PreprocessingConfig
+from clinicadl.dataset.transforms.transforms import Transforms
+from clinicadl.dataset.utils import CapsDatasetOutput
+from clinicadl.utils.enum import (
+    ExtractionMethod,
+    Pattern,
+    Preprocessing,
+    SliceDirection,
+    SliceMode,
+    Template,
+)
+from clinicadl.utils.exceptions import (
+    ClinicaDLCAPSError,
+    ClinicaDLTSVError,
+)
+from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader
+
+logger = getLogger("clinicadl")
+
+
+class CapsDataset(Dataset):
+    """Base class for CAPS-based datasets."""
+
+    def __init__(
+        self,
+        caps_directory: Path,
+        data_df: pd.DataFrame,
+        preprocessing: PreprocessingConfig,
+        transforms: Transforms,
+        index: Optional[int] = None,
+    ):
+        self.caps_directory = caps_directory
+        self.subjects_directory = caps_directory / "subjects"
+        self.preprocessing = preprocessing
+        self.transforms = transforms
+        self.extraction = transforms.extraction
+        # Needed by __getitem__, which is called below to compute self.size.
+        self.eval_mode = False
+        # The dataframe must be set before _get_full_image, which reads its first row.
+        self.df = data_df
+        mandatory_col = {
+            "participant_id",
+            "session_id",
+            "cohort",
+        }
+
+        if not mandatory_col.issubset(set(self.df.columns.values)):
+            raise ClinicaDLTSVError(
+                "The data file is not in the correct format. "
+                f"Columns should include {mandatory_col}"
+            )
+        self.image_0 = self._get_full_image()
+        self.elem_index = index
+        self.elem_per_image = self.extraction.num_elem_per_image(
+            elem_index=self.elem_index, image=self.image_0
+        )
+        self.size = self[0].image.size()
+
+    def __len__(self) -> int:
+        return len(self.df) * self.elem_per_image
+
+    def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:
+        """
+        Gets the path to the tensor image (*.pt)
+
+        Args:
+            participant: ID of the participant.
+            session: ID of the session.
+            cohort: Name of the cohort.
+        Returns:
+            image_path: path to the tensor containing the whole image.
+        """
+
+        # Try to find .nii.gz file
+        try:
+            results = clinicadl_file_reader(
+                [participant],
+                [session],
+                self.caps_directory,
+                self.preprocessing.file_type.model_dump(),
+            )
+            logger.debug(f"clinicadl_file_reader output: {results}")
+            filepath = Path(results[0][0])
+            # The tensor file keeps the NIfTI filename, with a .pt extension.
+            image_filename = filepath.name.replace(".nii.gz", ".pt")
+
+            image_dir = (
+                self.caps_directory
+                / "subjects"
+                / participant
+                / session
+                / "deep_learning_prepare_data"
+                / "image"
+                / self.preprocessing.compute_folder()
+            )
+            image_path = image_dir / image_filename
+        # Try to find .pt file
+        except ClinicaDLCAPSError:
+            self.preprocessing.file_type.pattern = (
+                self.preprocessing.file_type.pattern.replace(".nii.gz", ".pt")
+            )
+            results = clinicadl_file_reader(
+                [participant],
+                [session],
+                self.caps_directory,
+                self.preprocessing.file_type.model_dump(),
+            )
+            filepath = results[0]
+            image_path = Path(filepath[0])
+
+        return image_path
+
+    def _get_meta_data(self, idx: int) -> Tuple[str, str, str, int]:
+        """
+        Gets all meta data necessary to compute the path with _get_image_path
+
+        Args:
+            idx (int): row number of the meta-data contained in self.df
+        Returns:
+            participant (str): ID of the participant.
+            session (str): ID of the session.
+            cohort (str): Name of the cohort.
+            elem_index (int): Index of the part of the image.
+        """
+        image_idx = idx // self.elem_per_image
+        participant = self.df.at[image_idx, "participant_id"]
+        session = self.df.at[image_idx, "session_id"]
+        cohort = self.df.at[image_idx, "cohort"]
+
+        if self.elem_index is None:
+            elem_idx = idx % self.elem_per_image
+        else:
+            elem_idx = self.elem_index
+
+        return participant, session, cohort, elem_idx
+
+    def _get_full_image(self) -> torch.Tensor:
+        """
+        Gets an example of a full image from the dataset.
+        Useful to compute the number of elements when the extraction mode is not image.
+
+        Returns:
+            tensor of the first full image of the dataset.
+        """
+
+        participant_id = self.df.at[0, "participant_id"]
+        session_id = self.df.at[0, "session_id"]
+        cohort = self.df.at[0, "cohort"]
+
+        try:
+            image_path = self._get_image_path(participant_id, session_id, cohort)
+            image = torch.load(image_path, weights_only=True)
+        except IndexError:
+            results = clinicadl_file_reader(
+                [participant_id],
+                [session_id],
+                self.caps_directory,
+                self.preprocessing.file_type.model_dump(),
+            )
+            image_nii = nib.loadsave.load(results[0])  # type: ignore
+            image_np = image_nii.get_fdata()  # type: ignore
+            image = ToTensor()(image_np)
+
+        return image
+
+    def __getitem__(self, idx: int) -> CapsDatasetOutput:
+        """
+        Gets the sample containing all the information needed for training and testing tasks.
+
+        Args:
+            idx: row number of the meta-data contained in self.df
+        Returns:
+            a CapsDatasetOutput with the following attributes:
+                - image (torch.Tensor): the input given to the model,
+                - participant_id (str): ID of the participant,
+                - session_id (str): ID of the session,
+                - image_id (int): index of the extracted element,
+                - image_path (Path): path to the image loaded from the CAPS,
+                - mode (ExtractionMethod): the extraction method used.
+ + """ + participant, session, cohort, index = self._get_meta_data(idx) + + image_path = self._get_image_path(participant, session, cohort) + image = torch.load(image_path, weights_only=True) + + ( + image_trf, + object_trf, + image_augmentation, + object_augmentation, + ) = self.transforms.get_transforms() + + image = image_trf(image) + + if image_augmentation and not self.eval_mode: + image = image_augmentation(image) + + if not isinstance(self.extraction, ExtractionImageConfig): + tensor = self.transforms.extraction.extract_tensor( + image, + index, + ) + if object_trf: + tensor = object_trf(tensor) + + if object_augmentation and not self.eval_mode: + tensor = object_augmentation(tensor) + + out = tensor + index = 0 + + else: + out = image + + sample = CapsDatasetOutput( + image=out, + # label=label, + participant_id=participant, + session_id=session, + image_id=index, + image_path=image_path, + mode=self.extraction.extract_method, + ) + + return sample + + def num_elem_per_image(self) -> int: + """Computes the number of elements per image based on the full image.""" + return self.extraction.num_elem_per_image( + elem_index=self.elem_index, image=self.image_0 + ) + + def eval(self): + """Put the dataset on evaluation mode (data augmentation is not performed).""" + self.eval_mode = True + return self + + def train(self): + """Put the dataset on training mode (data augmentation is performed).""" + self.eval_mode = False + return self diff --git a/clinicadl/dataset/datasets/concat.py b/clinicadl/dataset/datasets/concat.py new file mode 100644 index 000000000..4c1bbf7fe --- /dev/null +++ b/clinicadl/dataset/datasets/concat.py @@ -0,0 +1,132 @@ +# coding: utf8 +# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? 
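+# ConcatDataset chains several CapsDataset objects sharing the same extraction,
+# preprocessing, transforms, size and elem_per_image (checked in check_configs).
+# For instance, ConcatDataset([ds_a, ds_b])[len(ds_a)] yields the first sample of ds_b.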
+import abc +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from pydantic import BaseModel +from torch.utils.data import Dataset + +from clinicadl.dataset.config.extraction import ( + ExtractionConfig, + ExtractionImageConfig, + ExtractionPatchConfig, + ExtractionROIConfig, + ExtractionSliceConfig, +) +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.config.utils import ( + get_infos_from_json, +) +from clinicadl.dataset.datasets.caps_dataset import CapsDataset +from clinicadl.dataset.transforms.transforms import Transforms +from clinicadl.dataset.utils import CapsDatasetOutput +from clinicadl.utils.enum import ( + ExtractionMethod, + Pattern, + Preprocessing, + SliceDirection, + SliceMode, + Template, +) +from clinicadl.utils.exceptions import ( + ClinicaDLCAPSError, + ClinicaDLConcatError, + ClinicaDLTSVError, +) +from clinicadl.utils.iotools.clinica_utils import check_caps_folder +from clinicadl.utils.iotools.utils import path_decoder, read_preprocessing + +logger = getLogger("clinicadl") + + +class ConcatDataset(CapsDataset): + def __init__(self, datasets: List[CapsDataset]): + self._datasets = datasets + self._len = sum(len(dataset) for dataset in datasets) + self._indexes = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + logger.debug(f"Datasets summary length: {self._len}") + logger.debug(f"Datasets indexes: {self._indexes}") + + self.caps_dict = self.compute_caps_dict() + self.check_configs() + + self.eval_mode = False + + def __getitem__(self, index: int) -> Optional[CapsDatasetOutput]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len + + def check_configs(self): + extraction = self._datasets[len(self._datasets) - 1].extraction + preprocessing = self._datasets[len(self._datasets) - 1].preprocessing + transforms = self._datasets[len(self._datasets) - 1].transforms + size = self._datasets[len(self._datasets) - 1].size + elem_per_image = self._datasets[len(self._datasets) - 1].elem_per_image + + for idx in range(len(self._datasets) - 1): + if self._datasets[idx].extraction != extraction: + raise ClinicaDLConcatError( + f"Different extraction modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].extraction}, " + f"Dataset {len(self._datasets)}: {extraction}" + ) + + if self._datasets[idx].preprocessing != preprocessing: + raise ClinicaDLConcatError( + f"Different preprocessing modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].preprocessing}, " + f"Dataset {len(self._datasets)}: {preprocessing}" + ) + + if self._datasets[idx].transforms != transforms: + raise ClinicaDLConcatError( + f"Different transforms modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].transforms}, " + f"Dataset {len(self._datasets)}: {transforms}" + ) + if self._datasets[idx].size != size: + raise ClinicaDLConcatError( + f"Different size modes found in datasets. 
" + f"Dataset {idx+1}: {self._datasets[idx].size}, " + f"Dataset {len(self._datasets)}: {size}" + ) + if self._datasets[idx].elem_per_image != elem_per_image: + raise ClinicaDLConcatError( + f"Different elem_per_image modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].elem_per_image}, " + f"Dataset {len(self._datasets)}: {elem_per_image}" + ) + + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms + self.size = size + self.elem_per_image = elem_per_image + + def compute_caps_dict(self) -> Dict[str, Path]: + caps_dict = dict() + for idx in range(len(self._datasets)): + cohort = idx + caps_path = self._datasets[idx].caps_directory + check_caps_folder(caps_path) + caps_dict[cohort] = caps_path + + return caps_dict diff --git a/clinicadl/dataset/prepare_data/prepare_data.py b/clinicadl/dataset/prepare_data/prepare_data.py deleted file mode 100644 index d9ef1c412..000000000 --- a/clinicadl/dataset/prepare_data/prepare_data.py +++ /dev/null @@ -1,230 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Optional - -from joblib import Parallel, delayed -from torch import save as save_tensor - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import compute_folder_and_file_type -from clinicadl.dataset.config.extraction import ( - ExtractionConfig, - ExtractionImageConfig, - ExtractionPatchConfig, - ExtractionROIConfig, - ExtractionSliceConfig, -) -from clinicadl.utils.enum import ExtractionMethod, Pattern, Preprocessing, Template -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.iotools.clinica_utils import ( - check_caps_folder, - clinicadl_file_reader, - container_from_filename, - determine_caps_or_bids, - get_subject_session_list, -) -from clinicadl.utils.iotools.utils import write_preprocessing - -from .prepare_data_utils import check_mask_list - - -def DeepLearningPrepareData( - config: CapsDatasetConfig, from_bids: Optional[Path] = None -): - logger = getLogger("clinicadl.prepare_data") - # Get subject and session list - if from_bids is not None: - try: - input_directory = Path(from_bids) - except ClinicaDLArgumentError: - logger.warning("Your BIDS directory doesn't exist.") - logger.debug(f"BIDS directory: {input_directory}.") - is_bids_dir = True - else: - input_directory = config.data.caps_directory - check_caps_folder(input_directory) - logger.debug(f"CAPS directory: {input_directory}.") - is_bids_dir = False - - subjects, sessions = get_subject_session_list( - input_directory, config.data.data_tsv, is_bids_dir, False, None - ) - - if config.extraction.save_features: - logger.info( - f"{config.extraction.extract_method.value}s will be extracted in Pytorch tensor from {len(sessions)} images." - ) - else: - logger.info( - f"Images will be extracted in Pytorch tensor from {len(sessions)} images." - ) - logger.info( - f"Information for {config.extraction.extract_method.value} will be saved in output JSON file and will be used " - f"during training for on-the-fly extraction." - ) - logger.debug(f"List of subjects: \n{subjects}.") - logger.debug(f"List of sessions: \n{sessions}.") - - # Select the correct filetype corresponding to modality - # and select the right folder output name corresponding to modality - logger.debug( - f"Selected images are preprocessed with {config.preprocessing} pipeline`." 
- ) - - mod_subfolder, file_type = compute_folder_and_file_type(config, from_bids) - - # Input file: - input_files = clinicadl_file_reader( - subjects, sessions, input_directory, file_type.model_dump() - )[0] - logger.debug(f"Selected image file name list: {input_files}.") - - def write_output_imgs(output_mode, container, subfolder): - # Write the extracted tensor on a .pt file - for filename, tensor in output_mode: - output_file_dir = ( - config.data.caps_directory - / container - / "deeplearning_prepare_data" - / subfolder - / mod_subfolder - ) - output_file_dir.mkdir(parents=True, exist_ok=True) - output_file = output_file_dir / filename - save_tensor(tensor, output_file) - logger.debug(f"Output tensor saved at {output_file}") - - if ( - config.extraction.extract_method == ExtractionMethod.IMAGE - or not config.extraction.save_features - ): - - def prepare_image(file): - from .prepare_data_utils import extract_images - - logger.debug(f"Processing of {file}.") - container = container_from_filename(file) - subfolder = "image_based" - output_mode = extract_images(Path(file)) - logger.debug("Image extracted.") - write_output_imgs(output_mode, container, subfolder) - - Parallel(n_jobs=config.dataloader.n_proc)( - delayed(prepare_image)(file) for file in input_files - ) - - elif config.extraction.save_features: - if config.extraction.extract_method == ExtractionMethod.SLICE: - assert isinstance(config.extraction, ExtractionSliceConfig) - - def prepare_slice(file): - from .prepare_data_utils import extract_slices - - assert isinstance(config.extraction, ExtractionSliceConfig) - logger.debug(f" Processing of {file}.") - container = container_from_filename(file) - subfolder = "slice_based" - output_mode = extract_slices( - Path(file), - slice_direction=config.extraction.slice_direction, - slice_mode=config.extraction.slice_mode, - discarded_slices=config.extraction.discarded_slices, - ) - logger.debug(f" {len(output_mode)} slices extracted.") - write_output_imgs(output_mode, container, subfolder) - - Parallel(n_jobs=config.dataloader.n_proc)( - delayed(prepare_slice)(file) for file in input_files - ) - - elif config.extraction.extract_method == ExtractionMethod.PATCH: - assert isinstance(config.extraction, ExtractionPatchConfig) - - def prepare_patch(file): - from .prepare_data_utils import extract_patches - - assert isinstance(config.extraction, ExtractionPatchConfig) - logger.debug(f" Processing of {file}.") - container = container_from_filename(file) - subfolder = "patch_based" - output_mode = extract_patches( - Path(file), - patch_size=config.extraction.patch_size, - stride_size=config.extraction.stride_size, - ) - logger.debug(f" {len(output_mode)} patches extracted.") - write_output_imgs(output_mode, container, subfolder) - - Parallel(n_jobs=config.dataloader.n_proc)( - delayed(prepare_patch)(file) for file in input_files - ) - - elif config.extraction.extract_method == ExtractionMethod.ROI: - assert isinstance(config.extraction, ExtractionROIConfig) - - def prepare_roi(file): - from .prepare_data_utils import extract_roi - - assert isinstance(config.extraction, ExtractionROIConfig) - logger.debug(f" Processing of {file}.") - container = container_from_filename(file) - subfolder = "roi_based" - if config.preprocessing == Preprocessing.CUSTOM: - if not config.extraction.roi_custom_template: - raise ClinicaDLArgumentError( - "A custom template must be defined when the modality is set to custom." 
- ) - roi_template = config.extraction.roi_custom_template - roi_mask_pattern = config.extraction.roi_custom_mask_pattern - else: - if config.preprocessing.preprocessing == Preprocessing.T1_LINEAR: - roi_template = Template.T1_LINEAR - roi_mask_pattern = Pattern.T1_LINEAR - elif config.preprocessing.preprocessing == Preprocessing.PET_LINEAR: - roi_template = Template.PET_LINEAR - roi_mask_pattern = Pattern.PET_LINEAR - elif ( - config.preprocessing.preprocessing == Preprocessing.FLAIR_LINEAR - ): - roi_template = Template.FLAIR_LINEAR - roi_mask_pattern = Pattern.FLAIR_LINEAR - - masks_location = input_directory / "masks" / f"tpl-{roi_template}" - - if len(config.extraction.roi_list) == 0: - raise ClinicaDLArgumentError( - "A list of regions of interest must be given." - ) - else: - check_mask_list( - masks_location, - config.extraction.roi_list, - roi_mask_pattern, - config.preprocessing.use_uncropped_image, - ) - - output_mode = extract_roi( - Path(file), - masks_location=masks_location, - mask_pattern=roi_mask_pattern, - cropped_input=not config.preprocessing.use_uncropped_image, - roi_names=config.extraction.roi_list, - uncrop_output=config.extraction.roi_uncrop_output, - ) - logger.debug("ROI extracted.") - write_output_imgs(output_mode, container, subfolder) - - Parallel(n_jobs=config.dataloader.n_proc)( - delayed(prepare_roi)(file) for file in input_files - ) - - else: - raise NotImplementedError( - f"Extraction is not implemented for mode {config.extraction.extract_method.value}." - ) - - # Save parameters dictionary - preprocessing_json_path = write_preprocessing( - config.extraction.model_dump(), config.data.caps_directory - ) - logger.info(f"Preprocessing JSON saved at {preprocessing_json_path}.") diff --git a/clinicadl/dataset/prepare_data/prepare_data_utils.py b/clinicadl/dataset/prepare_data/prepare_data_utils.py deleted file mode 100644 index 0acd2ec25..000000000 --- a/clinicadl/dataset/prepare_data/prepare_data_utils.py +++ /dev/null @@ -1,442 +0,0 @@ -# coding: utf8 -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from clinicadl.utils.enum import SliceDirection, SliceMode - - -############ -# SLICE # -############ -def compute_discarded_slices(discarded_slices: Union[int, tuple]) -> Tuple[int, int]: - if isinstance(discarded_slices, int): - begin_discard, end_discard = discarded_slices, discarded_slices - elif len(discarded_slices) == 1: - begin_discard, end_discard = discarded_slices[0], discarded_slices[0] - elif len(discarded_slices) == 2: - begin_discard, end_discard = discarded_slices[0], discarded_slices[1] - else: - raise IndexError( - f"Maximum two number of discarded slices can be defined. " - f"You gave discarded slices = {discarded_slices}." - ) - return begin_discard, end_discard - - -def extract_slices( - nii_path: Path, - slice_direction: SliceDirection = SliceDirection.SAGITTAL, - slice_mode: SliceMode = SliceMode.SINGLE, - discarded_slices: Union[int, tuple] = 0, -) -> List[Tuple[str, torch.Tensor]]: - """Extracts the slices from three directions - This function extracts slices form the preprocessed nifti image. - - The direction of extraction can be defined either on sagittal direction (0), - coronal direction (1) or axial direction (other). - - The output slices can be stored following two modes: - single (1 channel) or rgb (3 channels, all the same). - - Args: - nii_path: path to the NifTi input image. - slice_direction: along which axis slices are extracted. 
- slice_mode: 'single' or 'rgb'. - discarded_slices: Number of slices to discard at the beginning and the end of the image. - Will be a tuple of two integers if the number of slices to discard at the beginning - and at the end differ. - Returns: - list of tuples containing the path to the extracted slice - and the tensor of the corresponding slice. - """ - import nibabel as nib - - image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32") - image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() - - begin_discard, end_discard = compute_discarded_slices(discarded_slices) - index_list = range( - begin_discard, image_tensor.shape[int(slice_direction.value) + 1] - end_discard - ) - - slice_list = [] - for slice_index in index_list: - slice_tensor = extract_slice_tensor( - image_tensor, slice_direction, slice_mode, slice_index - ) - slice_path = extract_slice_path( - nii_path, slice_direction, slice_mode, slice_index - ) - - slice_list.append((slice_path, slice_tensor)) - - return slice_list - - -def extract_slice_tensor( - image_tensor: torch.Tensor, - slice_direction: SliceDirection, - slice_mode: SliceMode, - slice_index: int, -) -> torch.Tensor: - # Allow to select the slice `slice_index` in dimension `slice_direction` - idx_tuple = tuple( - [slice(None)] * (int(slice_direction.value) + 1) - + [slice_index] - + [slice(None)] * (2 - int(slice_direction.value)) - ) - slice_tensor = image_tensor[idx_tuple] # shape is 1 * W * L - - if slice_mode == "rgb": - slice_tensor = torch.cat( - (slice_tensor, slice_tensor, slice_tensor) - ) # shape is 3 * W * L - - return slice_tensor.clone() - - -def extract_slice_path( - img_path: Path, - slice_direction: SliceDirection, - slice_mode: SliceMode, - slice_index: int, -) -> str: - slice_dict = {0: "sag", 1: "cor", 2: "axi"} - input_img_filename = img_path.name - txt_idx = input_img_filename.rfind("_") - it_filename_prefix = input_img_filename[0:txt_idx] - it_filename_suffix = input_img_filename[txt_idx:] - it_filename_suffix = it_filename_suffix.replace(".nii.gz", ".pt") - return ( - f"{it_filename_prefix}_axis-{slice_dict[int(slice_direction.value)]}" - f"_channel-{slice_mode.value}_slice-{slice_index}{it_filename_suffix}" - ) - - -############ -# PATCH # -############ -def extract_patches( - nii_path: Path, - patch_size: int, - stride_size: int, -) -> List[Tuple[str, torch.Tensor]]: - """Extracts the patches - This function extracts patches form the preprocessed nifti image. Patch size - if provided as input and also the stride size. If stride size is smaller - than the patch size an overlap exist between consecutive patches. If stride - size is equal to path size there is no overlap. Otherwise, unprocessed - zones can exits. - Args: - nii_path: path to the NifTi input image. - patch_size: size of a single patch. - stride_size: size of the stride leading to next patch. - Returns: - list of tuples containing the path to the extracted patch - and the tensor of the corresponding patch. 
- """ - import nibabel as nib - - image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32") - image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() - - patches_tensor = ( - image_tensor.unfold(1, patch_size, stride_size) - .unfold(2, patch_size, stride_size) - .unfold(3, patch_size, stride_size) - .contiguous() - ) - patches_tensor = patches_tensor.view(-1, patch_size, patch_size, patch_size) - - patch_list = [] - for patch_index in range(patches_tensor.shape[0]): - patch_tensor = extract_patch_tensor( - image_tensor, patch_size, stride_size, patch_index, patches_tensor - ) - patch_path = extract_patch_path(nii_path, patch_size, stride_size, patch_index) - - patch_list.append((patch_path, patch_tensor)) - - return patch_list - - -def extract_patch_tensor( - image_tensor: torch.Tensor, - patch_size: int, - stride_size: int, - patch_index: int, - patches_tensor: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Extracts a single patch from image_tensor""" - - if patches_tensor is None: - patches_tensor = ( - image_tensor.unfold(1, patch_size, stride_size) - .unfold(2, patch_size, stride_size) - .unfold(3, patch_size, stride_size) - .contiguous() - ) - - # the dimension of patches_tensor is [1, patch_num1, patch_num2, patch_num3, patch_size1, patch_size2, patch_size3] - patches_tensor = patches_tensor.view(-1, patch_size, patch_size, patch_size) - - return patches_tensor[patch_index, ...].unsqueeze_(0).clone() - - -def extract_patch_path( - img_path: Path, patch_size: int, stride_size: int, patch_index: int -) -> str: - input_img_filename = img_path.name - txt_idx = input_img_filename.rfind("_") - it_filename_prefix = input_img_filename[0:txt_idx] - it_filename_suffix = input_img_filename[txt_idx:] - it_filename_suffix = it_filename_suffix.replace(".nii.gz", ".pt") - - return f"{it_filename_prefix}_patchsize-{patch_size}_stride-{stride_size}_patch-{patch_index}{it_filename_suffix}" - - -############ -# IMAGE # -############ -def extract_images(input_img: Path) -> List[Tuple[str, torch.Tensor]]: - """Extract the images - This function convert nifti image to tensor (.pt) version of the image. - Tensor version is saved at the same location than input_img. - Args: - input_img: path to the NifTi input image. - Returns: - filename (str): single tensor file saved on the disk. Same location than input file. - """ - import nibabel as nib - import torch - - image_array = nib.loadsave.load(input_img).get_fdata(dtype="float32") - image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() - # make sure the tensor type is torch.float32 - output_file = ( - Path(input_img.name.replace(".nii.gz", ".pt")), - image_tensor.clone(), - ) - - return [output_file] - - -############ -# ROI # -############ -def check_mask_list( - masks_location: Path, roi_list: List[str], mask_pattern: str, cropping: bool -) -> None: - import nibabel as nib - import numpy as np - - for roi in roi_list: - roi_path, desc = find_mask_path(masks_location, roi, mask_pattern, cropping) - if roi_path is None: - raise FileNotFoundError( - f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}" - ) - roi_mask = nib.loadsave.load(roi_path).get_fdata() - mask_values = set(np.unique(roi_mask)) - if mask_values != {0, 1}: - raise ValueError( - "The ROI masks used should be binary (composed of 0 and 1 only)." 
- ) - - -def find_mask_path( - masks_location: Path, roi: str, mask_pattern: str, cropping: bool -) -> Tuple[Union[None, str], str]: - """ - Finds masks corresponding to the pattern asked and containing the adequate cropping description - - Parameters - ---------- - masks_location: Path - Directory containing the masks. - roi: str - Name of the region. - mask_pattern: str - Pattern which should be found in the filename of the mask. - cropping: bool - If True the original image should contain the substring 'desc-Crop'. - - Returns - ------- - path of the mask or None if nothing was found. - a human-friendly description of the pattern looked for. - """ - - # Check that pattern begins and ends with _ to avoid mixing keys - if mask_pattern is None: - mask_pattern = "" - - candidates_pattern = f"*{mask_pattern}*_roi-{roi}_mask.nii*" - - desc = f"The mask should follow the pattern {candidates_pattern}. " - candidates = [e for e in masks_location.glob(candidates_pattern)] - if cropping is None: - # pass - candidates2 = candidates - elif cropping: - candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name] - desc += "and contain '_desc-Crop_' string." - else: - candidates2 = [mask for mask in candidates if "_desc-Crop_" not in mask.name] - desc += "and not contain '_desc-Crop_' string." - - if len(candidates2) == 0: - return None, desc - else: - return min(candidates2), desc - - -def compute_output_pattern(mask_path: Path, crop_output: bool): - """ - Computes the output pattern of the region cropped (without the source file prefix) - Parameters - ---------- - mask_path: Path - Path to the masks - crop_output: bool - If True the output is cropped, and the descriptor CropRoi must exist - - Returns - ------- - the output pattern - """ - - mask_filename = mask_path.name - template_id = mask_filename.split("_")[0].split("-")[1] - mask_descriptors = mask_filename.split("_")[1:-2:] - roi_id = mask_filename.split("_")[-2].split("-")[1] - if "desc-Crop" not in mask_descriptors and crop_output: - mask_descriptors = ["desc-CropRoi"] + mask_descriptors - elif "desc-Crop" in mask_descriptors: - mask_descriptors = [ - descriptor for descriptor in mask_descriptors if descriptor != "desc-Crop" - ] - if crop_output: - mask_descriptors = ["desc-CropRoi"] + mask_descriptors - else: - mask_descriptors = ["desc-CropImage"] + mask_descriptors - - mask_pattern = "_".join(mask_descriptors) - - if mask_pattern == "": - output_pattern = f"space-{template_id}_roi-{roi_id}" - else: - output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}" - - return output_pattern - - -def extract_roi( - nii_path: Path, - masks_location: Path, - mask_pattern: str, - cropped_input: bool, - roi_names: List[str], - uncrop_output: bool, -) -> List[Tuple[str, torch.Tensor]]: - """Extracts regions of interest defined by masks - This function extracts regions of interest from preprocessed nifti images. - The regions are defined using binary masks that must be located in the CAPS - at `masks/tpl-