Skip to content

Commit

Permalink
Merge pull request #779 from scap3yvt/778-add-option-to-not-apply-zer…
Browse files Browse the repository at this point in the history
…o-plane-cropping-for-gandlf_preprocess

Added option to not apply zero plane cropping for gandlf preprocess
  • Loading branch information
sarthakpati authored Jan 22, 2024
2 parents 2d5513a + abad6db commit f57483c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 57 deletions.
44 changes: 20 additions & 24 deletions GANDLF/cli/preprocess_and_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
parseTrainingCSV,
populate_header_in_parameters,
get_dataframe,
get_correct_padding_size,
)
from GANDLF.parseConfig import parseConfig
from GANDLF.data.ImagesFromDataFrame import ImagesFromDataFrame
Expand All @@ -18,8 +19,13 @@


def preprocess_and_save(
data_csv, config_file, output_dir, label_pad_mode="constant", applyaugs=False
):
data_csv: str,
config_file: str,
output_dir: str,
label_pad_mode: str = "constant",
applyaugs: bool = False,
apply_zero_crop: bool = False,
) -> None:
"""
This function performs preprocessing based on parameters provided and saves the output.
Expand All @@ -29,6 +35,7 @@ def preprocess_and_save(
output_dir (str): The output directory.
label_pad_mode (str): The padding strategy for the label. Defaults to "constant".
applyaugs (bool): If data augmentation is to be applied before saving the image. Defaults to False.
apply_zero_crop (bool): If zero cropping is to be applied before saving the image. Defaults to False.
Raises:
ValueError: Parameter check from previous
Expand All @@ -55,7 +62,11 @@ def preprocess_and_save(
parameters = populate_header_in_parameters(parameters, headers)

data_for_processing = ImagesFromDataFrame(
dataframe, parameters, train=applyaugs, apply_zero_crop=True, loader_type="full"
dataframe,
parameters,
train=applyaugs,
apply_zero_crop=apply_zero_crop,
loader_type="full",
)

dataloader_for_processing = DataLoader(
Expand Down Expand Up @@ -117,34 +128,19 @@ def preprocess_and_save(
subject_dict_to_write = torchio.Subject(subject_process)

# apply a different padding mode to image and label (so that label information is not duplicated)
if (parameters["patch_sampler"] == "label") or (
isinstance(parameters["patch_sampler"], dict)
):
if parameters["patch_sampler"]["type"] == "label":
# get the padding size from the patch_size
psize_pad = list(
np.asarray(np.ceil(np.divide(parameters["patch_size"], 2)), dtype=int)
psize_pad = get_correct_padding_size(
parameters["patch_size"], parameters["model"]["dimension"]
)
# initialize the padder for images
padder = torchio.transforms.Pad(
psize_pad, padding_mode="symmetric", include=keys_with_images
psize_pad,
padding_mode=label_pad_mode,
include=keys_with_images + ["label"],
)
subject_dict_to_write = padder(subject_dict_to_write)

if parameters["headers"]["labelHeader"] is not None:
# initialize the padder for label
padder_label = torchio.transforms.Pad(
psize_pad, padding_mode=label_pad_mode, include="label"
)
subject_dict_to_write = padder_label(subject_dict_to_write)

sampler = torchio.data.LabelSampler(parameters["patch_size"])
generator = sampler(subject_dict_to_write, num_patches=1)
for patch in generator:
for channel in parameters["headers"]["channelHeaders"]:
subject_dict_to_write[str(channel)] = patch[str(channel)]

subject_dict_to_write["label"] = patch["label"]

# write new images
common_ext = get_filename_extension_sanitized(subject["path_to_metadata"][0])
# in cases where the original image has a file format that does not support
Expand Down
54 changes: 25 additions & 29 deletions GANDLF/data/ImagesFromDataFrame.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Union
import os
from pathlib import Path
import numpy as np

import pandas
import torch
import torchio
from torchio.transforms import Pad
Expand All @@ -12,6 +14,7 @@
perform_sanity_check_on_subject,
resize_image,
get_filename_extension_sanitized,
get_correct_padding_size,
)
from .preprocessing import get_transforms_for_preprocessing
from .augmentation import global_augs_dict
Expand All @@ -31,30 +34,27 @@

# This function takes in a dataframe, with some other parameters and returns the dataloader
def ImagesFromDataFrame(
dataframe, parameters, train, apply_zero_crop=False, loader_type=""
):
dataframe: pandas.DataFrame,
parameters: dict,
train: bool,
apply_zero_crop: bool = False,
loader_type: str = "",
) -> Union[torchio.SubjectsDataset, torchio.Queue]:
"""
Reads the pandas dataframe and gives the dataloader to use for training/validation/testing
Parameters
----------
dataframe : pandas.DataFrame
The main input dataframe which is calculated after splitting the data CSV
parameters : dict
The parameters dictionary
train : bool
If the dataloader is for training or not. For training, the patching infrastructure and data augmentation is applied.
apply_zero_crop : bool
If enabled, the crop_external_zero_plane is applied.
loader_type : str
Type of loader for printing.
Returns
-------
subjects_dataset: torchio.SubjectsDataset
This is the output for validation/testing, where patching and data augmentation is disregarded
patches_queue: torchio.Queue
This is the output for training, which is the subjects_dataset queue after patching and data augmentation is taken into account
Reads the pandas dataframe and gives the dataloader to use for training/validation/testing.
Args:
dataframe (pandas.DataFrame): The main input dataframe which is calculated after splitting the data CSV.
parameters (dict): The parameters dictionary.
train (bool): If the dataloader is for training or not.
apply_zero_crop (bool, optional): If zero crop is to be applied. Defaults to False.
loader_type (str, optional): The type of loader to use for printing. Defaults to "".
Raises:
ValueError: If the subject cannot be loaded.
Returns:
Union[torchio.SubjectsDataset, torchio.Queue]: The subjects dataset for validation/testing (where patching and data augmentation are not required) or the patches queue for training (after patching and data augmentation are taken into account).
"""
# store in previous variable names
patch_size = parameters["patch_size"]
Expand Down Expand Up @@ -243,13 +243,9 @@ def _save_resized_images(

# # padding image, but only for label sampler, because we don't want to pad for uniform
if sampler["enable_padding"]:
psize_pad = list(
np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)
psize_pad = get_correct_padding_size(
patch_size, parameters["model"]["dimension"]
)
# ensure that the patch size for z-axis is not 1 for 2d images
if parameters["model"]["dimension"] == 2:
psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1]
# for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
padder = Pad(psize_pad, padding_mode=sampler["padding_mode"])
subject = padder(subject)

Expand Down
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
resample_image,
perform_sanity_check_on_subject,
write_training_patches,
get_correct_padding_size,
)

from .tensor import (
Expand Down
24 changes: 21 additions & 3 deletions GANDLF/utils/imaging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, pathlib, sys, math
from typing import Union
import os, pathlib, math, copy
import numpy as np
import SimpleITK as sitk
import torchio
Expand Down Expand Up @@ -126,8 +127,6 @@ def perform_sanity_check_on_subject(subject, parameters):
# read the first image and save that for comparison
file_reader_base = None

import copy

list_for_comparison = copy.deepcopy(parameters["headers"]["channelHeaders"])
if parameters["headers"]["labelHeader"] is not None:
list_for_comparison.append("label")
Expand Down Expand Up @@ -240,3 +239,22 @@ def write_training_patches(subject, params):
img_to_write,
os.path.join(training_output_dir_current_subject, "label_" + key + ext),
)


def get_correct_padding_size(patch_size: Union[list, tuple], model_dimension: int):
    """
    Compute the per-axis padding size as half of the patch size, rounded up.

    Args:
        patch_size (Union[list, tuple]): The patch size along each axis.
        model_dimension (int): The model dimension. When it is 2, a padding of exactly 1 on the last axis is replaced by 0.

    Returns:
        list: The padding size for each axis, in the same order as patch_size.
    """
    half_patch = np.ceil(np.asarray(patch_size) / 2)
    padding = list(half_patch.astype(int))

    # for 2D models the last axis (presumably z) must not receive a padding of 1,
    # which is what rounding up half of a unit-sized extent produces
    if model_dimension == 2 and padding[-1] == 1:
        padding[-1] = 0

    return padding
16 changes: 15 additions & 1 deletion gandlf_preprocess
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,25 @@ if __name__ == "__main__":
help="This specifies the whether to apply data augmentation during output creation. Defaults to False",
required=False,
)
# Register the option that controls zero cropping during output creation.
# NOTE(review): the short flag "-a" may clash with another option registered on
# this parser (not visible here) — confirm it is unique, since argparse rejects
# conflicting option strings.
parser.add_argument(
    "-a",
    "--cropzero",
    metavar="",
    # argparse passes the raw command-line string to `type`; `bool("False")` is
    # True because any non-empty string is truthy, so the text is parsed
    # explicitly: only recognised affirmative spellings enable the option.
    type=lambda value: str(value).strip().lower() in ("1", "true", "t", "yes", "y"),
    default=False,
    help="This specifies the whether to apply zero cropping during output creation. Defaults to False",
    required=False,
)

args = parser.parse_args()

preprocess_and_save(
args.inputdata, args.config, args.output, args.labelPad, args.applyaugs
args.inputdata,
args.config,
args.output,
args.labelPad,
args.applyaugs,
args.cropzero,
)

print("Finished.")

0 comments on commit f57483c

Please sign in to comment.