Module fmcib.callbacks
--Expand source code -
-from .prediction_saver import SavePredictions
-Sub-modules
--
-
fmcib.callbacks.prediction_saver
-- - - -
fmcib.callbacks.utils
-- - - -
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 44ee607..a98c6de 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,5 +25,5 @@ jobs: path: .cache restore-keys: | mkdocs-material- - - run: pip install mkdocs-material + - run: pip install mkdocs-material mkdocstrings mkdocstrings-python mkdocs-section-index mkdocs-literate-nav mkdocs-gen-files mkdocs-autorefs - run: mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index f4f09ed..05cbbc2 100644 --- a/.gitignore +++ b/.gitignore @@ -634,3 +634,4 @@ checkpoints/ data/csvs/ ./models ./outputs +docs/reference/* \ No newline at end of file diff --git a/docs/api_docs/fmcib/callbacks/index.html b/docs/api_docs/fmcib/callbacks/index.html deleted file mode 100644 index fa6bad3..0000000 --- a/docs/api_docs/fmcib/callbacks/index.html +++ /dev/null @@ -1,76 +0,0 @@ - - -
- - - -fmcib.callbacks
from .prediction_saver import SavePredictions
-fmcib.callbacks.prediction_saver
fmcib.callbacks.utils
fmcib.callbacks.prediction_saver
from typing import Any, List
-
-from pathlib import Path
-
-import pandas as pd
-import torchvision
-from loguru import logger
-from pytorch_lightning.callbacks import BasePredictionWriter
-
-from .utils import decollate, handle_image
-
-
-class SavePredictions(BasePredictionWriter):
- """
- A class that saves model predictions.
-
- Attributes:
- path (str): The path to save the output CSV file.
- save_preview_samples (bool): If True, save preview images.
- keys (List[str]): A list of keys.
- """
-
- def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
- """
- Initialize an instance of the class.
-
- Args:
- path (str): The path to save the output CSV file.
- save_preview_samples (bool, optional): A flag indicating whether to save preview samples. Defaults to False.
- keys (List[str], optional): A list of keys. Defaults to None.
-
- Raises:
- None
-
- Returns:
- None
- """
- super().__init__("epoch")
- self.output_csv = Path(path)
- self.keys = keys
- self.save_preview_samples = save_preview_samples
- self.output_csv.parent.mkdir(parents=True, exist_ok=True)
-
- def save_preview_image(self, data, tag):
- """
- Save a preview image to a specified directory.
-
- Args:
- self (object): The object calling the function. (self in Python)
- data (tuple): A tuple containing the image data and its corresponding tag.
- tag (str): The tag for the image.
-
- Returns:
- None
-
- Raises:
- None
- """
- self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
- self.output_dir.mkdir(parents=True, exist_ok=True)
- image, _ = data
- image = handle_image(image)
- fp = self.output_dir / f"{tag}.png"
- torchvision.utils.save_image(image, fp)
-
- def write_on_epoch_end(
- self,
- trainer,
- pl_module: "LightningModule",
- predictions: List[Any],
- batch_indices: List[Any],
- ):
- """
- Write predictions on epoch end.
-
- Args:
- self: The instance of the class.
- trainer: The trainer object.
- pl_module (LightningModule): The Lightning module.
- predictions (List[Any]): A list of prediction values.
- batch_indices (List[Any]): A list of batch indices.
-
- Raises:
- AssertionError: If 'predict' is not present in pl_module.datasets.
- AssertionError: If 'data' is not defined in pl_module.datasets.
-
- Returns:
- None
- """
- rows = []
- assert "predict" in pl_module.datasets, "`data` not defined"
- dataset = pl_module.datasets["predict"]
- predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
- for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
- for i, v in enumerate(pred):
- row[f"pred_{i}"] = v.item()
-
- rows.append(row)
-
- # Save image previews
- if idx <= self.save_preview_samples:
- input = dataset[idx]
- self.save_preview_image(input, idx)
-
- df = pd.DataFrame(rows)
- df.to_csv(self.output_csv)
-
-class SavePredictions
-(path: str, save_preview_samples: bool = False, keys: List[str] = None)
-
A class that saves model predictions.
-path
: str
save_preview_samples
: bool
keys
: List[str]
Initialize an instance of the class.
-path
: str
save_preview_samples
: bool
, optionalkeys
: List[str]
, optionalNone
-None
class SavePredictions(BasePredictionWriter):
- """
- A class that saves model predictions.
-
- Attributes:
- path (str): The path to save the output CSV file.
- save_preview_samples (bool): If True, save preview images.
- keys (List[str]): A list of keys.
- """
-
- def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
- """
- Initialize an instance of the class.
-
- Args:
- path (str): The path to save the output CSV file.
- save_preview_samples (bool, optional): A flag indicating whether to save preview samples. Defaults to False.
- keys (List[str], optional): A list of keys. Defaults to None.
-
- Raises:
- None
-
- Returns:
- None
- """
- super().__init__("epoch")
- self.output_csv = Path(path)
- self.keys = keys
- self.save_preview_samples = save_preview_samples
- self.output_csv.parent.mkdir(parents=True, exist_ok=True)
-
- def save_preview_image(self, data, tag):
- """
- Save a preview image to a specified directory.
-
- Args:
- self (object): The object calling the function. (self in Python)
- data (tuple): A tuple containing the image data and its corresponding tag.
- tag (str): The tag for the image.
-
- Returns:
- None
-
- Raises:
- None
- """
- self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
- self.output_dir.mkdir(parents=True, exist_ok=True)
- image, _ = data
- image = handle_image(image)
- fp = self.output_dir / f"{tag}.png"
- torchvision.utils.save_image(image, fp)
-
- def write_on_epoch_end(
- self,
- trainer,
- pl_module: "LightningModule",
- predictions: List[Any],
- batch_indices: List[Any],
- ):
- """
- Write predictions on epoch end.
-
- Args:
- self: The instance of the class.
- trainer: The trainer object.
- pl_module (LightningModule): The Lightning module.
- predictions (List[Any]): A list of prediction values.
- batch_indices (List[Any]): A list of batch indices.
-
- Raises:
- AssertionError: If 'predict' is not present in pl_module.datasets.
- AssertionError: If 'data' is not defined in pl_module.datasets.
-
- Returns:
- None
- """
- rows = []
- assert "predict" in pl_module.datasets, "`data` not defined"
- dataset = pl_module.datasets["predict"]
- predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
- for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
- for i, v in enumerate(pred):
- row[f"pred_{i}"] = v.item()
-
- rows.append(row)
-
- # Save image previews
- if idx <= self.save_preview_samples:
- input = dataset[idx]
- self.save_preview_image(input, idx)
-
- df = pd.DataFrame(rows)
- df.to_csv(self.output_csv)
-
-def save_preview_image(self, data, tag)
-
Save a preview image to a specified directory.
-self
: object
data
: tuple
tag
: str
None
-None
def save_preview_image(self, data, tag):
- """
- Save a preview image to a specified directory.
-
- Args:
- self (object): The object calling the function. (self in Python)
- data (tuple): A tuple containing the image data and its corresponding tag.
- tag (str): The tag for the image.
-
- Returns:
- None
-
- Raises:
- None
- """
- self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
- self.output_dir.mkdir(parents=True, exist_ok=True)
- image, _ = data
- image = handle_image(image)
- fp = self.output_dir / f"{tag}.png"
- torchvision.utils.save_image(image, fp)
-
-def write_on_epoch_end(self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any])
-
Write predictions on epoch end.
-self
trainer
pl_module
: LightningModule
predictions
: List[Any]
batch_indices
: List[Any]
AssertionError
AssertionError
None
def write_on_epoch_end(
- self,
- trainer,
- pl_module: "LightningModule",
- predictions: List[Any],
- batch_indices: List[Any],
-):
- """
- Write predictions on epoch end.
-
- Args:
- self: The instance of the class.
- trainer: The trainer object.
- pl_module (LightningModule): The Lightning module.
- predictions (List[Any]): A list of prediction values.
- batch_indices (List[Any]): A list of batch indices.
-
- Raises:
- AssertionError: If 'predict' is not present in pl_module.datasets.
- AssertionError: If 'data' is not defined in pl_module.datasets.
-
- Returns:
- None
- """
- rows = []
- assert "predict" in pl_module.datasets, "`data` not defined"
- dataset = pl_module.datasets["predict"]
- predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
- for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
- for i, v in enumerate(pred):
- row[f"pred_{i}"] = v.item()
-
- rows.append(row)
-
- # Save image previews
- if idx <= self.save_preview_samples:
- input = dataset[idx]
- self.save_preview_image(input, idx)
-
- df = pd.DataFrame(rows)
- df.to_csv(self.output_csv)
-fmcib.callbacks.utils
from typing import List
-
-import torch
-
-
-def decollate(data: List[torch.Tensor]):
- """
- Decollate a list of tensors into a list of values.
-
- Args:
- data (list): A list of batch tensors.
-
- Returns:
- list: A list of values from the input tensors.
-
- Raises:
- AssertionError: If the input is not a list of tensors.
- """
- assert isinstance(data, list), "Decollate only implemented for list of `batch` tensors"
-
- out = []
- for d in data:
- # Handles both cases: multiple elements and single element
- # https://pytorch.org/docs/stable/generated/torch.Tensor.tolist.html
- d = d.tolist()
-
- out += d
- return out
-
-
-def handle_image(image):
- """
- Handle image according to specific requirements.
-
- Args:
- image (tensor): An image tensor.
-
- Returns:
- tensor: The processed image tensor, based on the input conditions.
-
- Raises:
- None.
- """
- image = image.squeeze()
- if image.dim() == 3:
- return image[image.shape[0] // 2]
- else:
- return image
-
-def decollate(data: List[torch.Tensor])
-
Decollate a list of tensors into a list of values.
-data
: list
list
AssertionError
def decollate(data: List[torch.Tensor]):
- """
- Decollate a list of tensors into a list of values.
-
- Args:
- data (list): A list of batch tensors.
-
- Returns:
- list: A list of values from the input tensors.
-
- Raises:
- AssertionError: If the input is not a list of tensors.
- """
- assert isinstance(data, list), "Decollate only implemented for list of `batch` tensors"
-
- out = []
- for d in data:
- # Handles both cases: multiple elements and single element
- # https://pytorch.org/docs/stable/generated/torch.Tensor.tolist.html
- d = d.tolist()
-
- out += d
- return out
-
-def handle_image(image)
-
Handle image according to specific requirements.
-image
: tensor
tensor
None.
def handle_image(image):
- """
- Handle image according to specific requirements.
-
- Args:
- image (tensor): An image tensor.
-
- Returns:
- tensor: The processed image tensor, based on the input conditions.
-
- Raises:
- None.
- """
- image = image.squeeze()
- if image.dim() == 3:
- return image[image.shape[0] // 2]
- else:
- return image
-fmcib.datasets
import os
-import random
-from pathlib import Path
-
-import monai
-import numpy as np
-import pandas as pd
-import SimpleITK as sitk
-import wget
-from loguru import logger
-
-from .ssl_radiomics_dataset import SSLRadiomicsDataset
-
-
-def get_lung1_clinical_data():
- wget.download(
- "https://www.dropbox.com/s/ulp8t21eunep21y/NSCLC%20Radiomics%20Lung1.clinical-version3-Oct%202019.csv?dl=1",
- out="/tmp/lung1_clinical.csv",
- )
- return pd.read_csv("/tmp/lung1_clinical.csv")
-
-
-def get_radio_clinical_data():
- wget.download(
- "https://www.dropbox.com/s/mtpynjof550ulfo/NSCLCR01Radiogenomic_DATA_LABELS_2018-05-22_1500-shifted.csv?dl=1",
- out=f"/tmp/radio_clinical.csv",
- )
- return pd.read_csv("/tmp/radio_clinical.csv")
-
-
-def get_lung1_foundation_features():
- wget.download(
- "https://www.dropbox.com/s/ypbb2iogq3bsq5v/lung1.csv?dl=1",
- out=f"/tmp/lung1_foundation_features.csv",
- )
- df = pd.read_csv("/tmp/lung1_foundation_features.csv")
- filtered_df = df.filter(like="pred")
- filtered_df = filtered_df.reset_index() # reset the index
- filtered_df["PatientID"] = df["PatientID"]
- return filtered_df
-
-
-def get_radio_foundation_features():
- wget.download(
- "https://www.dropbox.com/s/pwl4rdlvp9jirar/radio.csv?dl=1",
- out=f"/tmp/radio_foundation_features.csv",
- )
-
- df = pd.read_csv("/tmp/radio_foundation_features.csv")
- filtered_df = df.filter(like="pred")
- filtered_df = filtered_df.reset_index() # reset the index
- filtered_df["PatientID"] = df["Case ID"]
- return filtered_df
-
-
-def generate_dummy_data(dir_path, size=10):
- path = Path(dir_path).resolve()
- path.mkdir(exist_ok=True, parents=True)
-
- row_list = []
- for i in range(size):
- row = create_dummy_row((32, 128, 128), str(path / f"dummy_{i}.nii.gz"))
- row_list.append(row)
-
- df = pd.DataFrame(row_list)
- df.to_csv(path / "dummy.csv", index=False)
-
- logger.info(f"Generated dummy data at {path}/dummy.csv")
-
-
-def create_dummy_row(size, output_filename):
- """
- Function to create a dummy row with path to an image and seed point corresponding to the image
- """
-
- # Create a np array initialized with random values between -1024 and 2048
- np_image = np.random.randint(-1024, 2048, size, dtype=np.int16)
-
- # Create an itk image from the numpy array
- itk_image = sitk.GetImageFromArray(np_image)
-
- # Save itk image to file with the given output filename
- sitk.WriteImage(itk_image, output_filename)
-
- x, y, z = generate_random_seed_point(itk_image.GetSize())
-
- # Convert to global coordinates
- x, y, z = itk_image.TransformContinuousIndexToPhysicalPoint((x, y, z))
-
- return {
- "image_path": output_filename,
- "PatientID": random.randint(0, 100000),
- "coordX": x,
- "coordY": y,
- "coordZ": z,
- "label": random.randint(0, 1),
- }
-
-
-def generate_random_seed_point(image_size):
- """
- Function to generate a random x, y, z coordinate within the image
- """
- x = random.randint(0, image_size[0] - 1)
- y = random.randint(0, image_size[1] - 1)
- z = random.randint(0, image_size[2] - 1)
-
- return (x, y, z)
-fmcib.datasets.ssl_radiomics_dataset
fmcib.datasets.utils
-def create_dummy_row(size, output_filename)
-
Function to create a dummy row with path to an image and seed point corresponding to the image
def create_dummy_row(size, output_filename):
- """
- Function to create a dummy row with path to an image and seed point corresponding to the image
- """
-
- # Create a np array initialized with random values between -1024 and 2048
- np_image = np.random.randint(-1024, 2048, size, dtype=np.int16)
-
- # Create an itk image from the numpy array
- itk_image = sitk.GetImageFromArray(np_image)
-
- # Save itk image to file with the given output filename
- sitk.WriteImage(itk_image, output_filename)
-
- x, y, z = generate_random_seed_point(itk_image.GetSize())
-
- # Convert to global coordinates
- x, y, z = itk_image.TransformContinuousIndexToPhysicalPoint((x, y, z))
-
- return {
- "image_path": output_filename,
- "PatientID": random.randint(0, 100000),
- "coordX": x,
- "coordY": y,
- "coordZ": z,
- "label": random.randint(0, 1),
- }
-
-def generate_dummy_data(dir_path, size=10)
-
def generate_dummy_data(dir_path, size=10):
- path = Path(dir_path).resolve()
- path.mkdir(exist_ok=True, parents=True)
-
- row_list = []
- for i in range(size):
- row = create_dummy_row((32, 128, 128), str(path / f"dummy_{i}.nii.gz"))
- row_list.append(row)
-
- df = pd.DataFrame(row_list)
- df.to_csv(path / "dummy.csv", index=False)
-
- logger.info(f"Generated dummy data at {path}/dummy.csv")
-
-def generate_random_seed_point(image_size)
-
Function to generate a random x, y, z coordinate within the image
def generate_random_seed_point(image_size):
- """
- Function to generate a random x, y, z coordinate within the image
- """
- x = random.randint(0, image_size[0] - 1)
- y = random.randint(0, image_size[1] - 1)
- z = random.randint(0, image_size[2] - 1)
-
- return (x, y, z)
-
-def get_lung1_clinical_data()
-
def get_lung1_clinical_data():
- wget.download(
- "https://www.dropbox.com/s/ulp8t21eunep21y/NSCLC%20Radiomics%20Lung1.clinical-version3-Oct%202019.csv?dl=1",
- out="/tmp/lung1_clinical.csv",
- )
- return pd.read_csv("/tmp/lung1_clinical.csv")
-
-def get_lung1_foundation_features()
-
def get_lung1_foundation_features():
- wget.download(
- "https://www.dropbox.com/s/ypbb2iogq3bsq5v/lung1.csv?dl=1",
- out=f"/tmp/lung1_foundation_features.csv",
- )
- df = pd.read_csv("/tmp/lung1_foundation_features.csv")
- filtered_df = df.filter(like="pred")
- filtered_df = filtered_df.reset_index() # reset the index
- filtered_df["PatientID"] = df["PatientID"]
- return filtered_df
-
-def get_radio_clinical_data()
-
def get_radio_clinical_data():
- wget.download(
- "https://www.dropbox.com/s/mtpynjof550ulfo/NSCLCR01Radiogenomic_DATA_LABELS_2018-05-22_1500-shifted.csv?dl=1",
- out=f"/tmp/radio_clinical.csv",
- )
- return pd.read_csv("/tmp/radio_clinical.csv")
-
-def get_radio_foundation_features()
-
def get_radio_foundation_features():
- wget.download(
- "https://www.dropbox.com/s/pwl4rdlvp9jirar/radio.csv?dl=1",
- out=f"/tmp/radio_foundation_features.csv",
- )
-
- df = pd.read_csv("/tmp/radio_foundation_features.csv")
- filtered_df = df.filter(like="pred")
- filtered_df = filtered_df.reset_index() # reset the index
- filtered_df["PatientID"] = df["Case ID"]
- return filtered_df
-fmcib.datasets.ssl_radiomics_dataset
from pathlib import Path
-
-import monai
-import numpy as np
-import pandas as pd
-import SimpleITK as sitk
-from loguru import logger
-from torch.utils.data import Dataset
-
-from .utils import resample_image_to_spacing, slice_image
-
-
-class SSLRadiomicsDataset(Dataset):
- """
- Dataset class for SSL Radiomics dataset.
-
- Args:
- path (str): The path to the dataset.
- label (str, optional): The label column name in the dataset annotations. Default is None.
- radius (int, optional): The radius around the centroid for positive patch extraction. Default is 25.
- orient (bool, optional): Whether to orient the images to LPI orientation. Default is False.
- resample_spacing (float or tuple, optional): The desired spacing for resampling the images. Default is None.
- enable_negatives (bool, optional): Whether to include negative samples. Default is True.
- transform (callable, optional): A function/transform to apply on the images. Default is None.
- """
-
- def __init__(
- self,
- path,
- label=None,
- radius=25,
- orient=False,
- resample_spacing=None,
- enable_negatives=True,
- transform=None,
- orient_patch=True,
- input_is_target=False,
- ):
- """
- Creates an instance of the SSLRadiomicsDataset class with the given parameters.
-
- Args:
- path (str): The path to the dataset.
- label (Optional[str]): The label to use for the dataset. Defaults to None.
- radius (int): The radius parameter. Defaults to 25.
- orient (bool): True if the dataset should be oriented, False otherwise. Defaults to False.
- resample_spacing (Optional[...]): The resample spacing parameter. Defaults to None.
- enable_negatives (bool): True if negatives are enabled, False otherwise. Defaults to True.
- transform: The transformation to apply to the dataset. Defaults to None.
- orient_patch (bool): True if the patch should be oriented, False otherwise. Defaults to True.
- input_is_target (bool): True if the input is the target, False otherwise. Defaults to False.
-
- Raises:
- None.
-
- Returns:
- None.
- """
- monai.data.set_track_meta(False)
- sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)
- super(SSLRadiomicsDataset, self).__init__()
- self._path = Path(path)
-
- self.radius = radius
- self.orient = orient
- self.resample_spacing = resample_spacing
- self.label = label
- self.enable_negatives = enable_negatives
- self.transform = transform
- self.orient_patch = orient_patch
- self.input_is_target = input_is_target
- self.annotations = pd.read_csv(self._path)
- self._num_samples = len(self.annotations) # set the length of the dataset
-
- def get_rows(self):
- """
- Get the rows of the annotations as a list of dictionaries.
-
- Returns:
- list of dict: The rows of the annotations as dictionaries.
- """
- return self.annotations.to_dict(orient="records")
-
- def get_labels(self):
- """
- Function to get labels for when they are available in the dataset.
-
- Args:
- None
-
- Returns:
- None
- """
-
- labels = self.annotations[self.label].values
- assert not np.any(labels == -1), "All labels must be specified"
- return labels
-
- def __len__(self):
- """
- Size of the dataset.
- """
- return self._num_samples
-
- def get_negative_sample(self, image):
- """
- Extract a negative sample from the image background with no overlap to the positive sample.
-
- Parameters:
- image: Image to extract sample
- positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
- """
- positive_patch_size = [self.radius * 2] * 3
- valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
- def get_random_patch():
- """
- Get a random patch from an image.
-
- Returns:
- list: A list containing the start and end indices of the random patch.
- """
- random_patch_idx = [
- [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
- ]
- return random_patch_idx
-
- random_patch_idx = get_random_patch()
-
- # escape_count = 0
- # while is_overlapping(positive_patch_idx, random_patch_idx):
- # if escape_count >= 3:
- # logger.warning("Random patch has overlap with positive patch")
- # return None
-
- # random_patch_idx = get_random_patch()
- # escape_count += 1
-
- random_patch = slice_image(image, random_patch_idx)
- random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
- negative_array = sitk.GetArrayFromImage(random_patch)
-
- negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
- return negative_tensor
-
- def __getitem__(self, idx: int):
- """
- Implement how to load the data corresponding to the idx element in the dataset from your data source.
- """
-
- # Get a row from the CSV file
- row = self.annotations.iloc[idx]
- image_path = row["image_path"]
- image = sitk.ReadImage(str(image_path))
- image = resample_image_to_spacing(image, self.resample_spacing, -1024) if self.resample_spacing is not None else image
-
- centroid = (row["coordX"], row["coordY"], row["coordZ"])
- centroid = image.TransformPhysicalPointToContinuousIndex(centroid)
- centroid = [int(d) for d in centroid]
-
- # Orient all images to LPI orientation
- image = sitk.DICOMOrient(image, "LPI") if self.orient else image
-
- # Extract positive with a specified radius around centroid
- patch_idx = [(c - self.radius, c + self.radius) for c in centroid]
- patch_image = slice_image(image, patch_idx)
-
- patch_image = sitk.DICOMOrient(patch_image, "LPS") if self.orient_patch else patch_image
-
- array = sitk.GetArrayFromImage(patch_image)
- tensor = array if self.transform is None else self.transform(array)
-
- if self.label is not None:
- target = int(row[self.label])
- elif self.input_is_target:
- target = tensor.clone()
- else:
- target = None
-
- if self.enable_negatives:
- return {"positive": tensor, "negative": self.get_negative_sample(image)}, target
-
- return tensor, target
-
-
-if __name__ == "__main__":
- from pathlib import Path
-
- # Test pytorch dataset
- print("Test pytorch dataset")
- dataset = SSLRadiomicsDataset(
- "/home/suraj/Repositories/cancer-imaging-ssl/src/pretraining/data_csv/deeplesion/train.csv",
- orient=True,
- resample_spacing=[1, 1, 1],
- )
-
- # Visualize item from dataset
- item = dataset[0]
-
- positive = sitk.GetImageFromArray(item[0][0])
- negative = sitk.GetImageFromArray(item[0][1])
- current_dir = Path(__file__).parent.resolve()
-
- sitk.WriteImage(positive, f"{str(current_dir)}/tests/positive.nrrd")
- sitk.WriteImage(negative, f"{str(current_dir)}/tests/negative.nrrd")
-
-class SSLRadiomicsDataset
-(path, label=None, radius=25, orient=False, resample_spacing=None, enable_negatives=True, transform=None, orient_patch=True, input_is_target=False)
-
Dataset class for SSL Radiomics dataset.
-path
: str
label
: str
, optionalradius
: int
, optionalorient
: bool
, optionalresample_spacing
: float
or tuple
, optionalenable_negatives
: bool
, optionaltransform
: callable
, optionalCreates an instance of the SSLRadiomicsDataset class with the given parameters.
-path
: str
label
: Optional[str]
radius
: int
orient
: bool
resample_spacing
: Optional[…]
enable_negatives
: bool
transform
orient_patch
: bool
input_is_target
: bool
None.
-None.
class SSLRadiomicsDataset(Dataset):
- """
- Dataset class for SSL Radiomics dataset.
-
- Args:
- path (str): The path to the dataset.
- label (str, optional): The label column name in the dataset annotations. Default is None.
- radius (int, optional): The radius around the centroid for positive patch extraction. Default is 25.
- orient (bool, optional): Whether to orient the images to LPI orientation. Default is False.
- resample_spacing (float or tuple, optional): The desired spacing for resampling the images. Default is None.
- enable_negatives (bool, optional): Whether to include negative samples. Default is True.
- transform (callable, optional): A function/transform to apply on the images. Default is None.
- """
-
- def __init__(
- self,
- path,
- label=None,
- radius=25,
- orient=False,
- resample_spacing=None,
- enable_negatives=True,
- transform=None,
- orient_patch=True,
- input_is_target=False,
- ):
- """
- Creates an instance of the SSLRadiomicsDataset class with the given parameters.
-
- Args:
- path (str): The path to the dataset.
- label (Optional[str]): The label to use for the dataset. Defaults to None.
- radius (int): The radius parameter. Defaults to 25.
- orient (bool): True if the dataset should be oriented, False otherwise. Defaults to False.
- resample_spacing (Optional[...]): The resample spacing parameter. Defaults to None.
- enable_negatives (bool): True if negatives are enabled, False otherwise. Defaults to True.
- transform: The transformation to apply to the dataset. Defaults to None.
- orient_patch (bool): True if the patch should be oriented, False otherwise. Defaults to True.
- input_is_target (bool): True if the input is the target, False otherwise. Defaults to False.
-
- Raises:
- None.
-
- Returns:
- None.
- """
- monai.data.set_track_meta(False)
- sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)
- super(SSLRadiomicsDataset, self).__init__()
- self._path = Path(path)
-
- self.radius = radius
- self.orient = orient
- self.resample_spacing = resample_spacing
- self.label = label
- self.enable_negatives = enable_negatives
- self.transform = transform
- self.orient_patch = orient_patch
- self.input_is_target = input_is_target
- self.annotations = pd.read_csv(self._path)
- self._num_samples = len(self.annotations) # set the length of the dataset
-
- def get_rows(self):
- """
- Get the rows of the annotations as a list of dictionaries.
-
- Returns:
- list of dict: The rows of the annotations as dictionaries.
- """
- return self.annotations.to_dict(orient="records")
-
- def get_labels(self):
- """
- Function to get labels for when they are available in the dataset.
-
- Args:
- None
-
- Returns:
- None
- """
-
- labels = self.annotations[self.label].values
- assert not np.any(labels == -1), "All labels must be specified"
- return labels
-
- def __len__(self):
- """
- Size of the dataset.
- """
- return self._num_samples
-
- def get_negative_sample(self, image):
- """
- Extract a negative sample from the image background with no overlap to the positive sample.
-
- Parameters:
- image: Image to extract sample
- positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
- """
- positive_patch_size = [self.radius * 2] * 3
- valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
- def get_random_patch():
- """
- Get a random patch from an image.
-
- Returns:
- list: A list containing the start and end indices of the random patch.
- """
- random_patch_idx = [
- [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
- ]
- return random_patch_idx
-
- random_patch_idx = get_random_patch()
-
- # escape_count = 0
- # while is_overlapping(positive_patch_idx, random_patch_idx):
- # if escape_count >= 3:
- # logger.warning("Random patch has overlap with positive patch")
- # return None
-
- # random_patch_idx = get_random_patch()
- # escape_count += 1
-
- random_patch = slice_image(image, random_patch_idx)
- random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
- negative_array = sitk.GetArrayFromImage(random_patch)
-
- negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
- return negative_tensor
-
- def __getitem__(self, idx: int):
- """
- Implement how to load the data corresponding to the idx element in the dataset from your data source.
- """
-
- # Get a row from the CSV file
- row = self.annotations.iloc[idx]
- image_path = row["image_path"]
- image = sitk.ReadImage(str(image_path))
- image = resample_image_to_spacing(image, self.resample_spacing, -1024) if self.resample_spacing is not None else image
-
- centroid = (row["coordX"], row["coordY"], row["coordZ"])
- centroid = image.TransformPhysicalPointToContinuousIndex(centroid)
- centroid = [int(d) for d in centroid]
-
- # Orient all images to LPI orientation
- image = sitk.DICOMOrient(image, "LPI") if self.orient else image
-
- # Extract positive with a specified radius around centroid
- patch_idx = [(c - self.radius, c + self.radius) for c in centroid]
- patch_image = slice_image(image, patch_idx)
-
- patch_image = sitk.DICOMOrient(patch_image, "LPS") if self.orient_patch else patch_image
-
- array = sitk.GetArrayFromImage(patch_image)
- tensor = array if self.transform is None else self.transform(array)
-
- if self.label is not None:
- target = int(row[self.label])
- elif self.input_is_target:
- target = tensor.clone()
- else:
- target = None
-
- if self.enable_negatives:
- return {"positive": tensor, "negative": self.get_negative_sample(image)}, target
-
- return tensor, target
-
-def get_labels(self)
-
Function to get labels for when they are available in the dataset.
-None
-None
def get_labels(self):
- """
- Function to get labels for when they are available in the dataset.
-
- Args:
- None
-
- Returns:
- None
- """
-
- labels = self.annotations[self.label].values
- assert not np.any(labels == -1), "All labels must be specified"
- return labels
-
-def get_negative_sample(self, image)
-
Extract a negative sample from the image background with no overlap to the positive sample.
-image: Image to extract sample -positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
def get_negative_sample(self, image):
- """
- Extract a negative sample from the image background with no overlap to the positive sample.
-
- Parameters:
- image: Image to extract sample
- positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
- """
- positive_patch_size = [self.radius * 2] * 3
- valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
- def get_random_patch():
- """
- Get a random patch from an image.
-
- Returns:
- list: A list containing the start and end indices of the random patch.
- """
- random_patch_idx = [
- [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
- ]
- return random_patch_idx
-
- random_patch_idx = get_random_patch()
-
- # escape_count = 0
- # while is_overlapping(positive_patch_idx, random_patch_idx):
- # if escape_count >= 3:
- # logger.warning("Random patch has overlap with positive patch")
- # return None
-
- # random_patch_idx = get_random_patch()
- # escape_count += 1
-
- random_patch = slice_image(image, random_patch_idx)
- random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
- negative_array = sitk.GetArrayFromImage(random_patch)
-
- negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
- return negative_tensor
-
-def get_rows(self)
-
Get the rows of the annotations as a list of dictionaries.
-list
of dict
def get_rows(self):
- """
- Get the rows of the annotations as a list of dictionaries.
-
- Returns:
- list of dict: The rows of the annotations as dictionaries.
- """
- return self.annotations.to_dict(orient="records")
-fmcib.datasets.utils
from pathlib import Path
-
-import numpy as np
-import SimpleITK as sitk
-
-# https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
-SITK_INTERPOLATOR_DICT = {
- "nearest": sitk.sitkNearestNeighbor,
- "linear": sitk.sitkLinear,
- "gaussian": sitk.sitkGaussian,
- "label_gaussian": sitk.sitkLabelGaussian,
- "bspline": sitk.sitkBSpline,
- "hamming_sinc": sitk.sitkHammingWindowedSinc,
- "cosine_windowed_sinc": sitk.sitkCosineWindowedSinc,
- "welch_windowed_sinc": sitk.sitkWelchWindowedSinc,
- "lanczos_windowed_sinc": sitk.sitkLanczosWindowedSinc,
-}
-
-
-def resample_image_to_spacing(image, new_spacing, default_value, interpolator="linear"):
- """
- Resample an image to a new spacing.
- """
- assert interpolator in SITK_INTERPOLATOR_DICT, (
- f"Interpolator '{interpolator}' not part of SimpleITK. "
- f"Please choose one of the following {list(SITK_INTERPOLATOR_DICT.keys())}."
- )
-
- assert image.GetDimension() == len(new_spacing), (
- f"Input is {image.GetDimension()}-dimensional while " f"the new spacing is {len(new_spacing)}-dimensional."
- )
-
- interpolator = SITK_INTERPOLATOR_DICT[interpolator]
- spacing = image.GetSpacing()
- size = image.GetSize()
- new_size = [int(round(siz * spac / n_spac)) for siz, spac, n_spac in zip(size, spacing, new_spacing)]
- return sitk.Resample(
- image,
- new_size, # size
- sitk.Transform(), # transform
- interpolator, # interpolator
- image.GetOrigin(), # outputOrigin
- new_spacing, # outputSpacing
- image.GetDirection(), # outputDirection
- default_value, # defaultPixelValue
- image.GetPixelID(),
- ) # outputPixelType
-
-
-def slice_image(image, patch_idx):
- """
- Slice an image.
- """
-
- start, stop = zip(*patch_idx)
- slice_filter = sitk.SliceImageFilter()
- slice_filter.SetStart(start)
- slice_filter.SetStop(stop)
- return slice_filter.Execute(image)
-
-
-def is_overlapping(patch1, patch2):
- """
- Check if two patches are overlapping.
-
- Args:
- patch1 (list of tuples): A list of tuples representing the ranges of each axis in patch1.
- patch2 (list of tuples): A list of tuples representing the ranges of each axis in patch2.
-
- Returns:
- bool: True if the two patches overlap, False otherwise.
-
- Note:
- This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch.
- The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range.
- The patches are considered overlapping if there is any overlap in the ranges of their axes.
- """
- overlap_by_axis = [max(axis1[0], axis2[0]) < min(axis1[1], axis2[1]) for axis1, axis2 in zip(patch1, patch2)]
-
- return np.all(overlap_by_axis)
-
-def is_overlapping(patch1, patch2)
-
Check if two patches are overlapping.
-patch1
: list
of tuples
patch2
: list
of tuples
bool
This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch. -The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range. -The patches are considered overlapping if there is any overlap in the ranges of their axes.
def is_overlapping(patch1, patch2):
- """
- Check if two patches are overlapping.
-
- Args:
- patch1 (list of tuples): A list of tuples representing the ranges of each axis in patch1.
- patch2 (list of tuples): A list of tuples representing the ranges of each axis in patch2.
-
- Returns:
- bool: True if the two patches overlap, False otherwise.
-
- Note:
- This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch.
- The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range.
- The patches are considered overlapping if there is any overlap in the ranges of their axes.
- """
- overlap_by_axis = [max(axis1[0], axis2[0]) < min(axis1[1], axis2[1]) for axis1, axis2 in zip(patch1, patch2)]
-
- return np.all(overlap_by_axis)
-
-def resample_image_to_spacing(image, new_spacing, default_value, interpolator='linear')
-
Resample an image to a new spacing.
def resample_image_to_spacing(image, new_spacing, default_value, interpolator="linear"):
- """
- Resample an image to a new spacing.
- """
- assert interpolator in SITK_INTERPOLATOR_DICT, (
- f"Interpolator '{interpolator}' not part of SimpleITK. "
- f"Please choose one of the following {list(SITK_INTERPOLATOR_DICT.keys())}."
- )
-
- assert image.GetDimension() == len(new_spacing), (
- f"Input is {image.GetDimension()}-dimensional while " f"the new spacing is {len(new_spacing)}-dimensional."
- )
-
- interpolator = SITK_INTERPOLATOR_DICT[interpolator]
- spacing = image.GetSpacing()
- size = image.GetSize()
- new_size = [int(round(siz * spac / n_spac)) for siz, spac, n_spac in zip(size, spacing, new_spacing)]
- return sitk.Resample(
- image,
- new_size, # size
- sitk.Transform(), # transform
- interpolator, # interpolator
- image.GetOrigin(), # outputOrigin
- new_spacing, # outputSpacing
- image.GetDirection(), # outputDirection
- default_value, # defaultPixelValue
- image.GetPixelID(),
- ) # outputPixelType
-
-def slice_image(image, patch_idx)
-
Slice an image.
def slice_image(image, patch_idx):
- """
- Slice an image.
- """
-
- start, stop = zip(*patch_idx)
- slice_filter = sitk.SliceImageFilter()
- slice_filter.SetStart(start)
- slice_filter.SetStop(stop)
- return slice_filter.Execute(image)
-fmcib
__version__ = "0.0.1a22"
-fmcib.callbacks
fmcib.datasets
fmcib.models
fmcib.optimizers
fmcib.preprocessing
fmcib.run
fmcib.ssl
fmcib.transforms
fmcib.utils
fmcib.visualization
fmcib.models.autoencoder
import torch
-import torch.nn as nn
-from monai.networks.blocks import Convolution, ResidualUnit
-from monai.networks.nets import AutoEncoder
-
-
-class CustomAE(AutoEncoder):
- """
- A custom AutoEncoder class.
-
- Inherits from AutoEncoder.
-
- Attributes:
- padding (int): The padding size for the convolutional layers.
- decoder (bool, optional): Determines if the decoder part of the network is included.
- kwargs: Additional keyword arguments passed to the parent class.
-
- Methods:
- _get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network.
- _get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.
- """
-
- def __init__(self, padding, decoder=True, **kwargs):
- """
- Initialize the object.
-
- Args:
- padding (int): Padding value.
- decoder (bool, optional): If True, use a decoder. Defaults to True.
- **kwargs: Additional keyword arguments.
-
- Attributes:
- padding (int): Padding value.
-
- Raises:
- None
- """
- self.padding = padding
- super().__init__(**kwargs)
- if not decoder:
- self.decode = nn.Sequential(nn.AvgPool3d(3), nn.Flatten())
-
- def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:
- """
- Returns a single layer of the encoder part of the network.
- """
- mod: nn.Module
- if self.num_res_units > 0:
- mod = ResidualUnit(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.kernel_size,
- padding=self.padding,
- subunits=self.num_res_units,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- last_conv_only=is_last,
- )
- return mod
- mod = Convolution(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.kernel_size,
- padding=self.padding,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- conv_only=is_last,
- )
- return mod
-
- def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:
- """
- Returns a single layer of the decoder part of the network.
- """
- decode = nn.Sequential()
-
- conv = Convolution(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.up_kernel_size,
- padding=self.padding,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- conv_only=is_last and self.num_res_units == 0,
- is_transposed=True,
- )
-
- decode.add_module("conv", conv)
-
- if self.num_res_units > 0:
- ru = ResidualUnit(
- spatial_dims=self.dimensions,
- in_channels=out_channels,
- out_channels=out_channels,
- padding=self.padding,
- strides=strides,
- kernel_size=self.kernel_size,
- subunits=1,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- last_conv_only=is_last,
- )
-
- decode.add_module("resunit", ru)
-
- return decode
-
-class CustomAE
-(padding, decoder=True, **kwargs)
-
A custom AutoEncoder class.
-Inherits from AutoEncoder.
-padding
: int
decoder
: bool
, optionalkwargs
_get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network. -_get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.
-Initialize the object.
-padding
: int
decoder
: bool
, optional**kwargs
padding
: int
None
class CustomAE(AutoEncoder):
- """
- A custom AutoEncoder class.
-
- Inherits from AutoEncoder.
-
- Attributes:
- padding (int): The padding size for the convolutional layers.
- decoder (bool, optional): Determines if the decoder part of the network is included.
- kwargs: Additional keyword arguments passed to the parent class.
-
- Methods:
- _get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network.
- _get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.
- """
-
- def __init__(self, padding, decoder=True, **kwargs):
- """
- Initialize the object.
-
- Args:
- padding (int): Padding value.
- decoder (bool, optional): If True, use a decoder. Defaults to True.
- **kwargs: Additional keyword arguments.
-
- Attributes:
- padding (int): Padding value.
-
- Raises:
- None
- """
- self.padding = padding
- super().__init__(**kwargs)
- if not decoder:
- self.decode = nn.Sequential(nn.AvgPool3d(3), nn.Flatten())
-
- def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:
- """
- Returns a single layer of the encoder part of the network.
- """
- mod: nn.Module
- if self.num_res_units > 0:
- mod = ResidualUnit(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.kernel_size,
- padding=self.padding,
- subunits=self.num_res_units,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- last_conv_only=is_last,
- )
- return mod
- mod = Convolution(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.kernel_size,
- padding=self.padding,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- conv_only=is_last,
- )
- return mod
-
- def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:
- """
- Returns a single layer of the decoder part of the network.
- """
- decode = nn.Sequential()
-
- conv = Convolution(
- spatial_dims=self.dimensions,
- in_channels=in_channels,
- out_channels=out_channels,
- strides=strides,
- kernel_size=self.up_kernel_size,
- padding=self.padding,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- conv_only=is_last and self.num_res_units == 0,
- is_transposed=True,
- )
-
- decode.add_module("conv", conv)
-
- if self.num_res_units > 0:
- ru = ResidualUnit(
- spatial_dims=self.dimensions,
- in_channels=out_channels,
- out_channels=out_channels,
- padding=self.padding,
- strides=strides,
- kernel_size=self.kernel_size,
- subunits=1,
- act=self.act,
- norm=self.norm,
- dropout=self.dropout,
- bias=self.bias,
- last_conv_only=is_last,
- )
-
- decode.add_module("resunit", ru)
-
- return decode
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x: torch.Tensor) ‑> Any
-
Define the computation performed at every call.
-Should be overridden by all subclasses.
-Note
-Although the recipe for forward pass needs to be defined within
-this function, one should call the :class:Module
instance afterwards
-instead of this since the former takes care of running the
-registered hooks while the latter silently ignores them.
def forward(self, x: torch.Tensor) -> Any:
- x = self.encode(x)
- x = self.intermediate(x)
- x = self.decode(x)
- return x
-fmcib.models
import os
-import pickle
-from pathlib import Path
-
-import wget
-from monai.networks.nets import resnet50
-
-from fmcib.utils.download_utils import bar_progress
-
-from .autoencoder import CustomAE as AutoEncoder
-from .load_model import LoadModel
-from .models_genesis import UNet3D as ModelsGenesisUNet3D
-
-
-def get_linear_classifier(weights_path=None, download_url="https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1"):
- if weights_path is None:
- weights_path = "/tmp/linear_model.pkl"
- wget.download(download_url, out=weights_path)
-
- return pickle.load(open(weights_path, "rb"))
-
-
-def fmcib_model():
- trunk = resnet50(
- pretrained=False,
- n_input_channels=1,
- widen_factor=2,
- conv1_t_stride=2,
- feed_forward=False,
- bias_downsample=True,
- )
- weights_url = "https://zenodo.org/records/10528450/files/model_weights.torch?download=1"
- current_path = Path(os.getcwd())
- if not (current_path / "model_weights.torch").exists():
- wget.download(weights_url, bar=bar_progress)
- model = LoadModel(trunk=trunk, weights_path=current_path / "model_weights.torch", heads=[])
- return model
-fmcib.models.autoencoder
fmcib.models.load_model
fmcib.models.models_genesis
-def fmcib_model()
-
def fmcib_model():
- trunk = resnet50(
- pretrained=False,
- n_input_channels=1,
- widen_factor=2,
- conv1_t_stride=2,
- feed_forward=False,
- bias_downsample=True,
- )
- weights_url = "https://zenodo.org/records/10528450/files/model_weights.torch?download=1"
- current_path = Path(os.getcwd())
- if not (current_path / "model_weights.torch").exists():
- wget.download(weights_url, bar=bar_progress)
- model = LoadModel(trunk=trunk, weights_path=current_path / "model_weights.torch", heads=[])
- return model
-
-def get_linear_classifier(weights_path=None, download_url='https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1')
-
def get_linear_classifier(weights_path=None, download_url="https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1"):
- if weights_path is None:
- weights_path = "/tmp/linear_model.pkl"
- wget.download(download_url, out=weights_path)
-
- return pickle.load(open(weights_path, "rb"))
-fmcib.models.load_model
from collections import OrderedDict
-
-import torch
-from loguru import logger
-from torch import nn
-
-
-class LoadModel(nn.Module):
- """
- A class representing a loaded model.
-
- Args:
- trunk (nn.Module, optional): The trunk of the model. Defaults to None.
- weights_path (str, optional): The path to the weights file. Defaults to None.
- heads (list, optional): The list of head layers in the model. Defaults to [].
-
- Attributes:
- trunk (nn.Module): The trunk of the model.
- heads (nn.Sequential): The concatenated head layers of the model.
-
- Methods:
- forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model.
- load(weights): Load the pretrained model weights.
- """
-
- def __init__(self, trunk=None, weights_path=None, heads=[]) -> None:
- """
- Initialize the model.
-
- Args:
- trunk (optional): The trunk of the model.
- weights_path (optional): The path to the weights file.
- heads (list, optional): A list of layer sizes for the heads of the model.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__()
- self.trunk = trunk
- head_layers = []
- for idx in range(len(heads) - 1):
- current_layers = []
- current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
- if idx != (len(heads) - 2):
- current_layers.append(nn.ReLU(inplace=True))
-
- head_layers.append(nn.Sequential(*current_layers))
-
- if len(head_layers):
- self.heads = nn.Sequential(*head_layers)
- else:
- self.heads = nn.Identity()
-
- if weights_path is not None:
- self.load(weights_path)
-
- def forward(self, x: torch.Tensor):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): The input tensor.
-
- Returns:
- torch.Tensor: The output tensor.
- """
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
- def load(self, weights):
- """
- Load pretrained model weights from a file.
-
- Args:
- weights (str): The path to the file containing the pretrained model weights.
-
- Raises:
- KeyError: If the input weights file does not contain the expected keys.
- Exception: If there is an error when loading the pretrained heads.
-
- Returns:
- None.
-
- Note:
- This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
- Warnings:
- - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
- - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
- Raises the appropriate warnings and logs informational messages.
- """
- pretrained_model = torch.load(weights)
-
- if "trunk_state_dict" in pretrained_model: # Loading ViSSL pretrained model
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- if "state_dict" in pretrained_model: # Loading Med3D pretrained model
- trained_model = pretrained_model["state_dict"]
-
- # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
- weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
- weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
- msg = self.trunk.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
- msg = self.heads.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-
-class LoadModel
-(trunk=None, weights_path=None, heads=[])
-
A class representing a loaded model.
-trunk
: nn.Module
, optionalweights_path
: str
, optionalheads
: list
, optionaltrunk
: nn.Module
heads
: nn.Sequential
forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model. -load(weights): Load the pretrained model weights.
-Initialize the model.
-trunk
: optional
weights_path
: optional
heads
: list
, optionalNone
-None
class LoadModel(nn.Module):
- """
- A class representing a loaded model.
-
- Args:
- trunk (nn.Module, optional): The trunk of the model. Defaults to None.
- weights_path (str, optional): The path to the weights file. Defaults to None.
- heads (list, optional): The list of head layers in the model. Defaults to [].
-
- Attributes:
- trunk (nn.Module): The trunk of the model.
- heads (nn.Sequential): The concatenated head layers of the model.
-
- Methods:
- forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model.
- load(weights): Load the pretrained model weights.
- """
-
- def __init__(self, trunk=None, weights_path=None, heads=[]) -> None:
- """
- Initialize the model.
-
- Args:
- trunk (optional): The trunk of the model.
- weights_path (optional): The path to the weights file.
- heads (list, optional): A list of layer sizes for the heads of the model.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__()
- self.trunk = trunk
- head_layers = []
- for idx in range(len(heads) - 1):
- current_layers = []
- current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
- if idx != (len(heads) - 2):
- current_layers.append(nn.ReLU(inplace=True))
-
- head_layers.append(nn.Sequential(*current_layers))
-
- if len(head_layers):
- self.heads = nn.Sequential(*head_layers)
- else:
- self.heads = nn.Identity()
-
- if weights_path is not None:
- self.load(weights_path)
-
- def forward(self, x: torch.Tensor):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): The input tensor.
-
- Returns:
- torch.Tensor: The output tensor.
- """
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
- def load(self, weights):
- """
- Load pretrained model weights from a file.
-
- Args:
- weights (str): The path to the file containing the pretrained model weights.
-
- Raises:
- KeyError: If the input weights file does not contain the expected keys.
- Exception: If there is an error when loading the pretrained heads.
-
- Returns:
- None.
-
- Note:
- This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
- Warnings:
- - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
- - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
- Raises the appropriate warnings and logs informational messages.
- """
- pretrained_model = torch.load(weights)
-
- if "trunk_state_dict" in pretrained_model: # Loading ViSSL pretrained model
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- if "state_dict" in pretrained_model: # Loading Med3D pretrained model
- trained_model = pretrained_model["state_dict"]
-
- # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
- weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
- weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
- msg = self.trunk.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
- msg = self.heads.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x: torch.Tensor) ‑> Callable[..., Any]
-
Forward pass of the neural network.
-x
: torch.Tensor
torch.Tensor
def forward(self, x: torch.Tensor):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): The input tensor.
-
- Returns:
- torch.Tensor: The output tensor.
- """
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
-def load(self, weights)
-
Load pretrained model weights from a file.
-weights
: str
KeyError
Exception
None.
-This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-Raises the appropriate warnings and logs informational messages.
def load(self, weights):
- """
- Load pretrained model weights from a file.
-
- Args:
- weights (str): The path to the file containing the pretrained model weights.
-
- Raises:
- KeyError: If the input weights file does not contain the expected keys.
- Exception: If there is an error when loading the pretrained heads.
-
- Returns:
- None.
-
- Note:
- This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
- Warnings:
- - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
- - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
- Raises the appropriate warnings and logs informational messages.
- """
- pretrained_model = torch.load(weights)
-
- if "trunk_state_dict" in pretrained_model: # Loading ViSSL pretrained model
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- if "state_dict" in pretrained_model: # Loading Med3D pretrained model
- trained_model = pretrained_model["state_dict"]
-
- # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
- weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
- weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
- msg = self.trunk.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
- msg = self.heads.load_state_dict(weights, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-fmcib.models.models_genesis
import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
- """
- A class representing a 3D contextual batch normalization layer.
-
- Attributes:
- running_mean (torch.Tensor): The running mean of the batch normalization.
- running_var (torch.Tensor): The running variance of the batch normalization.
- weight (torch.Tensor): The learnable weights of the batch normalization.
- bias (torch.Tensor): The learnable bias of the batch normalization.
- momentum (float): The momentum for updating the running statistics.
- eps (float): Small value added to the denominator for numerical stability.
- """
-
- def _check_input_dim(self, input):
- """
- Check if the input tensor is 5-dimensional.
-
- Args:
- input (torch.Tensor): Input tensor to check the dimensionality.
-
- Raises:
- ValueError: If the input tensor is not 5-dimensional.
- """
- if input.dim() != 5:
- raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
- # super(ContBatchNorm3d, self)._check_input_dim(input)
-
- def forward(self, input):
- """
- Apply forward pass for the input through batch normalization layer.
-
- Args:
- input (Tensor): Input tensor to be normalized.
-
- Returns:
- Tensor: Normalized output tensor.
-
- Raises:
- ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
- """
- self._check_input_dim(input)
- return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-
-
-class LUConv(nn.Module):
- """
- A class representing a LUConv module.
-
- This module performs a convolution operation on the input data with a specified number of input channels and output channels.
- The convolution is followed by batch normalization and an activation function.
-
- Attributes:
- in_chan (int): The number of input channels.
- out_chan (int): The number of output channels.
- act (str): The activation function to be applied. Can be one of 'relu', 'prelu', or 'elu'.
- """
-
- def __init__(self, in_chan, out_chan, act):
- """
- Initialize a LUConv layer.
-
- Args:
- in_chan (int): Number of input channels.
- out_chan (int): Number of output channels.
- act (str): Activation function. Options: 'relu', 'prelu', 'elu'.
-
- Returns:
- None
-
- Raises:
- TypeError: If the activation function is not one of the specified options.
- """
- super(LUConv, self).__init__()
- self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
- self.bn1 = ContBatchNorm3d(out_chan)
-
- if act == "relu":
- self.activation = nn.ReLU(out_chan)
- elif act == "prelu":
- self.activation = nn.PReLU(out_chan)
- elif act == "elu":
- self.activation = nn.ELU(inplace=True)
- else:
- raise
-
- def forward(self, x):
- """
- Apply forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor after passing through the network.
- """
- out = self.activation(self.bn1(self.conv1(x)))
- return out
-
-
-def _make_nConv(in_channel, depth, act, double_chnnel=False):
- """
- Make a two-layer convolutional neural network module.
-
- Args:
- in_channel (int): The number of input channels.
- depth (int): The depth of the network.
- act: Activation function to be used in the network.
- double_channel (bool, optional): If True, double the number of channels in the network. Defaults to False.
-
- Returns:
- nn.Sequential: A sequential module representing the two-layer convolutional network.
-
- Note:
- - If double_channel is True, the first layer will have 32 * 2 ** (depth + 1) channels and the second layer will have the same number of channels.
- - If double_channel is False, the first layer will have 32 * 2 ** depth channels and the second layer will have 32 * 2 ** depth * 2 channels.
- """
- if double_chnnel:
- layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act)
- layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act)
- else:
- layer1 = LUConv(in_channel, 32 * (2**depth), act)
- layer2 = LUConv(32 * (2**depth), 32 * (2**depth) * 2, act)
-
- return nn.Sequential(layer1, layer2)
-
-
-# class InputTransition(nn.Module):
-# def __init__(self, outChans, elu):
-# super(InputTransition, self).__init__()
-# self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
-# self.bn1 = ContBatchNorm3d(16)
-# self.relu1 = ELUCons(elu, 16)
-#
-# def forward(self, x):
-# # do we want a PRELU here as well?
-# out = self.bn1(self.conv1(x))
-# # split input in to 16 channels
-# x16 = torch.cat((x, x, x, x, x, x, x, x,
-# x, x, x, x, x, x, x, x), 1)
-# out = self.relu1(torch.add(out, x16))
-# return out
-
-
-class DownTransition(nn.Module):
- """
- A class representing a down transition module in a neural network.
-
- Attributes:
- in_channel (int): The number of input channels.
- depth (int): The depth of the down transition module.
- act (nn.Module): The activation function used in the module.
- """
-
- def __init__(self, in_channel, depth, act):
- """
- Initialize a DownTransition object.
-
- Args:
- in_channel (int): The number of channels in the input.
- depth (int): The depth of the DownTransition.
- act (function): The activation function.
-
- Returns:
- None
-
- Raises:
- None
- """
- super(DownTransition, self).__init__()
- self.ops = _make_nConv(in_channel, depth, act)
- self.maxpool = nn.MaxPool3d(2)
- self.current_depth = depth
-
- def forward(self, x):
- """
- Perform a forward pass through the neural network.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
- Raises:
- None
- """
- if self.current_depth == 3:
- out = self.ops(x)
- out_before_pool = out
- else:
- out_before_pool = self.ops(x)
- out = self.maxpool(out_before_pool)
- return out, out_before_pool
-
-
-class UpTransition(nn.Module):
- """
- A class representing an up transition layer in a neural network.
-
- Attributes:
- inChans (int): The number of input channels.
- outChans (int): The number of output channels.
- depth (int): The depth of the layer.
- act (str): The activation function to be applied.
- """
-
- def __init__(self, inChans, outChans, depth, act):
- """
- Initialize the UpTransition module.
-
- Args:
- inChans (int): The number of input channels.
- outChans (int): The number of output channels.
- depth (int): The depth of the module.
- act (nn.Module): The activation function to be used.
-
- Returns:
- None.
-
- Raises:
- None.
- """
- super(UpTransition, self).__init__()
- self.depth = depth
- self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
- self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)
-
- def forward(self, x, skip_x):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): Input tensor.
- skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
- Returns:
- torch.Tensor: The output tensor after passing through the network.
- """
- out_up_conv = self.up_conv(x)
- concat = torch.cat((out_up_conv, skip_x), 1)
- out = self.ops(concat)
- return out
-
-
-class OutputTransition(nn.Module):
- """
- A class representing the output transition in a neural network.
-
- Attributes:
- inChans (int): The number of input channels.
- n_labels (int): The number of output labels.
- """
-
- def __init__(self, inChans, n_labels):
- """
- Initialize the OutputTransition class.
-
- Args:
- inChans (int): Number of input channels.
- n_labels (int): Number of output labels.
-
- Returns:
- None
-
- Raises:
- None
- """
- super(OutputTransition, self).__init__()
- self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, x):
- """
- Forward pass through a neural network model.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- Tensor: The output tensor after passing through the model.
- """
- out = self.sigmoid(self.final_conv(x))
- return out
-
-
-class UNet3D(nn.Module):
- # the number of convolutions in each layer corresponds
- # to what is in the actual prototxt, not the intent
- """
- A class representing a 3D UNet model for segmentation.
-
- Attributes:
- n_class (int): The number of classes for segmentation.
- act (str): The activation function type used in the model.
- decoder (bool): Whether to include the decoder part in the model.
-
- Methods:
- forward(x): Forward pass of the model.
- """
-
- def __init__(self, n_class=1, act="relu", decoder=True):
- """
- Initialize a 3D UNet neural network model.
-
- Args:
- n_class (int): The number of output classes. Defaults to 1.
- act (str): The activation function to use. Defaults to 'relu'.
- decoder (bool): Whether to include the decoder layers. Defaults to True.
-
- Attributes:
- decoder (bool): Whether the model includes decoder layers.
- down_tr64 (DownTransition): The first down transition layer.
- down_tr128 (DownTransition): The second down transition layer.
- down_tr256 (DownTransition): The third down transition layer.
- down_tr512 (DownTransition): The fourth down transition layer.
- up_tr256 (UpTransition): The first up transition layer. (Only exists if `decoder` is True)
- up_tr128 (UpTransition): The second up transition layer. (Only exists if `decoder` is True)
- up_tr64 (UpTransition): The third up transition layer. (Only exists if `decoder` is True)
- out_tr (OutputTransition): The output transition layer. (Only exists if `decoder` is True)
- avg_pool (nn.AvgPool3d): The average pooling layer. (Only exists if `decoder` is False)
- flatten (nn.Flatten): The flattening layer. (Only exists if `decoder` is False)
- """
- super(UNet3D, self).__init__()
-
- self.decoder = decoder
-
- self.down_tr64 = DownTransition(1, 0, act)
- self.down_tr128 = DownTransition(64, 1, act)
- self.down_tr256 = DownTransition(128, 2, act)
- self.down_tr512 = DownTransition(256, 3, act)
-
- if self.decoder:
- self.up_tr256 = UpTransition(512, 512, 2, act)
- self.up_tr128 = UpTransition(256, 256, 1, act)
- self.up_tr64 = UpTransition(128, 128, 0, act)
- self.out_tr = OutputTransition(64, n_class)
- else:
- self.avg_pool = nn.AvgPool3d(3, stride=2)
- self.flatten = nn.Flatten()
-
- def forward(self, x):
- """
- Perform forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor from the network.
-
- Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
- Raises:
- None.
- """
- self.out64, self.skip_out64 = self.down_tr64(x)
- self.out128, self.skip_out128 = self.down_tr128(self.out64)
- self.out256, self.skip_out256 = self.down_tr256(self.out128)
- self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
- if self.decoder:
- self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
- self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
- self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
- self.out = self.out_tr(self.out_up_64)
- else:
- self.out = self.avg_pool(self.out512)
- self.out = self.flatten(self.out)
-
- return self.out
-
-class ContBatchNorm3d
-(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, dtype=None)
-
A class representing a 3D contextual batch normalization layer.
-running_mean
: torch.Tensor
running_var
: torch.Tensor
weight
: torch.Tensor
bias
: torch.Tensor
momentum
: float
eps
: float
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
- """
- A class representing a 3D contextual batch normalization layer.
-
- Attributes:
- running_mean (torch.Tensor): The running mean of the batch normalization.
- running_var (torch.Tensor): The running variance of the batch normalization.
- weight (torch.Tensor): The learnable weights of the batch normalization.
- bias (torch.Tensor): The learnable bias of the batch normalization.
- momentum (float): The momentum for updating the running statistics.
- eps (float): Small value added to the denominator for numerical stability.
- """
-
- def _check_input_dim(self, input):
- """
- Check if the input tensor is 5-dimensional.
-
- Args:
- input (torch.Tensor): Input tensor to check the dimensionality.
-
- Raises:
- ValueError: If the input tensor is not 5-dimensional.
- """
- if input.dim() != 5:
- raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
- # super(ContBatchNorm3d, self)._check_input_dim(input)
-
- def forward(self, input):
- """
- Apply forward pass for the input through batch normalization layer.
-
- Args:
- input (Tensor): Input tensor to be normalized.
-
- Returns:
- Tensor: Normalized output tensor.
-
- Raises:
- ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
- """
- self._check_input_dim(input)
- return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-var affine : bool
var eps : float
var momentum : float
var num_features : int
var track_running_stats : bool
-def forward(self, input) ‑> Callable[..., Any]
-
Apply forward pass for the input through batch normalization layer.
-input
: Tensor
Tensor
ValueError
def forward(self, input):
- """
- Apply forward pass for the input through batch normalization layer.
-
- Args:
- input (Tensor): Input tensor to be normalized.
-
- Returns:
- Tensor: Normalized output tensor.
-
- Raises:
- ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
- """
- self._check_input_dim(input)
- return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-
-class DownTransition
-(in_channel, depth, act)
-
A class representing a down transition module in a neural network.
-in_channel
: int
depth
: int
act
: nn.Module
Initialize a DownTransition object.
-in_channel
: int
depth
: int
act
: function
None
-None
class DownTransition(nn.Module):
- """
- A class representing a down transition module in a neural network.
-
- Attributes:
- in_channel (int): The number of input channels.
- depth (int): The depth of the down transition module.
- act (nn.Module): The activation function used in the module.
- """
-
- def __init__(self, in_channel, depth, act):
- """
- Initialize a DownTransition object.
-
- Args:
- in_channel (int): The number of channels in the input.
- depth (int): The depth of the DownTransition.
- act (function): The activation function.
-
- Returns:
- None
-
- Raises:
- None
- """
- super(DownTransition, self).__init__()
- self.ops = _make_nConv(in_channel, depth, act)
- self.maxpool = nn.MaxPool3d(2)
- self.current_depth = depth
-
- def forward(self, x):
- """
- Perform a forward pass through the neural network.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
- Raises:
- None
- """
- if self.current_depth == 3:
- out = self.ops(x)
- out_before_pool = out
- else:
- out_before_pool = self.ops(x)
- out = self.maxpool(out_before_pool)
- return out, out_before_pool
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x) ‑> Callable[..., Any]
-
Perform a forward pass through the neural network.
-x
: Tensor
tuple
None
def forward(self, x):
- """
- Perform a forward pass through the neural network.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
- Raises:
- None
- """
- if self.current_depth == 3:
- out = self.ops(x)
- out_before_pool = out
- else:
- out_before_pool = self.ops(x)
- out = self.maxpool(out_before_pool)
- return out, out_before_pool
-
-class LUConv
-(in_chan, out_chan, act)
-
A class representing a LUConv module.
-This module performs a convolution operation on the input data with a specified number of input channels and output channels. -The convolution is followed by batch normalization and an activation function.
-in_chan
: int
out_chan
: int
act
: str
Initialize a LUConv layer.
-in_chan
: int
out_chan
: int
act
: str
None
-TypeError
class LUConv(nn.Module):
- """
- A class representing a LUConv module.
-
- This module performs a convolution operation on the input data with a specified number of input channels and output channels.
- The convolution is followed by batch normalization and an activation function.
-
- Attributes:
- in_chan (int): The number of input channels.
- out_chan (int): The number of output channels.
- act (str): The activation function to be applied. Can be one of 'relu', 'prelu', or 'elu'.
- """
-
- def __init__(self, in_chan, out_chan, act):
- """
- Initialize a LUConv layer.
-
- Args:
- in_chan (int): Number of input channels.
- out_chan (int): Number of output channels.
- act (str): Activation function. Options: 'relu', 'prelu', 'elu'.
-
- Returns:
- None
-
- Raises:
- TypeError: If the activation function is not one of the specified options.
- """
- super(LUConv, self).__init__()
- self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
- self.bn1 = ContBatchNorm3d(out_chan)
-
- if act == "relu":
- self.activation = nn.ReLU(out_chan)
- elif act == "prelu":
- self.activation = nn.PReLU(out_chan)
- elif act == "elu":
- self.activation = nn.ELU(inplace=True)
- else:
- raise
-
- def forward(self, x):
- """
- Apply forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor after passing through the network.
- """
- out = self.activation(self.bn1(self.conv1(x)))
- return out
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x) ‑> Callable[..., Any]
-
Apply forward pass through the neural network.
-x
: Tensor
Tensor
def forward(self, x):
- """
- Apply forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor after passing through the network.
- """
- out = self.activation(self.bn1(self.conv1(x)))
- return out
-
-class OutputTransition
-(inChans, n_labels)
-
A class representing the output transition in a neural network.
-inChans
: int
n_labels
: int
Initialize the OutputTransition class.
-inChans
: int
n_labels
: int
None
-None
class OutputTransition(nn.Module):
- """
- A class representing the output transition in a neural network.
-
- Attributes:
- inChans (int): The number of input channels.
- n_labels (int): The number of output labels.
- """
-
- def __init__(self, inChans, n_labels):
- """
- Initialize the OutputTransition class.
-
- Args:
- inChans (int): Number of input channels.
- n_labels (int): Number of output labels.
-
- Returns:
- None
-
- Raises:
- None
- """
- super(OutputTransition, self).__init__()
- self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, x):
- """
- Forward pass through a neural network model.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- Tensor: The output tensor after passing through the model.
- """
- out = self.sigmoid(self.final_conv(x))
- return out
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x) ‑> Callable[..., Any]
-
Forward pass through a neural network model.
-x
: Tensor
Tensor
def forward(self, x):
- """
- Forward pass through a neural network model.
-
- Args:
- x (Tensor): The input tensor.
-
- Returns:
- Tensor: The output tensor after passing through the model.
- """
- out = self.sigmoid(self.final_conv(x))
- return out
-
-class UNet3D
-(n_class=1, act='relu', decoder=True)
-
A class representing a 3D UNet model for segmentation.
-n_class
: int
act
: str
decoder
: bool
forward(x): Forward pass of the model.
-Initialize a 3D UNet neural network model.
-n_class
: int
act
: str
decoder
: bool
decoder
: bool
down_tr64
: DownTransition
down_tr128
: DownTransition
down_tr256
: DownTransition
down_tr512
: DownTransition
up_tr256
: UpTransition
decoder
is True)up_tr128
: UpTransition
decoder
is True)up_tr64
: UpTransition
decoder
is True)out_tr
: OutputTransition
decoder
is True)avg_pool
: nn.AvgPool3d
decoder
is False)flatten
: nn.Flatten
decoder
is False)class UNet3D(nn.Module):
- # the number of convolutions in each layer corresponds
- # to what is in the actual prototxt, not the intent
- """
- A class representing a 3D UNet model for segmentation.
-
- Attributes:
- n_class (int): The number of classes for segmentation.
- act (str): The activation function type used in the model.
- decoder (bool): Whether to include the decoder part in the model.
-
- Methods:
- forward(x): Forward pass of the model.
- """
-
- def __init__(self, n_class=1, act="relu", decoder=True):
- """
- Initialize a 3D UNet neural network model.
-
- Args:
- n_class (int): The number of output classes. Defaults to 1.
- act (str): The activation function to use. Defaults to 'relu'.
- decoder (bool): Whether to include the decoder layers. Defaults to True.
-
- Attributes:
- decoder (bool): Whether the model includes decoder layers.
- down_tr64 (DownTransition): The first down transition layer.
- down_tr128 (DownTransition): The second down transition layer.
- down_tr256 (DownTransition): The third down transition layer.
- down_tr512 (DownTransition): The fourth down transition layer.
- up_tr256 (UpTransition): The first up transition layer. (Only exists if `decoder` is True)
- up_tr128 (UpTransition): The second up transition layer. (Only exists if `decoder` is True)
- up_tr64 (UpTransition): The third up transition layer. (Only exists if `decoder` is True)
- out_tr (OutputTransition): The output transition layer. (Only exists if `decoder` is True)
- avg_pool (nn.AvgPool3d): The average pooling layer. (Only exists if `decoder` is False)
- flatten (nn.Flatten): The flattening layer. (Only exists if `decoder` is False)
- """
- super(UNet3D, self).__init__()
-
- self.decoder = decoder
-
- self.down_tr64 = DownTransition(1, 0, act)
- self.down_tr128 = DownTransition(64, 1, act)
- self.down_tr256 = DownTransition(128, 2, act)
- self.down_tr512 = DownTransition(256, 3, act)
-
- if self.decoder:
- self.up_tr256 = UpTransition(512, 512, 2, act)
- self.up_tr128 = UpTransition(256, 256, 1, act)
- self.up_tr64 = UpTransition(128, 128, 0, act)
- self.out_tr = OutputTransition(64, n_class)
- else:
- self.avg_pool = nn.AvgPool3d(3, stride=2)
- self.flatten = nn.Flatten()
-
- def forward(self, x):
- """
- Perform forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor from the network.
-
- Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
- Raises:
- None.
- """
- self.out64, self.skip_out64 = self.down_tr64(x)
- self.out128, self.skip_out128 = self.down_tr128(self.out64)
- self.out256, self.skip_out256 = self.down_tr256(self.out128)
- self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
- if self.decoder:
- self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
- self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
- self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
- self.out = self.out_tr(self.out_up_64)
- else:
- self.out = self.avg_pool(self.out512)
- self.out = self.flatten(self.out)
-
- return self.out
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x) ‑> Callable[..., Any]
-
Perform forward pass through the neural network.
-x
: Tensor
Tensor
Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-None.
def forward(self, x):
- """
- Perform forward pass through the neural network.
-
- Args:
- x (Tensor): Input tensor to the network.
-
- Returns:
- Tensor: Output tensor from the network.
-
- Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
- Raises:
- None.
- """
- self.out64, self.skip_out64 = self.down_tr64(x)
- self.out128, self.skip_out128 = self.down_tr128(self.out64)
- self.out256, self.skip_out256 = self.down_tr256(self.out128)
- self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
- if self.decoder:
- self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
- self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
- self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
- self.out = self.out_tr(self.out_up_64)
- else:
- self.out = self.avg_pool(self.out512)
- self.out = self.flatten(self.out)
-
- return self.out
-
-class UpTransition
-(inChans, outChans, depth, act)
-
A class representing an up transition layer in a neural network.
-inChans
: int
outChans
: int
depth
: int
act
: str
Initialize the UpTransition module.
-inChans
: int
outChans
: int
depth
: int
act
: nn.Module
None.
-None.
class UpTransition(nn.Module):
- """
- A class representing an up transition layer in a neural network.
-
- Attributes:
- inChans (int): The number of input channels.
- outChans (int): The number of output channels.
- depth (int): The depth of the layer.
- act (str): The activation function to be applied.
- """
-
- def __init__(self, inChans, outChans, depth, act):
- """
- Initialize the UpTransition module.
-
- Args:
- inChans (int): The number of input channels.
- outChans (int): The number of output channels.
- depth (int): The depth of the module.
- act (nn.Module): The activation function to be used.
-
- Returns:
- None.
-
- Raises:
- None.
- """
- super(UpTransition, self).__init__()
- self.depth = depth
- self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
- self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)
-
- def forward(self, x, skip_x):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): Input tensor.
- skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
- Returns:
- torch.Tensor: The output tensor after passing through the network.
- """
- out_up_conv = self.up_conv(x)
- concat = torch.cat((out_up_conv, skip_x), 1)
- out = self.ops(concat)
- return out
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x, skip_x) ‑> Callable[..., Any]
-
Forward pass of the neural network.
-x
: torch.Tensor
skip_x
: torch.Tensor
torch.Tensor
def forward(self, x, skip_x):
- """
- Forward pass of the neural network.
-
- Args:
- x (torch.Tensor): Input tensor.
- skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
- Returns:
- torch.Tensor: The output tensor after passing through the network.
- """
- out_up_conv = self.up_conv(x)
- concat = torch.cat((out_up_conv, skip_x), 1)
- out = self.ops(concat)
- return out
-fmcib.models.resnet50
import os
-from pathlib import Path
-
-import torch
-import tqdm
-import wget
-from loguru import logger
-from monai.networks.nets import resnet50 as resnet50_monai
-
-from fmcib.utils.download_utils import bar_progress
-
-
-def resnet50(
- pretrained=True,
- device="cuda",
- weights_path=None,
- download_url="https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1",
-):
- """
- Constructs a ResNet-50 model for image classification.
-
- Args:
- pretrained (bool, optional): If True, loads the pretrained weights. Default is True.
- device (str, optional): The device to load the model on. Default is "cuda".
- weights_path (str or Path, optional): The path to the pretrained weights file. If None, the weights will be downloaded. Default is None.
- download_url (str, optional): The URL to download the pretrained weights. Default is "https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1".
-
- Returns:
- torch.nn.Module: The ResNet-50 model.
- """
- logger.info(f"Loading pretrained foundation model (Resnet50) on {device}...")
-
- model = resnet50_monai(pretrained=False, n_input_channels=1, widen_factor=2, conv1_t_stride=2, feed_forward=False)
- model = model.to(device)
- if pretrained:
- if weights_path is None:
- current_path = Path(os.getcwd())
- if not (current_path / "fmcib.torch").exists():
- wget.download(download_url, bar=bar_progress)
- weights_path = current_path / "fmcib.torch"
-
- checkpoint = torch.load(weights_path, map_location=device)
-
- if "trunk_state_dict" in checkpoint:
- model_state_dict = checkpoint["trunk_state_dict"]
- elif "state_dict" in checkpoint:
- model_state_dict = checkpoint["state_dict"]
- model_state_dict = {key.replace("model.backbone.", ""): value for key, value in model_state_dict.items()}
-
- model.load_state_dict(model_state_dict, strict=False)
-
- return model
-
-def resnet50(pretrained=True, device='cuda', weights_path=None, download_url='https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1')
-
Constructs a ResNet-50 model for image classification.
-pretrained
: bool
, optionaldevice
: str
, optionalweights_path
: str
or Path
, optionaldownload_url
: str
, optionaltorch.nn.Module
def resnet50(
- pretrained=True,
- device="cuda",
- weights_path=None,
- download_url="https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1",
-):
- """
- Constructs a ResNet-50 model for image classification.
-
- Args:
- pretrained (bool, optional): If True, loads the pretrained weights. Default is True.
- device (str, optional): The device to load the model on. Default is "cuda".
- weights_path (str or Path, optional): The path to the pretrained weights file. If None, the weights will be downloaded. Default is None.
- download_url (str, optional): The URL to download the pretrained weights. Default is "https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1".
-
- Returns:
- torch.nn.Module: The ResNet-50 model.
- """
- logger.info(f"Loading pretrained foundation model (Resnet50) on {device}...")
-
- model = resnet50_monai(pretrained=False, n_input_channels=1, widen_factor=2, conv1_t_stride=2, feed_forward=False)
- model = model.to(device)
- if pretrained:
- if weights_path is None:
- current_path = Path(os.getcwd())
- if not (current_path / "fmcib.torch").exists():
- wget.download(download_url, bar=bar_progress)
- weights_path = current_path / "fmcib.torch"
-
- checkpoint = torch.load(weights_path, map_location=device)
-
- if "trunk_state_dict" in checkpoint:
- model_state_dict = checkpoint["trunk_state_dict"]
- elif "state_dict" in checkpoint:
- model_state_dict = checkpoint["state_dict"]
- model_state_dict = {key.replace("model.backbone.", ""): value for key, value in model_state_dict.items()}
-
- model.load_state_dict(model_state_dict, strict=False)
-
- return model
-fmcib.optimizers
from .lars import LARS
-fmcib.optimizers.lars
fmcib.optimizers.lars
"""
-References:
- - https://arxiv.org/pdf/1708.03888.pdf
- - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
-"""
-import torch
-from torch.optim.optimizer import Optimizer, required
-
-
-class LARS(Optimizer):
- """Extends SGD in PyTorch with LARS scaling from the paper
- `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float): learning rate
- momentum (float, optional): momentum factor (default: 0)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- dampening (float, optional): dampening for momentum (default: 0)
- nesterov (bool, optional): enables Nesterov momentum (default: False)
- trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
- eps (float, optional): eps for division denominator (default: 1e-8)
-
- Example:
- >>> model = torch.nn.Linear(10, 1)
- >>> input = torch.Tensor(10)
- >>> target = torch.Tensor([1.])
- >>> loss_fn = lambda input, target: (input - target) ** 2
- >>> #
- >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
-
- .. note::
- The application of momentum in the SGD part is modified according to
- the PyTorch standards. LARS scaling fits into the equation in the
- following fashion.
-
- .. math::
- \begin{aligned}
- g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
- v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
- p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
- \\end{aligned}
-
- where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
- parameters, gradient, velocity, momentum, and weight decay respectively.
- The :math:`lars_lr` is defined by Eq. 6 in the paper.
- The Nesterov version is analogously modified.
-
- .. warning::
- Parameters with weight decay set to 0 will automatically be excluded from
- layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
- and BYOL.
- """
-
- def __init__(
- self,
- params,
- lr=required,
- momentum=0,
- dampening=0,
- weight_decay=0,
- nesterov=False,
- trust_coefficient=0.001,
- eps=1e-8,
- ):
- """
- Initialize an optimizer with the given parameters.
-
- Args:
- params (iterable): Iterable of parameters to optimize.
- lr (float, optional): Learning rate. Default is required.
- momentum (float, optional): Momentum factor. Default is 0.
- dampening (float, optional): Dampening for momentum. Default is 0.
- weight_decay (float, optional): Weight decay factor. Default is 0.
- nesterov (bool, optional): Use Nesterov momentum. Default is False.
- trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
- eps (float, optional): Small value for numerical stability. Default is 1e-08.
-
- Raises:
- ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
- ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
- """
- if lr is not required and lr < 0.0:
- raise ValueError(f"Invalid learning rate: {lr}")
- if momentum < 0.0:
- raise ValueError(f"Invalid momentum value: {momentum}")
- if weight_decay < 0.0:
- raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
- defaults = dict(
- lr=lr,
- momentum=momentum,
- dampening=dampening,
- weight_decay=weight_decay,
- nesterov=nesterov,
- trust_coefficient=trust_coefficient,
- eps=eps,
- )
- if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
- super().__init__(params, defaults)
-
- def __setstate__(self, state):
- """
- Set the state of the optimizer.
-
- Args:
- state (dict): A dictionary containing the state of the optimizer.
-
- Returns:
- None
-
- Note:
- This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
- """
- super().__setstate__(state)
-
- for group in self.param_groups:
- group.setdefault("nesterov", False)
-
- @torch.no_grad()
- def step(self, closure=None):
- """
- Performs a single optimization step.
-
- Parameters:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-
-class LARS
-(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)
-
Extends SGD in PyTorch with LARS scaling from the paper
-Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>
_.
params
: iterable
lr
: float
momentum
: float
, optionalweight_decay
: float
, optionaldampening
: float
, optionalnesterov
: bool
, optionaltrust_coefficient
: float
, optionaleps
: float
, optional>>> model = torch.nn.Linear(10, 1)
->>> input = torch.Tensor(10)
->>> target = torch.Tensor([1.])
->>> loss_fn = lambda input, target: (input - target) ** 2
->>> #
->>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
->>> optimizer.zero_grad()
->>> loss_fn(model(input), target).backward()
->>> optimizer.step()
-
-Note
-The application of momentum in the SGD part is modified according to -the PyTorch standards. LARS scaling fits into the equation in the -following fashion.
-[ egin{aligned}
-g_{t+1} & =
-ext{lars_lr} * (eta * p_{t} + g_{t+1}), \
-v_{t+1} & = \mu * v_{t} + g_{t+1}, \
-p_{t+1} & = p_{t} -
-ext{lr} * v_{t+1},
-\end{aligned} ]
-where :math:p
, :math:g
, :math:v
, :math:\mu
and :math:eta
denote the
-parameters, gradient, velocity, momentum, and weight decay respectively.
-The :math:lars_lr
is defined by Eq. 6 in the paper.
-The Nesterov version is analogously modified.
Warning
-Parameters with weight decay set to 0 will automatically be excluded from -layer-wise LR scaling. This is to ensure consistency with papers like SimCLR -and BYOL.
-Initialize an optimizer with the given parameters.
-params
: iterable
lr
: float
, optionalmomentum
: float
, optionaldampening
: float
, optionalweight_decay
: float
, optionalnesterov
: bool
, optionaltrust_coefficient
: float
, optionaleps
: float
, optionalValueError
ValueError
class LARS(Optimizer):
- """Extends SGD in PyTorch with LARS scaling from the paper
- `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float): learning rate
- momentum (float, optional): momentum factor (default: 0)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- dampening (float, optional): dampening for momentum (default: 0)
- nesterov (bool, optional): enables Nesterov momentum (default: False)
- trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
- eps (float, optional): eps for division denominator (default: 1e-8)
-
- Example:
- >>> model = torch.nn.Linear(10, 1)
- >>> input = torch.Tensor(10)
- >>> target = torch.Tensor([1.])
- >>> loss_fn = lambda input, target: (input - target) ** 2
- >>> #
- >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
-
- .. note::
- The application of momentum in the SGD part is modified according to
- the PyTorch standards. LARS scaling fits into the equation in the
- following fashion.
-
- .. math::
- \begin{aligned}
- g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
- v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
- p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
- \\end{aligned}
-
- where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
- parameters, gradient, velocity, momentum, and weight decay respectively.
- The :math:`lars_lr` is defined by Eq. 6 in the paper.
- The Nesterov version is analogously modified.
-
- .. warning::
- Parameters with weight decay set to 0 will automatically be excluded from
- layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
- and BYOL.
- """
-
- def __init__(
- self,
- params,
- lr=required,
- momentum=0,
- dampening=0,
- weight_decay=0,
- nesterov=False,
- trust_coefficient=0.001,
- eps=1e-8,
- ):
- """
- Initialize an optimizer with the given parameters.
-
- Args:
- params (iterable): Iterable of parameters to optimize.
- lr (float, optional): Learning rate. Default is required.
- momentum (float, optional): Momentum factor. Default is 0.
- dampening (float, optional): Dampening for momentum. Default is 0.
- weight_decay (float, optional): Weight decay factor. Default is 0.
- nesterov (bool, optional): Use Nesterov momentum. Default is False.
- trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
- eps (float, optional): Small value for numerical stability. Default is 1e-08.
-
- Raises:
- ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
- ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
- """
- if lr is not required and lr < 0.0:
- raise ValueError(f"Invalid learning rate: {lr}")
- if momentum < 0.0:
- raise ValueError(f"Invalid momentum value: {momentum}")
- if weight_decay < 0.0:
- raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
- defaults = dict(
- lr=lr,
- momentum=momentum,
- dampening=dampening,
- weight_decay=weight_decay,
- nesterov=nesterov,
- trust_coefficient=trust_coefficient,
- eps=eps,
- )
- if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
- super().__init__(params, defaults)
-
- def __setstate__(self, state):
- """
- Set the state of the optimizer.
-
- Args:
- state (dict): A dictionary containing the state of the optimizer.
-
- Returns:
- None
-
- Note:
- This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
- """
- super().__setstate__(state)
-
- for group in self.param_groups:
- group.setdefault("nesterov", False)
-
- @torch.no_grad()
- def step(self, closure=None):
- """
- Performs a single optimization step.
-
- Parameters:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-var OptimizerPostHook : typing_extensions.TypeAlias
var OptimizerPreHook : typing_extensions.TypeAlias
-def step(self, closure=None)
-
Performs a single optimization step.
-closure (callable, optional): A closure that reevaluates the model -and returns the loss.
@torch.no_grad()
-def step(self, closure=None):
- """
- Performs a single optimization step.
-
- Parameters:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-fmcib.preprocessing
import monai
-import torchvision
-from loguru import logger
-from monai import transforms as monai_transforms
-
-from .seed_based_crop import SeedBasedPatchCropd
-
-
-def preprocess(image, spatial_size=(50, 50, 50)):
- T = get_transforms(spatial_size=spatial_size)
- return T(image)
-
-
-def get_transforms(spatial_size=(50, 50, 50), precropped=False):
- if precropped:
- return monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.ScaleIntensityRanged(
- keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
- ),
- monai_transforms.SelectItemsd(keys=["image_path"]),
- monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
- torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
- ]
- )
- else:
- return monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
- monai_transforms.ScaleIntensityRanged(
- keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
- ),
- monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
- SeedBasedPatchCropd(
- keys=["image_path"], roi_size=spatial_size[::-1], coord_orientation="LPS", global_coordinates=True
- ),
- monai_transforms.SelectItemsd(keys=["image_path"]),
- monai_transforms.Transposed(keys=["image_path"], indices=(0, 3, 2, 1)),
- monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
- torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
- ]
- )
-
-
-def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False):
- logger.info("Building dataloader instance ...")
- T = get_transforms(spatial_size=spatial_size, precropped=precropped)
- dataset = monai.data.CSVDataset(csv_path, transform=T)
- dataloader = monai.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
- return dataloader
-fmcib.preprocessing.seed_based_crop
Author: Suraj Pai -Email: bspai@bwh.harvard.edu -This script contains two classes: -1. SeedBasedPatchCropd -2. SeedBasedPatchCrop
-def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False)
-
def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False):
- logger.info("Building dataloader instance ...")
- T = get_transforms(spatial_size=spatial_size, precropped=precropped)
- dataset = monai.data.CSVDataset(csv_path, transform=T)
- dataloader = monai.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
- return dataloader
-
-def get_transforms(spatial_size=(50, 50, 50), precropped=False)
-
def get_transforms(spatial_size=(50, 50, 50), precropped=False):
- if precropped:
- return monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.ScaleIntensityRanged(
- keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
- ),
- monai_transforms.SelectItemsd(keys=["image_path"]),
- monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
- torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
- ]
- )
- else:
- return monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
- monai_transforms.ScaleIntensityRanged(
- keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
- ),
- monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
- SeedBasedPatchCropd(
- keys=["image_path"], roi_size=spatial_size[::-1], coord_orientation="LPS", global_coordinates=True
- ),
- monai_transforms.SelectItemsd(keys=["image_path"]),
- monai_transforms.Transposed(keys=["image_path"], indices=(0, 3, 2, 1)),
- monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
- torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
- ]
- )
-
-def preprocess(image, spatial_size=(50, 50, 50))
-
def preprocess(image, spatial_size=(50, 50, 50)):
- T = get_transforms(spatial_size=spatial_size)
- return T(image)
-fmcib.preprocessing.seed_based_crop
Author: Suraj Pai -Email: bspai@bwh.harvard.edu -This script contains two classes: -1. SeedBasedPatchCropd -2. SeedBasedPatchCrop
-"""
-Author: Suraj Pai
-Email: bspai@bwh.harvard.edu
-This script contains two classes:
-1. SeedBasedPatchCropd
-2. SeedBasedPatchCrop
-"""
-
-from typing import Any, Dict, Hashable, Mapping, Tuple
-
-import numpy as np
-from monai.config.type_definitions import NdarrayOrTensor
-from monai.transforms import MapTransform, Transform
-
-
-class SeedBasedPatchCropd(MapTransform):
- """
- A class representing a seed-based patch crop transformation.
-
- Inherits from MapTransform.
-
- Attributes:
- keys (list): List of keys for images in the input data dictionary.
- roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
- allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
- coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
- global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
- """
-
- def __init__(self, keys, roi_size, allow_missing_keys=False, coord_orientation="RAS", global_coordinates=True) -> None:
- """
- Initialize SeedBasedPatchCropd class.
-
- Args:
- keys (List): List of keys for images in the input data dictionary.
- roi_size (Tuple): Tuple indicating the size of the region of interest (ROI).
- allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
- coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
- global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
- """
- super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
- self.coord_orientation = coord_orientation
- self.global_coordinates = global_coordinates
- self.cropper = SeedBasedPatchCrop(roi_size=roi_size)
-
- def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
- """
- Apply transformation to given data.
-
- Args:
- data (dict): Dictionary with image keys and required center coordinates.
-
- Returns:
- dict: Dictionary with cropped patches for each key in the input data dictionary.
- """
- d = dict(data)
-
- assert "coordX" in d.keys(), "coordX not found in data"
- assert "coordY" in d.keys(), "coordY not found in data"
- assert "coordZ" in d.keys(), "coordZ not found in data"
-
- # Convert coordinates to RAS orientation to match image orientation
- if self.coord_orientation == "RAS":
- center = (d["coordX"], d["coordY"], d["coordZ"])
- elif self.coord_orientation == "LPS":
- center = (-d["coordX"], -d["coordY"], d["coordZ"])
-
- for key in self.key_iterator(d):
- d[key] = self.cropper(d[key], center=center, global_coordinates=self.global_coordinates)
- return d
-
-
-class SeedBasedPatchCrop(Transform):
- """
- A class representing a seed-based patch crop transformation.
-
- Attributes:
- roi_size: Tuple indicating the size of the region of interest (ROI).
-
- Methods:
- __call__: Crop a patch from the input image centered around the seed coordinate.
-
- Args:
- roi_size: Tuple indicating the size of the region of interest (ROI).
-
- Returns:
- NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-
- Raises:
- AssertionError: If the input image has dimensions other than (C, H, W, D)
- AssertionError: If the coordinates are invalid (e.g., min_h >= max_h)
- """
-
- def __init__(self, roi_size) -> None:
- """
- Initialize SeedBasedPatchCrop class.
-
- Args:
- roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
- """
- super().__init__()
- self.roi_size = roi_size
-
- def __call__(self, img: NdarrayOrTensor, center: tuple, global_coordinates=False) -> NdarrayOrTensor:
- """
- Crop a patch from the input image centered around the seed coordinate.
-
- Args:
- img (NdarrayOrTensor): Image to crop, with dimensions (C, H, W, D). C is the number of channels.
- center (tuple): Seed coordinate around which to crop the patch (X, Y, Z).
- global_coordinates (bool): If True, seed coordinate is in global space; otherwise, local space.
-
- Returns:
- NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
- """
- assert len(img.shape) == 4, "Input image must have dimensions: (C, H, W, D)"
- C, H, W, D = img.shape
- Ph, Pw, Pd = self.roi_size
-
- # If global coordinates, convert to local coordinates
- if global_coordinates:
- center = np.linalg.inv(np.array(img.affine)) @ np.array(center + (1,))
- center = [int(x) for x in center[:3]]
-
- # Calculate and clamp the ranges for cropping
- min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
- min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
- min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
- # Check if coordinates are valid
- assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
- assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
- assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
- # Crop the patch from the image
- patch = img[:, min_h:max_h, min_w:max_w, min_d:max_d]
-
- return patch
-
-class SeedBasedPatchCrop
-(roi_size)
-
A class representing a seed-based patch crop transformation.
-roi_size
call: Crop a patch from the input image centered around the seed coordinate.
-roi_size
NdarrayOrTensor
AssertionError
AssertionError
Initialize SeedBasedPatchCrop class.
-roi_size
: tuple
class SeedBasedPatchCrop(Transform):
- """
- A class representing a seed-based patch crop transformation.
-
- Attributes:
- roi_size: Tuple indicating the size of the region of interest (ROI).
-
- Methods:
- __call__: Crop a patch from the input image centered around the seed coordinate.
-
- Args:
- roi_size: Tuple indicating the size of the region of interest (ROI).
-
- Returns:
- NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-
- Raises:
- AssertionError: If the input image has dimensions other than (C, H, W, D)
- AssertionError: If the coordinates are invalid (e.g., min_h >= max_h)
- """
-
- def __init__(self, roi_size) -> None:
- """
- Initialize SeedBasedPatchCrop class.
-
- Args:
- roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
- """
- super().__init__()
- self.roi_size = roi_size
-
- def __call__(self, img: NdarrayOrTensor, center: tuple, global_coordinates=False) -> NdarrayOrTensor:
- """
- Crop a patch from the input image centered around the seed coordinate.
-
- Args:
- img (NdarrayOrTensor): Image to crop, with dimensions (C, H, W, D). C is the number of channels.
- center (tuple): Seed coordinate around which to crop the patch (X, Y, Z).
- global_coordinates (bool): If True, seed coordinate is in global space; otherwise, local space.
-
- Returns:
- NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
- """
- assert len(img.shape) == 4, "Input image must have dimensions: (C, H, W, D)"
- C, H, W, D = img.shape
- Ph, Pw, Pd = self.roi_size
-
- # If global coordinates, convert to local coordinates
- if global_coordinates:
- center = np.linalg.inv(np.array(img.affine)) @ np.array(center + (1,))
- center = [int(x) for x in center[:3]]
-
- # Calculate and clamp the ranges for cropping
- min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
- min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
- min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
- # Check if coordinates are valid
- assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
- assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
- assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
- # Crop the patch from the image
- patch = img[:, min_h:max_h, min_w:max_w, min_d:max_d]
-
- return patch
-var backend : list[TransformBackends]
-class SeedBasedPatchCropd
-(keys, roi_size, allow_missing_keys=False, coord_orientation='RAS', global_coordinates=True)
-
A class representing a seed-based patch crop transformation.
-Inherits from MapTransform.
-keys
: list
roi_size
: tuple
allow_missing_keys
: bool
coord_orientation
: str
global_coordinates
: bool
Initialize SeedBasedPatchCropd class.
-keys
: List
roi_size
: Tuple
allow_missing_keys
: bool
coord_orientation
: str
global_coordinates
: bool
class SeedBasedPatchCropd(MapTransform):
- """
- A class representing a seed-based patch crop transformation.
-
- Inherits from MapTransform.
-
- Attributes:
- keys (list): List of keys for images in the input data dictionary.
- roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
- allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
- coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
- global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
- """
-
- def __init__(self, keys, roi_size, allow_missing_keys=False, coord_orientation="RAS", global_coordinates=True) -> None:
- """
- Initialize SeedBasedPatchCropd class.
-
- Args:
- keys (List): List of keys for images in the input data dictionary.
- roi_size (Tuple): Tuple indicating the size of the region of interest (ROI).
- allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
- coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
- global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
- """
- super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
- self.coord_orientation = coord_orientation
- self.global_coordinates = global_coordinates
- self.cropper = SeedBasedPatchCrop(roi_size=roi_size)
-
- def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
- """
- Apply transformation to given data.
-
- Args:
- data (dict): Dictionary with image keys and required center coordinates.
-
- Returns:
- dict: Dictionary with cropped patches for each key in the input data dictionary.
- """
- d = dict(data)
-
- assert "coordX" in d.keys(), "coordX not found in data"
- assert "coordY" in d.keys(), "coordY not found in data"
- assert "coordZ" in d.keys(), "coordZ not found in data"
-
- # Convert coordinates to RAS orientation to match image orientation
- if self.coord_orientation == "RAS":
- center = (d["coordX"], d["coordY"], d["coordZ"])
- elif self.coord_orientation == "LPS":
- center = (-d["coordX"], -d["coordY"], d["coordZ"])
-
- for key in self.key_iterator(d):
- d[key] = self.cropper(d[key], center=center, global_coordinates=self.global_coordinates)
- return d
-var backend : list[TransformBackends]
fmcib.run
import numpy as np
-import pandas as pd
-import torch
-from loguru import logger
-from monai.networks.nets import resnet50
-from tqdm import tqdm
-
-from .models import LoadModel, fmcib_model
-from .preprocessing import get_dataloader
-
-
-def get_features(
- csv_path,
- weights_path=None,
- spatial_size=(50, 50, 50),
- precropped=False,
-):
- """
- Extracts features from images specified in a CSV file.
-
- Args:
- csv_path (str): Path to the CSV file containing image paths.
- weights_path (str, optional): Path to the pre-trained weights file. Default is None.
- spatial_size (tuple, optional): Spatial size of the input images. Default is (50, 50, 50).
- precropped (bool, optional): Whether the images are already pre-cropped. Default is False.
-
- Returns:
- pandas.DataFrame: DataFrame containing the original data from the CSV file along with the extracted features.
- """
- logger.info("Loading CSV file ...")
- df = pd.read_csv(csv_path)
- dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped)
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if weights_path is None:
- model = fmcib_model().to(device)
- else:
- logger.warning(
- "Loading custom model provided from weights file. If this is not intended, please do not provide the weights_path argument."
- )
- trunk = resnet50(
- pretrained=False,
- n_input_channels=1,
- widen_factor=2,
- conv1_t_stride=2,
- feed_forward=False,
- bias_downsample=True,
- )
- model = LoadModel(trunk=trunk, weights_path=weights_path).to(device)
-
- feature_list = []
- logger.info("Running inference over batches ...")
-
- model.eval()
- for batch in tqdm(dataloader, total=len(dataloader)):
- feature = model(batch.to(device)).detach().cpu().numpy()
- feature_list.append(feature)
-
- features = np.concatenate(feature_list, axis=0)
- # Flatten features into a list
- features = features.reshape(-1, 4096)
-
- # Add the features to the dataframe
- df = pd.concat([df, pd.DataFrame(features, columns=[f"pred_{idx}" for idx in range(4096)])], axis=1)
- return df
-
-def get_features(csv_path, weights_path=None, spatial_size=(50, 50, 50), precropped=False)
-
Extracts features from images specified in a CSV file.
-csv_path
: str
weights_path
: str
, optionalspatial_size
: tuple
, optionalprecropped
: bool
, optionalpandas.DataFrame
def get_features(
- csv_path,
- weights_path=None,
- spatial_size=(50, 50, 50),
- precropped=False,
-):
- """
- Extracts features from images specified in a CSV file.
-
- Args:
- csv_path (str): Path to the CSV file containing image paths.
- weights_path (str, optional): Path to the pre-trained weights file. Default is None.
- spatial_size (tuple, optional): Spatial size of the input images. Default is (50, 50, 50).
- precropped (bool, optional): Whether the images are already pre-cropped. Default is False.
-
- Returns:
- pandas.DataFrame: DataFrame containing the original data from the CSV file along with the extracted features.
- """
- logger.info("Loading CSV file ...")
- df = pd.read_csv(csv_path)
- dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped)
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if weights_path is None:
- model = fmcib_model().to(device)
- else:
- logger.warning(
- "Loading custom model provided from weights file. If this is not intended, please do not provide the weights_path argument."
- )
- trunk = resnet50(
- pretrained=False,
- n_input_channels=1,
- widen_factor=2,
- conv1_t_stride=2,
- feed_forward=False,
- bias_downsample=True,
- )
- model = LoadModel(trunk=trunk, weights_path=weights_path).to(device)
-
- feature_list = []
- logger.info("Running inference over batches ...")
-
- model.eval()
- for batch in tqdm(dataloader, total=len(dataloader)):
- feature = model(batch.to(device)).detach().cpu().numpy()
- feature_list.append(feature)
-
- features = np.concatenate(feature_list, axis=0)
- # Flatten features into a list
- features = features.reshape(-1, 4096)
-
- # Add the features to the dataframe
- df = pd.concat([df, pd.DataFrame(features, columns=[f"pred_{idx}" for idx in range(4096)])], axis=1)
- return df
-fmcib.ssl
fmcib.ssl.losses
fmcib.ssl.modules
fmcib.ssl.losses
from .neg_mining_info_nce_loss import NegativeMiningInfoNCECriterion
-from .nnclr_loss import NNCLRLoss
-from .ntxent_loss import NTXentLoss
-from .ntxent_mined_loss import NTXentNegativeMinedLoss
-from .swav_loss import SwaVLoss
-fmcib.ssl.losses.neg_mining_info_nce_loss
fmcib.ssl.losses.nnclr_loss
fmcib.ssl.losses.ntxent_loss
fmcib.ssl.losses.ntxent_mined_loss
Contrastive Loss Functions
fmcib.ssl.losses.swav_loss
fmcib.ssl.losses.neg_mining_info_nce_loss
# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pprint
-
-import numpy as np
-import torch
-from lightly.utils import dist
-from loguru import logger
-from torch import nn
-
-
-class NegativeMiningInfoNCECriterion(nn.Module):
- """
- The criterion corresponding to the SimCLR loss as defined in the paper
- https://arxiv.org/abs/2002.05709.
-
- Args:
- temperature (float): The temperature to be applied on the logits.
- buffer_params (dict): A dictionary containing the following keys:
- - world_size (int): Total number of trainers in training.
- - embedding_dim (int): Output dimensions of the features projects.
- - effective_batch_size (int): Total batch size used (includes positives).
- """
-
- def __init__(
- self, embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True
- ):
- """
- Initialize the NegativeMiningInfoNCECriterion class.
-
- Args:
- embedding_dim (int): The dimension of the embedding space.
- batch_size (int): The size of the input batch.
- world_size (int): The number of distributed processes.
- gather_distributed (bool): Whether to gather distributed data.
- temperature (float): The temperature used in the computation.
- balanced (bool): Whether to use balanced sampling.
-
- Attributes:
- embedding_dim (int): The dimension of the embedding space.
- use_gpu (bool): Whether to use GPU for computations.
- temperature (float): The temperature used in the computation.
- num_pos (int): The number of positive samples.
- num_neg (int): The number of negative samples.
- criterion (nn.CrossEntropyLoss): The loss function.
- gather_distributed (bool): Whether to gather distributed data.
- world_size (int): The number of distributed processes.
- effective_batch_size (int): The effective batch size, taking into account world size and number of positive samples.
- pos_mask (None or Tensor): Mask for positive samples.
- neg_mask (None or Tensor): Mask for negative samples.
- balanced (bool): Whether to use balanced sampling.
- setup (bool): Whether the setup has been done.
- """
- super(NegativeMiningInfoNCECriterion, self).__init__()
- self.embedding_dim = embedding_dim
- self.use_gpu = torch.cuda.is_available()
- self.temperature = temperature
- self.num_pos = 2
-
- # Same number of negatives as positives are loaded
- self.num_neg = self.num_pos
- self.criterion = nn.CrossEntropyLoss()
- self.gather_distributed = gather_distributed
- self.world_size = world_size
- self.effective_batch_size = batch_size * self.world_size * self.num_pos
- self.pos_mask = None
- self.neg_mask = None
- self.balanced = balanced
- self.setup = False
-
- def precompute_pos_neg_mask(self):
- """
- Precompute the positive and negative masks to speed up the loss calculation.
- """
- # computed once at the begining of training
-
- # total_images is x2 SimCLR Info-NCE loss
- # as we have negative samples for each positive sample
-
- total_images = self.effective_batch_size * self.num_neg
- world_size = self.world_size
-
- # Batch size computation is different from SimCLR paper
- batch_size = self.effective_batch_size // world_size
- orig_images = batch_size // self.num_pos
- rank = dist.rank()
-
- pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
- neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
- all_indices = np.arange(total_images)
-
- # Index for pairs of images (original + copy)
- pairs = orig_images * np.arange(self.num_pos)
-
- # Remove all indices associated with positive samples & copies (for neg_mask)
- all_pos_members = []
- for _rank in range(world_size):
- all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
- all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
- # Index of original positive images
- orig_members = torch.arange(orig_images)
-
- for anchor in np.arange(self.num_pos):
- for img_idx in range(orig_images):
- # delete_inds are spaced by batch_size for each rank as
- # all_indices_pos_removed (half of the indices) is deleted first
- delete_inds = batch_size * rank + img_idx + pairs
- neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
- neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
- for pos in np.delete(np.arange(self.num_pos), anchor):
- # Pos_inds are spaced by batch_size * self.num_neg for each rank
- pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
- pos_mask[
- torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
- pos_inds.long(),
- ] = 1
-
- self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
- self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-
- def forward(self, out: torch.Tensor):
- """
- Calculate the loss. Operates on embeddings tensor.
- """
- if not self.setup:
- logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
- self.precompute_pos_neg_mask()
- self.setup = True
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
- embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
- embedding = nn.functional.normalize(embedding, dim=1, p=2)
- assert embedding.ndim == 2
- assert embedding.shape[1] == int(self.embedding_dim)
-
- batch_size = embedding.shape[0]
- T = self.temperature
- num_pos = self.num_pos
-
- assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
- assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
- # Step 1: gather all the embeddings. Shape example: 4096 x 128
- embeddings_buffer = self.gather_embeddings(embedding)
-
- # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
- # divide by temperature.
- similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
- pos = torch.sum(similarity * self.pos_mask, 1)
- neg = torch.sum(similarity * self.neg_mask, 1)
-
- # Ignore the negative samples as entries for loss calculation
- pos = pos[: (batch_size // 2)]
- neg = neg[: (batch_size // 2)]
-
- loss = -(torch.mean(torch.log(pos / (pos + neg))))
- return loss
-
- def __repr__(self):
- """
- Return a string representation of the object.
-
- Returns:
- str: A formatted string representation of the object.
-
- Examples:
- The following example shows the string representation of the object:
-
- {
- 'name': <object_name>,
- 'temperature': <temperature_value>,
- 'num_negatives': <num_negatives_value>,
- 'num_pos': <num_pos_value>,
- 'dist_rank': <dist_rank_value>
- }
-
- Note:
- This function is intended to be used with the pprint module for pretty printing.
- """
- num_negatives = self.effective_batch_size - 2
- T = self.temperature
- num_pos = self.num_pos
- repr_dict = {
- "name": self._get_name(),
- "temperature": T,
- "num_negatives": num_negatives,
- "num_pos": num_pos,
- "dist_rank": dist.rank(),
- }
- return pprint.pformat(repr_dict, indent=2)
-
- def gather_embeddings(self, embedding: torch.Tensor):
- """
- Do a gather over all embeddings, so we can compute the loss.
- Final shape is like: (batch_size * num_gpus) x embedding_dim
- """
- if self.gather_distributed:
- embedding_gathered = torch.cat(dist.gather(embedding), 0)
- else:
- embedding_gathered = embedding
- return embedding_gathered
-
-class NegativeMiningInfoNCECriterion
-(embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True)
-
The criterion corresponding to the SimCLR loss as defined in the paper -https://arxiv.org/abs/2002.05709.
-temperature
: float
buffer_params
: dict
Initialize the NegativeMiningInfoNCECriterion class.
-embedding_dim
: int
batch_size
: int
world_size
: int
gather_distributed
: bool
temperature
: float
balanced
: bool
embedding_dim
: int
use_gpu
: bool
temperature
: float
num_pos
: int
num_neg
: int
criterion
: nn.CrossEntropyLoss
gather_distributed
: bool
world_size
: int
effective_batch_size
: int
pos_mask
: None
or Tensor
neg_mask
: None
or Tensor
balanced
: bool
setup
: bool
class NegativeMiningInfoNCECriterion(nn.Module):
- """
- The criterion corresponding to the SimCLR loss as defined in the paper
- https://arxiv.org/abs/2002.05709.
-
- Args:
- temperature (float): The temperature to be applied on the logits.
- buffer_params (dict): A dictionary containing the following keys:
- - world_size (int): Total number of trainers in training.
- - embedding_dim (int): Output dimensions of the features projects.
- - effective_batch_size (int): Total batch size used (includes positives).
- """
-
- def __init__(
- self, embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True
- ):
- """
- Initialize the NegativeMiningInfoNCECriterion class.
-
- Args:
- embedding_dim (int): The dimension of the embedding space.
- batch_size (int): The size of the input batch.
- world_size (int): The number of distributed processes.
- gather_distributed (bool): Whether to gather distributed data.
- temperature (float): The temperature used in the computation.
- balanced (bool): Whether to use balanced sampling.
-
- Attributes:
- embedding_dim (int): The dimension of the embedding space.
- use_gpu (bool): Whether to use GPU for computations.
- temperature (float): The temperature used in the computation.
- num_pos (int): The number of positive samples.
- num_neg (int): The number of negative samples.
- criterion (nn.CrossEntropyLoss): The loss function.
- gather_distributed (bool): Whether to gather distributed data.
- world_size (int): The number of distributed processes.
- effective_batch_size (int): The effective batch size, taking into account world size and number of positive samples.
- pos_mask (None or Tensor): Mask for positive samples.
- neg_mask (None or Tensor): Mask for negative samples.
- balanced (bool): Whether to use balanced sampling.
- setup (bool): Whether the setup has been done.
- """
- super(NegativeMiningInfoNCECriterion, self).__init__()
- self.embedding_dim = embedding_dim
- self.use_gpu = torch.cuda.is_available()
- self.temperature = temperature
- self.num_pos = 2
-
- # Same number of negatives as positives are loaded
- self.num_neg = self.num_pos
- self.criterion = nn.CrossEntropyLoss()
- self.gather_distributed = gather_distributed
- self.world_size = world_size
- self.effective_batch_size = batch_size * self.world_size * self.num_pos
- self.pos_mask = None
- self.neg_mask = None
- self.balanced = balanced
- self.setup = False
-
- def precompute_pos_neg_mask(self):
- """
- Precompute the positive and negative masks to speed up the loss calculation.
- """
- # computed once at the begining of training
-
- # total_images is x2 SimCLR Info-NCE loss
- # as we have negative samples for each positive sample
-
- total_images = self.effective_batch_size * self.num_neg
- world_size = self.world_size
-
- # Batch size computation is different from SimCLR paper
- batch_size = self.effective_batch_size // world_size
- orig_images = batch_size // self.num_pos
- rank = dist.rank()
-
- pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
- neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
- all_indices = np.arange(total_images)
-
- # Index for pairs of images (original + copy)
- pairs = orig_images * np.arange(self.num_pos)
-
- # Remove all indices associated with positive samples & copies (for neg_mask)
- all_pos_members = []
- for _rank in range(world_size):
- all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
- all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
- # Index of original positive images
- orig_members = torch.arange(orig_images)
-
- for anchor in np.arange(self.num_pos):
- for img_idx in range(orig_images):
- # delete_inds are spaced by batch_size for each rank as
- # all_indices_pos_removed (half of the indices) is deleted first
- delete_inds = batch_size * rank + img_idx + pairs
- neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
- neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
- for pos in np.delete(np.arange(self.num_pos), anchor):
- # Pos_inds are spaced by batch_size * self.num_neg for each rank
- pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
- pos_mask[
- torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
- pos_inds.long(),
- ] = 1
-
- self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
- self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-
- def forward(self, out: torch.Tensor):
- """
- Calculate the loss. Operates on embeddings tensor.
- """
- if not self.setup:
- logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
- self.precompute_pos_neg_mask()
- self.setup = True
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
- embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
- embedding = nn.functional.normalize(embedding, dim=1, p=2)
- assert embedding.ndim == 2
- assert embedding.shape[1] == int(self.embedding_dim)
-
- batch_size = embedding.shape[0]
- T = self.temperature
- num_pos = self.num_pos
-
- assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
- assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
- # Step 1: gather all the embeddings. Shape example: 4096 x 128
- embeddings_buffer = self.gather_embeddings(embedding)
-
- # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
- # divide by temperature.
- similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
- pos = torch.sum(similarity * self.pos_mask, 1)
- neg = torch.sum(similarity * self.neg_mask, 1)
-
- # Ignore the negative samples as entries for loss calculation
- pos = pos[: (batch_size // 2)]
- neg = neg[: (batch_size // 2)]
-
- loss = -(torch.mean(torch.log(pos / (pos + neg))))
- return loss
-
- def __repr__(self):
- """
- Return a string representation of the object.
-
- Returns:
- str: A formatted string representation of the object.
-
- Examples:
- The following example shows the string representation of the object:
-
- {
- 'name': <object_name>,
- 'temperature': <temperature_value>,
- 'num_negatives': <num_negatives_value>,
- 'num_pos': <num_pos_value>,
- 'dist_rank': <dist_rank_value>
- }
-
- Note:
- This function is intended to be used with the pprint module for pretty printing.
- """
- num_negatives = self.effective_batch_size - 2
- T = self.temperature
- num_pos = self.num_pos
- repr_dict = {
- "name": self._get_name(),
- "temperature": T,
- "num_negatives": num_negatives,
- "num_pos": num_pos,
- "dist_rank": dist.rank(),
- }
- return pprint.pformat(repr_dict, indent=2)
-
- def gather_embeddings(self, embedding: torch.Tensor):
- """
- Do a gather over all embeddings, so we can compute the loss.
- Final shape is like: (batch_size * num_gpus) x embedding_dim
- """
- if self.gather_distributed:
- embedding_gathered = torch.cat(dist.gather(embedding), 0)
- else:
- embedding_gathered = embedding
- return embedding_gathered
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, out: torch.Tensor) ‑> Callable[..., Any]
-
Calculate the loss. Operates on embeddings tensor.
def forward(self, out: torch.Tensor):
- """
- Calculate the loss. Operates on embeddings tensor.
- """
- if not self.setup:
- logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
- self.precompute_pos_neg_mask()
- self.setup = True
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
- embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
- embedding = nn.functional.normalize(embedding, dim=1, p=2)
- assert embedding.ndim == 2
- assert embedding.shape[1] == int(self.embedding_dim)
-
- batch_size = embedding.shape[0]
- T = self.temperature
- num_pos = self.num_pos
-
- assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
- assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
- # Step 1: gather all the embeddings. Shape example: 4096 x 128
- embeddings_buffer = self.gather_embeddings(embedding)
-
- # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
- # divide by temperature.
- similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
- pos = torch.sum(similarity * self.pos_mask, 1)
- neg = torch.sum(similarity * self.neg_mask, 1)
-
- # Ignore the negative samples as entries for loss calculation
- pos = pos[: (batch_size // 2)]
- neg = neg[: (batch_size // 2)]
-
- loss = -(torch.mean(torch.log(pos / (pos + neg))))
- return loss
-
-def gather_embeddings(self, embedding: torch.Tensor)
-
Do a gather over all embeddings, so we can compute the loss. -Final shape is like: (batch_size * num_gpus) x embedding_dim
def gather_embeddings(self, embedding: torch.Tensor):
- """
- Do a gather over all embeddings, so we can compute the loss.
- Final shape is like: (batch_size * num_gpus) x embedding_dim
- """
- if self.gather_distributed:
- embedding_gathered = torch.cat(dist.gather(embedding), 0)
- else:
- embedding_gathered = embedding
- return embedding_gathered
-
-def precompute_pos_neg_mask(self)
-
Precompute the positive and negative masks to speed up the loss calculation.
def precompute_pos_neg_mask(self):
- """
- Precompute the positive and negative masks to speed up the loss calculation.
- """
- # computed once at the begining of training
-
- # total_images is x2 SimCLR Info-NCE loss
- # as we have negative samples for each positive sample
-
- total_images = self.effective_batch_size * self.num_neg
- world_size = self.world_size
-
- # Batch size computation is different from SimCLR paper
- batch_size = self.effective_batch_size // world_size
- orig_images = batch_size // self.num_pos
- rank = dist.rank()
-
- pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
- neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
- all_indices = np.arange(total_images)
-
- # Index for pairs of images (original + copy)
- pairs = orig_images * np.arange(self.num_pos)
-
- # Remove all indices associated with positive samples & copies (for neg_mask)
- all_pos_members = []
- for _rank in range(world_size):
- all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
- all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
- # Index of original positive images
- orig_members = torch.arange(orig_images)
-
- for anchor in np.arange(self.num_pos):
- for img_idx in range(orig_images):
- # delete_inds are spaced by batch_size for each rank as
- # all_indices_pos_removed (half of the indices) is deleted first
- delete_inds = batch_size * rank + img_idx + pairs
- neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
- neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
- for pos in np.delete(np.arange(self.num_pos), anchor):
- # Pos_inds are spaced by batch_size * self.num_neg for each rank
- pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
- pos_mask[
- torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
- pos_inds.long(),
- ] = 1
-
- self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
- self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-fmcib.ssl.losses.nnclr_loss
from lightly.loss import NTXentLoss
-
-
-class NNCLRLoss(NTXentLoss):
- """
- A class representing the NNCLRLoss.
-
- This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.
-
- Attributes:
- temperature (float): The temperature for the loss function. Default is 0.1.
- gather_distributed (bool): A flag indicating whether the distributed gathering is used. Default is False.
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize a new instance of the class.
-
- Args:
- temperature (float): The temperature to use for initialization. Default value is 0.1.
- gather_distributed (bool): Whether to use gather distributed mode. Default value is False.
- """
- super().__init__(temperature, gather_distributed)
-
- def forward(self, out):
- """
- Symmetric loss function for NNCLR.
- """
- (z0, p0), (z1, p1) = out
- loss0 = super().forward(z0, p0)
- loss1 = super().forward(z1, p1)
- return (loss0 + loss1) / 2
-
-class NNCLRLoss
-(temperature: float = 0.1, gather_distributed: bool = False)
-
A class representing the NNCLRLoss.
-This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.
-temperature
: float
gather_distributed
: bool
Initialize a new instance of the class.
-temperature
: float
gather_distributed
: bool
class NNCLRLoss(NTXentLoss):
- """
- A class representing the NNCLRLoss.
-
- This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.
-
- Attributes:
- temperature (float): The temperature for the loss function. Default is 0.1.
- gather_distributed (bool): A flag indicating whether the distributed gathering is used. Default is False.
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize a new instance of the class.
-
- Args:
- temperature (float): The temperature to use for initialization. Default value is 0.1.
- gather_distributed (bool): Whether to use gather distributed mode. Default value is False.
- """
- super().__init__(temperature, gather_distributed)
-
- def forward(self, out):
- """
- Symmetric loss function for NNCLR.
- """
- (z0, p0), (z1, p1) = out
- loss0 = super().forward(z0, p0)
- loss1 = super().forward(z1, p1)
- return (loss0 + loss1) / 2
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, out) ‑> Callable[..., Any]
-
Symmetric loss function for NNCLR.
def forward(self, out):
- """
- Symmetric loss function for NNCLR.
- """
- (z0, p0), (z1, p1) = out
- loss0 = super().forward(z0, p0)
- loss1 = super().forward(z1, p1)
- return (loss0 + loss1) / 2
-fmcib.ssl.losses.ntxent_loss
from typing import List
-
-from lightly.loss import NTXentLoss as lightly_NTXentLoss
-
-
-class NTXentLoss(lightly_NTXentLoss):
- """
- NTXentNegativeMinedLoss:
- NTXentLoss with explicitly mined negatives
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize an instance of the class.
-
- Args:
- temperature (float, optional): The temperature parameter for the instance. Defaults to 0.1.
- gather_distributed (bool, optional): Whether to gather distributed data. Defaults to False.
- """
- super().__init__(temperature, gather_distributed)
-
- def forward(self, out: List):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (List[torch.Tensor]): List of tensors
-
- Returns:
- float: Contrastive Cross Entropy Loss value.
- """
- return super().forward(*out)
-
-class NTXentLoss
-(temperature: float = 0.1, gather_distributed: bool = False)
-
NTXentNegativeMinedLoss: -NTXentLoss with explicitly mined negatives
-Initialize an instance of the class.
-temperature
: float
, optionalgather_distributed
: bool
, optionalclass NTXentLoss(lightly_NTXentLoss):
- """
- NTXentNegativeMinedLoss:
- NTXentLoss with explicitly mined negatives
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize an instance of the class.
-
- Args:
- temperature (float, optional): The temperature parameter for the instance. Defaults to 0.1.
- gather_distributed (bool, optional): Whether to gather distributed data. Defaults to False.
- """
- super().__init__(temperature, gather_distributed)
-
- def forward(self, out: List):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (List[torch.Tensor]): List of tensors
-
- Returns:
- float: Contrastive Cross Entropy Loss value.
- """
- return super().forward(*out)
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, out: List[~T]) ‑> Callable[..., Any]
-
Forward pass through Negative mining contrastive Cross-Entropy Loss.
-out
: List[torch.Tensor]
float
def forward(self, out: List):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (List[torch.Tensor]): List of tensors
-
- Returns:
- float: Contrastive Cross Entropy Loss value.
- """
- return super().forward(*out)
-fmcib.ssl.losses.ntxent_mined_loss
Contrastive Loss Functions
-""" Contrastive Loss Functions """
-
-# Copyright (c) 2020. Lightly AG and its affiliates.
-# All Rights Reserved
-
-# Modified to function for explicitly selected negatives
-
-from typing import Dict
-
-import torch
-from lightly.utils import dist
-from torch import nn
-
-
-class NTXentNegativeMinedLoss(torch.nn.Module):
- """
- NTXentNegativeMinedLoss:
- NTXentLoss with explicitly mined negatives
-
- Args:
- temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
- gather_distributed (bool): Whether to gather hidden representations from other processes in a distributed setting. Default is False.
-
- Raises:
- ValueError: If the absolute value of temperature is less than 1e-8.
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize the NTXentNegativeMinedLoss object.
-
- Args:
- temperature (float, optional): The temperature parameter for the loss function. Defaults to 0.1.
- gather_distributed (bool, optional): Whether to use distributed gathering or not. Defaults to False.
-
- Raises:
- ValueError: If the absolute value of the temperature is too small.
-
- Attributes:
- temperature (float): The temperature parameter for the loss function.
- gather_distributed (bool): Whether to use distributed gathering or not.
- cross_entropy (torch.nn.CrossEntropyLoss): The cross entropy loss function.
- eps (float): A small value to avoid division by zero.
- """
- super(NTXentNegativeMinedLoss, self).__init__()
- self.temperature = temperature
- self.gather_distributed = gather_distributed
- self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
- self.eps = 1e-8
-
- if abs(self.temperature) < self.eps:
- raise ValueError("Illegal temperature: abs({}) < 1e-8".format(self.temperature))
-
- def forward(self, out: Dict):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
- Returns:
- torch.Tensor: Contrastive Cross Entropy Loss value.
-
- Raises:
- AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
- """
-
- assert "positive" in out, "`positive` key needs to be specified"
- assert "negative" in out, "`negative` key needs to be specified"
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
-
- device = pos0.device
- batch_size, _ = pos0.shape
-
- # normalize the output to length 1
- pos0 = nn.functional.normalize(pos0, dim=1)
- pos1 = nn.functional.normalize(pos1, dim=1)
- neg0 = nn.functional.normalize(neg0, dim=1)
- neg1 = nn.functional.normalize(neg1, dim=1)
-
- if self.gather_distributed and dist.world_size() > 1:
- # gather hidden representations from other processes
- pos0_large = torch.cat(dist.gather(pos0), 0)
- pos1_large = torch.cat(dist.gather(pos1), 0)
- neg0_large = torch.cat(dist.gather(neg0), 0)
- neg1_large = torch.cat(dist.gather(neg1), 0)
- diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
- else:
- # gather hidden representations from other processes
- pos0_large = pos0
- pos1_large = pos1
- neg0_large = neg0
- neg1_large = neg1
- diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
- logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
- logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
- logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
- logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
- logits_01 = logits_01[diag_mask].view(batch_size, -1)
- logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
- logits_0100 = torch.cat([logits_01, logits_00], dim=1)
- logits_1011 = torch.cat([logits_10, logits_11], dim=1)
- logits = torch.cat([logits_0100, logits_1011], dim=0)
-
- labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
- loss = self.cross_entropy(logits, labels)
-
- return loss
-
-class NTXentNegativeMinedLoss
-(temperature: float = 0.1, gather_distributed: bool = False)
-
NTXentNegativeMinedLoss: -NTXentLoss with explicitly mined negatives
-temperature
: float
gather_distributed
: bool
ValueError
Initialize the NTXentNegativeMinedLoss object.
-temperature
: float
, optionalgather_distributed
: bool
, optionalValueError
temperature
: float
gather_distributed
: bool
cross_entropy
: torch.nn.CrossEntropyLoss
eps
: float
class NTXentNegativeMinedLoss(torch.nn.Module):
- """
- NTXentNegativeMinedLoss:
- NTXentLoss with explicitly mined negatives
-
- Args:
- temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
- gather_distributed (bool): Whether to gather hidden representations from other processes in a distributed setting. Default is False.
-
- Raises:
- ValueError: If the absolute value of temperature is less than 1e-8.
- """
-
- def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
- """
- Initialize the NTXentNegativeMinedLoss object.
-
- Args:
- temperature (float, optional): The temperature parameter for the loss function. Defaults to 0.1.
- gather_distributed (bool, optional): Whether to use distributed gathering or not. Defaults to False.
-
- Raises:
- ValueError: If the absolute value of the temperature is too small.
-
- Attributes:
- temperature (float): The temperature parameter for the loss function.
- gather_distributed (bool): Whether to use distributed gathering or not.
- cross_entropy (torch.nn.CrossEntropyLoss): The cross entropy loss function.
- eps (float): A small value to avoid division by zero.
- """
- super(NTXentNegativeMinedLoss, self).__init__()
- self.temperature = temperature
- self.gather_distributed = gather_distributed
- self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
- self.eps = 1e-8
-
- if abs(self.temperature) < self.eps:
- raise ValueError("Illegal temperature: abs({}) < 1e-8".format(self.temperature))
-
- def forward(self, out: Dict):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
- Returns:
- torch.Tensor: Contrastive Cross Entropy Loss value.
-
- Raises:
- AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
- """
-
- assert "positive" in out, "`positive` key needs to be specified"
- assert "negative" in out, "`negative` key needs to be specified"
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
-
- device = pos0.device
- batch_size, _ = pos0.shape
-
- # normalize the output to length 1
- pos0 = nn.functional.normalize(pos0, dim=1)
- pos1 = nn.functional.normalize(pos1, dim=1)
- neg0 = nn.functional.normalize(neg0, dim=1)
- neg1 = nn.functional.normalize(neg1, dim=1)
-
- if self.gather_distributed and dist.world_size() > 1:
- # gather hidden representations from other processes
- pos0_large = torch.cat(dist.gather(pos0), 0)
- pos1_large = torch.cat(dist.gather(pos1), 0)
- neg0_large = torch.cat(dist.gather(neg0), 0)
- neg1_large = torch.cat(dist.gather(neg1), 0)
- diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
- else:
- # gather hidden representations from other processes
- pos0_large = pos0
- pos1_large = pos1
- neg0_large = neg0
- neg1_large = neg1
- diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
- logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
- logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
- logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
- logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
- logits_01 = logits_01[diag_mask].view(batch_size, -1)
- logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
- logits_0100 = torch.cat([logits_01, logits_00], dim=1)
- logits_1011 = torch.cat([logits_10, logits_11], dim=1)
- logits = torch.cat([logits_0100, logits_1011], dim=0)
-
- labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
- loss = self.cross_entropy(logits, labels)
-
- return loss
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, out: Dict[~KT, ~VT]) ‑> Callable[..., Any]
-
Forward pass through Negative mining contrastive Cross-Entropy Loss.
-out
: Dict
positive
and negative
keys to represent positive selected and negative selected samples.torch.Tensor
AssertionError
positive
or negative
keys are not specified in the input dictionary.def forward(self, out: Dict):
- """
- Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
- Args:
- out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
- Returns:
- torch.Tensor: Contrastive Cross Entropy Loss value.
-
- Raises:
- AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
- """
-
- assert "positive" in out, "`positive` key needs to be specified"
- assert "negative" in out, "`negative` key needs to be specified"
-
- pos0, pos1 = out["positive"]
- neg0, neg1 = out["negative"]
-
- device = pos0.device
- batch_size, _ = pos0.shape
-
- # normalize the output to length 1
- pos0 = nn.functional.normalize(pos0, dim=1)
- pos1 = nn.functional.normalize(pos1, dim=1)
- neg0 = nn.functional.normalize(neg0, dim=1)
- neg1 = nn.functional.normalize(neg1, dim=1)
-
- if self.gather_distributed and dist.world_size() > 1:
- # gather hidden representations from other processes
- pos0_large = torch.cat(dist.gather(pos0), 0)
- pos1_large = torch.cat(dist.gather(pos1), 0)
- neg0_large = torch.cat(dist.gather(neg0), 0)
- neg1_large = torch.cat(dist.gather(neg1), 0)
- diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
- else:
- # gather hidden representations from other processes
- pos0_large = pos0
- pos1_large = pos1
- neg0_large = neg0
- neg1_large = neg1
- diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
- logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
- logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
- logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
- logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
- logits_01 = logits_01[diag_mask].view(batch_size, -1)
- logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
- logits_0100 = torch.cat([logits_01, logits_00], dim=1)
- logits_1011 = torch.cat([logits_10, logits_11], dim=1)
- logits = torch.cat([logits_0100, logits_1011], dim=0)
-
- labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
- loss = self.cross_entropy(logits, labels)
-
- return loss
-fmcib.ssl.losses.swav_loss
import lightly
-
-
-class SwaVLoss(lightly.loss.swav_loss.SwaVLoss):
- """
- A class representing a custom SwaV loss function.
-
- Attributes:
- temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
- sinkhorn_iterations (int): The number of iterations for Sinkhorn algorithm. Default is 3.
- sinkhorn_epsilon (float): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
- sinkhorn_gather_distributed (bool): Whether to gather distributed results for Sinkhorn algorithm. Default is False.
- """
-
- def __init__(
- self,
- temperature: float = 0.1,
- sinkhorn_iterations: int = 3,
- sinkhorn_epsilon: float = 0.05,
- sinkhorn_gather_distributed: bool = False,
- ):
- """
- Initialize the object with specified parameters.
-
- Args:
- temperature (float, optional): The temperature parameter. Default is 0.1.
- sinkhorn_iterations (int, optional): The number of Sinkhorn iterations. Default is 3.
- sinkhorn_epsilon (float, optional): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
- sinkhorn_gather_distributed (bool, optional): Whether to use distributed computation for Sinkhorn algorithm. Default is False.
- """
- super().__init__(temperature, sinkhorn_iterations, sinkhorn_epsilon, sinkhorn_gather_distributed)
-
- def forward(self, pred):
- """
- Perform a forward pass of the model.
-
- Args:
- pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
- Returns:
- The output of the forward pass.
- """
- high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
- return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-
-class SwaVLoss
-(temperature: float = 0.1, sinkhorn_iterations: int = 3, sinkhorn_epsilon: float = 0.05, sinkhorn_gather_distributed: bool = False)
-
A class representing a custom SwaV loss function.
-temperature
: float
sinkhorn_iterations
: int
sinkhorn_epsilon
: float
sinkhorn_gather_distributed
: bool
Initialize the object with specified parameters.
-temperature
: float
, optionalsinkhorn_iterations
: int
, optionalsinkhorn_epsilon
: float
, optionalsinkhorn_gather_distributed
: bool
, optionalclass SwaVLoss(lightly.loss.swav_loss.SwaVLoss):
- """
- A class representing a custom SwaV loss function.
-
- Attributes:
- temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
- sinkhorn_iterations (int): The number of iterations for Sinkhorn algorithm. Default is 3.
- sinkhorn_epsilon (float): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
- sinkhorn_gather_distributed (bool): Whether to gather distributed results for Sinkhorn algorithm. Default is False.
- """
-
- def __init__(
- self,
- temperature: float = 0.1,
- sinkhorn_iterations: int = 3,
- sinkhorn_epsilon: float = 0.05,
- sinkhorn_gather_distributed: bool = False,
- ):
- """
- Initialize the object with specified parameters.
-
- Args:
- temperature (float, optional): The temperature parameter. Default is 0.1.
- sinkhorn_iterations (int, optional): The number of Sinkhorn iterations. Default is 3.
- sinkhorn_epsilon (float, optional): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
- sinkhorn_gather_distributed (bool, optional): Whether to use distributed computation for Sinkhorn algorithm. Default is False.
- """
- super().__init__(temperature, sinkhorn_iterations, sinkhorn_epsilon, sinkhorn_gather_distributed)
-
- def forward(self, pred):
- """
- Perform a forward pass of the model.
-
- Args:
- pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
- Returns:
- The output of the forward pass.
- """
- high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
- return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, pred) ‑> Callable[..., Any]
-
Perform a forward pass of the model.
-pred
: tuple
The output of the forward pass.
def forward(self, pred):
- """
- Perform a forward pass of the model.
-
- Args:
- pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
- Returns:
- The output of the forward pass.
- """
- high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
- return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-fmcib.ssl.modules.exneg_simclr
from typing import Dict, Union
-
-import torch
-import torch.nn as nn
-from lightly.models import SimCLR
-
-
-class ExNegSimCLR(SimCLR):
- """
- Extended Negative Sampling SimCLR model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
- out_dim (int): Dimension of the output feature embeddings. Default is 128.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
- print(backbone)
- super().__init__(backbone, num_ftrs, out_dim)
-
- def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- out (Dict): Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-
-class ExNegSimCLR
-(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128)
-
Extended Negative Sampling SimCLR model.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
Initializes internal Module state, shared by both nn.Module and ScriptModule.
class ExNegSimCLR(SimCLR):
- """
- Extended Negative Sampling SimCLR model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
- out_dim (int): Dimension of the output feature embeddings. Default is 128.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
- print(backbone)
- super().__init__(backbone, num_ftrs, out_dim)
-
- def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- out (Dict): Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-
-def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False) ‑> Callable[..., Any]
-
Forward pass of the ExNegSimCLR model.
-x
: Union[Dict, torch.Tensor]
return_features
: bool
out (Dict): Output dictionary containing the forward pass results for each input view.
def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- out (Dict): Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-fmcib.ssl.modules
from .exneg_simclr import ExNegSimCLR
-from .load_pretrained_resnet import LoadPretrainedResnet3D
-fmcib.ssl.modules.exneg_simclr
fmcib.ssl.modules.load_pretrained_resnet
fmcib.ssl.modules.load_pretrained_resnet
import monai
-import torch
-from loguru import logger
-from monai.networks.nets.resnet import ResNetBottleneck as Bottleneck
-from torch import nn
-
-
-class LoadPretrainedResnet3D(nn.Module):
- """
- LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.
-
- Args:
- pretrained (str): Path to the pretrained model file. Default is None.
- vissl (bool): Whether the pretrained model is from VISSL. Default is False.
- heads (list): List of integers specifying the number of input and output channels for each head. Default is an empty list.
-
- Attributes:
- trunk (nn.Module): The ResNet trunk network.
- heads (nn.Module): The sequential module containing the heads.
-
- Methods:
- forward(x): Forward pass of the model.
- load(pretrained): Load the pretrained model weights.
-
- Example:
- model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])
- """
-
- def __init__(self, pretrained=None, vissl=False, heads=[]) -> None:
- super().__init__()
- self.trunk = monai.networks.nets.resnet.ResNet(
- block=Bottleneck,
- layers=(3, 4, 6, 3),
- block_inplanes=(64, 128, 256, 512),
- spatial_dims=3,
- n_input_channels=1,
- conv1_t_stride=2,
- conv1_t_size=7,
- widen_factor=2,
- feed_forward=False,
- )
-
- head_layers = []
- for idx in range(len(heads) - 1):
- current_layers = []
- current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
- if idx != (len(heads) - 2):
- current_layers.append(nn.ReLU(inplace=True))
-
- head_layers.append(nn.Sequential(*current_layers))
-
- if len(head_layers):
- self.heads = nn.Sequential(*head_layers)
- else:
- self.heads = nn.Identity()
-
- if pretrained is not None:
- self.load(pretrained)
-
- def forward(self, x: torch.Tensor):
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
- def load(self, pretrained):
- pretrained_model = torch.load(pretrained)
-
- # Load trained trunk
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-
-class LoadPretrainedResnet3D
-(pretrained=None, vissl=False, heads=[])
-
LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.
-pretrained
: str
vissl
: bool
heads
: list
trunk
: nn.Module
heads
: nn.Module
forward(x): Forward pass of the model. -load(pretrained): Load the pretrained model weights.
-model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])
-Initializes internal Module state, shared by both nn.Module and ScriptModule.
class LoadPretrainedResnet3D(nn.Module):
- """
- LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.
-
- Args:
- pretrained (str): Path to the pretrained model file. Default is None.
- vissl (bool): Whether the pretrained model is from VISSL. Default is False.
- heads (list): List of integers specifying the number of input and output channels for each head. Default is an empty list.
-
- Attributes:
- trunk (nn.Module): The ResNet trunk network.
- heads (nn.Module): The sequential module containing the heads.
-
- Methods:
- forward(x): Forward pass of the model.
- load(pretrained): Load the pretrained model weights.
-
- Example:
- model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])
- """
-
- def __init__(self, pretrained=None, vissl=False, heads=[]) -> None:
- super().__init__()
- self.trunk = monai.networks.nets.resnet.ResNet(
- block=Bottleneck,
- layers=(3, 4, 6, 3),
- block_inplanes=(64, 128, 256, 512),
- spatial_dims=3,
- n_input_channels=1,
- conv1_t_stride=2,
- conv1_t_size=7,
- widen_factor=2,
- feed_forward=False,
- )
-
- head_layers = []
- for idx in range(len(heads) - 1):
- current_layers = []
- current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
- if idx != (len(heads) - 2):
- current_layers.append(nn.ReLU(inplace=True))
-
- head_layers.append(nn.Sequential(*current_layers))
-
- if len(head_layers):
- self.heads = nn.Sequential(*head_layers)
- else:
- self.heads = nn.Identity()
-
- if pretrained is not None:
- self.load(pretrained)
-
- def forward(self, x: torch.Tensor):
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
- def load(self, pretrained):
- pretrained_model = torch.load(pretrained)
-
- # Load trained trunk
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-
-def forward(self, x: torch.Tensor) ‑> Callable[..., Any]
-
Defines the computation performed at every call.
-Should be overridden by all subclasses.
-Note
-Although the recipe for forward pass needs to be defined within
-this function, one should call the :class:Module
instance afterwards
-instead of this since the former takes care of running the
-registered hooks while the latter silently ignores them.
def forward(self, x: torch.Tensor):
- out = self.trunk(x)
- out = self.heads(out)
- return out
-
-def load(self, pretrained)
-
def load(self, pretrained):
- pretrained_model = torch.load(pretrained)
-
- # Load trained trunk
- trained_trunk = pretrained_model["trunk_state_dict"]
- msg = self.trunk.load_state_dict(trained_trunk, strict=False)
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- # Load trained heads
- if "head_state_dict" in pretrained_model:
- trained_heads = pretrained_model["head_state_dict"]
-
- try:
- msg = self.heads.load_state_dict(trained_heads, strict=False)
- except Exception as e:
- logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
- logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
- logger.info(f"Loaded pretrained model weights \n")
-fmcib.ssl.modules.exneg_simclr
from typing import Dict, Union
-
-import torch
-import torch.nn as nn
-from lightly.models import SimCLR as lightly_SimCLR
-from lightly.models.modules import SimCLRProjectionHead
-
-
-class ExNegSimCLR(lightly_SimCLR):
- """
- Extended Negative Sampling SimCLR model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
- out_dim (int): Dimension of the output feature embeddings. Default is 128.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
- """
- Initialize the object.
-
- Args:
- backbone (nn.Module): The backbone neural network.
- num_ftrs (int, optional): The number of input features for the projection head. Default is 32.
- out_dim (int, optional): The output dimension of the projection head. Default is 128.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__(backbone, num_ftrs, out_dim)
- # replace the projection head with a new one
- self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
- def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- Dict: Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-
-class ExNegSimCLR
-(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128)
-
Extended Negative Sampling SimCLR model.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
Initialize the object.
-backbone
: nn.Module
num_ftrs
: int
, optionalout_dim
: int
, optionalNone
-None
class ExNegSimCLR(lightly_SimCLR):
- """
- Extended Negative Sampling SimCLR model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
- out_dim (int): Dimension of the output feature embeddings. Default is 128.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
- """
- Initialize the object.
-
- Args:
- backbone (nn.Module): The backbone neural network.
- num_ftrs (int, optional): The number of input features for the projection head. Default is 32.
- out_dim (int, optional): The output dimension of the projection head. Default is 128.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__(backbone, num_ftrs, out_dim)
- # replace the projection head with a new one
- self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
- def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- Dict: Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x: Union[Dict[~KT, ~VT], torch.Tensor], return_features: bool = False) ‑> Callable[..., Any]
-
Forward pass of the ExNegSimCLR model.
-x
: Union[Dict, torch.Tensor]
return_features
: bool
Dict
def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
- """
- Forward pass of the ExNegSimCLR model.
-
- Args:
- x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
- return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
- Returns:
- Dict: Output dictionary containing the forward pass results for each input view.
- """
- assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
- out = {}
- for key, value in x.items():
- if isinstance(value, list):
- out[key] = super().forward(*value, return_features)
-
- return out
-fmcib.ssl.modules
from .exneg_simclr import ExNegSimCLR
-from .nnclr import NNCLR
-from .simclr import SimCLR
-from .swav import SwaV
-fmcib.ssl.modules.exneg_simclr
fmcib.ssl.modules.nnclr
fmcib.ssl.modules.simclr
fmcib.ssl.modules.swav
fmcib.ssl.modules.nnclr
from typing import Any, Dict, List, Optional, Union
-
-import torch
-import torch.nn as nn
-from lightly.models.modules import NNCLRPredictionHead, NNCLRProjectionHead, NNMemoryBankModule
-
-
-class NNCLR(nn.Module):
- """
- Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py
- """
-
- def __init__(
- self,
- backbone: nn.Module,
- num_ftrs: int = 4096,
- proj_hidden_dim: int = 4096,
- pred_hidden_dim: int = 4096,
- out_dim: int = 256,
- memory_bank_size: int = 4096,
- ) -> None:
- """
- Initialize the NNCLR model.
-
- Args:
- backbone (nn.Module): The backbone neural network model.
- num_ftrs (int, optional): The number of features in the backbone output. Default is 4096.
- proj_hidden_dim (int, optional): The hidden dimension of the projection head. Default is 4096.
- pred_hidden_dim (int, optional): The hidden dimension of the prediction head. Default is 4096.
- out_dim (int, optional): The output dimension of the model. Default is 256.
- memory_bank_size (int, optional): The size of the memory bank module. Default is 4096.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__()
- self.backbone = backbone
- self.projection_head = NNCLRProjectionHead(num_ftrs, proj_hidden_dim, out_dim)
- self.prediction_head = NNCLRPredictionHead(out_dim, pred_hidden_dim, out_dim)
- self.memory_bank = NNMemoryBankModule(memory_bank_size)
-
- def forward(
- self,
- x: List[torch.Tensor],
- get_nearest_neighbor: bool = True,
- ):
- # forward pass of first input x0
- """
- Forward pass of the model.
-
- Args:
- x (List[torch.Tensor]): A list containing two input tensors.
- get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
- Defaults to True.
-
- Returns:
- Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
- for each input tensor.
- """
- x0, x1 = x
- f0 = self.backbone(x0).flatten(start_dim=1)
- z0 = self.projection_head(f0)
- p0 = self.prediction_head(z0)
-
- if get_nearest_neighbor:
- z0 = self.memory_bank(z0, update=False)
-
- # forward pass of second input x1
- f1 = self.backbone(x1).flatten(start_dim=1)
- z1 = self.projection_head(f1)
- p1 = self.prediction_head(z1)
-
- if get_nearest_neighbor:
- z1 = self.memory_bank(z1, update=True)
-
- return (z0, p0), (z1, p1)
-
-class NNCLR
-(backbone: torch.nn.modules.module.Module, num_ftrs: int = 4096, proj_hidden_dim: int = 4096, pred_hidden_dim: int = 4096, out_dim: int = 256, memory_bank_size: int = 4096)
-
Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py
-Initialize the NNCLR model.
-backbone
: nn.Module
num_ftrs
: int
, optionalproj_hidden_dim
: int
, optionalpred_hidden_dim
: int
, optionalout_dim
: int
, optionalmemory_bank_size
: int
, optionalNone
-None
class NNCLR(nn.Module):
- """
- Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py
- """
-
- def __init__(
- self,
- backbone: nn.Module,
- num_ftrs: int = 4096,
- proj_hidden_dim: int = 4096,
- pred_hidden_dim: int = 4096,
- out_dim: int = 256,
- memory_bank_size: int = 4096,
- ) -> None:
- """
- Initialize the NNCLR model.
-
- Args:
- backbone (nn.Module): The backbone neural network model.
- num_ftrs (int, optional): The number of features in the backbone output. Default is 4096.
- proj_hidden_dim (int, optional): The hidden dimension of the projection head. Default is 4096.
- pred_hidden_dim (int, optional): The hidden dimension of the prediction head. Default is 4096.
- out_dim (int, optional): The output dimension of the model. Default is 256.
- memory_bank_size (int, optional): The size of the memory bank module. Default is 4096.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__()
- self.backbone = backbone
- self.projection_head = NNCLRProjectionHead(num_ftrs, proj_hidden_dim, out_dim)
- self.prediction_head = NNCLRPredictionHead(out_dim, pred_hidden_dim, out_dim)
- self.memory_bank = NNMemoryBankModule(memory_bank_size)
-
- def forward(
- self,
- x: List[torch.Tensor],
- get_nearest_neighbor: bool = True,
- ):
- # forward pass of first input x0
- """
- Forward pass of the model.
-
- Args:
- x (List[torch.Tensor]): A list containing two input tensors.
- get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
- Defaults to True.
-
- Returns:
- Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
- for each input tensor.
- """
- x0, x1 = x
- f0 = self.backbone(x0).flatten(start_dim=1)
- z0 = self.projection_head(f0)
- p0 = self.prediction_head(z0)
-
- if get_nearest_neighbor:
- z0 = self.memory_bank(z0, update=False)
-
- # forward pass of second input x1
- f1 = self.backbone(x1).flatten(start_dim=1)
- z1 = self.projection_head(f1)
- p1 = self.prediction_head(z1)
-
- if get_nearest_neighbor:
- z1 = self.memory_bank(z1, update=True)
-
- return (z0, p0), (z1, p1)
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x: List[torch.Tensor], get_nearest_neighbor: bool = True) ‑> Callable[..., Any]
-
Forward pass of the model.
-x
: List[torch.Tensor]
get_nearest_neighbor
: bool
, optionalTuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: -A tuple containing two tuples. The inner tuples contain the projection and prediction vectors -for each input tensor.
def forward(
- self,
- x: List[torch.Tensor],
- get_nearest_neighbor: bool = True,
-):
- # forward pass of first input x0
- """
- Forward pass of the model.
-
- Args:
- x (List[torch.Tensor]): A list containing two input tensors.
- get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
- Defaults to True.
-
- Returns:
- Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
- for each input tensor.
- """
- x0, x1 = x
- f0 = self.backbone(x0).flatten(start_dim=1)
- z0 = self.projection_head(f0)
- p0 = self.prediction_head(z0)
-
- if get_nearest_neighbor:
- z0 = self.memory_bank(z0, update=False)
-
- # forward pass of second input x1
- f1 = self.backbone(x1).flatten(start_dim=1)
- z1 = self.projection_head(f1)
- p1 = self.prediction_head(z1)
-
- if get_nearest_neighbor:
- z1 = self.memory_bank(z1, update=True)
-
- return (z0, p0), (z1, p1)
-fmcib.ssl.modules.simclr
import torch
-import torch.nn as nn
-from lightly.models import SimCLR as lightly_SimCLR
-from lightly.models.modules import SimCLRProjectionHead
-
-
-class SimCLR(lightly_SimCLR):
- """
- A class representing a SimCLR model.
-
- Attributes:
- backbone (nn.Module): The backbone model used in the SimCLR model.
- num_ftrs (int): The number of output features from the backbone model.
- out_dim (int): The dimension of the output representations.
- projection_head (SimCLRProjectionHead): The projection head used for projection head training.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128):
- """
- Initialize the object with a backbone network, number of features, and output dimension.
-
- Args:
- backbone (nn.Module): The backbone network.
- num_ftrs (int): The number of features. Default is 32.
- out_dim (int): The output dimension. Default is 128.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__(backbone, num_ftrs, out_dim)
- self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
- def forward(self, x, return_features=False):
- """
- Perform a forward pass of the neural network.
-
- Args:
- x (tuple): A tuple of input data. Each element of the tuple represents a different input.
- return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
- Returns:
- torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
- If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
- Raises:
- None.
- """
- return super().forward(*x, return_features)
-
-class SimCLR
-(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128)
-
A class representing a SimCLR model.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
projection_head
: SimCLRProjectionHead
Initialize the object with a backbone network, number of features, and output dimension.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
None
-None
class SimCLR(lightly_SimCLR):
- """
- A class representing a SimCLR model.
-
- Attributes:
- backbone (nn.Module): The backbone model used in the SimCLR model.
- num_ftrs (int): The number of output features from the backbone model.
- out_dim (int): The dimension of the output representations.
- projection_head (SimCLRProjectionHead): The projection head used for projection head training.
- """
-
- def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128):
- """
- Initialize the object with a backbone network, number of features, and output dimension.
-
- Args:
- backbone (nn.Module): The backbone network.
- num_ftrs (int): The number of features. Default is 32.
- out_dim (int): The output dimension. Default is 128.
-
- Returns:
- None
-
- Raises:
- None
- """
- super().__init__(backbone, num_ftrs, out_dim)
- self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
- def forward(self, x, return_features=False):
- """
- Perform a forward pass of the neural network.
-
- Args:
- x (tuple): A tuple of input data. Each element of the tuple represents a different input.
- return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
- Returns:
- torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
- If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
- Raises:
- None.
- """
- return super().forward(*x, return_features)
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, x, return_features=False) ‑> Callable[..., Any]
-
Perform a forward pass of the neural network.
-x
: tuple
return_features
: bool
, optionaltorch.Tensor
or tuple
None.
def forward(self, x, return_features=False):
- """
- Perform a forward pass of the neural network.
-
- Args:
- x (tuple): A tuple of input data. Each element of the tuple represents a different input.
- return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
- Returns:
- torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
- If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
- Raises:
- None.
- """
- return super().forward(*x, return_features)
-fmcib.ssl.modules.swav
import torch
-from lightly.loss.memory_bank import MemoryBankModule
-from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
-from torch import nn
-
-torch.set_float32_matmul_precision("medium")
-
-
-class SwaV(nn.Module):
- """
- Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
-
- Args:
- backbone (nn.Module): CNN backbone for feature extraction.
- num_ftrs (int): Number of input features for the projection head.
- out_dim (int): Output dimension for the projection head.
- n_prototypes (int): Number of prototypes to compute.
- n_queues (int): Number of memory banks (queues). Should be equal to the number of high-resolution inputs.
- queue_length (int, optional): Length of the memory bank. Defaults to 0.
- start_queue_at_epoch (int, optional): Number of the epoch at which SwaV starts using the queued features. Defaults to 0.
- n_steps_frozen_prototypes (int, optional): Number of steps during which we keep the prototypes fixed. Defaults to 0.
- """
-
- def __init__(
- self,
- backbone: nn.Module,
- num_ftrs: int,
- out_dim: int,
- n_prototypes: int,
- n_queues: int,
- queue_length: int = 0,
- start_queue_at_epoch: int = 0,
- n_steps_frozen_prototypes: int = 0,
- ):
- """
- Initialize a SwaV model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): The number of input features.
- out_dim (int): The dimension of the output.
- n_prototypes (int): The number of prototypes.
- n_queues (int): The number of queues.
- queue_length (int, optional): The length of the queue. Default is 0.
- start_queue_at_epoch (int, optional): The epoch at which to start using the queue. Default is 0.
- n_steps_frozen_prototypes (int, optional): The number of steps to freeze prototypes. Default is 0.
-
- Returns:
- None
-
- Attributes:
- backbone (nn.Module): The backbone model.
- projection_head (SwaVProjectionHead): The projection head.
- prototypes (SwaVPrototypes): The prototypes.
- queues (nn.ModuleList, optional): The queues. If n_queues > 0, this will be initialized with MemoryBankModules.
- queue_length (int, optional): The length of the queue.
- num_features_queued (int): The number of features queued.
- start_queue_at_epoch (int): The epoch at which to start using the queue.
- """
- super().__init__()
- # Backbone for feature extraction
- self.backbone = backbone
- # Projection head to project features to a lower-dimensional space
- self.projection_head = SwaVProjectionHead(num_ftrs, num_ftrs // 2, out_dim)
- # SwAV Prototypes module for prototype computation
- self.prototypes = SwaVPrototypes(out_dim, n_prototypes, n_steps_frozen_prototypes)
-
- self.queues = None
- if n_queues > 0:
- # Initialize the memory banks (queues)
- self.queues = nn.ModuleList([MemoryBankModule(size=queue_length) for _ in range(n_queues)])
- self.queue_length = queue_length
- self.num_features_queued = 0
- self.start_queue_at_epoch = start_queue_at_epoch
-
- def forward(self, input, epoch=None, step=None):
- """
- Performs the forward pass for the SwAV model.
-
- Args:
- input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
- and a list of low-resolution input images.
- epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
- step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
- Returns:
- Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
- low-resolution prototypes, and queue prototypes.
- """
- high_resolution, low_resolution = input
-
- # Normalize prototypes
- self.prototypes.normalize()
-
- # Compute high and low resolution features
- high_resolution_features = [self._subforward(x) for x in high_resolution]
- low_resolution_features = [self._subforward(x) for x in low_resolution]
-
- # Compute prototypes for high and low resolution features
- high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
- low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
- # Compute prototypes for queued features
- queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
- return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-
- def _subforward(self, input):
- """
- Subforward pass to compute features for the input image.
-
- Args:
- input (Tensor): Input image tensor.
-
- Returns:
- Tensor: L2-normalized feature tensor.
- """
- # Extract features using the backbone
- features = self.backbone(input).flatten(start_dim=1)
- # Project features using the projection head
- features = self.projection_head(features)
- # L2-normalize features
- features = nn.functional.normalize(features, dim=1, p=2)
- return features
-
- @torch.no_grad()
- def _get_queue_prototypes(self, high_resolution_features, epoch=None):
- """
- Compute the queue prototypes for the given high-resolution features.
-
- Args:
- high_resolution_features (List[Tensor]): List of high-resolution feature tensors.
- epoch (int, optional): Current epoch number. Required if `start_queue_at_epoch` > 0. Defaults to None.
-
- Returns:
- List[Tensor] or None: List of queue prototype tensors if conditions are met, otherwise None.
- """
- if self.queues is None:
- return None
-
- if len(high_resolution_features) != len(self.queues):
- raise ValueError(
- f"The number of queues ({len(self.queues)}) should be equal to the number of high "
- f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
- )
-
- # Get the queue features
- queue_features = []
- for i in range(len(self.queues)):
- _, features = self.queues[i](high_resolution_features[i], update=True)
- # Queue features are in (num_ftrs X queue_length) shape, while the high res
- # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
- features = torch.permute(features, (1, 0))
- queue_features.append(features)
-
- # Do not return queue prototypes if not enough features have been queued
- self.num_features_queued += high_resolution_features[0].shape[0]
- if self.num_features_queued < self.queue_length:
- return None
-
- # If loss calculation with queue prototypes starts at a later epoch,
- # just queue the features and return None instead of queue prototypes.
- if self.start_queue_at_epoch > 0:
- if epoch is None:
- raise ValueError(
- "The epoch number must be passed to the `forward()` " "method if `start_queue_at_epoch` is greater than 0."
- )
- if epoch < self.start_queue_at_epoch:
- return None
-
- # Assign prototypes
- queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
- # Do not return queue prototypes if not enough features have been queued
- return queue_prototypes
-
-class SwaV
-(backbone: torch.nn.modules.module.Module, num_ftrs: int, out_dim: int, n_prototypes: int, n_queues: int, queue_length: int = 0, start_queue_at_epoch: int = 0, n_steps_frozen_prototypes: int = 0)
-
Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
n_prototypes
: int
n_queues
: int
queue_length
: int
, optionalstart_queue_at_epoch
: int
, optionaln_steps_frozen_prototypes
: int
, optionalInitialize a SwaV model.
-backbone
: nn.Module
num_ftrs
: int
out_dim
: int
n_prototypes
: int
n_queues
: int
queue_length
: int
, optionalstart_queue_at_epoch
: int
, optionaln_steps_frozen_prototypes
: int
, optionalNone
-backbone
: nn.Module
projection_head
: SwaVProjectionHead
prototypes
: SwaVPrototypes
queues
: nn.ModuleList
, optionalqueue_length
: int
, optionalnum_features_queued
: int
start_queue_at_epoch
: int
class SwaV(nn.Module):
- """
- Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
-
- Args:
- backbone (nn.Module): CNN backbone for feature extraction.
- num_ftrs (int): Number of input features for the projection head.
- out_dim (int): Output dimension for the projection head.
- n_prototypes (int): Number of prototypes to compute.
- n_queues (int): Number of memory banks (queues). Should be equal to the number of high-resolution inputs.
- queue_length (int, optional): Length of the memory bank. Defaults to 0.
- start_queue_at_epoch (int, optional): Number of the epoch at which SwaV starts using the queued features. Defaults to 0.
- n_steps_frozen_prototypes (int, optional): Number of steps during which we keep the prototypes fixed. Defaults to 0.
- """
-
- def __init__(
- self,
- backbone: nn.Module,
- num_ftrs: int,
- out_dim: int,
- n_prototypes: int,
- n_queues: int,
- queue_length: int = 0,
- start_queue_at_epoch: int = 0,
- n_steps_frozen_prototypes: int = 0,
- ):
- """
- Initialize a SwaV model.
-
- Args:
- backbone (nn.Module): The backbone model.
- num_ftrs (int): The number of input features.
- out_dim (int): The dimension of the output.
- n_prototypes (int): The number of prototypes.
- n_queues (int): The number of queues.
- queue_length (int, optional): The length of the queue. Default is 0.
- start_queue_at_epoch (int, optional): The epoch at which to start using the queue. Default is 0.
- n_steps_frozen_prototypes (int, optional): The number of steps to freeze prototypes. Default is 0.
-
- Returns:
- None
-
- Attributes:
- backbone (nn.Module): The backbone model.
- projection_head (SwaVProjectionHead): The projection head.
- prototypes (SwaVPrototypes): The prototypes.
- queues (nn.ModuleList, optional): The queues. If n_queues > 0, this will be initialized with MemoryBankModules.
- queue_length (int, optional): The length of the queue.
- num_features_queued (int): The number of features queued.
- start_queue_at_epoch (int): The epoch at which to start using the queue.
- """
- super().__init__()
- # Backbone for feature extraction
- self.backbone = backbone
- # Projection head to project features to a lower-dimensional space
- self.projection_head = SwaVProjectionHead(num_ftrs, num_ftrs // 2, out_dim)
- # SwAV Prototypes module for prototype computation
- self.prototypes = SwaVPrototypes(out_dim, n_prototypes, n_steps_frozen_prototypes)
-
- self.queues = None
- if n_queues > 0:
- # Initialize the memory banks (queues)
- self.queues = nn.ModuleList([MemoryBankModule(size=queue_length) for _ in range(n_queues)])
- self.queue_length = queue_length
- self.num_features_queued = 0
- self.start_queue_at_epoch = start_queue_at_epoch
-
- def forward(self, input, epoch=None, step=None):
- """
- Performs the forward pass for the SwAV model.
-
- Args:
- input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
- and a list of low-resolution input images.
- epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
- step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
- Returns:
- Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
- low-resolution prototypes, and queue prototypes.
- """
- high_resolution, low_resolution = input
-
- # Normalize prototypes
- self.prototypes.normalize()
-
- # Compute high and low resolution features
- high_resolution_features = [self._subforward(x) for x in high_resolution]
- low_resolution_features = [self._subforward(x) for x in low_resolution]
-
- # Compute prototypes for high and low resolution features
- high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
- low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
- # Compute prototypes for queued features
- queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
- return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-
- def _subforward(self, input):
- """
- Subforward pass to compute features for the input image.
-
- Args:
- input (Tensor): Input image tensor.
-
- Returns:
- Tensor: L2-normalized feature tensor.
- """
- # Extract features using the backbone
- features = self.backbone(input).flatten(start_dim=1)
- # Project features using the projection head
- features = self.projection_head(features)
- # L2-normalize features
- features = nn.functional.normalize(features, dim=1, p=2)
- return features
-
- @torch.no_grad()
- def _get_queue_prototypes(self, high_resolution_features, epoch=None):
- """
- Compute the queue prototypes for the given high-resolution features.
-
- Args:
- high_resolution_features (List[Tensor]): List of high-resolution feature tensors.
- epoch (int, optional): Current epoch number. Required if `start_queue_at_epoch` > 0. Defaults to None.
-
- Returns:
- List[Tensor] or None: List of queue prototype tensors if conditions are met, otherwise None.
- """
- if self.queues is None:
- return None
-
- if len(high_resolution_features) != len(self.queues):
- raise ValueError(
- f"The number of queues ({len(self.queues)}) should be equal to the number of high "
- f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
- )
-
- # Get the queue features
- queue_features = []
- for i in range(len(self.queues)):
- _, features = self.queues[i](high_resolution_features[i], update=True)
- # Queue features are in (num_ftrs X queue_length) shape, while the high res
- # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
- features = torch.permute(features, (1, 0))
- queue_features.append(features)
-
- # Do not return queue prototypes if not enough features have been queued
- self.num_features_queued += high_resolution_features[0].shape[0]
- if self.num_features_queued < self.queue_length:
- return None
-
- # If loss calculation with queue prototypes starts at a later epoch,
- # just queue the features and return None instead of queue prototypes.
- if self.start_queue_at_epoch > 0:
- if epoch is None:
- raise ValueError(
- "The epoch number must be passed to the `forward()` " "method if `start_queue_at_epoch` is greater than 0."
- )
- if epoch < self.start_queue_at_epoch:
- return None
-
- # Assign prototypes
- queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
- # Do not return queue prototypes if not enough features have been queued
- return queue_prototypes
-var call_super_init : bool
var dump_patches : bool
var training : bool
-def forward(self, input, epoch=None, step=None) ‑> Callable[..., Any]
-
Performs the forward pass for the SwAV model.
-input
: Tuple[List[Tensor], List[Tensor]]
epoch
: int
, optionalstart_queue_at_epoch
> 0. Defaults to None.step
: int
, optionaln_steps_frozen_prototypes
> 0. Defaults to None.Tuple[List[Tensor], List[Tensor], List[Tensor]]
def forward(self, input, epoch=None, step=None):
- """
- Performs the forward pass for the SwAV model.
-
- Args:
- input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
- and a list of low-resolution input images.
- epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
- step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
- Returns:
- Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
- low-resolution prototypes, and queue prototypes.
- """
- high_resolution, low_resolution = input
-
- # Normalize prototypes
- self.prototypes.normalize()
-
- # Compute high and low resolution features
- high_resolution_features = [self._subforward(x) for x in high_resolution]
- low_resolution_features = [self._subforward(x) for x in low_resolution]
-
- # Compute prototypes for high and low resolution features
- high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
- low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
- # Compute prototypes for queued features
- queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
- return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-fmcib.ssl.optimizers
from .lars import LARS
-fmcib.ssl.optimizers.lars
fmcib.ssl.optimizers.lars
"""
-References:
- - https://arxiv.org/pdf/1708.03888.pdf
- - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
-"""
-import torch
-from torch.optim.optimizer import Optimizer, required
-
-
-class LARS(Optimizer):
- """Extends SGD in PyTorch with LARS scaling from the paper
- `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float): learning rate
- momentum (float, optional): momentum factor (default: 0)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- dampening (float, optional): dampening for momentum (default: 0)
- nesterov (bool, optional): enables Nesterov momentum (default: False)
- trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
- eps (float, optional): eps for division denominator (default: 1e-8)
-
- Example:
- >>> model = torch.nn.Linear(10, 1)
- >>> input = torch.Tensor(10)
- >>> target = torch.Tensor([1.])
- >>> loss_fn = lambda input, target: (input - target) ** 2
- >>> #
- >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
-
- .. note::
- The application of momentum in the SGD part is modified according to
- the PyTorch standards. LARS scaling fits into the equation in the
- following fashion.
-
- .. math::
- \begin{aligned}
- g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
- v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
- p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
- \\end{aligned}
-
- where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
- parameters, gradient, velocity, momentum, and weight decay respectively.
- The :math:`lars_lr` is defined by Eq. 6 in the paper.
- The Nesterov version is analogously modified.
-
- .. warning::
- Parameters with weight decay set to 0 will automatically be excluded from
- layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
- and BYOL.
- """
-
- def __init__(
- self,
- params,
- lr=required,
- momentum=0,
- dampening=0,
- weight_decay=0,
- nesterov=False,
- trust_coefficient=0.001,
- eps=1e-8,
- ):
- if lr is not required and lr < 0.0:
- raise ValueError(f"Invalid learning rate: {lr}")
- if momentum < 0.0:
- raise ValueError(f"Invalid momentum value: {momentum}")
- if weight_decay < 0.0:
- raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
- defaults = dict(
- lr=lr,
- momentum=momentum,
- dampening=dampening,
- weight_decay=weight_decay,
- nesterov=nesterov,
- trust_coefficient=trust_coefficient,
- eps=eps,
- )
- if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
- super().__init__(params, defaults)
-
- def __setstate__(self, state):
- super().__setstate__(state)
-
- for group in self.param_groups:
- group.setdefault("nesterov", False)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-
-class LARS
-(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)
-
Extends SGD in PyTorch with LARS scaling from the paper
-Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>
_.
params
: iterable
lr
: float
momentum
: float
, optionalweight_decay
: float
, optionaldampening
: float
, optionalnesterov
: bool
, optionaltrust_coefficient
: float
, optionaleps
: float
, optional>>> model = torch.nn.Linear(10, 1)
->>> input = torch.Tensor(10)
->>> target = torch.Tensor([1.])
->>> loss_fn = lambda input, target: (input - target) ** 2
->>> #
->>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
->>> optimizer.zero_grad()
->>> loss_fn(model(input), target).backward()
->>> optimizer.step()
-
-Note
-The application of momentum in the SGD part is modified according to -the PyTorch standards. LARS scaling fits into the equation in the -following fashion.
-[ egin{aligned}
-g_{t+1} & =
-ext{lars_lr} * (eta * p_{t} + g_{t+1}), \
-v_{t+1} & = \mu * v_{t} + g_{t+1}, \
-p_{t+1} & = p_{t} -
-ext{lr} * v_{t+1},
-\end{aligned} ]
-where :math:p
, :math:g
, :math:v
, :math:\mu
and :math:eta
denote the
-parameters, gradient, velocity, momentum, and weight decay respectively.
-The :math:lars_lr
is defined by Eq. 6 in the paper.
-The Nesterov version is analogously modified.
Warning
-Parameters with weight decay set to 0 will automatically be excluded from -layer-wise LR scaling. This is to ensure consistency with papers like SimCLR -and BYOL.
-class LARS(Optimizer):
- """Extends SGD in PyTorch with LARS scaling from the paper
- `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float): learning rate
- momentum (float, optional): momentum factor (default: 0)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- dampening (float, optional): dampening for momentum (default: 0)
- nesterov (bool, optional): enables Nesterov momentum (default: False)
- trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
- eps (float, optional): eps for division denominator (default: 1e-8)
-
- Example:
- >>> model = torch.nn.Linear(10, 1)
- >>> input = torch.Tensor(10)
- >>> target = torch.Tensor([1.])
- >>> loss_fn = lambda input, target: (input - target) ** 2
- >>> #
- >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
-
- .. note::
- The application of momentum in the SGD part is modified according to
- the PyTorch standards. LARS scaling fits into the equation in the
- following fashion.
-
- .. math::
- \begin{aligned}
- g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
- v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
- p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
- \\end{aligned}
-
- where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
- parameters, gradient, velocity, momentum, and weight decay respectively.
- The :math:`lars_lr` is defined by Eq. 6 in the paper.
- The Nesterov version is analogously modified.
-
- .. warning::
- Parameters with weight decay set to 0 will automatically be excluded from
- layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
- and BYOL.
- """
-
- def __init__(
- self,
- params,
- lr=required,
- momentum=0,
- dampening=0,
- weight_decay=0,
- nesterov=False,
- trust_coefficient=0.001,
- eps=1e-8,
- ):
- if lr is not required and lr < 0.0:
- raise ValueError(f"Invalid learning rate: {lr}")
- if momentum < 0.0:
- raise ValueError(f"Invalid momentum value: {momentum}")
- if weight_decay < 0.0:
- raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
- defaults = dict(
- lr=lr,
- momentum=momentum,
- dampening=dampening,
- weight_decay=weight_decay,
- nesterov=nesterov,
- trust_coefficient=trust_coefficient,
- eps=eps,
- )
- if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
- super().__init__(params, defaults)
-
- def __setstate__(self, state):
- super().__setstate__(state)
-
- for group in self.param_groups:
- group.setdefault("nesterov", False)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-
-def step(self, closure=None)
-
Performs a single optimization step.
-closure
: callable
, optional@torch.no_grad()
-def step(self, closure=None):
- """Performs a single optimization step.
-
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- # exclude scaling for params with 0 weight decay
- for group in self.param_groups:
- weight_decay = group["weight_decay"]
- momentum = group["momentum"]
- dampening = group["dampening"]
- nesterov = group["nesterov"]
-
- for p in group["params"]:
- if p.grad is None:
- continue
-
- d_p = p.grad
- p_norm = torch.norm(p.data)
- g_norm = torch.norm(p.grad.data)
-
- # lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
- lars_lr *= group["trust_coefficient"]
-
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
-
- # sgd part
- if momentum != 0:
- param_state = self.state[p]
- if "momentum_buffer" not in param_state:
- buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
- else:
- buf = param_state["momentum_buffer"]
- buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
- if nesterov:
- d_p = d_p.add(buf, alpha=momentum)
- else:
- d_p = buf
-
- p.add_(d_p, alpha=-group["lr"])
-
- return loss
-fmcib.ssl.transforms.duplicate
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-
-
-class Duplicate:
- """Duplicate an input and apply two different transforms. Used for SimCLR primarily."""
-
- def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
- """Duplicates an input and applies the given transformations to each copy separately.
-
- Args:
- transforms1 (Optional[Callable], optional): _description_. Defaults to None.
- transforms2 (Optional[Callable], optional): _description_. Defaults to None.
- """
- # Wrapped into a list if it isn't one already to allow both a
- # list of transforms as well as `torchvision.transform.Compose` transforms.
- self.transforms1 = transforms1
- self.transforms2 = transforms2
-
- def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Args:
- input (torch.Tensor or any other type supported by the given transforms): Input.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: a tuple of two tensors.
- """
- out1, out2 = input, deepcopy(input)
- if self.transforms1 is not None:
- out1 = self.transforms1(out1)
- if self.transforms2 is not None:
- out2 = self.transforms2(out2)
- return (out1, out2)
-
-class Duplicate
-(transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None)
-
Duplicate an input and apply two different transforms. Used for SimCLR primarily.
-Duplicates an input and applies the given transformations to each copy separately.
-transforms1
: Optional[Callable]
, optionaltransforms2
: Optional[Callable]
, optionalclass Duplicate:
- """Duplicate an input and apply two different transforms. Used for SimCLR primarily."""
-
- def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
- """Duplicates an input and applies the given transformations to each copy separately.
-
- Args:
- transforms1 (Optional[Callable], optional): _description_. Defaults to None.
- transforms2 (Optional[Callable], optional): _description_. Defaults to None.
- """
- # Wrapped into a list if it isn't one already to allow both a
- # list of transforms as well as `torchvision.transform.Compose` transforms.
- self.transforms1 = transforms1
- self.transforms2 = transforms2
-
- def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Args:
- input (torch.Tensor or any other type supported by the given transforms): Input.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: a tuple of two tensors.
- """
- out1, out2 = input, deepcopy(input)
- if self.transforms1 is not None:
- out1 = self.transforms1(out1)
- if self.transforms2 is not None:
- out2 = self.transforms2(out2)
- return (out1, out2)
-fmcib.ssl.transforms
from .duplicate import Duplicate
-from .random_resized_crop import RandomResizedCrop3D
-fmcib.ssl.transforms.duplicate
fmcib.ssl.transforms.random_resized_crop
fmcib.ssl.transforms.random_resized_crop
from typing import Any, Dict, List
-
-import torch
-from monai.transforms import RandScaleCrop, Resize, Transform
-
-
-class RandomResizedCrop3D(Transform):
- """
- Combines monai's random spatial crop followed by resize to the desired size.
-
- Modification:
- 1. The spatial crop is done with same dimensions for all the axes
- 2. Handles cases where the image_size is less than the crop_size by choosing
- the smallest dimension as the random scale.
-
- """
-
- def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
- """
- Args:
- scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- """
- super().__init__()
- self.prob = prob
- self.scale = scale
- self.size = [size] * 3
-
- def __call__(self, image):
- if torch.rand(1) < self.prob:
- random_scale = torch.empty(1).uniform_(*self.scale).item()
- rand_cropper = RandScaleCrop(random_scale, random_size=False)
- resizer = Resize(self.size, mode="trilinear")
-
- for transform in [rand_cropper, resizer]:
- image = transform(image)
-
- return image
-
-class RandomResizedCrop3D
-(prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0])
-
Combines monai's random spatial crop followed by resize to the desired size.
-Modification: -1. The spatial crop is done with same dimensions for all the axes -2. Handles cases where the image_size is less than the crop_size by choosing -the smallest dimension as the random scale.
-scale
: List[int]
before resizing. The scale is defined with respect to the area of the original image.
class RandomResizedCrop3D(Transform):
- """
- Combines monai's random spatial crop followed by resize to the desired size.
-
- Modification:
- 1. The spatial crop is done with same dimensions for all the axes
- 2. Handles cases where the image_size is less than the crop_size by choosing
- the smallest dimension as the random scale.
-
- """
-
- def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
- """
- Args:
- scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- """
- super().__init__()
- self.prob = prob
- self.scale = scale
- self.size = [size] * 3
-
- def __call__(self, image):
- if torch.rand(1) < self.prob:
- random_scale = torch.empty(1).uniform_(*self.scale).item()
- rand_cropper = RandScaleCrop(random_scale, random_size=False)
- resizer = Resize(self.size, mode="trilinear")
-
- for transform in [rand_cropper, resizer]:
- image = transform(image)
-
- return image
-fmcib.transforms.duplicate
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-
-
-class Duplicate:
- """
- Duplicate an input and apply two different transforms. Used for SimCLR primarily.
- """
-
- def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
- """
- Duplicates an input and applies the given transformations to each copy separately.
-
- Args:
- transforms1 (Optional[Callable]): _description_. Default is None.
- transforms2 (Optional[Callable]): _description_. Default is None.
- """
- # Wrapped into a list if it isn't one already to allow both a
- # list of transforms as well as `torchvision.transform.Compose` transforms.
- self.transforms1 = transforms1
- self.transforms2 = transforms2
-
- def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Args:
- input (torch.Tensor or any other type supported by the given transforms): Input.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors.
- """
- out1, out2 = input, deepcopy(input)
- if self.transforms1 is not None:
- out1 = self.transforms1(out1)
- if self.transforms2 is not None:
- out2 = self.transforms2(out2)
- return (out1, out2)
-
-class Duplicate
-(transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None)
-
Duplicate an input and apply two different transforms. Used for SimCLR primarily.
-Duplicates an input and applies the given transformations to each copy separately.
-transforms1
: Optional[Callable]
transforms2
: Optional[Callable]
class Duplicate:
- """
- Duplicate an input and apply two different transforms. Used for SimCLR primarily.
- """
-
- def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
- """
- Duplicates an input and applies the given transformations to each copy separately.
-
- Args:
- transforms1 (Optional[Callable]): _description_. Default is None.
- transforms2 (Optional[Callable]): _description_. Default is None.
- """
- # Wrapped into a list if it isn't one already to allow both a
- # list of transforms as well as `torchvision.transform.Compose` transforms.
- self.transforms1 = transforms1
- self.transforms2 = transforms2
-
- def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Args:
- input (torch.Tensor or any other type supported by the given transforms): Input.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors.
- """
- out1, out2 = input, deepcopy(input)
- if self.transforms1 is not None:
- out1 = self.transforms1(out1)
- if self.transforms2 is not None:
- out2 = self.transforms2(out2)
- return (out1, out2)
-fmcib.transforms
from .duplicate import Duplicate
-from .med3d import IntensityNormalizeOneVolume
-from .multicrop import MultiCrop
-from .random_resized_crop import RandomResizedCrop3D
-fmcib.transforms.duplicate
fmcib.transforms.med3d
fmcib.transforms.multicrop
fmcib.transforms.random_resized_crop
fmcib.transforms.med3d
import numpy as np
-from monai.transforms import Transform
-
-
-class IntensityNormalizeOneVolume(Transform):
- """
- A class representing an intensity normalized volume.
-
- Attributes:
- None
-
- Methods:
- __call__(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.
-
- Args:
- volume (numpy.ndarray): The input n-dimensional volume.
-
- Returns:
- out (numpy.ndarray): The normalized n-dimensional volume.
- """
-
- def __init__(self):
- """
- Initialize the object.
-
- Returns:
- None
- """
- super().__init__()
-
- def __call__(self, volume):
- """
- Normalize the intensity of an nd volume based on the mean and std of the non-zero region.
-
- Args:
- volume: The input nd volume.
-
- Returns:
- out: The normalized nd volume.
- """
- volume = volume.astype(np.float32)
- low, high = np.percentile(volume, [0.5, 99.5])
- if high > 0:
- volume = np.clip(volume, low, high)
-
- pixels = volume[volume > 0]
- mean = pixels.mean()
- std = pixels.std()
- out = (volume - mean) / std
- out_random = np.random.normal(0, 1, size=volume.shape)
- out[volume == 0] = out_random[volume == 0]
- return out
-
-class IntensityNormalizeOneVolume
-
A class representing an intensity normalized volume.
-None
-call(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.
-Args: -volume (numpy.ndarray): The input n-dimensional volume.
-Returns: -out (numpy.ndarray): The normalized n-dimensional volume.
-Initialize the object.
-None
class IntensityNormalizeOneVolume(Transform):
- """
- A class representing an intensity normalized volume.
-
- Attributes:
- None
-
- Methods:
- __call__(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.
-
- Args:
- volume (numpy.ndarray): The input n-dimensional volume.
-
- Returns:
- out (numpy.ndarray): The normalized n-dimensional volume.
- """
-
- def __init__(self):
- """
- Initialize the object.
-
- Returns:
- None
- """
- super().__init__()
-
- def __call__(self, volume):
- """
- Normalize the intensity of an nd volume based on the mean and std of the non-zero region.
-
- Args:
- volume: The input nd volume.
-
- Returns:
- out: The normalized nd volume.
- """
- volume = volume.astype(np.float32)
- low, high = np.percentile(volume, [0.5, 99.5])
- if high > 0:
- volume = np.clip(volume, low, high)
-
- pixels = volume[volume > 0]
- mean = pixels.mean()
- std = pixels.std()
- out = (volume - mean) / std
- out_random = np.random.normal(0, 1, size=volume.shape)
- out[volume == 0] = out_random[volume == 0]
- return out
-var backend : list[TransformBackends]
fmcib.transforms.multicrop
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-from lighter.utils.misc import ensure_list
-
-
-class MultiCrop:
- """
- Multi-Crop augmentation.
- """
-
- def __init__(self, high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]]):
- """
- Initialize an instance of a class with transformations for high-resolution and low-resolution images.
-
- Args:
- high_resolution_transforms (list): A list of Callable objects representing the transformations to be applied to high-resolution images.
- low_resolution_transforms (list, optional): A list of Callable objects representing the transformations to be applied to low-resolution images. Default is None.
- """
- self.high_resolution_transforms = ensure_list(high_resolution_transforms)
- self.low_resolution_transforms = ensure_list(low_resolution_transforms)
-
- def __call__(self, input):
- """
- This function applies a set of transformations to an input image and returns high and low-resolution crops.
-
- Args:
- input (image): The input image to be transformed.
-
- Returns:
- tuple: A tuple containing two lists:
- - high_resolution_crops (list): A list of high-resolution cropped images.
- - low_resolution_crops (list): A list of low-resolution cropped images.
- """
- high_resolution_crops = [transform(input) for transform in self.high_resolution_transforms]
- low_resolution_crops = [transform(input) for transform in self.low_resolution_transforms]
- return high_resolution_crops, low_resolution_crops
-
-class MultiCrop
-(high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]])
-
Multi-Crop augmentation.
-Initialize an instance of a class with transformations for high-resolution and low-resolution images.
-high_resolution_transforms
: list
low_resolution_transforms
: list
, optionalclass MultiCrop:
- """
- Multi-Crop augmentation.
- """
-
- def __init__(self, high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]]):
- """
- Initialize an instance of a class with transformations for high-resolution and low-resolution images.
-
- Args:
- high_resolution_transforms (list): A list of Callable objects representing the transformations to be applied to high-resolution images.
- low_resolution_transforms (list, optional): A list of Callable objects representing the transformations to be applied to low-resolution images. Default is None.
- """
- self.high_resolution_transforms = ensure_list(high_resolution_transforms)
- self.low_resolution_transforms = ensure_list(low_resolution_transforms)
-
- def __call__(self, input):
- """
- This function applies a set of transformations to an input image and returns high and low-resolution crops.
-
- Args:
- input (image): The input image to be transformed.
-
- Returns:
- tuple: A tuple containing two lists:
- - high_resolution_crops (list): A list of high-resolution cropped images.
- - low_resolution_crops (list): A list of low-resolution cropped images.
- """
- high_resolution_crops = [transform(input) for transform in self.high_resolution_transforms]
- low_resolution_crops = [transform(input) for transform in self.low_resolution_transforms]
- return high_resolution_crops, low_resolution_crops
-fmcib.transforms.random_resized_crop
from typing import Any, Dict, List
-
-import torch
-from monai.transforms import RandScaleCrop, Resize, Transform
-
-
-class RandomResizedCrop3D(Transform):
- """
- Combines monai's random spatial crop followed by resize to the desired size.
-
- Modifications:
- 1. The spatial crop is done with the same dimensions for all the axes.
- 2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.
- """
-
- def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
- """
- Args:
- scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- """
- super().__init__()
- self.prob = prob
- self.scale = scale
- self.size = [size] * 3
-
- def __call__(self, image):
- """
- Call method to apply random scale cropping and resizing to an image.
-
- Args:
- image (torch.Tensor): The input image.
-
- Returns:
- torch.Tensor: The transformed image.
- """
- if torch.rand(1) < self.prob:
- random_scale = torch.empty(1).uniform_(*self.scale).item()
- rand_cropper = RandScaleCrop(random_scale, random_size=False)
- resizer = Resize(self.size, mode="trilinear")
-
- for transform in [rand_cropper, resizer]:
- image = transform(image)
-
- return image
-
-class RandomResizedCrop3D
-(prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0])
-
Combines monai's random spatial crop followed by resize to the desired size.
-Modifications: -1. The spatial crop is done with the same dimensions for all the axes. -2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.
-scale
: List[int]
before resizing. The scale is defined with respect to the area of the original image.
class RandomResizedCrop3D(Transform):
- """
- Combines monai's random spatial crop followed by resize to the desired size.
-
- Modifications:
- 1. The spatial crop is done with the same dimensions for all the axes.
- 2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.
- """
-
- def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
- """
- Args:
- scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- """
- super().__init__()
- self.prob = prob
- self.scale = scale
- self.size = [size] * 3
-
- def __call__(self, image):
- """
- Call method to apply random scale cropping and resizing to an image.
-
- Args:
- image (torch.Tensor): The input image.
-
- Returns:
- torch.Tensor: The transformed image.
- """
- if torch.rand(1) < self.prob:
- random_scale = torch.empty(1).uniform_(*self.scale).item()
- rand_cropper = RandScaleCrop(random_scale, random_size=False)
- resizer = Resize(self.size, mode="trilinear")
-
- for transform in [rand_cropper, resizer]:
- image = transform(image)
-
- return image
-var backend : list[TransformBackends]
fmcib.utils.download_utils
import sys
-
-
-# create this bar_progress method which is invoked automatically from wget
-def bar_progress(current, total, width=80):
- """
- Display a progress bar for a download.
-
- Args:
- current (int): The current progress value.
- total (int): The total progress value.
- width (int, optional): The width of the progress bar in characters. Defaults to 80.
-
- Raises:
- None
-
- Returns:
- None
- """
- progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
- # Don't use print() as it will print in new line every time.
- sys.stdout.write("\r" + progress_message)
- sys.stdout.flush()
-
-def bar_progress(current, total, width=80)
-
Display a progress bar for a download.
-current
: int
total
: int
width
: int
, optionalNone
-None
def bar_progress(current, total, width=80):
- """
- Display a progress bar for a download.
-
- Args:
- current (int): The current progress value.
- total (int): The total progress value.
- width (int, optional): The width of the progress bar in characters. Defaults to 80.
-
- Raises:
- None
-
- Returns:
- None
- """
- progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
- # Don't use print() as it will print in new line every time.
- sys.stdout.write("\r" + progress_message)
- sys.stdout.flush()
-fmcib.utils.idc_helper
import concurrent.futures
-import os
-import sys
-from pathlib import Path
-
-import numpy as np
-import pandas as pd
-import pydicom
-import pydicom_seg
-import SimpleITK as sitk
-import wget
-
-
-class SuppressPrint:
- """
- A class that temporarily suppresses print statements.
-
- Methods:
- __enter__(): Sets sys.stdout to a dummy file object, suppressing print output.
- __exit__(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.
- """
-
- def __enter__(self):
- """
- Enter the context manager and redirect the standard output to nothing.
-
- Returns:
- object: The context manager object.
-
- Notes:
- This context manager is used to redirect the standard output to nothing using the `open` function.
- It saves the original standard output and assigns a new output destination as `/dev/null` on Unix-like systems.
- """
- self._original_stdout = sys.stdout
- sys.stdout = open(os.devnull, "w")
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """
- Restores the original stdout and closes the modified stdout.
-
- Args:
- exc_type (type): The exception type, if an exception occurred. Otherwise, None.
- exc_val (Exception): The exception instance, if an exception occurred. Otherwise, None.
- exc_tb (traceback): The traceback object, if an exception occurred. Otherwise, None.
-
- Returns:
- None
-
- Raises:
- None
- """
- sys.stdout.close()
- sys.stdout = self._original_stdout
-
-
-with SuppressPrint():
- from dcmrtstruct2nii import dcmrtstruct2nii
- from dcmrtstruct2nii.adapters.input.image.dcminputadapter import DcmInputAdapter
- from dcmrtstruct2nii.adapters.output.niioutputadapter import NiiOutputAdapter
-
-from google.cloud import storage
-from loguru import logger
-from tqdm import tqdm
-
-
-def dcmseg2nii(dcmseg_path, output_dir, tag=""):
- """
- Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.
-
- Args:
- dcmseg_path (str): The file path of the DICOM Segmentation object.
- output_dir (str): The directory where the NIfTI files will be saved.
- tag (str, optional): An optional tag to prepend to the output file names. Defaults to "".
- """
- dcm = pydicom.dcmread(dcmseg_path)
- reader = pydicom_seg.SegmentReader()
- result = reader.read(dcm)
-
- for segment_number in result.available_segments:
- image = result.segment_image(segment_number) # lazy construction
- sitk.WriteImage(image, output_dir + f"/{tag}{segment_number}.nii.gz", True)
-
-
-def download_from_manifest(df, save_dir, samples):
- """
- Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.
-
- Parameters:
- df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files.
- save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved.
- samples (int): The number of random samples to download. If None, all available samples will be downloaded.
-
- Returns:
- None
- """
- # Instantiates a client
- storage_client = storage.Client()
- logger.info("Downloading DICOM data from IDC (Imaging Data Commons) ...")
- (save_dir / "dicom").mkdir(exist_ok=True, parents=True)
-
- if samples is not None:
- assert "PatientID" in df.columns
- rows_with_annotations = df[df["Modality"].isin(["RTSTRUCT", "SEG"])]
- unique_elements = rows_with_annotations["PatientID"].unique()
- selected_elements = np.random.choice(unique_elements, min(len(unique_elements), samples), replace=False)
- df = df[df["PatientID"].isin(selected_elements)]
-
- def download_file(row):
- """
- Download a file from Google Cloud Storage.
-
- Args:
- row (dict): A dictionary containing the row data.
-
- Raises:
- None
-
- Returns:
- None
- """
- bucket_name, directory, file = row["gcs_url"].split("/")[-3:]
- fn = f"{directory}/{file}"
- bucket = storage_client.bucket(bucket_name)
- blob = bucket.blob(fn)
-
- current_save_dir = save_dir / "dicom" / row["PatientID"] / row["StudyInstanceUID"]
- current_save_dir.mkdir(exist_ok=True, parents=True)
- blob.download_to_filename(
- str(current_save_dir / f'{row["Modality"]}_{row["SeriesInstanceUID"]}_{row["InstanceNumber"]}.dcm')
- )
-
- with concurrent.futures.ThreadPoolExecutor() as executor:
- futures = []
- for idx, row in df.iterrows():
- futures.append(executor.submit(download_file, row))
- for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
- pass
-
-
-def download_LUNG1(path, samples=None):
- """
- Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.
-
- Parameters:
- path (str): The directory path where the LUNG1 data manifest will be saved.
- samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.
-
- Returns:
- None
- """
- save_dir = Path(path).resolve()
- save_dir.mkdir(exist_ok=True, parents=True)
-
- logger.info("Downloading LUNG1 manifest from Dropbox ...")
- # Download LUNG1 data manifest, this is precomputed but any set of GCS dicom files can be used here
- wget.download(
- "https://www.dropbox.com/s/lkvv33nmepecyu5/nsclc_radiomics.csv?dl=1",
- out=f"{save_dir}/nsclc_radiomics.csv",
- )
-
- df = pd.read_csv(f"{save_dir}/nsclc_radiomics.csv")
-
- download_from_manifest(df, save_dir, samples)
-
-
-def download_RADIO(path, samples=None):
- """
- Downloads the RADIO manifest from Dropbox and saves it to the specified path.
-
- Args:
- path (str): The path where the manifest file will be saved.
- samples (list, optional): A list of sample names to download. If None, all samples will be downloaded.
-
- Returns:
- None
- """
- save_dir = Path(path).resolve()
- save_dir.mkdir(exist_ok=True, parents=True)
-
- logger.info("Downloading RADIO manifest from Dropbox ...")
- # Download RADIO data manifest, this is precomputed but any set of GCS dicom files can be used here
- wget.download(
- "https://www.dropbox.com/s/nhh1tb0rclrb7mw/nsclc_radiogenomics.csv?dl=1",
- out=f"{save_dir}/nsclc_radiogenomics.csv",
- )
-
- df = pd.read_csv(f"{save_dir}/nsclc_radiogenomics.csv")
-
- download_from_manifest(df, save_dir, samples)
-
-
-def process_series_dir(series_dir):
- """
- Process the series directory and extract relevant information.
-
- Args:
- series_dir (Path): The path to the series directory.
-
- Returns:
- dict: A dictionary containing the extracted information, including the image path, patient ID, and coordinates.
-
- Raises:
- None
- """
- # Check if RTSTRUCT file exists
- rtstuct_files = list(series_dir.glob("*RTSTRUCT*"))
- seg_files = list(series_dir.glob("*SEG*"))
-
- if len(rtstuct_files) != 0:
- dcmrtstruct2nii(str(rtstuct_files[0]), str(series_dir), str(series_dir))
-
- elif len(seg_files) != 0:
- dcmseg2nii(str(seg_files[0]), str(series_dir), tag="GTV-")
-
- series_id = str(list(series_dir.glob("CT*.dcm"))[0]).split("_")[-2]
- dicom_image = DcmInputAdapter().ingest(str(series_dir), series_id=series_id)
- nii_output_adapter = NiiOutputAdapter()
- nii_output_adapter.write(dicom_image, f"{series_dir}/image", gzip=True)
- else:
- logger.warning("Skipped file without any RTSTRUCT or SEG file")
- return None
-
- image = sitk.ReadImage(str(series_dir / "image.nii.gz"))
- mask = sitk.ReadImage(str(list(series_dir.glob("*GTV-1*"))[0]))
-
- # Get centroid from label shape filter
- label_shape_filter = sitk.LabelShapeStatisticsImageFilter()
- label_shape_filter.Execute(mask)
-
- try:
- centroid = label_shape_filter.GetCentroid(255)
- except:
- centroid = label_shape_filter.GetCentroid(1)
-
- x, y, z = centroid
-
- row = {
- "image_path": str(series_dir / "image.nii.gz"),
- "PatientID": series_dir.parent.name,
- "coordX": x,
- "coordY": y,
- "coordZ": z,
- }
-
- return row
-
-
-def build_image_seed_dict(path, samples=None):
- """
- Build a dictionary of image seeds from DICOM files.
-
- Args:
- path (str): The path to the directory containing DICOM files.
- samples (int, optional): The number of samples to process. If None, all samples will be processed.
-
- Returns:
- pd.DataFrame: A DataFrame containing the image seeds.
- """
- sorted_dir = Path(path).resolve()
- series_dirs = [x.parent for x in sorted_dir.rglob("*.dcm")]
- series_dirs = sorted(list(set(series_dirs)))
-
- logger.info("Converting DICOM files to NIFTI ...")
-
- if samples is None:
- samples = len(series_dirs)
-
- rows = []
-
- num_workers = os.cpu_count() # Adjust this value based on the number of available CPU cores
- with concurrent.futures.ProcessPoolExecutor(num_workers) as executor:
- processed_rows = list(tqdm(executor.map(process_series_dir, series_dirs[:samples]), total=samples))
-
- rows = [row for row in processed_rows if row]
- return pd.DataFrame(rows)
-
-def build_image_seed_dict(path, samples=None)
-
Build a dictionary of image seeds from DICOM files.
-path
: str
samples
: int
, optionalpd.DataFrame
def build_image_seed_dict(path, samples=None):
- """
- Build a dictionary of image seeds from DICOM files.
-
- Args:
- path (str): The path to the directory containing DICOM files.
- samples (int, optional): The number of samples to process. If None, all samples will be processed.
-
- Returns:
- pd.DataFrame: A DataFrame containing the image seeds.
- """
- sorted_dir = Path(path).resolve()
- series_dirs = [x.parent for x in sorted_dir.rglob("*.dcm")]
- series_dirs = sorted(list(set(series_dirs)))
-
- logger.info("Converting DICOM files to NIFTI ...")
-
- if samples is None:
- samples = len(series_dirs)
-
- rows = []
-
- num_workers = os.cpu_count() # Adjust this value based on the number of available CPU cores
- with concurrent.futures.ProcessPoolExecutor(num_workers) as executor:
- processed_rows = list(tqdm(executor.map(process_series_dir, series_dirs[:samples]), total=samples))
-
- rows = [row for row in processed_rows if row]
- return pd.DataFrame(rows)
-
-def dcmseg2nii(dcmseg_path, output_dir, tag='')
-
Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.
-dcmseg_path
: str
output_dir
: str
tag
: str
, optionaldef dcmseg2nii(dcmseg_path, output_dir, tag=""):
- """
- Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.
-
- Args:
- dcmseg_path (str): The file path of the DICOM Segmentation object.
- output_dir (str): The directory where the NIfTI files will be saved.
- tag (str, optional): An optional tag to prepend to the output file names. Defaults to "".
- """
- dcm = pydicom.dcmread(dcmseg_path)
- reader = pydicom_seg.SegmentReader()
- result = reader.read(dcm)
-
- for segment_number in result.available_segments:
- image = result.segment_image(segment_number) # lazy construction
- sitk.WriteImage(image, output_dir + f"/{tag}{segment_number}.nii.gz", True)
-
-def download_LUNG1(path, samples=None)
-
Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.
-path (str): The directory path where the LUNG1 data manifest will be saved. -samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.
-None
def download_LUNG1(path, samples=None):
- """
- Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.
-
- Parameters:
- path (str): The directory path where the LUNG1 data manifest will be saved.
- samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.
-
- Returns:
- None
- """
- save_dir = Path(path).resolve()
- save_dir.mkdir(exist_ok=True, parents=True)
-
- logger.info("Downloading LUNG1 manifest from Dropbox ...")
- # Download LUNG1 data manifest, this is precomputed but any set of GCS dicom files can be used here
- wget.download(
- "https://www.dropbox.com/s/lkvv33nmepecyu5/nsclc_radiomics.csv?dl=1",
- out=f"{save_dir}/nsclc_radiomics.csv",
- )
-
- df = pd.read_csv(f"{save_dir}/nsclc_radiomics.csv")
-
- download_from_manifest(df, save_dir, samples)
-
-def download_RADIO(path, samples=None)
-
Downloads the RADIO manifest from Dropbox and saves it to the specified path.
-path
: str
samples
: list
, optionalNone
def download_RADIO(path, samples=None):
- """
- Downloads the RADIO manifest from Dropbox and saves it to the specified path.
-
- Args:
- path (str): The path where the manifest file will be saved.
- samples (list, optional): A list of sample names to download. If None, all samples will be downloaded.
-
- Returns:
- None
- """
- save_dir = Path(path).resolve()
- save_dir.mkdir(exist_ok=True, parents=True)
-
- logger.info("Downloading RADIO manifest from Dropbox ...")
- # Download RADIO data manifest, this is precomputed but any set of GCS dicom files can be used here
- wget.download(
- "https://www.dropbox.com/s/nhh1tb0rclrb7mw/nsclc_radiogenomics.csv?dl=1",
- out=f"{save_dir}/nsclc_radiogenomics.csv",
- )
-
- df = pd.read_csv(f"{save_dir}/nsclc_radiogenomics.csv")
-
- download_from_manifest(df, save_dir, samples)
-
-def download_from_manifest(df, save_dir, samples)
-
Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.
-df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files. -save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved. -samples (int): The number of random samples to download. If None, all available samples will be downloaded.
-None
def download_from_manifest(df, save_dir, samples):
- """
- Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.
-
- Parameters:
- df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files.
- save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved.
- samples (int): The number of random samples to download. If None, all available samples will be downloaded.
-
- Returns:
- None
- """
- # Instantiates a client
- storage_client = storage.Client()
- logger.info("Downloading DICOM data from IDC (Imaging Data Commons) ...")
- (save_dir / "dicom").mkdir(exist_ok=True, parents=True)
-
- if samples is not None:
- assert "PatientID" in df.columns
- rows_with_annotations = df[df["Modality"].isin(["RTSTRUCT", "SEG"])]
- unique_elements = rows_with_annotations["PatientID"].unique()
- selected_elements = np.random.choice(unique_elements, min(len(unique_elements), samples), replace=False)
- df = df[df["PatientID"].isin(selected_elements)]
-
- def download_file(row):
- """
- Download a file from Google Cloud Storage.
-
- Args:
- row (dict): A dictionary containing the row data.
-
- Raises:
- None
-
- Returns:
- None
- """
- bucket_name, directory, file = row["gcs_url"].split("/")[-3:]
- fn = f"{directory}/{file}"
- bucket = storage_client.bucket(bucket_name)
- blob = bucket.blob(fn)
-
- current_save_dir = save_dir / "dicom" / row["PatientID"] / row["StudyInstanceUID"]
- current_save_dir.mkdir(exist_ok=True, parents=True)
- blob.download_to_filename(
- str(current_save_dir / f'{row["Modality"]}_{row["SeriesInstanceUID"]}_{row["InstanceNumber"]}.dcm')
- )
-
- with concurrent.futures.ThreadPoolExecutor() as executor:
- futures = []
- for idx, row in df.iterrows():
- futures.append(executor.submit(download_file, row))
- for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
- pass
-
-def process_series_dir(series_dir)
-
Process the series directory and extract relevant information.
-series_dir
: Path
dict
None
def process_series_dir(series_dir):
- """
- Process the series directory and extract relevant information.
-
- Args:
- series_dir (Path): The path to the series directory.
-
- Returns:
- dict: A dictionary containing the extracted information, including the image path, patient ID, and coordinates.
-
- Raises:
- None
- """
- # Check if RTSTRUCT file exists
- rtstuct_files = list(series_dir.glob("*RTSTRUCT*"))
- seg_files = list(series_dir.glob("*SEG*"))
-
- if len(rtstuct_files) != 0:
- dcmrtstruct2nii(str(rtstuct_files[0]), str(series_dir), str(series_dir))
-
- elif len(seg_files) != 0:
- dcmseg2nii(str(seg_files[0]), str(series_dir), tag="GTV-")
-
- series_id = str(list(series_dir.glob("CT*.dcm"))[0]).split("_")[-2]
- dicom_image = DcmInputAdapter().ingest(str(series_dir), series_id=series_id)
- nii_output_adapter = NiiOutputAdapter()
- nii_output_adapter.write(dicom_image, f"{series_dir}/image", gzip=True)
- else:
- logger.warning("Skipped file without any RTSTRUCT or SEG file")
- return None
-
- image = sitk.ReadImage(str(series_dir / "image.nii.gz"))
- mask = sitk.ReadImage(str(list(series_dir.glob("*GTV-1*"))[0]))
-
- # Get centroid from label shape filter
- label_shape_filter = sitk.LabelShapeStatisticsImageFilter()
- label_shape_filter.Execute(mask)
-
- try:
- centroid = label_shape_filter.GetCentroid(255)
- except:
- centroid = label_shape_filter.GetCentroid(1)
-
- x, y, z = centroid
-
- row = {
- "image_path": str(series_dir / "image.nii.gz"),
- "PatientID": series_dir.parent.name,
- "coordX": x,
- "coordY": y,
- "coordZ": z,
- }
-
- return row
-
-class SuppressPrint
-
A class that temporarily suppresses print statements.
-enter(): Sets sys.stdout to a dummy file object, suppressing print output. -exit(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.
class SuppressPrint:
- """
- A class that temporarily suppresses print statements.
-
- Methods:
- __enter__(): Sets sys.stdout to a dummy file object, suppressing print output.
- __exit__(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.
- """
-
- def __enter__(self):
- """
- Enter the context manager and redirect the standard output to nothing.
-
- Returns:
- object: The context manager object.
-
- Notes:
- This context manager is used to redirect the standard output to nothing using the `open` function.
- It saves the original standard output and assigns a new output destination as `/dev/null` on Unix-like systems.
- """
- self._original_stdout = sys.stdout
- sys.stdout = open(os.devnull, "w")
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """
- Restores the original stdout and closes the modified stdout.
-
- Args:
- exc_type (type): The exception type, if an exception occurred. Otherwise, None.
- exc_val (Exception): The exception instance, if an exception occurred. Otherwise, None.
- exc_tb (traceback): The traceback object, if an exception occurred. Otherwise, None.
-
- Returns:
- None
-
- Raises:
- None
- """
- sys.stdout.close()
- sys.stdout = self._original_stdout
-fmcib.utils
from .download_utils import bar_progress
-from .idc_helper import *
-fmcib.utils.download_utils
fmcib.utils.idc_helper
fmcib.visualization
from .verify_io import visualize_seed_point
-fmcib.visualization.verify_io
fmcib.visualization.verify_io
import matplotlib.pyplot as plt
-import monai.transforms as monai_transforms
-import numpy as np
-import torch
-from monai.visualize import blend_images
-
-
-def visualize_seed_point(row):
- """
- This function visualizes a seed point on an image.
-
- Args:
- row (pandas.Series): A row containing the information of the seed point, including the image path and the coordinates.
- The following columns are expected: "image_path", "coordX", "coordY", "coordZ".
-
- Returns:
- None
- """
- # Define the transformation pipeline
- T = monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
- monai_transforms.ScaleIntensityRanged(keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True),
- monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
- monai_transforms.SelectItemsd(keys=["image_path", "coordX", "coordY", "coordZ"]),
- ]
- )
-
- # Apply the transformation pipeline
- out = T(row)
-
- # Calculate the center of the image
- center = (-out["coordX"], -out["coordY"], out["coordZ"])
- center = np.linalg.inv(np.array(out["image_path"].affine)) @ np.array(center + (1,))
- center = [int(x) for x in center[:3]]
-
- # Define the image and label
- image = out["image_path"]
- label = torch.zeros_like(image)
-
- # Define the dimensions of the image and the patch
- C, H, W, D = image.shape
- Ph, Pw, Pd = 50, 50, 50
-
- # Calculate and clamp the ranges for cropping
- min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
- min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
- min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
- # Check if coordinates are valid
- assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
- assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
- assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
- # Define the label for the cropped region
- label[:, min_h:max_h, min_w:max_w, min_d:max_d] = 1
-
- # Blend the image and the label
- ret = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
- ret = ret.permute(3, 2, 1, 0)
-
- # Plot axial slice
- plt.figure(figsize=(10, 10))
- plt.subplot(1, 3, 1)
- plt.imshow(ret[center[2], :, :])
- plt.title("Axial")
- plt.axis("off")
-
- # Plot sagittal slice
- plt.subplot(1, 3, 2)
- plt.imshow(np.flipud(ret[:, center[1], :]))
- plt.title("Coronal")
- plt.axis("off")
-
- # Plot coronal slice
- plt.subplot(1, 3, 3)
- plt.imshow(np.flipud(ret[:, :, center[0]]))
- plt.title("Sagittal")
-
- plt.axis("off")
- plt.show()
-
-def visualize_seed_point(row)
-
This function visualizes a seed point on an image.
-row
: pandas.Series
None
def visualize_seed_point(row):
- """
- This function visualizes a seed point on an image.
-
- Args:
- row (pandas.Series): A row containing the information of the seed point, including the image path and the coordinates.
- The following columns are expected: "image_path", "coordX", "coordY", "coordZ".
-
- Returns:
- None
- """
- # Define the transformation pipeline
- T = monai_transforms.Compose(
- [
- monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
- monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
- monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
- monai_transforms.ScaleIntensityRanged(keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True),
- monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
- monai_transforms.SelectItemsd(keys=["image_path", "coordX", "coordY", "coordZ"]),
- ]
- )
-
- # Apply the transformation pipeline
- out = T(row)
-
- # Calculate the center of the image
- center = (-out["coordX"], -out["coordY"], out["coordZ"])
- center = np.linalg.inv(np.array(out["image_path"].affine)) @ np.array(center + (1,))
- center = [int(x) for x in center[:3]]
-
- # Define the image and label
- image = out["image_path"]
- label = torch.zeros_like(image)
-
- # Define the dimensions of the image and the patch
- C, H, W, D = image.shape
- Ph, Pw, Pd = 50, 50, 50
-
- # Calculate and clamp the ranges for cropping
- min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
- min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
- min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
- # Check if coordinates are valid
- assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
- assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
- assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
- # Define the label for the cropped region
- label[:, min_h:max_h, min_w:max_w, min_d:max_d] = 1
-
- # Blend the image and the label
- ret = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
- ret = ret.permute(3, 2, 1, 0)
-
- # Plot axial slice
- plt.figure(figsize=(10, 10))
- plt.subplot(1, 3, 1)
- plt.imshow(ret[center[2], :, :])
- plt.title("Axial")
- plt.axis("off")
-
- # Plot sagittal slice
- plt.subplot(1, 3, 2)
- plt.imshow(np.flipud(ret[:, center[1], :]))
- plt.title("Coronal")
- plt.axis("off")
-
- # Plot coronal slice
- plt.subplot(1, 3, 3)
- plt.imshow(np.flipud(ret[:, :, center[0]]))
- plt.title("Sagittal")
-
- plt.axis("off")
- plt.show()
-