diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 44ee607..a98c6de 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,5 +25,5 @@ jobs: path: .cache restore-keys: | mkdocs-material- - - run: pip install mkdocs-material + - run: pip install mkdocs-material mkdocstrings mkdocstrings-python mkdocs-section-index mkdocs-literate-nav mkdocs-gen-files mkdocs-autorefs - run: mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index f4f09ed..05cbbc2 100644 --- a/.gitignore +++ b/.gitignore @@ -634,3 +634,4 @@ checkpoints/ data/csvs/ ./models ./outputs +docs/reference/* \ No newline at end of file diff --git a/docs/api_docs/fmcib/callbacks/index.html b/docs/api_docs/fmcib/callbacks/index.html deleted file mode 100644 index fa6bad3..0000000 --- a/docs/api_docs/fmcib/callbacks/index.html +++ /dev/null @@ -1,76 +0,0 @@ - - - - - - -fmcib.callbacks API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.callbacks

-
-
-
- -Expand source code - -
from .prediction_saver import SavePredictions
-
-
-
-

Sub-modules

-
-
fmcib.callbacks.prediction_saver
-
-
-
-
fmcib.callbacks.utils
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/callbacks/prediction_saver.html b/docs/api_docs/fmcib/callbacks/prediction_saver.html deleted file mode 100644 index 6e8811f..0000000 --- a/docs/api_docs/fmcib/callbacks/prediction_saver.html +++ /dev/null @@ -1,437 +0,0 @@ - - - - - - -fmcib.callbacks.prediction_saver API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.callbacks.prediction_saver

-
-
-
- -Expand source code - -
from typing import Any, List
-
-from pathlib import Path
-
-import pandas as pd
-import torchvision
-from loguru import logger
-from pytorch_lightning.callbacks import BasePredictionWriter
-
-from .utils import decollate, handle_image
-
-
-class SavePredictions(BasePredictionWriter):
-    """
-    A class that saves model predictions.
-
-    Attributes:
-        path (str): The path to save the output CSV file.
-        save_preview_samples (bool): If True, save preview images.
-        keys (List[str]): A list of keys.
-    """
-
-    def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
-        """
-        Initialize an instance of the class.
-
-        Args:
-            path (str): The path to save the output CSV file.
-            save_preview_samples (bool, optional): A flag indicating whether to save preview samples. Defaults to False.
-            keys (List[str], optional): A list of keys. Defaults to None.
-
-        Raises:
-            None
-
-        Returns:
-            None
-        """
-        super().__init__("epoch")
-        self.output_csv = Path(path)
-        self.keys = keys
-        self.save_preview_samples = save_preview_samples
-        self.output_csv.parent.mkdir(parents=True, exist_ok=True)
-
-    def save_preview_image(self, data, tag):
-        """
-        Save a preview image to a specified directory.
-
-        Args:
-            self (object): The object calling the function. (self in Python)
-            data (tuple): A tuple containing the image data and its corresponding tag.
-            tag (str): The tag for the image.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
-        self.output_dir.mkdir(parents=True, exist_ok=True)
-        image, _ = data
-        image = handle_image(image)
-        fp = self.output_dir / f"{tag}.png"
-        torchvision.utils.save_image(image, fp)
-
-    def write_on_epoch_end(
-        self,
-        trainer,
-        pl_module: "LightningModule",
-        predictions: List[Any],
-        batch_indices: List[Any],
-    ):
-        """
-        Write predictions on epoch end.
-
-        Args:
-            self: The instance of the class.
-            trainer: The trainer object.
-            pl_module (LightningModule): The Lightning module.
-            predictions (List[Any]): A list of prediction values.
-            batch_indices (List[Any]): A list of batch indices.
-
-        Raises:
-            AssertionError: If 'predict' is not present in pl_module.datasets.
-            AssertionError: If 'data' is not defined in pl_module.datasets.
-
-        Returns:
-            None
-        """
-        rows = []
-        assert "predict" in pl_module.datasets, "`data` not defined"
-        dataset = pl_module.datasets["predict"]
-        predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
-        for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
-            for i, v in enumerate(pred):
-                row[f"pred_{i}"] = v.item()
-
-            rows.append(row)
-
-            # Save image previews
-            if idx <= self.save_preview_samples:
-                input = dataset[idx]
-                self.save_preview_image(input, idx)
-
-        df = pd.DataFrame(rows)
-        df.to_csv(self.output_csv)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SavePredictions -(path: str, save_preview_samples: bool = False, keys: List[str] = None) -
-
-

A class that saves model predictions.

-

Attributes

-
-
path : str
-
The path to save the output CSV file.
-
save_preview_samples : bool
-
If True, save preview images.
-
keys : List[str]
-
A list of keys.
-
-

Initialize an instance of the class.

-

Args

-
-
path : str
-
The path to save the output CSV file.
-
save_preview_samples : bool, optional
-
A flag indicating whether to save preview samples. Defaults to False.
-
keys : List[str], optional
-
A list of keys. Defaults to None.
-
-

Raises

-

None

-

Returns

-

None

-
- -Expand source code - -
class SavePredictions(BasePredictionWriter):
-    """
-    A class that saves model predictions.
-
-    Attributes:
-        path (str): The path to save the output CSV file.
-        save_preview_samples (bool): If True, save preview images.
-        keys (List[str]): A list of keys.
-    """
-
-    def __init__(self, path: str, save_preview_samples: bool = False, keys: List[str] = None):
-        """
-        Initialize an instance of the class.
-
-        Args:
-            path (str): The path to save the output CSV file.
-            save_preview_samples (bool, optional): A flag indicating whether to save preview samples. Defaults to False.
-            keys (List[str], optional): A list of keys. Defaults to None.
-
-        Raises:
-            None
-
-        Returns:
-            None
-        """
-        super().__init__("epoch")
-        self.output_csv = Path(path)
-        self.keys = keys
-        self.save_preview_samples = save_preview_samples
-        self.output_csv.parent.mkdir(parents=True, exist_ok=True)
-
-    def save_preview_image(self, data, tag):
-        """
-        Save a preview image to a specified directory.
-
-        Args:
-            self (object): The object calling the function. (self in Python)
-            data (tuple): A tuple containing the image data and its corresponding tag.
-            tag (str): The tag for the image.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
-        self.output_dir.mkdir(parents=True, exist_ok=True)
-        image, _ = data
-        image = handle_image(image)
-        fp = self.output_dir / f"{tag}.png"
-        torchvision.utils.save_image(image, fp)
-
-    def write_on_epoch_end(
-        self,
-        trainer,
-        pl_module: "LightningModule",
-        predictions: List[Any],
-        batch_indices: List[Any],
-    ):
-        """
-        Write predictions on epoch end.
-
-        Args:
-            self: The instance of the class.
-            trainer: The trainer object.
-            pl_module (LightningModule): The Lightning module.
-            predictions (List[Any]): A list of prediction values.
-            batch_indices (List[Any]): A list of batch indices.
-
-        Raises:
-            AssertionError: If 'predict' is not present in pl_module.datasets.
-            AssertionError: If 'data' is not defined in pl_module.datasets.
-
-        Returns:
-            None
-        """
-        rows = []
-        assert "predict" in pl_module.datasets, "`data` not defined"
-        dataset = pl_module.datasets["predict"]
-        predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
-        for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
-            for i, v in enumerate(pred):
-                row[f"pred_{i}"] = v.item()
-
-            rows.append(row)
-
-            # Save image previews
-            if idx <= self.save_preview_samples:
-                input = dataset[idx]
-                self.save_preview_image(input, idx)
-
-        df = pd.DataFrame(rows)
-        df.to_csv(self.output_csv)
-
-

Ancestors

-
    -
  • pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter
  • -
  • pytorch_lightning.callbacks.callback.Callback
  • -
-

Methods

-
-
-def save_preview_image(self, data, tag) -
-
-

Save a preview image to a specified directory.

-

Args

-
-
self : object
-
The object calling the function. (self in Python)
-
data : tuple
-
A tuple containing the image data and its corresponding tag.
-
tag : str
-
The tag for the image.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
def save_preview_image(self, data, tag):
-    """
-    Save a preview image to a specified directory.
-
-    Args:
-        self (object): The object calling the function. (self in Python)
-        data (tuple): A tuple containing the image data and its corresponding tag.
-        tag (str): The tag for the image.
-
-    Returns:
-        None
-
-    Raises:
-        None
-    """
-    self.output_dir = self.output_csv.parent / f"previews_{self.output_csv.stem}"
-    self.output_dir.mkdir(parents=True, exist_ok=True)
-    image, _ = data
-    image = handle_image(image)
-    fp = self.output_dir / f"{tag}.png"
-    torchvision.utils.save_image(image, fp)
-
-
-
-def write_on_epoch_end(self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]) -
-
-

Write predictions on epoch end.

-

Args

-
-
self
-
The instance of the class.
-
trainer
-
The trainer object.
-
pl_module : LightningModule
-
The Lightning module.
-
predictions : List[Any]
-
A list of prediction values.
-
batch_indices : List[Any]
-
A list of batch indices.
-
-

Raises

-
-
AssertionError
-
If 'predict' is not present in pl_module.datasets.
-
AssertionError
-
If 'data' is not defined in pl_module.datasets.
-
-

Returns

-

None

-
- -Expand source code - -
def write_on_epoch_end(
-    self,
-    trainer,
-    pl_module: "LightningModule",
-    predictions: List[Any],
-    batch_indices: List[Any],
-):
-    """
-    Write predictions on epoch end.
-
-    Args:
-        self: The instance of the class.
-        trainer: The trainer object.
-        pl_module (LightningModule): The Lightning module.
-        predictions (List[Any]): A list of prediction values.
-        batch_indices (List[Any]): A list of batch indices.
-
-    Raises:
-        AssertionError: If 'predict' is not present in pl_module.datasets.
-        AssertionError: If 'data' is not defined in pl_module.datasets.
-
-    Returns:
-        None
-    """
-    rows = []
-    assert "predict" in pl_module.datasets, "`data` not defined"
-    dataset = pl_module.datasets["predict"]
-    predictions = [pred for batch_pred in predictions for pred in batch_pred["pred"]]
-
-    for idx, (row, pred) in enumerate(zip(dataset.get_rows(), predictions)):
-        for i, v in enumerate(pred):
-            row[f"pred_{i}"] = v.item()
-
-        rows.append(row)
-
-        # Save image previews
-        if idx <= self.save_preview_samples:
-            input = dataset[idx]
-            self.save_preview_image(input, idx)
-
-    df = pd.DataFrame(rows)
-    df.to_csv(self.output_csv)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/callbacks/utils.html b/docs/api_docs/fmcib/callbacks/utils.html deleted file mode 100644 index 89bf7f1..0000000 --- a/docs/api_docs/fmcib/callbacks/utils.html +++ /dev/null @@ -1,205 +0,0 @@ - - - - - - -fmcib.callbacks.utils API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.callbacks.utils

-
-
-
- -Expand source code - -
from typing import List
-
-import torch
-
-
-def decollate(data: List[torch.Tensor]):
-    """
-    Decollate a list of tensors into a list of values.
-
-    Args:
-        data (list): A list of batch tensors.
-
-    Returns:
-        list: A list of values from the input tensors.
-
-    Raises:
-        AssertionError: If the input is not a list of tensors.
-    """
-    assert isinstance(data, list), "Decollate only implemented for list of `batch` tensors"
-
-    out = []
-    for d in data:
-        # Handles both cases: multiple elements and single element
-        # https://pytorch.org/docs/stable/generated/torch.Tensor.tolist.html
-        d = d.tolist()
-
-        out += d
-    return out
-
-
-def handle_image(image):
-    """
-    Handle image according to specific requirements.
-
-    Args:
-        image (tensor): An image tensor.
-
-    Returns:
-        tensor: The processed image tensor, based on the input conditions.
-
-    Raises:
-        None.
-    """
-    image = image.squeeze()
-    if image.dim() == 3:
-        return image[image.shape[0] // 2]
-    else:
-        return image
-
-
-
-
-
-
-
-

Functions

-
-
-def decollate(data: List[torch.Tensor]) -
-
-

Decollate a list of tensors into a list of values.

-

Args

-
-
data : list
-
A list of batch tensors.
-
-

Returns

-
-
list
-
A list of values from the input tensors.
-
-

Raises

-
-
AssertionError
-
If the input is not a list of tensors.
-
-
- -Expand source code - -
def decollate(data: List[torch.Tensor]):
-    """
-    Decollate a list of tensors into a list of values.
-
-    Args:
-        data (list): A list of batch tensors.
-
-    Returns:
-        list: A list of values from the input tensors.
-
-    Raises:
-        AssertionError: If the input is not a list of tensors.
-    """
-    assert isinstance(data, list), "Decollate only implemented for list of `batch` tensors"
-
-    out = []
-    for d in data:
-        # Handles both cases: multiple elements and single element
-        # https://pytorch.org/docs/stable/generated/torch.Tensor.tolist.html
-        d = d.tolist()
-
-        out += d
-    return out
-
-
-
-def handle_image(image) -
-
-

Handle image according to specific requirements.

-

Args

-
-
image : tensor
-
An image tensor.
-
-

Returns

-
-
tensor
-
The processed image tensor, based on the input conditions.
-
-

Raises

-

None.

-
- -Expand source code - -
def handle_image(image):
-    """
-    Handle image according to specific requirements.
-
-    Args:
-        image (tensor): An image tensor.
-
-    Returns:
-        tensor: The processed image tensor, based on the input conditions.
-
-    Raises:
-        None.
-    """
-    image = image.squeeze()
-    if image.dim() == 3:
-        return image[image.shape[0] // 2]
-    else:
-        return image
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/datasets/index.html b/docs/api_docs/fmcib/datasets/index.html deleted file mode 100644 index ebe683b..0000000 --- a/docs/api_docs/fmcib/datasets/index.html +++ /dev/null @@ -1,356 +0,0 @@ - - - - - - -fmcib.datasets API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.datasets

-
-
-
- -Expand source code - -
import os
-import random
-from pathlib import Path
-
-import monai
-import numpy as np
-import pandas as pd
-import SimpleITK as sitk
-import wget
-from loguru import logger
-
-from .ssl_radiomics_dataset import SSLRadiomicsDataset
-
-
-def get_lung1_clinical_data():
-    wget.download(
-        "https://www.dropbox.com/s/ulp8t21eunep21y/NSCLC%20Radiomics%20Lung1.clinical-version3-Oct%202019.csv?dl=1",
-        out="/tmp/lung1_clinical.csv",
-    )
-    return pd.read_csv("/tmp/lung1_clinical.csv")
-
-
-def get_radio_clinical_data():
-    wget.download(
-        "https://www.dropbox.com/s/mtpynjof550ulfo/NSCLCR01Radiogenomic_DATA_LABELS_2018-05-22_1500-shifted.csv?dl=1",
-        out=f"/tmp/radio_clinical.csv",
-    )
-    return pd.read_csv("/tmp/radio_clinical.csv")
-
-
-def get_lung1_foundation_features():
-    wget.download(
-        "https://www.dropbox.com/s/ypbb2iogq3bsq5v/lung1.csv?dl=1",
-        out=f"/tmp/lung1_foundation_features.csv",
-    )
-    df = pd.read_csv("/tmp/lung1_foundation_features.csv")
-    filtered_df = df.filter(like="pred")
-    filtered_df = filtered_df.reset_index()  # reset the index
-    filtered_df["PatientID"] = df["PatientID"]
-    return filtered_df
-
-
-def get_radio_foundation_features():
-    wget.download(
-        "https://www.dropbox.com/s/pwl4rdlvp9jirar/radio.csv?dl=1",
-        out=f"/tmp/radio_foundation_features.csv",
-    )
-
-    df = pd.read_csv("/tmp/radio_foundation_features.csv")
-    filtered_df = df.filter(like="pred")
-    filtered_df = filtered_df.reset_index()  # reset the index
-    filtered_df["PatientID"] = df["Case ID"]
-    return filtered_df
-
-
-def generate_dummy_data(dir_path, size=10):
-    path = Path(dir_path).resolve()
-    path.mkdir(exist_ok=True, parents=True)
-
-    row_list = []
-    for i in range(size):
-        row = create_dummy_row((32, 128, 128), str(path / f"dummy_{i}.nii.gz"))
-        row_list.append(row)
-
-    df = pd.DataFrame(row_list)
-    df.to_csv(path / "dummy.csv", index=False)
-
-    logger.info(f"Generated dummy data at {path}/dummy.csv")
-
-
-def create_dummy_row(size, output_filename):
-    """
-    Function to create a dummy row with path to an image and seed point corresponding to the image
-    """
-
-    # Create a np array initialized with random values between -1024 and 2048
-    np_image = np.random.randint(-1024, 2048, size, dtype=np.int16)
-
-    # Create an itk image from the numpy array
-    itk_image = sitk.GetImageFromArray(np_image)
-
-    # Save itk image to file with the given output filename
-    sitk.WriteImage(itk_image, output_filename)
-
-    x, y, z = generate_random_seed_point(itk_image.GetSize())
-
-    # Convert to global coordinates
-    x, y, z = itk_image.TransformContinuousIndexToPhysicalPoint((x, y, z))
-
-    return {
-        "image_path": output_filename,
-        "PatientID": random.randint(0, 100000),
-        "coordX": x,
-        "coordY": y,
-        "coordZ": z,
-        "label": random.randint(0, 1),
-    }
-
-
-def generate_random_seed_point(image_size):
-    """
-    Function to generate a random x, y, z coordinate within the image
-    """
-    x = random.randint(0, image_size[0] - 1)
-    y = random.randint(0, image_size[1] - 1)
-    z = random.randint(0, image_size[2] - 1)
-
-    return (x, y, z)
-
-
-
-

Sub-modules

-
-
fmcib.datasets.ssl_radiomics_dataset
-
-
-
-
fmcib.datasets.utils
-
-
-
-
-
-
-
-
-

Functions

-
-
-def create_dummy_row(size, output_filename) -
-
-

Function to create a dummy row with path to an image and seed point corresponding to the image

-
- -Expand source code - -
def create_dummy_row(size, output_filename):
-    """
-    Function to create a dummy row with path to an image and seed point corresponding to the image
-    """
-
-    # Create a np array initialized with random values between -1024 and 2048
-    np_image = np.random.randint(-1024, 2048, size, dtype=np.int16)
-
-    # Create an itk image from the numpy array
-    itk_image = sitk.GetImageFromArray(np_image)
-
-    # Save itk image to file with the given output filename
-    sitk.WriteImage(itk_image, output_filename)
-
-    x, y, z = generate_random_seed_point(itk_image.GetSize())
-
-    # Convert to global coordinates
-    x, y, z = itk_image.TransformContinuousIndexToPhysicalPoint((x, y, z))
-
-    return {
-        "image_path": output_filename,
-        "PatientID": random.randint(0, 100000),
-        "coordX": x,
-        "coordY": y,
-        "coordZ": z,
-        "label": random.randint(0, 1),
-    }
-
-
-
-def generate_dummy_data(dir_path, size=10) -
-
-
-
- -Expand source code - -
def generate_dummy_data(dir_path, size=10):
-    path = Path(dir_path).resolve()
-    path.mkdir(exist_ok=True, parents=True)
-
-    row_list = []
-    for i in range(size):
-        row = create_dummy_row((32, 128, 128), str(path / f"dummy_{i}.nii.gz"))
-        row_list.append(row)
-
-    df = pd.DataFrame(row_list)
-    df.to_csv(path / "dummy.csv", index=False)
-
-    logger.info(f"Generated dummy data at {path}/dummy.csv")
-
-
-
-def generate_random_seed_point(image_size) -
-
-

Function to generate a random x, y, z coordinate within the image

-
- -Expand source code - -
def generate_random_seed_point(image_size):
-    """
-    Function to generate a random x, y, z coordinate within the image
-    """
-    x = random.randint(0, image_size[0] - 1)
-    y = random.randint(0, image_size[1] - 1)
-    z = random.randint(0, image_size[2] - 1)
-
-    return (x, y, z)
-
-
-
-def get_lung1_clinical_data() -
-
-
-
- -Expand source code - -
def get_lung1_clinical_data():
-    wget.download(
-        "https://www.dropbox.com/s/ulp8t21eunep21y/NSCLC%20Radiomics%20Lung1.clinical-version3-Oct%202019.csv?dl=1",
-        out="/tmp/lung1_clinical.csv",
-    )
-    return pd.read_csv("/tmp/lung1_clinical.csv")
-
-
-
-def get_lung1_foundation_features() -
-
-
-
- -Expand source code - -
def get_lung1_foundation_features():
-    wget.download(
-        "https://www.dropbox.com/s/ypbb2iogq3bsq5v/lung1.csv?dl=1",
-        out=f"/tmp/lung1_foundation_features.csv",
-    )
-    df = pd.read_csv("/tmp/lung1_foundation_features.csv")
-    filtered_df = df.filter(like="pred")
-    filtered_df = filtered_df.reset_index()  # reset the index
-    filtered_df["PatientID"] = df["PatientID"]
-    return filtered_df
-
-
-
-def get_radio_clinical_data() -
-
-
-
- -Expand source code - -
def get_radio_clinical_data():
-    wget.download(
-        "https://www.dropbox.com/s/mtpynjof550ulfo/NSCLCR01Radiogenomic_DATA_LABELS_2018-05-22_1500-shifted.csv?dl=1",
-        out=f"/tmp/radio_clinical.csv",
-    )
-    return pd.read_csv("/tmp/radio_clinical.csv")
-
-
-
-def get_radio_foundation_features() -
-
-
-
- -Expand source code - -
def get_radio_foundation_features():
-    wget.download(
-        "https://www.dropbox.com/s/pwl4rdlvp9jirar/radio.csv?dl=1",
-        out=f"/tmp/radio_foundation_features.csv",
-    )
-
-    df = pd.read_csv("/tmp/radio_foundation_features.csv")
-    filtered_df = df.filter(like="pred")
-    filtered_df = filtered_df.reset_index()  # reset the index
-    filtered_df["PatientID"] = df["Case ID"]
-    return filtered_df
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/datasets/ssl_radiomics_dataset.html b/docs/api_docs/fmcib/datasets/ssl_radiomics_dataset.html deleted file mode 100644 index 691fcc3..0000000 --- a/docs/api_docs/fmcib/datasets/ssl_radiomics_dataset.html +++ /dev/null @@ -1,619 +0,0 @@ - - - - - - -fmcib.datasets.ssl_radiomics_dataset API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.datasets.ssl_radiomics_dataset

-
-
-
- -Expand source code - -
from pathlib import Path
-
-import monai
-import numpy as np
-import pandas as pd
-import SimpleITK as sitk
-from loguru import logger
-from torch.utils.data import Dataset
-
-from .utils import resample_image_to_spacing, slice_image
-
-
-class SSLRadiomicsDataset(Dataset):
-    """
-    Dataset class for SSL Radiomics dataset.
-
-    Args:
-        path (str): The path to the dataset.
-        label (str, optional): The label column name in the dataset annotations. Default is None.
-        radius (int, optional): The radius around the centroid for positive patch extraction. Default is 25.
-        orient (bool, optional): Whether to orient the images to LPI orientation. Default is False.
-        resample_spacing (float or tuple, optional): The desired spacing for resampling the images. Default is None.
-        enable_negatives (bool, optional): Whether to include negative samples. Default is True.
-        transform (callable, optional): A function/transform to apply on the images. Default is None.
-    """
-
-    def __init__(
-        self,
-        path,
-        label=None,
-        radius=25,
-        orient=False,
-        resample_spacing=None,
-        enable_negatives=True,
-        transform=None,
-        orient_patch=True,
-        input_is_target=False,
-    ):
-        """
-        Creates an instance of the SSLRadiomicsDataset class with the given parameters.
-
-        Args:
-            path (str): The path to the dataset.
-            label (Optional[str]): The label to use for the dataset. Defaults to None.
-            radius (int): The radius parameter. Defaults to 25.
-            orient (bool): True if the dataset should be oriented, False otherwise. Defaults to False.
-            resample_spacing (Optional[...]): The resample spacing parameter. Defaults to None.
-            enable_negatives (bool): True if negatives are enabled, False otherwise. Defaults to True.
-            transform: The transformation to apply to the dataset. Defaults to None.
-            orient_patch (bool): True if the patch should be oriented, False otherwise. Defaults to True.
-            input_is_target (bool): True if the input is the target, False otherwise. Defaults to False.
-
-        Raises:
-            None.
-
-        Returns:
-            None.
-        """
-        monai.data.set_track_meta(False)
-        sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)
-        super(SSLRadiomicsDataset, self).__init__()
-        self._path = Path(path)
-
-        self.radius = radius
-        self.orient = orient
-        self.resample_spacing = resample_spacing
-        self.label = label
-        self.enable_negatives = enable_negatives
-        self.transform = transform
-        self.orient_patch = orient_patch
-        self.input_is_target = input_is_target
-        self.annotations = pd.read_csv(self._path)
-        self._num_samples = len(self.annotations)  # set the length of the dataset
-
-    def get_rows(self):
-        """
-        Get the rows of the annotations as a list of dictionaries.
-
-        Returns:
-            list of dict: The rows of the annotations as dictionaries.
-        """
-        return self.annotations.to_dict(orient="records")
-
-    def get_labels(self):
-        """
-        Function to get labels for when they are available in the dataset.
-
-        Args:
-            None
-
-        Returns:
-            None
-        """
-
-        labels = self.annotations[self.label].values
-        assert not np.any(labels == -1), "All labels must be specified"
-        return labels
-
-    def __len__(self):
-        """
-        Size of the dataset.
-        """
-        return self._num_samples
-
-    def get_negative_sample(self, image):
-        """
-        Extract a negative sample from the image background with no overlap to the positive sample.
-
-        Parameters:
-            image: Image to extract sample
-            positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
-        """
-        positive_patch_size = [self.radius * 2] * 3
-        valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
-        def get_random_patch():
-            """
-            Get a random patch from an image.
-
-            Returns:
-                list: A list containing the start and end indices of the random patch.
-            """
-            random_patch_idx = [
-                [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
-            ]
-            return random_patch_idx
-
-        random_patch_idx = get_random_patch()
-
-        # escape_count = 0
-        # while is_overlapping(positive_patch_idx, random_patch_idx):
-        #     if escape_count >= 3:
-        #         logger.warning("Random patch has overlap with positive patch")
-        #         return None
-
-        #     random_patch_idx = get_random_patch()
-        #     escape_count += 1
-
-        random_patch = slice_image(image, random_patch_idx)
-        random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
-        negative_array = sitk.GetArrayFromImage(random_patch)
-
-        negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
-        return negative_tensor
-
-    def __getitem__(self, idx: int):
-        """
-        Implement how to load the data corresponding to the idx element in the dataset from your data source.
-        """
-
-        # Get a row from the CSV file
-        row = self.annotations.iloc[idx]
-        image_path = row["image_path"]
-        image = sitk.ReadImage(str(image_path))
-        image = resample_image_to_spacing(image, self.resample_spacing, -1024) if self.resample_spacing is not None else image
-
-        centroid = (row["coordX"], row["coordY"], row["coordZ"])
-        centroid = image.TransformPhysicalPointToContinuousIndex(centroid)
-        centroid = [int(d) for d in centroid]
-
-        # Orient all images to LPI orientation
-        image = sitk.DICOMOrient(image, "LPI") if self.orient else image
-
-        # Extract positive with a specified radius around centroid
-        patch_idx = [(c - self.radius, c + self.radius) for c in centroid]
-        patch_image = slice_image(image, patch_idx)
-
-        patch_image = sitk.DICOMOrient(patch_image, "LPS") if self.orient_patch else patch_image
-
-        array = sitk.GetArrayFromImage(patch_image)
-        tensor = array if self.transform is None else self.transform(array)
-
-        if self.label is not None:
-            target = int(row[self.label])
-        elif self.input_is_target:
-            target = tensor.clone()
-        else:
-            target = None
-
-        if self.enable_negatives:
-            return {"positive": tensor, "negative": self.get_negative_sample(image)}, target
-
-        return tensor, target
-
-
-if __name__ == "__main__":
-    from pathlib import Path
-
-    # Test pytorch dataset
-    print("Test pytorch dataset")
-    dataset = SSLRadiomicsDataset(
-        "/home/suraj/Repositories/cancer-imaging-ssl/src/pretraining/data_csv/deeplesion/train.csv",
-        orient=True,
-        resample_spacing=[1, 1, 1],
-    )
-
-    # Visualize item from dataset
-    item = dataset[0]
-
-    positive = sitk.GetImageFromArray(item[0][0])
-    negative = sitk.GetImageFromArray(item[0][1])
-    current_dir = Path(__file__).parent.resolve()
-
-    sitk.WriteImage(positive, f"{str(current_dir)}/tests/positive.nrrd")
-    sitk.WriteImage(negative, f"{str(current_dir)}/tests/negative.nrrd")
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SSLRadiomicsDataset -(path, label=None, radius=25, orient=False, resample_spacing=None, enable_negatives=True, transform=None, orient_patch=True, input_is_target=False) -
-
-

Dataset class for SSL Radiomics dataset.

-

Args

-
-
path : str
-
The path to the dataset.
-
label : str, optional
-
The label column name in the dataset annotations. Default is None.
-
radius : int, optional
-
The radius around the centroid for positive patch extraction. Default is 25.
-
orient : bool, optional
-
Whether to orient the images to LPI orientation. Default is False.
-
resample_spacing : float or tuple, optional
-
The desired spacing for resampling the images. Default is None.
-
enable_negatives : bool, optional
-
Whether to include negative samples. Default is True.
-
transform : callable, optional
-
A function/transform to apply on the images. Default is None.
-
-

Creates an instance of the SSLRadiomicsDataset class with the given parameters.

-

Args

-
-
path : str
-
The path to the dataset.
-
label : Optional[str]
-
The label to use for the dataset. Defaults to None.
-
radius : int
-
The radius parameter. Defaults to 25.
-
orient : bool
-
True if the dataset should be oriented, False otherwise. Defaults to False.
-
resample_spacing : Optional[…]
-
The resample spacing parameter. Defaults to None.
-
enable_negatives : bool
-
True if negatives are enabled, False otherwise. Defaults to True.
-
transform
-
The transformation to apply to the dataset. Defaults to None.
-
orient_patch : bool
-
True if the patch should be oriented, False otherwise. Defaults to True.
-
input_is_target : bool
-
True if the input is the target, False otherwise. Defaults to False.
-
-

Raises

-

None.

-

Returns

-

None.

-
- -Expand source code - -
class SSLRadiomicsDataset(Dataset):
-    """
-    Dataset class for SSL Radiomics dataset.
-
-    Args:
-        path (str): The path to the dataset.
-        label (str, optional): The label column name in the dataset annotations. Default is None.
-        radius (int, optional): The radius around the centroid for positive patch extraction. Default is 25.
-        orient (bool, optional): Whether to orient the images to LPI orientation. Default is False.
-        resample_spacing (float or tuple, optional): The desired spacing for resampling the images. Default is None.
-        enable_negatives (bool, optional): Whether to include negative samples. Default is True.
-        transform (callable, optional): A function/transform to apply on the images. Default is None.
-    """
-
-    def __init__(
-        self,
-        path,
-        label=None,
-        radius=25,
-        orient=False,
-        resample_spacing=None,
-        enable_negatives=True,
-        transform=None,
-        orient_patch=True,
-        input_is_target=False,
-    ):
-        """
-        Creates an instance of the SSLRadiomicsDataset class with the given parameters.
-
-        Args:
-            path (str): The path to the dataset.
-            label (Optional[str]): The label to use for the dataset. Defaults to None.
-            radius (int): The radius parameter. Defaults to 25.
-            orient (bool): True if the dataset should be oriented, False otherwise. Defaults to False.
-            resample_spacing (Optional[...]): The resample spacing parameter. Defaults to None.
-            enable_negatives (bool): True if negatives are enabled, False otherwise. Defaults to True.
-            transform: The transformation to apply to the dataset. Defaults to None.
-            orient_patch (bool): True if the patch should be oriented, False otherwise. Defaults to True.
-            input_is_target (bool): True if the input is the target, False otherwise. Defaults to False.
-
-        Raises:
-            None.
-
-        Returns:
-            None.
-        """
-        monai.data.set_track_meta(False)
-        sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)
-        super(SSLRadiomicsDataset, self).__init__()
-        self._path = Path(path)
-
-        self.radius = radius
-        self.orient = orient
-        self.resample_spacing = resample_spacing
-        self.label = label
-        self.enable_negatives = enable_negatives
-        self.transform = transform
-        self.orient_patch = orient_patch
-        self.input_is_target = input_is_target
-        self.annotations = pd.read_csv(self._path)
-        self._num_samples = len(self.annotations)  # set the length of the dataset
-
-    def get_rows(self):
-        """
-        Get the rows of the annotations as a list of dictionaries.
-
-        Returns:
-            list of dict: The rows of the annotations as dictionaries.
-        """
-        return self.annotations.to_dict(orient="records")
-
-    def get_labels(self):
-        """
-        Function to get labels for when they are available in the dataset.
-
-        Args:
-            None
-
-        Returns:
-            None
-        """
-
-        labels = self.annotations[self.label].values
-        assert not np.any(labels == -1), "All labels must be specified"
-        return labels
-
-    def __len__(self):
-        """
-        Size of the dataset.
-        """
-        return self._num_samples
-
-    def get_negative_sample(self, image):
-        """
-        Extract a negative sample from the image background with no overlap to the positive sample.
-
-        Parameters:
-            image: Image to extract sample
-            positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
-        """
-        positive_patch_size = [self.radius * 2] * 3
-        valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
-        def get_random_patch():
-            """
-            Get a random patch from an image.
-
-            Returns:
-                list: A list containing the start and end indices of the random patch.
-            """
-            random_patch_idx = [
-                [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
-            ]
-            return random_patch_idx
-
-        random_patch_idx = get_random_patch()
-
-        # escape_count = 0
-        # while is_overlapping(positive_patch_idx, random_patch_idx):
-        #     if escape_count >= 3:
-        #         logger.warning("Random patch has overlap with positive patch")
-        #         return None
-
-        #     random_patch_idx = get_random_patch()
-        #     escape_count += 1
-
-        random_patch = slice_image(image, random_patch_idx)
-        random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
-        negative_array = sitk.GetArrayFromImage(random_patch)
-
-        negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
-        return negative_tensor
-
-    def __getitem__(self, idx: int):
-        """
-        Implement how to load the data corresponding to the idx element in the dataset from your data source.
-        """
-
-        # Get a row from the CSV file
-        row = self.annotations.iloc[idx]
-        image_path = row["image_path"]
-        image = sitk.ReadImage(str(image_path))
-        image = resample_image_to_spacing(image, self.resample_spacing, -1024) if self.resample_spacing is not None else image
-
-        centroid = (row["coordX"], row["coordY"], row["coordZ"])
-        centroid = image.TransformPhysicalPointToContinuousIndex(centroid)
-        centroid = [int(d) for d in centroid]
-
-        # Orient all images to LPI orientation
-        image = sitk.DICOMOrient(image, "LPI") if self.orient else image
-
-        # Extract positive with a specified radius around centroid
-        patch_idx = [(c - self.radius, c + self.radius) for c in centroid]
-        patch_image = slice_image(image, patch_idx)
-
-        patch_image = sitk.DICOMOrient(patch_image, "LPS") if self.orient_patch else patch_image
-
-        array = sitk.GetArrayFromImage(patch_image)
-        tensor = array if self.transform is None else self.transform(array)
-
-        if self.label is not None:
-            target = int(row[self.label])
-        elif self.input_is_target:
-            target = tensor.clone()
-        else:
-            target = None
-
-        if self.enable_negatives:
-            return {"positive": tensor, "negative": self.get_negative_sample(image)}, target
-
-        return tensor, target
-
-

Ancestors

-
    -
  • torch.utils.data.dataset.Dataset
  • -
  • typing.Generic
  • -
-

Methods

-
-
-def get_labels(self) -
-
-

Function to get labels for when they are available in the dataset.

-

Args

-

None

-

Returns

-

None

-
- -Expand source code - -
def get_labels(self):
-    """
-    Function to get labels for when they are available in the dataset.
-
-    Args:
-        None
-
-    Returns:
-        None
-    """
-
-    labels = self.annotations[self.label].values
-    assert not np.any(labels == -1), "All labels must be specified"
-    return labels
-
-
-
-def get_negative_sample(self, image) -
-
-

Extract a negative sample from the image background with no overlap to the positive sample.

-

Parameters

-

image: Image to extract sample -positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]

-
- -Expand source code - -
def get_negative_sample(self, image):
-    """
-    Extract a negative sample from the image background with no overlap to the positive sample.
-
-    Parameters:
-        image: Image to extract sample
-        positive_patch_idx: Index of the positive patch in [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
-    """
-    positive_patch_size = [self.radius * 2] * 3
-    valid_patch_size = monai.data.utils.get_valid_patch_size(image.GetSize(), positive_patch_size)
-
-    def get_random_patch():
-        """
-        Get a random patch from an image.
-
-        Returns:
-            list: A list containing the start and end indices of the random patch.
-        """
-        random_patch_idx = [
-            [x.start, x.stop] for x in monai.data.utils.get_random_patch(image.GetSize(), valid_patch_size)
-        ]
-        return random_patch_idx
-
-    random_patch_idx = get_random_patch()
-
-    # escape_count = 0
-    # while is_overlapping(positive_patch_idx, random_patch_idx):
-    #     if escape_count >= 3:
-    #         logger.warning("Random patch has overlap with positive patch")
-    #         return None
-
-    #     random_patch_idx = get_random_patch()
-    #     escape_count += 1
-
-    random_patch = slice_image(image, random_patch_idx)
-    random_patch = sitk.DICOMOrient(random_patch, "LPS") if self.orient_patch else random_patch
-    negative_array = sitk.GetArrayFromImage(random_patch)
-
-    negative_tensor = negative_array if self.transform is None else self.transform(negative_array)
-    return negative_tensor
-
-
-
-def get_rows(self) -
-
-

Get the rows of the annotations as a list of dictionaries.

-

Returns

-
-
list of dict
-
The rows of the annotations as dictionaries.
-
-
- -Expand source code - -
def get_rows(self):
-    """
-    Get the rows of the annotations as a list of dictionaries.
-
-    Returns:
-        list of dict: The rows of the annotations as dictionaries.
-    """
-    return self.annotations.to_dict(orient="records")
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/datasets/utils.html b/docs/api_docs/fmcib/datasets/utils.html deleted file mode 100644 index 0943a94..0000000 --- a/docs/api_docs/fmcib/datasets/utils.html +++ /dev/null @@ -1,254 +0,0 @@ - - - - - - -fmcib.datasets.utils API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.datasets.utils

-
-
-
- -Expand source code - -
from pathlib import Path
-
-import numpy as np
-import SimpleITK as sitk
-
-# https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
-SITK_INTERPOLATOR_DICT = {
-    "nearest": sitk.sitkNearestNeighbor,
-    "linear": sitk.sitkLinear,
-    "gaussian": sitk.sitkGaussian,
-    "label_gaussian": sitk.sitkLabelGaussian,
-    "bspline": sitk.sitkBSpline,
-    "hamming_sinc": sitk.sitkHammingWindowedSinc,
-    "cosine_windowed_sinc": sitk.sitkCosineWindowedSinc,
-    "welch_windowed_sinc": sitk.sitkWelchWindowedSinc,
-    "lanczos_windowed_sinc": sitk.sitkLanczosWindowedSinc,
-}
-
-
-def resample_image_to_spacing(image, new_spacing, default_value, interpolator="linear"):
-    """
-    Resample an image to a new spacing.
-    """
-    assert interpolator in SITK_INTERPOLATOR_DICT, (
-        f"Interpolator '{interpolator}' not part of SimpleITK. "
-        f"Please choose one of the following {list(SITK_INTERPOLATOR_DICT.keys())}."
-    )
-
-    assert image.GetDimension() == len(new_spacing), (
-        f"Input is {image.GetDimension()}-dimensional while " f"the new spacing is {len(new_spacing)}-dimensional."
-    )
-
-    interpolator = SITK_INTERPOLATOR_DICT[interpolator]
-    spacing = image.GetSpacing()
-    size = image.GetSize()
-    new_size = [int(round(siz * spac / n_spac)) for siz, spac, n_spac in zip(size, spacing, new_spacing)]
-    return sitk.Resample(
-        image,
-        new_size,  # size
-        sitk.Transform(),  # transform
-        interpolator,  # interpolator
-        image.GetOrigin(),  # outputOrigin
-        new_spacing,  # outputSpacing
-        image.GetDirection(),  # outputDirection
-        default_value,  # defaultPixelValue
-        image.GetPixelID(),
-    )  # outputPixelType
-
-
-def slice_image(image, patch_idx):
-    """
-    Slice an image.
-    """
-
-    start, stop = zip(*patch_idx)
-    slice_filter = sitk.SliceImageFilter()
-    slice_filter.SetStart(start)
-    slice_filter.SetStop(stop)
-    return slice_filter.Execute(image)
-
-
-def is_overlapping(patch1, patch2):
-    """
-    Check if two patches are overlapping.
-
-    Args:
-        patch1 (list of tuples): A list of tuples representing the ranges of each axis in patch1.
-        patch2 (list of tuples): A list of tuples representing the ranges of each axis in patch2.
-
-    Returns:
-        bool: True if the two patches overlap, False otherwise.
-
-    Note:
-        This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch.
-        The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range.
-        The patches are considered overlapping if there is any overlap in the ranges of their axes.
-    """
-    overlap_by_axis = [max(axis1[0], axis2[0]) < min(axis1[1], axis2[1]) for axis1, axis2 in zip(patch1, patch2)]
-
-    return np.all(overlap_by_axis)
-
-
-
-
-
-
-
-

Functions

-
-
-def is_overlapping(patch1, patch2) -
-
-

Check if two patches are overlapping.

-

Args

-
-
patch1 : list of tuples
-
A list of tuples representing the ranges of each axis in patch1.
-
patch2 : list of tuples
-
A list of tuples representing the ranges of each axis in patch2.
-
-

Returns

-
-
bool
-
True if the two patches overlap, False otherwise.
-
-

Note

-

This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch. -The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range. -The patches are considered overlapping if there is any overlap in the ranges of their axes.

-
- -Expand source code - -
def is_overlapping(patch1, patch2):
-    """
-    Check if two patches are overlapping.
-
-    Args:
-        patch1 (list of tuples): A list of tuples representing the ranges of each axis in patch1.
-        patch2 (list of tuples): A list of tuples representing the ranges of each axis in patch2.
-
-    Returns:
-        bool: True if the two patches overlap, False otherwise.
-
-    Note:
-        This function assumes that each patch is represented by a list of tuples, where each tuple represents the range of an axis in the patch.
-        The range of an axis is represented by a tuple (start, end), where start is the start value of the range and end is the end value of the range.
-        The patches are considered overlapping if there is any overlap in the ranges of their axes.
-    """
-    overlap_by_axis = [max(axis1[0], axis2[0]) < min(axis1[1], axis2[1]) for axis1, axis2 in zip(patch1, patch2)]
-
-    return np.all(overlap_by_axis)
-
-
-
-def resample_image_to_spacing(image, new_spacing, default_value, interpolator='linear') -
-
-

Resample an image to a new spacing.

-
- -Expand source code - -
def resample_image_to_spacing(image, new_spacing, default_value, interpolator="linear"):
-    """
-    Resample an image to a new spacing.
-    """
-    assert interpolator in SITK_INTERPOLATOR_DICT, (
-        f"Interpolator '{interpolator}' not part of SimpleITK. "
-        f"Please choose one of the following {list(SITK_INTERPOLATOR_DICT.keys())}."
-    )
-
-    assert image.GetDimension() == len(new_spacing), (
-        f"Input is {image.GetDimension()}-dimensional while " f"the new spacing is {len(new_spacing)}-dimensional."
-    )
-
-    interpolator = SITK_INTERPOLATOR_DICT[interpolator]
-    spacing = image.GetSpacing()
-    size = image.GetSize()
-    new_size = [int(round(siz * spac / n_spac)) for siz, spac, n_spac in zip(size, spacing, new_spacing)]
-    return sitk.Resample(
-        image,
-        new_size,  # size
-        sitk.Transform(),  # transform
-        interpolator,  # interpolator
-        image.GetOrigin(),  # outputOrigin
-        new_spacing,  # outputSpacing
-        image.GetDirection(),  # outputDirection
-        default_value,  # defaultPixelValue
-        image.GetPixelID(),
-    )  # outputPixelType
-
-
-
-def slice_image(image, patch_idx) -
-
-

Slice an image.

-
- -Expand source code - -
def slice_image(image, patch_idx):
-    """
-    Slice an image.
-    """
-
-    start, stop = zip(*patch_idx)
-    slice_filter = sitk.SliceImageFilter()
-    slice_filter.SetStart(start)
-    slice_filter.SetStop(stop)
-    return slice_filter.Execute(image)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/index.html b/docs/api_docs/fmcib/index.html deleted file mode 100644 index df7aa8b..0000000 --- a/docs/api_docs/fmcib/index.html +++ /dev/null @@ -1,111 +0,0 @@ - - - - - - -fmcib API documentation - - - - - - - - - - - -
-
-
-

Package fmcib

-
-
-
- -Expand source code - -
__version__ = "0.0.1a22"
-
-
-
-

Sub-modules

-
-
fmcib.callbacks
-
-
-
-
fmcib.datasets
-
-
-
-
fmcib.models
-
-
-
-
fmcib.optimizers
-
-
-
-
fmcib.preprocessing
-
-
-
-
fmcib.run
-
-
-
-
fmcib.ssl
-
-
-
-
fmcib.transforms
-
-
-
-
fmcib.utils
-
-
-
-
fmcib.visualization
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/models/autoencoder.html b/docs/api_docs/fmcib/models/autoencoder.html deleted file mode 100644 index 9f5f44f..0000000 --- a/docs/api_docs/fmcib/models/autoencoder.html +++ /dev/null @@ -1,397 +0,0 @@ - - - - - - -fmcib.models.autoencoder API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.models.autoencoder

-
-
-
- -Expand source code - -
import torch
-import torch.nn as nn
-from monai.networks.blocks import Convolution, ResidualUnit
-from monai.networks.nets import AutoEncoder
-
-
-class CustomAE(AutoEncoder):
-    """
-    A custom AutoEncoder class.
-
-    Inherits from AutoEncoder.
-
-    Attributes:
-        padding (int): The padding size for the convolutional layers.
-        decoder (bool, optional): Determines if the decoder part of the network is included.
-        kwargs: Additional keyword arguments passed to the parent class.
-
-    Methods:
-        _get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network.
-        _get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.
-    """
-
-    def __init__(self, padding, decoder=True, **kwargs):
-        """
-        Initialize the object.
-
-        Args:
-            padding (int): Padding value.
-            decoder (bool, optional): If True, use a decoder. Defaults to True.
-            **kwargs: Additional keyword arguments.
-
-        Attributes:
-            padding (int): Padding value.
-
-        Raises:
-            None
-        """
-        self.padding = padding
-        super().__init__(**kwargs)
-        if not decoder:
-            self.decode = nn.Sequential(nn.AvgPool3d(3), nn.Flatten())
-
-    def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:
-        """
-        Returns a single layer of the encoder part of the network.
-        """
-        mod: nn.Module
-        if self.num_res_units > 0:
-            mod = ResidualUnit(
-                spatial_dims=self.dimensions,
-                in_channels=in_channels,
-                out_channels=out_channels,
-                strides=strides,
-                kernel_size=self.kernel_size,
-                padding=self.padding,
-                subunits=self.num_res_units,
-                act=self.act,
-                norm=self.norm,
-                dropout=self.dropout,
-                bias=self.bias,
-                last_conv_only=is_last,
-            )
-            return mod
-        mod = Convolution(
-            spatial_dims=self.dimensions,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            strides=strides,
-            kernel_size=self.kernel_size,
-            padding=self.padding,
-            act=self.act,
-            norm=self.norm,
-            dropout=self.dropout,
-            bias=self.bias,
-            conv_only=is_last,
-        )
-        return mod
-
-    def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:
-        """
-        Returns a single layer of the decoder part of the network.
-        """
-        decode = nn.Sequential()
-
-        conv = Convolution(
-            spatial_dims=self.dimensions,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            strides=strides,
-            kernel_size=self.up_kernel_size,
-            padding=self.padding,
-            act=self.act,
-            norm=self.norm,
-            dropout=self.dropout,
-            bias=self.bias,
-            conv_only=is_last and self.num_res_units == 0,
-            is_transposed=True,
-        )
-
-        decode.add_module("conv", conv)
-
-        if self.num_res_units > 0:
-            ru = ResidualUnit(
-                spatial_dims=self.dimensions,
-                in_channels=out_channels,
-                out_channels=out_channels,
-                padding=self.padding,
-                strides=strides,
-                kernel_size=self.kernel_size,
-                subunits=1,
-                act=self.act,
-                norm=self.norm,
-                dropout=self.dropout,
-                bias=self.bias,
-                last_conv_only=is_last,
-            )
-
-            decode.add_module("resunit", ru)
-
-        return decode
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class CustomAE -(padding, decoder=True, **kwargs) -
-
-

A custom AutoEncoder class.

-

Inherits from AutoEncoder.

-

Attributes

-
-
padding : int
-
The padding size for the convolutional layers.
-
decoder : bool, optional
-
Determines if the decoder part of the network is included.
-
kwargs
-
Additional keyword arguments passed to the parent class.
-
-

Methods

-

_get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network. -_get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.

-

Initialize the object.

-

Args

-
-
padding : int
-
Padding value.
-
decoder : bool, optional
-
If True, use a decoder. Defaults to True.
-
**kwargs
-
Additional keyword arguments.
-
-

Attributes

-
-
padding : int
-
Padding value.
-
-

Raises

-

None

-
- -Expand source code - -
class CustomAE(AutoEncoder):
-    """
-    A custom AutoEncoder class.
-
-    Inherits from AutoEncoder.
-
-    Attributes:
-        padding (int): The padding size for the convolutional layers.
-        decoder (bool, optional): Determines if the decoder part of the network is included.
-        kwargs: Additional keyword arguments passed to the parent class.
-
-    Methods:
-        _get_encode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the encoder part of the network.
-        _get_decode_layer(in_channels, out_channels, strides, is_last): Returns a single layer of the decoder part of the network.
-    """
-
-    def __init__(self, padding, decoder=True, **kwargs):
-        """
-        Initialize the object.
-
-        Args:
-            padding (int): Padding value.
-            decoder (bool, optional): If True, use a decoder. Defaults to True.
-            **kwargs: Additional keyword arguments.
-
-        Attributes:
-            padding (int): Padding value.
-
-        Raises:
-            None
-        """
-        self.padding = padding
-        super().__init__(**kwargs)
-        if not decoder:
-            self.decode = nn.Sequential(nn.AvgPool3d(3), nn.Flatten())
-
-    def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:
-        """
-        Returns a single layer of the encoder part of the network.
-        """
-        mod: nn.Module
-        if self.num_res_units > 0:
-            mod = ResidualUnit(
-                spatial_dims=self.dimensions,
-                in_channels=in_channels,
-                out_channels=out_channels,
-                strides=strides,
-                kernel_size=self.kernel_size,
-                padding=self.padding,
-                subunits=self.num_res_units,
-                act=self.act,
-                norm=self.norm,
-                dropout=self.dropout,
-                bias=self.bias,
-                last_conv_only=is_last,
-            )
-            return mod
-        mod = Convolution(
-            spatial_dims=self.dimensions,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            strides=strides,
-            kernel_size=self.kernel_size,
-            padding=self.padding,
-            act=self.act,
-            norm=self.norm,
-            dropout=self.dropout,
-            bias=self.bias,
-            conv_only=is_last,
-        )
-        return mod
-
-    def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:
-        """
-        Returns a single layer of the decoder part of the network.
-        """
-        decode = nn.Sequential()
-
-        conv = Convolution(
-            spatial_dims=self.dimensions,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            strides=strides,
-            kernel_size=self.up_kernel_size,
-            padding=self.padding,
-            act=self.act,
-            norm=self.norm,
-            dropout=self.dropout,
-            bias=self.bias,
-            conv_only=is_last and self.num_res_units == 0,
-            is_transposed=True,
-        )
-
-        decode.add_module("conv", conv)
-
-        if self.num_res_units > 0:
-            ru = ResidualUnit(
-                spatial_dims=self.dimensions,
-                in_channels=out_channels,
-                out_channels=out_channels,
-                padding=self.padding,
-                strides=strides,
-                kernel_size=self.kernel_size,
-                subunits=1,
-                act=self.act,
-                norm=self.norm,
-                dropout=self.dropout,
-                bias=self.bias,
-                last_conv_only=is_last,
-            )
-
-            decode.add_module("resunit", ru)
-
-        return decode
-
-

Ancestors

-
    -
  • monai.networks.nets.autoencoder.AutoEncoder
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x: torch.Tensor) ‑> Any -
-
-

Define the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
-
- -Expand source code - -
def forward(self, x: torch.Tensor) -> Any:
-    x = self.encode(x)
-    x = self.intermediate(x)
-    x = self.decode(x)
-    return x
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/models/index.html b/docs/api_docs/fmcib/models/index.html deleted file mode 100644 index 5f2270a..0000000 --- a/docs/api_docs/fmcib/models/index.html +++ /dev/null @@ -1,169 +0,0 @@ - - - - - - -fmcib.models API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.models

-
-
-
- -Expand source code - -
import os
-import pickle
-from pathlib import Path
-
-import wget
-from monai.networks.nets import resnet50
-
-from fmcib.utils.download_utils import bar_progress
-
-from .autoencoder import CustomAE as AutoEncoder
-from .load_model import LoadModel
-from .models_genesis import UNet3D as ModelsGenesisUNet3D
-
-
-def get_linear_classifier(weights_path=None, download_url="https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1"):
-    if weights_path is None:
-        weights_path = "/tmp/linear_model.pkl"
-        wget.download(download_url, out=weights_path)
-
-    return pickle.load(open(weights_path, "rb"))
-
-
-def fmcib_model():
-    trunk = resnet50(
-        pretrained=False,
-        n_input_channels=1,
-        widen_factor=2,
-        conv1_t_stride=2,
-        feed_forward=False,
-        bias_downsample=True,
-    )
-    weights_url = "https://zenodo.org/records/10528450/files/model_weights.torch?download=1"
-    current_path = Path(os.getcwd())
-    if not (current_path / "model_weights.torch").exists():
-        wget.download(weights_url, bar=bar_progress)
-    model = LoadModel(trunk=trunk, weights_path=current_path / "model_weights.torch", heads=[])
-    return model
-
-
-
-

Sub-modules

-
-
fmcib.models.autoencoder
-
-
-
-
fmcib.models.load_model
-
-
-
-
fmcib.models.models_genesis
-
-
-
-
-
-
-
-
-

Functions

-
-
-def fmcib_model() -
-
-
-
- -Expand source code - -
def fmcib_model():
-    trunk = resnet50(
-        pretrained=False,
-        n_input_channels=1,
-        widen_factor=2,
-        conv1_t_stride=2,
-        feed_forward=False,
-        bias_downsample=True,
-    )
-    weights_url = "https://zenodo.org/records/10528450/files/model_weights.torch?download=1"
-    current_path = Path(os.getcwd())
-    if not (current_path / "model_weights.torch").exists():
-        wget.download(weights_url, bar=bar_progress)
-    model = LoadModel(trunk=trunk, weights_path=current_path / "model_weights.torch", heads=[])
-    return model
-
-
-
-def get_linear_classifier(weights_path=None, download_url='https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1') -
-
-
-
- -Expand source code - -
def get_linear_classifier(weights_path=None, download_url="https://www.dropbox.com/s/77zg2av5c6edjfu/task3.pkl?dl=1"):
-    if weights_path is None:
-        weights_path = "/tmp/linear_model.pkl"
-        wget.download(download_url, out=weights_path)
-
-    return pickle.load(open(weights_path, "rb"))
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/models/load_model.html b/docs/api_docs/fmcib/models/load_model.html deleted file mode 100644 index f6773a9..0000000 --- a/docs/api_docs/fmcib/models/load_model.html +++ /dev/null @@ -1,512 +0,0 @@ - - - - - - -fmcib.models.load_model API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.models.load_model

-
-
-
- -Expand source code - -
from collections import OrderedDict
-
-import torch
-from loguru import logger
-from torch import nn
-
-
-class LoadModel(nn.Module):
-    """
-    A class representing a loaded model.
-
-    Args:
-        trunk (nn.Module, optional): The trunk of the model. Defaults to None.
-        weights_path (str, optional): The path to the weights file. Defaults to None.
-        heads (list, optional): The list of head layers in the model. Defaults to [].
-
-    Attributes:
-        trunk (nn.Module): The trunk of the model.
-        heads (nn.Sequential): The concatenated head layers of the model.
-
-    Methods:
-        forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model.
-        load(weights): Load the pretrained model weights.
-    """
-
-    def __init__(self, trunk=None, weights_path=None, heads=[]) -> None:
-        """
-        Initialize the model.
-
-        Args:
-            trunk (optional): The trunk of the model.
-            weights_path (optional): The path to the weights file.
-            heads (list, optional): A list of layer sizes for the heads of the model.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__()
-        self.trunk = trunk
-        head_layers = []
-        for idx in range(len(heads) - 1):
-            current_layers = []
-            current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
-            if idx != (len(heads) - 2):
-                current_layers.append(nn.ReLU(inplace=True))
-
-            head_layers.append(nn.Sequential(*current_layers))
-
-        if len(head_layers):
-            self.heads = nn.Sequential(*head_layers)
-        else:
-            self.heads = nn.Identity()
-
-        if weights_path is not None:
-            self.load(weights_path)
-
-    def forward(self, x: torch.Tensor):
-        """
-        Forward pass of the neural network.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor.
-        """
-        out = self.trunk(x)
-        out = self.heads(out)
-        return out
-
-    def load(self, weights):
-        """
-        Load pretrained model weights from a file.
-
-        Args:
-            weights (str): The path to the file containing the pretrained model weights.
-
-        Raises:
-            KeyError: If the input weights file does not contain the expected keys.
-            Exception: If there is an error when loading the pretrained heads.
-
-        Returns:
-            None.
-
-        Note:
-            This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
-        Warnings:
-            - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
-            - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
-        Raises the appropriate warnings and logs informational messages.
-        """
-        pretrained_model = torch.load(weights)
-
-        if "trunk_state_dict" in pretrained_model:  # Loading ViSSL pretrained model
-            trained_trunk = pretrained_model["trunk_state_dict"]
-            msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        if "state_dict" in pretrained_model:  # Loading Med3D pretrained model
-            trained_model = pretrained_model["state_dict"]
-
-            # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
-            weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
-            weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
-            msg = self.trunk.load_state_dict(weights, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-            weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
-            msg = self.heads.load_state_dict(weights, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        # Load trained heads
-        if "head_state_dict" in pretrained_model:
-            trained_heads = pretrained_model["head_state_dict"]
-
-            try:
-                msg = self.heads.load_state_dict(trained_heads, strict=False)
-            except Exception as e:
-                logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        logger.info(f"Loaded pretrained model weights \n")
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class LoadModel -(trunk=None, weights_path=None, heads=[]) -
-
-

A class representing a loaded model.

-

Args

-
-
trunk : nn.Module, optional
-
The trunk of the model. Defaults to None.
-
weights_path : str, optional
-
The path to the weights file. Defaults to None.
-
heads : list, optional
-
The list of head layers in the model. Defaults to [].
-
-

Attributes

-
-
trunk : nn.Module
-
The trunk of the model.
-
heads : nn.Sequential
-
The concatenated head layers of the model.
-
-

Methods

-

forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model. -load(weights): Load the pretrained model weights.

-

Initialize the model.

-

Args

-
-
trunk : optional
-
The trunk of the model.
-
weights_path : optional
-
The path to the weights file.
-
heads : list, optional
-
A list of layer sizes for the heads of the model.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class LoadModel(nn.Module):
-    """
-    A class representing a loaded model.
-
-    Args:
-        trunk (nn.Module, optional): The trunk of the model. Defaults to None.
-        weights_path (str, optional): The path to the weights file. Defaults to None.
-        heads (list, optional): The list of head layers in the model. Defaults to [].
-
-    Attributes:
-        trunk (nn.Module): The trunk of the model.
-        heads (nn.Sequential): The concatenated head layers of the model.
-
-    Methods:
-        forward(x: torch.Tensor) -> torch.Tensor: Forward pass through the model.
-        load(weights): Load the pretrained model weights.
-    """
-
-    def __init__(self, trunk=None, weights_path=None, heads=[]) -> None:
-        """
-        Initialize the model.
-
-        Args:
-            trunk (optional): The trunk of the model.
-            weights_path (optional): The path to the weights file.
-            heads (list, optional): A list of layer sizes for the heads of the model.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__()
-        self.trunk = trunk
-        head_layers = []
-        for idx in range(len(heads) - 1):
-            current_layers = []
-            current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
-            if idx != (len(heads) - 2):
-                current_layers.append(nn.ReLU(inplace=True))
-
-            head_layers.append(nn.Sequential(*current_layers))
-
-        if len(head_layers):
-            self.heads = nn.Sequential(*head_layers)
-        else:
-            self.heads = nn.Identity()
-
-        if weights_path is not None:
-            self.load(weights_path)
-
-    def forward(self, x: torch.Tensor):
-        """
-        Forward pass of the neural network.
-
-        Args:
-            x (torch.Tensor): The input tensor.
-
-        Returns:
-            torch.Tensor: The output tensor.
-        """
-        out = self.trunk(x)
-        out = self.heads(out)
-        return out
-
-    def load(self, weights):
-        """
-        Load pretrained model weights from a file.
-
-        Args:
-            weights (str): The path to the file containing the pretrained model weights.
-
-        Raises:
-            KeyError: If the input weights file does not contain the expected keys.
-            Exception: If there is an error when loading the pretrained heads.
-
-        Returns:
-            None.
-
-        Note:
-            This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
-        Warnings:
-            - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
-            - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
-        Raises the appropriate warnings and logs informational messages.
-        """
-        pretrained_model = torch.load(weights)
-
-        if "trunk_state_dict" in pretrained_model:  # Loading ViSSL pretrained model
-            trained_trunk = pretrained_model["trunk_state_dict"]
-            msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        if "state_dict" in pretrained_model:  # Loading Med3D pretrained model
-            trained_model = pretrained_model["state_dict"]
-
-            # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
-            weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
-            weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
-            msg = self.trunk.load_state_dict(weights, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-            weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
-            msg = self.heads.load_state_dict(weights, strict=False)
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        # Load trained heads
-        if "head_state_dict" in pretrained_model:
-            trained_heads = pretrained_model["head_state_dict"]
-
-            try:
-                msg = self.heads.load_state_dict(trained_heads, strict=False)
-            except Exception as e:
-                logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        logger.info(f"Loaded pretrained model weights \n")
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x: torch.Tensor) ‑> Callable[..., Any] -
-
-

Forward pass of the neural network.

-

Args

-
-
x : torch.Tensor
-
The input tensor.
-
-

Returns

-
-
torch.Tensor
-
The output tensor.
-
-
- -Expand source code - -
def forward(self, x: torch.Tensor):
-    """
-    Forward pass of the neural network.
-
-    Args:
-        x (torch.Tensor): The input tensor.
-
-    Returns:
-        torch.Tensor: The output tensor.
-    """
-    out = self.trunk(x)
-    out = self.heads(out)
-    return out
-
-
-
-def load(self, weights) -
-
-

Load pretrained model weights from a file.

-

Args

-
-
weights : str
-
The path to the file containing the pretrained model weights.
-
-

Raises

-
-
KeyError
-
If the input weights file does not contain the expected keys.
-
Exception
-
If there is an error when loading the pretrained heads.
-
-

Returns

-

None.

-

Note

-

This function assumes that the pretrained model weights file is in the format expected by the model architecture.

-

Warnings

-
    -
  • Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
  • -
  • Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
  • -
-

Raises the appropriate warnings and logs informational messages.

-
- -Expand source code - -
def load(self, weights):
-    """
-    Load pretrained model weights from a file.
-
-    Args:
-        weights (str): The path to the file containing the pretrained model weights.
-
-    Raises:
-        KeyError: If the input weights file does not contain the expected keys.
-        Exception: If there is an error when loading the pretrained heads.
-
-    Returns:
-        None.
-
-    Note:
-        This function assumes that the pretrained model weights file is in the format expected by the model architecture.
-
-    Warnings:
-        - Missing keys: This warning message indicates the keys in the pretrained model weights file that are missing from the current model.
-        - Unexpected keys: This warning message indicates the keys in the pretrained model weights file that are not expected by the current model.
-
-    Raises the appropriate warnings and logs informational messages.
-    """
-    pretrained_model = torch.load(weights)
-
-    if "trunk_state_dict" in pretrained_model:  # Loading ViSSL pretrained model
-        trained_trunk = pretrained_model["trunk_state_dict"]
-        msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-    if "state_dict" in pretrained_model:  # Loading Med3D pretrained model
-        trained_model = pretrained_model["state_dict"]
-
-        # match the keys (https://github.com/Project-MONAI/MONAI/issues/6811)
-        weights = {key.replace("module.", ""): value for key, value in trained_model.items()}
-        weights = {key.replace("model.trunk.", ""): value for key, value in trained_model.items()}
-        msg = self.trunk.load_state_dict(weights, strict=False)
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        weights = {key.replace("model.heads.", ""): value for key, value in trained_model.items()}
-        msg = self.heads.load_state_dict(weights, strict=False)
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-    # Load trained heads
-    if "head_state_dict" in pretrained_model:
-        trained_heads = pretrained_model["head_state_dict"]
-
-        try:
-            msg = self.heads.load_state_dict(trained_heads, strict=False)
-        except Exception as e:
-            logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-    logger.info(f"Loaded pretrained model weights \n")
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/models/models_genesis.html b/docs/api_docs/fmcib/models/models_genesis.html deleted file mode 100644 index 0560823..0000000 --- a/docs/api_docs/fmcib/models/models_genesis.html +++ /dev/null @@ -1,1427 +0,0 @@ - - - - - - -fmcib.models.models_genesis API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.models.models_genesis

-
-
-
- -Expand source code - -
import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
-    """
-    A class representing a 3D contextual batch normalization layer.
-
-    Attributes:
-        running_mean (torch.Tensor): The running mean of the batch normalization.
-        running_var (torch.Tensor): The running variance of the batch normalization.
-        weight (torch.Tensor): The learnable weights of the batch normalization.
-        bias (torch.Tensor): The learnable bias of the batch normalization.
-        momentum (float): The momentum for updating the running statistics.
-        eps (float): Small value added to the denominator for numerical stability.
-    """
-
-    def _check_input_dim(self, input):
-        """
-        Check if the input tensor is 5-dimensional.
-
-        Args:
-            input (torch.Tensor): Input tensor to check the dimensionality.
-
-        Raises:
-            ValueError: If the input tensor is not 5-dimensional.
-        """
-        if input.dim() != 5:
-            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
-        # super(ContBatchNorm3d, self)._check_input_dim(input)
-
-    def forward(self, input):
-        """
-        Apply forward pass for the input through batch normalization layer.
-
-        Args:
-            input (Tensor): Input tensor to be normalized.
-
-        Returns:
-            Tensor: Normalized output tensor.
-
-        Raises:
-            ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
-        """
-        self._check_input_dim(input)
-        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-
-
-class LUConv(nn.Module):
-    """
-    A class representing a LUConv module.
-
-    This module performs a convolution operation on the input data with a specified number of input channels and output channels.
-    The convolution is followed by batch normalization and an activation function.
-
-    Attributes:
-        in_chan (int): The number of input channels.
-        out_chan (int): The number of output channels.
-        act (str): The activation function to be applied. Can be one of 'relu', 'prelu', or 'elu'.
-    """
-
-    def __init__(self, in_chan, out_chan, act):
-        """
-        Initialize a LUConv layer.
-
-        Args:
-            in_chan (int): Number of input channels.
-            out_chan (int): Number of output channels.
-            act (str): Activation function. Options: 'relu', 'prelu', 'elu'.
-
-        Returns:
-            None
-
-        Raises:
-            TypeError: If the activation function is not one of the specified options.
-        """
-        super(LUConv, self).__init__()
-        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
-        self.bn1 = ContBatchNorm3d(out_chan)
-
-        if act == "relu":
-            self.activation = nn.ReLU(out_chan)
-        elif act == "prelu":
-            self.activation = nn.PReLU(out_chan)
-        elif act == "elu":
-            self.activation = nn.ELU(inplace=True)
-        else:
-            raise
-
-    def forward(self, x):
-        """
-        Apply forward pass through the neural network.
-
-        Args:
-            x (Tensor): Input tensor to the network.
-
-        Returns:
-            Tensor: Output tensor after passing through the network.
-        """
-        out = self.activation(self.bn1(self.conv1(x)))
-        return out
-
-
-def _make_nConv(in_channel, depth, act, double_chnnel=False):
-    """
-    Make a two-layer convolutional neural network module.
-
-    Args:
-        in_channel (int): The number of input channels.
-        depth (int): The depth of the network.
-        act: Activation function to be used in the network.
-        double_channel (bool, optional): If True, double the number of channels in the network. Defaults to False.
-
-    Returns:
-        nn.Sequential: A sequential module representing the two-layer convolutional network.
-
-    Note:
-        - If double_channel is True, the first layer will have 32 * 2 ** (depth + 1) channels and the second layer will have the same number of channels.
-        - If double_channel is False, the first layer will have 32 * 2 ** depth channels and the second layer will have 32 * 2 ** depth * 2 channels.
-    """
-    if double_chnnel:
-        layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act)
-        layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act)
-    else:
-        layer1 = LUConv(in_channel, 32 * (2**depth), act)
-        layer2 = LUConv(32 * (2**depth), 32 * (2**depth) * 2, act)
-
-    return nn.Sequential(layer1, layer2)
-
-
-# class InputTransition(nn.Module):
-#     def __init__(self, outChans, elu):
-#         super(InputTransition, self).__init__()
-#         self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
-#         self.bn1 = ContBatchNorm3d(16)
-#         self.relu1 = ELUCons(elu, 16)
-#
-#     def forward(self, x):
-#         # do we want a PRELU here as well?
-#         out = self.bn1(self.conv1(x))
-#         # split input in to 16 channels
-#         x16 = torch.cat((x, x, x, x, x, x, x, x,
-#                          x, x, x, x, x, x, x, x), 1)
-#         out = self.relu1(torch.add(out, x16))
-#         return out
-
-
-class DownTransition(nn.Module):
-    """
-    A class representing a down transition module in a neural network.
-
-    Attributes:
-        in_channel (int): The number of input channels.
-        depth (int): The depth of the down transition module.
-        act (nn.Module): The activation function used in the module.
-    """
-
-    def __init__(self, in_channel, depth, act):
-        """
-        Initialize a DownTransition object.
-
-        Args:
-            in_channel (int): The number of channels in the input.
-            depth (int): The depth of the DownTransition.
-            act (function): The activation function.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super(DownTransition, self).__init__()
-        self.ops = _make_nConv(in_channel, depth, act)
-        self.maxpool = nn.MaxPool3d(2)
-        self.current_depth = depth
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the neural network.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
-        Raises:
-            None
-        """
-        if self.current_depth == 3:
-            out = self.ops(x)
-            out_before_pool = out
-        else:
-            out_before_pool = self.ops(x)
-            out = self.maxpool(out_before_pool)
-        return out, out_before_pool
-
-
-class UpTransition(nn.Module):
-    """
-    A class representing an up transition layer in a neural network.
-
-    Attributes:
-        inChans (int): The number of input channels.
-        outChans (int): The number of output channels.
-        depth (int): The depth of the layer.
-        act (str): The activation function to be applied.
-    """
-
-    def __init__(self, inChans, outChans, depth, act):
-        """
-        Initialize the UpTransition module.
-
-        Args:
-            inChans (int): The number of input channels.
-            outChans (int): The number of output channels.
-            depth (int): The depth of the module.
-            act (nn.Module): The activation function to be used.
-
-        Returns:
-            None.
-
-        Raises:
-            None.
-        """
-        super(UpTransition, self).__init__()
-        self.depth = depth
-        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
-        self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)
-
-    def forward(self, x, skip_x):
-        """
-        Forward pass of the neural network.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-            skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
-        Returns:
-            torch.Tensor: The output tensor after passing through the network.
-        """
-        out_up_conv = self.up_conv(x)
-        concat = torch.cat((out_up_conv, skip_x), 1)
-        out = self.ops(concat)
-        return out
-
-
-class OutputTransition(nn.Module):
-    """
-    A class representing the output transition in a neural network.
-
-    Attributes:
-        inChans (int): The number of input channels.
-        n_labels (int): The number of output labels.
-    """
-
-    def __init__(self, inChans, n_labels):
-        """
-        Initialize the OutputTransition class.
-
-        Args:
-            inChans (int): Number of input channels.
-            n_labels (int): Number of output labels.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super(OutputTransition, self).__init__()
-        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        """
-        Forward pass through a neural network model.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor after passing through the model.
-        """
-        out = self.sigmoid(self.final_conv(x))
-        return out
-
-
-class UNet3D(nn.Module):
-    # the number of convolutions in each layer corresponds
-    # to what is in the actual prototxt, not the intent
-    """
-    A class representing a 3D UNet model for segmentation.
-
-    Attributes:
-        n_class (int): The number of classes for segmentation.
-        act (str): The activation function type used in the model.
-        decoder (bool): Whether to include the decoder part in the model.
-
-    Methods:
-        forward(x): Forward pass of the model.
-    """
-
-    def __init__(self, n_class=1, act="relu", decoder=True):
-        """
-        Initialize a 3D UNet neural network model.
-
-        Args:
-            n_class (int): The number of output classes. Defaults to 1.
-            act (str): The activation function to use. Defaults to 'relu'.
-            decoder (bool): Whether to include the decoder layers. Defaults to True.
-
-        Attributes:
-            decoder (bool): Whether the model includes decoder layers.
-            down_tr64 (DownTransition): The first down transition layer.
-            down_tr128 (DownTransition): The second down transition layer.
-            down_tr256 (DownTransition): The third down transition layer.
-            down_tr512 (DownTransition): The fourth down transition layer.
-            up_tr256 (UpTransition): The first up transition layer. (Only exists if `decoder` is True)
-            up_tr128 (UpTransition): The second up transition layer. (Only exists if `decoder` is True)
-            up_tr64 (UpTransition): The third up transition layer. (Only exists if `decoder` is True)
-            out_tr (OutputTransition): The output transition layer. (Only exists if `decoder` is True)
-            avg_pool (nn.AvgPool3d): The average pooling layer. (Only exists if `decoder` is False)
-            flatten (nn.Flatten): The flattening layer. (Only exists if `decoder` is False)
-        """
-        super(UNet3D, self).__init__()
-
-        self.decoder = decoder
-
-        self.down_tr64 = DownTransition(1, 0, act)
-        self.down_tr128 = DownTransition(64, 1, act)
-        self.down_tr256 = DownTransition(128, 2, act)
-        self.down_tr512 = DownTransition(256, 3, act)
-
-        if self.decoder:
-            self.up_tr256 = UpTransition(512, 512, 2, act)
-            self.up_tr128 = UpTransition(256, 256, 1, act)
-            self.up_tr64 = UpTransition(128, 128, 0, act)
-            self.out_tr = OutputTransition(64, n_class)
-        else:
-            self.avg_pool = nn.AvgPool3d(3, stride=2)
-            self.flatten = nn.Flatten()
-
-    def forward(self, x):
-        """
-        Perform forward pass through the neural network.
-
-        Args:
-            x (Tensor): Input tensor to the network.
-
-        Returns:
-            Tensor: Output tensor from the network.
-
-        Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
-        Raises:
-            None.
-        """
-        self.out64, self.skip_out64 = self.down_tr64(x)
-        self.out128, self.skip_out128 = self.down_tr128(self.out64)
-        self.out256, self.skip_out256 = self.down_tr256(self.out128)
-        self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
-        if self.decoder:
-            self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
-            self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
-            self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
-            self.out = self.out_tr(self.out_up_64)
-        else:
-            self.out = self.avg_pool(self.out512)
-            self.out = self.flatten(self.out)
-
-        return self.out
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class ContBatchNorm3d -(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, dtype=None) -
-
-

A class representing a 3D contextual batch normalization layer.

-

Attributes

-
-
running_mean : torch.Tensor
-
The running mean of the batch normalization.
-
running_var : torch.Tensor
-
The running variance of the batch normalization.
-
weight : torch.Tensor
-
The learnable weights of the batch normalization.
-
bias : torch.Tensor
-
The learnable bias of the batch normalization.
-
momentum : float
-
The momentum for updating the running statistics.
-
eps : float
-
Small value added to the denominator for numerical stability.
-
-

Initialize internal Module state, shared by both nn.Module and ScriptModule.

-
- -Expand source code - -
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
-    """
-    A class representing a 3D contextual batch normalization layer.
-
-    Attributes:
-        running_mean (torch.Tensor): The running mean of the batch normalization.
-        running_var (torch.Tensor): The running variance of the batch normalization.
-        weight (torch.Tensor): The learnable weights of the batch normalization.
-        bias (torch.Tensor): The learnable bias of the batch normalization.
-        momentum (float): The momentum for updating the running statistics.
-        eps (float): Small value added to the denominator for numerical stability.
-    """
-
-    def _check_input_dim(self, input):
-        """
-        Check if the input tensor is 5-dimensional.
-
-        Args:
-            input (torch.Tensor): Input tensor to check the dimensionality.
-
-        Raises:
-            ValueError: If the input tensor is not 5-dimensional.
-        """
-        if input.dim() != 5:
-            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
-        # super(ContBatchNorm3d, self)._check_input_dim(input)
-
-    def forward(self, input):
-        """
-        Apply forward pass for the input through batch normalization layer.
-
-        Args:
-            input (Tensor): Input tensor to be normalized.
-
-        Returns:
-            Tensor: Normalized output tensor.
-
-        Raises:
-            ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
-        """
-        self._check_input_dim(input)
-        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-
-

Ancestors

-
    -
  • torch.nn.modules.batchnorm._BatchNorm
  • -
  • torch.nn.modules.batchnorm._NormBase
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var affine : bool
-
-
-
-
var eps : float
-
-
-
-
var momentum : float
-
-
-
-
var num_features : int
-
-
-
-
var track_running_stats : bool
-
-
-
-
-

Methods

-
-
-def forward(self, input) ‑> Callable[..., Any] -
-
-

Apply forward pass for the input through batch normalization layer.

-

Args

-
-
input : Tensor
-
Input tensor to be normalized.
-
-

Returns

-
-
Tensor
-
Normalized output tensor.
-
-

Raises

-
-
ValueError
-
If the dimensions of the input tensor do not match the expected input dimensions.
-
-
- -Expand source code - -
def forward(self, input):
-    """
-    Apply forward pass for the input through batch normalization layer.
-
-    Args:
-        input (Tensor): Input tensor to be normalized.
-
-    Returns:
-        Tensor: Normalized output tensor.
-
-    Raises:
-        ValueError: If the dimensions of the input tensor do not match the expected input dimensions.
-    """
-    self._check_input_dim(input)
-    return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)
-
-
-
-
-
-class DownTransition -(in_channel, depth, act) -
-
-

A class representing a down transition module in a neural network.

-

Attributes

-
-
in_channel : int
-
The number of input channels.
-
depth : int
-
The depth of the down transition module.
-
act : nn.Module
-
The activation function used in the module.
-
-

Initialize a DownTransition object.

-

Args

-
-
in_channel : int
-
The number of channels in the input.
-
depth : int
-
The depth of the DownTransition.
-
act : function
-
The activation function.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class DownTransition(nn.Module):
-    """
-    A class representing a down transition module in a neural network.
-
-    Attributes:
-        in_channel (int): The number of input channels.
-        depth (int): The depth of the down transition module.
-        act (nn.Module): The activation function used in the module.
-    """
-
-    def __init__(self, in_channel, depth, act):
-        """
-        Initialize a DownTransition object.
-
-        Args:
-            in_channel (int): The number of channels in the input.
-            depth (int): The depth of the DownTransition.
-            act (function): The activation function.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super(DownTransition, self).__init__()
-        self.ops = _make_nConv(in_channel, depth, act)
-        self.maxpool = nn.MaxPool3d(2)
-        self.current_depth = depth
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the neural network.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
-        Raises:
-            None
-        """
-        if self.current_depth == 3:
-            out = self.ops(x)
-            out_before_pool = out
-        else:
-            out_before_pool = self.ops(x)
-            out = self.maxpool(out_before_pool)
-        return out, out_before_pool
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x) ‑> Callable[..., Any] -
-
-

Perform a forward pass through the neural network.

-

Args

-
-
x : Tensor
-
The input tensor.
-
-

Returns

-
-
tuple
-
A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
-

Raises

-

None

-
- -Expand source code - -
def forward(self, x):
-    """
-    Perform a forward pass through the neural network.
-
-    Args:
-        x (Tensor): The input tensor.
-
-    Returns:
-        tuple: A tuple containing two tensors. The first tensor is the output of the forward pass. The second tensor is the output before applying the max pooling operation.
-
-    Raises:
-        None
-    """
-    if self.current_depth == 3:
-        out = self.ops(x)
-        out_before_pool = out
-    else:
-        out_before_pool = self.ops(x)
-        out = self.maxpool(out_before_pool)
-    return out, out_before_pool
-
-
-
-
-
-class LUConv -(in_chan, out_chan, act) -
-
-

A class representing a LUConv module.

-

This module performs a convolution operation on the input data with a specified number of input channels and output channels. -The convolution is followed by batch normalization and an activation function.

-

Attributes

-
-
in_chan : int
-
The number of input channels.
-
out_chan : int
-
The number of output channels.
-
act : str
-
The activation function to be applied. Can be one of 'relu', 'prelu', or 'elu'.
-
-

Initialize a LUConv layer.

-

Args

-
-
in_chan : int
-
Number of input channels.
-
out_chan : int
-
Number of output channels.
-
act : str
-
Activation function. Options: 'relu', 'prelu', 'elu'.
-
-

Returns

-

None

-

Raises

-
-
TypeError
-
If the activation function is not one of the specified options.
-
-
- -Expand source code - -
class LUConv(nn.Module):
-    """
-    A class representing a LUConv module.
-
-    This module performs a convolution operation on the input data with a specified number of input channels and output channels.
-    The convolution is followed by batch normalization and an activation function.
-
-    Attributes:
-        in_chan (int): The number of input channels.
-        out_chan (int): The number of output channels.
-        act (str): The activation function to be applied. Can be one of 'relu', 'prelu', or 'elu'.
-    """
-
-    def __init__(self, in_chan, out_chan, act):
-        """
-        Initialize a LUConv layer.
-
-        Args:
-            in_chan (int): Number of input channels.
-            out_chan (int): Number of output channels.
-            act (str): Activation function. Options: 'relu', 'prelu', 'elu'.
-
-        Returns:
-            None
-
-        Raises:
-            TypeError: If the activation function is not one of the specified options.
-        """
-        super(LUConv, self).__init__()
-        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
-        self.bn1 = ContBatchNorm3d(out_chan)
-
-        if act == "relu":
-            self.activation = nn.ReLU(out_chan)
-        elif act == "prelu":
-            self.activation = nn.PReLU(out_chan)
-        elif act == "elu":
-            self.activation = nn.ELU(inplace=True)
-        else:
-            raise
-
-    def forward(self, x):
-        """
-        Apply forward pass through the neural network.
-
-        Args:
-            x (Tensor): Input tensor to the network.
-
-        Returns:
-            Tensor: Output tensor after passing through the network.
-        """
-        out = self.activation(self.bn1(self.conv1(x)))
-        return out
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x) ‑> Callable[..., Any] -
-
-

Apply forward pass through the neural network.

-

Args

-
-
x : Tensor
-
Input tensor to the network.
-
-

Returns

-
-
Tensor
-
Output tensor after passing through the network.
-
-
- -Expand source code - -
def forward(self, x):
-    """
-    Apply forward pass through the neural network.
-
-    Args:
-        x (Tensor): Input tensor to the network.
-
-    Returns:
-        Tensor: Output tensor after passing through the network.
-    """
-    out = self.activation(self.bn1(self.conv1(x)))
-    return out
-
-
-
-
-
-class OutputTransition -(inChans, n_labels) -
-
-

A class representing the output transition in a neural network.

-

Attributes

-
-
inChans : int
-
The number of input channels.
-
n_labels : int
-
The number of output labels.
-
-

Initialize the OutputTransition class.

-

Args

-
-
inChans : int
-
Number of input channels.
-
n_labels : int
-
Number of output labels.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class OutputTransition(nn.Module):
-    """
-    A class representing the output transition in a neural network.
-
-    Attributes:
-        inChans (int): The number of input channels.
-        n_labels (int): The number of output labels.
-    """
-
-    def __init__(self, inChans, n_labels):
-        """
-        Initialize the OutputTransition class.
-
-        Args:
-            inChans (int): Number of input channels.
-            n_labels (int): Number of output labels.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super(OutputTransition, self).__init__()
-        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        """
-        Forward pass through a neural network model.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor after passing through the model.
-        """
-        out = self.sigmoid(self.final_conv(x))
-        return out
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x) ‑> Callable[..., Any] -
-
-

Forward pass through a neural network model.

-

Args

-
-
x : Tensor
-
The input tensor.
-
-

Returns

-
-
Tensor
-
The output tensor after passing through the model.
-
-
- -Expand source code - -
def forward(self, x):
-    """
-    Forward pass through a neural network model.
-
-    Args:
-        x (Tensor): The input tensor.
-
-    Returns:
-        Tensor: The output tensor after passing through the model.
-    """
-    out = self.sigmoid(self.final_conv(x))
-    return out
-
-
-
-
-
-class UNet3D -(n_class=1, act='relu', decoder=True) -
-
-

A class representing a 3D UNet model for segmentation.

-

Attributes

-
-
n_class : int
-
The number of classes for segmentation.
-
act : str
-
The activation function type used in the model.
-
decoder : bool
-
Whether to include the decoder part in the model.
-
-

Methods

-

forward(x): Forward pass of the model.

-

Initialize a 3D UNet neural network model.

-

Args

-
-
n_class : int
-
The number of output classes. Defaults to 1.
-
act : str
-
The activation function to use. Defaults to 'relu'.
-
decoder : bool
-
Whether to include the decoder layers. Defaults to True.
-
-

Attributes

-
-
decoder : bool
-
Whether the model includes decoder layers.
-
down_tr64 : DownTransition
-
The first down transition layer.
-
down_tr128 : DownTransition
-
The second down transition layer.
-
down_tr256 : DownTransition
-
The third down transition layer.
-
down_tr512 : DownTransition
-
The fourth down transition layer.
-
up_tr256 : UpTransition
-
The first up transition layer. (Only exists if decoder is True)
-
up_tr128 : UpTransition
-
The second up transition layer. (Only exists if decoder is True)
-
up_tr64 : UpTransition
-
The third up transition layer. (Only exists if decoder is True)
-
out_tr : OutputTransition
-
The output transition layer. (Only exists if decoder is True)
-
avg_pool : nn.AvgPool3d
-
The average pooling layer. (Only exists if decoder is False)
-
flatten : nn.Flatten
-
The flattening layer. (Only exists if decoder is False)
-
-
- -Expand source code - -
class UNet3D(nn.Module):
-    # the number of convolutions in each layer corresponds
-    # to what is in the actual prototxt, not the intent
-    """
-    A class representing a 3D UNet model for segmentation.
-
-    Attributes:
-        n_class (int): The number of classes for segmentation.
-        act (str): The activation function type used in the model.
-        decoder (bool): Whether to include the decoder part in the model.
-
-    Methods:
-        forward(x): Forward pass of the model.
-    """
-
-    def __init__(self, n_class=1, act="relu", decoder=True):
-        """
-        Initialize a 3D UNet neural network model.
-
-        Args:
-            n_class (int): The number of output classes. Defaults to 1.
-            act (str): The activation function to use. Defaults to 'relu'.
-            decoder (bool): Whether to include the decoder layers. Defaults to True.
-
-        Attributes:
-            decoder (bool): Whether the model includes decoder layers.
-            down_tr64 (DownTransition): The first down transition layer.
-            down_tr128 (DownTransition): The second down transition layer.
-            down_tr256 (DownTransition): The third down transition layer.
-            down_tr512 (DownTransition): The fourth down transition layer.
-            up_tr256 (UpTransition): The first up transition layer. (Only exists if `decoder` is True)
-            up_tr128 (UpTransition): The second up transition layer. (Only exists if `decoder` is True)
-            up_tr64 (UpTransition): The third up transition layer. (Only exists if `decoder` is True)
-            out_tr (OutputTransition): The output transition layer. (Only exists if `decoder` is True)
-            avg_pool (nn.AvgPool3d): The average pooling layer. (Only exists if `decoder` is False)
-            flatten (nn.Flatten): The flattening layer. (Only exists if `decoder` is False)
-        """
-        super(UNet3D, self).__init__()
-
-        self.decoder = decoder
-
-        self.down_tr64 = DownTransition(1, 0, act)
-        self.down_tr128 = DownTransition(64, 1, act)
-        self.down_tr256 = DownTransition(128, 2, act)
-        self.down_tr512 = DownTransition(256, 3, act)
-
-        if self.decoder:
-            self.up_tr256 = UpTransition(512, 512, 2, act)
-            self.up_tr128 = UpTransition(256, 256, 1, act)
-            self.up_tr64 = UpTransition(128, 128, 0, act)
-            self.out_tr = OutputTransition(64, n_class)
-        else:
-            self.avg_pool = nn.AvgPool3d(3, stride=2)
-            self.flatten = nn.Flatten()
-
-    def forward(self, x):
-        """
-        Perform forward pass through the neural network.
-
-        Args:
-            x (Tensor): Input tensor to the network.
-
-        Returns:
-            Tensor: Output tensor from the network.
-
-        Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
-        Raises:
-            None.
-        """
-        self.out64, self.skip_out64 = self.down_tr64(x)
-        self.out128, self.skip_out128 = self.down_tr128(self.out64)
-        self.out256, self.skip_out256 = self.down_tr256(self.out128)
-        self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
-        if self.decoder:
-            self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
-            self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
-            self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
-            self.out = self.out_tr(self.out_up_64)
-        else:
-            self.out = self.avg_pool(self.out512)
-            self.out = self.flatten(self.out)
-
-        return self.out
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x) ‑> Callable[..., Any] -
-
-

Perform forward pass through the neural network.

-

Args

-
-
x : Tensor
-
Input tensor to the network.
-
-

Returns

-
-
Tensor
-
Output tensor from the network.
-
-

Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.

-

Raises

-

None.

-
- -Expand source code - -
def forward(self, x):
-    """
-    Perform forward pass through the neural network.
-
-    Args:
-        x (Tensor): Input tensor to the network.
-
-    Returns:
-        Tensor: Output tensor from the network.
-
-    Note: This function performs a series of operations to downsample the input tensor, followed by upsampling if the 'decoder' flag is set. If the 'decoder' flag is not set, the output tensor goes through average pooling and flattening.
-
-    Raises:
-        None.
-    """
-    self.out64, self.skip_out64 = self.down_tr64(x)
-    self.out128, self.skip_out128 = self.down_tr128(self.out64)
-    self.out256, self.skip_out256 = self.down_tr256(self.out128)
-    self.out512, self.skip_out512 = self.down_tr512(self.out256)
-
-    if self.decoder:
-        self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
-        self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
-        self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
-        self.out = self.out_tr(self.out_up_64)
-    else:
-        self.out = self.avg_pool(self.out512)
-        self.out = self.flatten(self.out)
-
-    return self.out
-
-
-
-
-
-class UpTransition -(inChans, outChans, depth, act) -
-
-

A class representing an up transition layer in a neural network.

-

Attributes

-
-
inChans : int
-
The number of input channels.
-
outChans : int
-
The number of output channels.
-
depth : int
-
The depth of the layer.
-
act : str
-
The activation function to be applied.
-
-

Initialize the UpTransition module.

-

Args

-
-
inChans : int
-
The number of input channels.
-
outChans : int
-
The number of output channels.
-
depth : int
-
The depth of the module.
-
act : nn.Module
-
The activation function to be used.
-
-

Returns

-

None.

-

Raises

-

None.

-
- -Expand source code - -
class UpTransition(nn.Module):
-    """
-    A class representing an up transition layer in a neural network.
-
-    Attributes:
-        inChans (int): The number of input channels.
-        outChans (int): The number of output channels.
-        depth (int): The depth of the layer.
-        act (str): The activation function to be applied.
-    """
-
-    def __init__(self, inChans, outChans, depth, act):
-        """
-        Initialize the UpTransition module.
-
-        Args:
-            inChans (int): The number of input channels.
-            outChans (int): The number of output channels.
-            depth (int): The depth of the module.
-            act (nn.Module): The activation function to be used.
-
-        Returns:
-            None.
-
-        Raises:
-            None.
-        """
-        super(UpTransition, self).__init__()
-        self.depth = depth
-        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
-        self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_chnnel=True)
-
-    def forward(self, x, skip_x):
-        """
-        Forward pass of the neural network.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-            skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
-        Returns:
-            torch.Tensor: The output tensor after passing through the network.
-        """
-        out_up_conv = self.up_conv(x)
-        concat = torch.cat((out_up_conv, skip_x), 1)
-        out = self.ops(concat)
-        return out
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x, skip_x) ‑> Callable[..., Any] -
-
-

Forward pass of the neural network.

-

Args

-
-
x : torch.Tensor
-
Input tensor.
-
skip_x : torch.Tensor
-
Tensor to be concatenated with the upsampled convolution output.
-
-

Returns

-
-
torch.Tensor
-
The output tensor after passing through the network.
-
-
- -Expand source code - -
def forward(self, x, skip_x):
-    """
-    Forward pass of the neural network.
-
-    Args:
-        x (torch.Tensor): Input tensor.
-        skip_x (torch.Tensor): Tensor to be concatenated with the upsampled convolution output.
-
-    Returns:
-        torch.Tensor: The output tensor after passing through the network.
-    """
-    out_up_conv = self.up_conv(x)
-    concat = torch.cat((out_up_conv, skip_x), 1)
-    out = self.ops(concat)
-    return out
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/models/resnet50.html b/docs/api_docs/fmcib/models/resnet50.html deleted file mode 100644 index 36fe011..0000000 --- a/docs/api_docs/fmcib/models/resnet50.html +++ /dev/null @@ -1,185 +0,0 @@ - - - - - - -fmcib.models.resnet50 API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.models.resnet50

-
-
-
- -Expand source code - -
import os
-from pathlib import Path
-
-import torch
-import tqdm
-import wget
-from loguru import logger
-from monai.networks.nets import resnet50 as resnet50_monai
-
-from fmcib.utils.download_utils import bar_progress
-
-
-def resnet50(
-    pretrained=True,
-    device="cuda",
-    weights_path=None,
-    download_url="https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1",
-):
-    """
-    Constructs a ResNet-50 model for image classification.
-
-    Args:
-        pretrained (bool, optional): If True, loads the pretrained weights. Default is True.
-        device (str, optional): The device to load the model on. Default is "cuda".
-        weights_path (str or Path, optional): The path to the pretrained weights file. If None, the weights will be downloaded. Default is None.
-        download_url (str, optional): The URL to download the pretrained weights. Default is "https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1".
-
-    Returns:
-        torch.nn.Module: The ResNet-50 model.
-    """
-    logger.info(f"Loading pretrained foundation model (Resnet50) on {device}...")
-
-    model = resnet50_monai(pretrained=False, n_input_channels=1, widen_factor=2, conv1_t_stride=2, feed_forward=False)
-    model = model.to(device)
-    if pretrained:
-        if weights_path is None:
-            current_path = Path(os.getcwd())
-            if not (current_path / "fmcib.torch").exists():
-                wget.download(download_url, bar=bar_progress)
-            weights_path = current_path / "fmcib.torch"
-
-        checkpoint = torch.load(weights_path, map_location=device)
-
-        if "trunk_state_dict" in checkpoint:
-            model_state_dict = checkpoint["trunk_state_dict"]
-        elif "state_dict" in checkpoint:
-            model_state_dict = checkpoint["state_dict"]
-            model_state_dict = {key.replace("model.backbone.", ""): value for key, value in model_state_dict.items()}
-
-        model.load_state_dict(model_state_dict, strict=False)
-
-    return model
-
-
-
-
-
-
-
-

Functions

-
-
-def resnet50(pretrained=True, device='cuda', weights_path=None, download_url='https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1') -
-
-

Constructs a ResNet-50 model for image classification.

-

Args

-
-
pretrained : bool, optional
-
If True, loads the pretrained weights. Default is True.
-
device : str, optional
-
The device to load the model on. Default is "cuda".
-
weights_path : str or Path, optional
-
The path to the pretrained weights file. If None, the weights will be downloaded. Default is None.
-
download_url : str, optional
-
The URL to download the pretrained weights. Default is "https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1".
-
-

Returns

-
-
torch.nn.Module
-
The ResNet-50 model.
-
-
- -Expand source code - -
def resnet50(
-    pretrained=True,
-    device="cuda",
-    weights_path=None,
-    download_url="https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1",
-):
-    """
-    Constructs a ResNet-50 model for image classification.
-
-    Args:
-        pretrained (bool, optional): If True, loads the pretrained weights. Default is True.
-        device (str, optional): The device to load the model on. Default is "cuda".
-        weights_path (str or Path, optional): The path to the pretrained weights file. If None, the weights will be downloaded. Default is None.
-        download_url (str, optional): The URL to download the pretrained weights. Default is "https://www.dropbox.com/s/bd7azdsvx1jhalp/fmcib.torch?dl=1".
-
-    Returns:
-        torch.nn.Module: The ResNet-50 model.
-    """
-    logger.info(f"Loading pretrained foundation model (Resnet50) on {device}...")
-
-    model = resnet50_monai(pretrained=False, n_input_channels=1, widen_factor=2, conv1_t_stride=2, feed_forward=False)
-    model = model.to(device)
-    if pretrained:
-        if weights_path is None:
-            current_path = Path(os.getcwd())
-            if not (current_path / "fmcib.torch").exists():
-                wget.download(download_url, bar=bar_progress)
-            weights_path = current_path / "fmcib.torch"
-
-        checkpoint = torch.load(weights_path, map_location=device)
-
-        if "trunk_state_dict" in checkpoint:
-            model_state_dict = checkpoint["trunk_state_dict"]
-        elif "state_dict" in checkpoint:
-            model_state_dict = checkpoint["state_dict"]
-            model_state_dict = {key.replace("model.backbone.", ""): value for key, value in model_state_dict.items()}
-
-        model.load_state_dict(model_state_dict, strict=False)
-
-    return model
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/optimizers/index.html b/docs/api_docs/fmcib/optimizers/index.html deleted file mode 100644 index 4bd4a99..0000000 --- a/docs/api_docs/fmcib/optimizers/index.html +++ /dev/null @@ -1,75 +0,0 @@ - - - - - - -fmcib.optimizers API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.optimizers

-
-
-
- -Expand source code - -
from .lars import LARS
-
-
-
-

Sub-modules

-
-
fmcib.optimizers.lars
-
- -
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/optimizers/lars.html b/docs/api_docs/fmcib/optimizers/lars.html deleted file mode 100644 index 755a5b4..0000000 --- a/docs/api_docs/fmcib/optimizers/lars.html +++ /dev/null @@ -1,607 +0,0 @@ - - - - - - -fmcib.optimizers.lars API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.optimizers.lars

-
-
-

References

- -
- -Expand source code - -
"""
-References:
-    - https://arxiv.org/pdf/1708.03888.pdf
-    - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
-"""
-import torch
-from torch.optim.optimizer import Optimizer, required
-
-
-class LARS(Optimizer):
-    """Extends SGD in PyTorch with LARS scaling from the paper
-    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float): learning rate
-        momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): dampening for momentum (default: 0)
-        nesterov (bool, optional): enables Nesterov momentum (default: False)
-        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
-        eps (float, optional): eps for division denominator (default: 1e-8)
-
-    Example:
-        >>> model = torch.nn.Linear(10, 1)
-        >>> input = torch.Tensor(10)
-        >>> target = torch.Tensor([1.])
-        >>> loss_fn = lambda input, target: (input - target) ** 2
-        >>> #
-        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
-        >>> optimizer.zero_grad()
-        >>> loss_fn(model(input), target).backward()
-        >>> optimizer.step()
-
-    .. note::
-        The application of momentum in the SGD part is modified according to
-        the PyTorch standards. LARS scaling fits into the equation in the
-        following fashion.
-
-        .. math::
-            \begin{aligned}
-                g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
-                v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
-                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
-            \\end{aligned}
-
-        where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
-        parameters, gradient, velocity, momentum, and weight decay respectively.
-        The :math:`lars_lr` is defined by Eq. 6 in the paper.
-        The Nesterov version is analogously modified.
-
-    .. warning::
-        Parameters with weight decay set to 0 will automatically be excluded from
-        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
-        and BYOL.
-    """
-
-    def __init__(
-        self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
-        trust_coefficient=0.001,
-        eps=1e-8,
-    ):
-        """
-        Initialize an optimizer with the given parameters.
-
-        Args:
-            params (iterable): Iterable of parameters to optimize.
-            lr (float, optional): Learning rate. Default is required.
-            momentum (float, optional): Momentum factor. Default is 0.
-            dampening (float, optional): Dampening for momentum. Default is 0.
-            weight_decay (float, optional): Weight decay factor. Default is 0.
-            nesterov (bool, optional): Use Nesterov momentum. Default is False.
-            trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
-            eps (float, optional): Small value for numerical stability. Default is 1e-08.
-
-        Raises:
-            ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
-            ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
-        """
-        if lr is not required and lr < 0.0:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if momentum < 0.0:
-            raise ValueError(f"Invalid momentum value: {momentum}")
-        if weight_decay < 0.0:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            trust_coefficient=trust_coefficient,
-            eps=eps,
-        )
-        if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
-        super().__init__(params, defaults)
-
-    def __setstate__(self, state):
-        """
-        Set the state of the optimizer.
-
-        Args:
-            state (dict): A dictionary containing the state of the optimizer.
-
-        Returns:
-            None
-
-        Note:
-            This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
-        """
-        super().__setstate__(state)
-
-        for group in self.param_groups:
-            group.setdefault("nesterov", False)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Parameters:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        # exclude scaling for params with 0 weight decay
-        for group in self.param_groups:
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-            dampening = group["dampening"]
-            nesterov = group["nesterov"]
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                d_p = p.grad
-                p_norm = torch.norm(p.data)
-                g_norm = torch.norm(p.grad.data)
-
-                # lars scaling + weight decay part
-                if weight_decay != 0:
-                    if p_norm != 0 and g_norm != 0:
-                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                        lars_lr *= group["trust_coefficient"]
-
-                        d_p = d_p.add(p, alpha=weight_decay)
-                        d_p *= lars_lr
-
-                # sgd part
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if "momentum_buffer" not in param_state:
-                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state["momentum_buffer"]
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
-
-                p.add_(d_p, alpha=-group["lr"])
-
-        return loss
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class LARS -(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08) -
-
-

Extends SGD in PyTorch with LARS scaling from the paper -Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>_.

-

Args

-
-
params : iterable
-
iterable of parameters to optimize or dicts defining -parameter groups
-
lr : float
-
learning rate
-
momentum : float, optional
-
momentum factor (default: 0)
-
weight_decay : float, optional
-
weight decay (L2 penalty) (default: 0)
-
dampening : float, optional
-
dampening for momentum (default: 0)
-
nesterov : bool, optional
-
enables Nesterov momentum (default: False)
-
trust_coefficient : float, optional
-
trust coefficient for computing LR (default: 0.001)
-
eps : float, optional
-
eps for division denominator (default: 1e-8)
-
-

Example

-
>>> model = torch.nn.Linear(10, 1)
->>> input = torch.Tensor(10)
->>> target = torch.Tensor([1.])
->>> loss_fn = lambda input, target: (input - target) ** 2
->>> #
->>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
->>> optimizer.zero_grad()
->>> loss_fn(model(input), target).backward()
->>> optimizer.step()
-
-
-

Note

-

The application of momentum in the SGD part is modified according to -the PyTorch standards. LARS scaling fits into the equation in the -following fashion.

-

[ egin{aligned} -g_{t+1} & = -ext{lars_lr} * (eta * p_{t} + g_{t+1}), \ -v_{t+1} & = \mu * v_{t} + g_{t+1}, \ -p_{t+1} & = p_{t} - -ext{lr} * v_{t+1}, -\end{aligned} ] -where :math:p, :math:g, :math:v, :math:\mu and :math:eta denote the -parameters, gradient, velocity, momentum, and weight decay respectively. -The :math:lars_lr is defined by Eq. 6 in the paper. -The Nesterov version is analogously modified.

-
-
-

Warning

-

Parameters with weight decay set to 0 will automatically be excluded from -layer-wise LR scaling. This is to ensure consistency with papers like SimCLR -and BYOL.

-
-

Initialize an optimizer with the given parameters.

-

Args

-
-
params : iterable
-
Iterable of parameters to optimize.
-
lr : float, optional
-
Learning rate. Default is required.
-
momentum : float, optional
-
Momentum factor. Default is 0.
-
dampening : float, optional
-
Dampening for momentum. Default is 0.
-
weight_decay : float, optional
-
Weight decay factor. Default is 0.
-
nesterov : bool, optional
-
Use Nesterov momentum. Default is False.
-
trust_coefficient : float, optional
-
Trust coefficient. Default is 0.001.
-
eps : float, optional
-
Small value for numerical stability. Default is 1e-08.
-
-

Raises

-
-
ValueError
-
If an invalid value is provided for lr, momentum, or weight_decay.
-
ValueError
-
If nesterov momentum is enabled without providing a momentum and zero dampening.
-
-
- -Expand source code - -
class LARS(Optimizer):
-    """Extends SGD in PyTorch with LARS scaling from the paper
-    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float): learning rate
-        momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): dampening for momentum (default: 0)
-        nesterov (bool, optional): enables Nesterov momentum (default: False)
-        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
-        eps (float, optional): eps for division denominator (default: 1e-8)
-
-    Example:
-        >>> model = torch.nn.Linear(10, 1)
-        >>> input = torch.Tensor(10)
-        >>> target = torch.Tensor([1.])
-        >>> loss_fn = lambda input, target: (input - target) ** 2
-        >>> #
-        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
-        >>> optimizer.zero_grad()
-        >>> loss_fn(model(input), target).backward()
-        >>> optimizer.step()
-
-    .. note::
-        The application of momentum in the SGD part is modified according to
-        the PyTorch standards. LARS scaling fits into the equation in the
-        following fashion.
-
-        .. math::
-            \begin{aligned}
-                g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
-                v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
-                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
-            \\end{aligned}
-
-        where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
-        parameters, gradient, velocity, momentum, and weight decay respectively.
-        The :math:`lars_lr` is defined by Eq. 6 in the paper.
-        The Nesterov version is analogously modified.
-
-    .. warning::
-        Parameters with weight decay set to 0 will automatically be excluded from
-        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
-        and BYOL.
-    """
-
-    def __init__(
-        self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
-        trust_coefficient=0.001,
-        eps=1e-8,
-    ):
-        """
-        Initialize an optimizer with the given parameters.
-
-        Args:
-            params (iterable): Iterable of parameters to optimize.
-            lr (float, optional): Learning rate. Default is required.
-            momentum (float, optional): Momentum factor. Default is 0.
-            dampening (float, optional): Dampening for momentum. Default is 0.
-            weight_decay (float, optional): Weight decay factor. Default is 0.
-            nesterov (bool, optional): Use Nesterov momentum. Default is False.
-            trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
-            eps (float, optional): Small value for numerical stability. Default is 1e-08.
-
-        Raises:
-            ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
-            ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
-        """
-        if lr is not required and lr < 0.0:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if momentum < 0.0:
-            raise ValueError(f"Invalid momentum value: {momentum}")
-        if weight_decay < 0.0:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            trust_coefficient=trust_coefficient,
-            eps=eps,
-        )
-        if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
-        super().__init__(params, defaults)
-
-    def __setstate__(self, state):
-        """
-        Set the state of the optimizer.
-
-        Args:
-            state (dict): A dictionary containing the state of the optimizer.
-
-        Returns:
-            None
-
-        Note:
-            This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
-        """
-        super().__setstate__(state)
-
-        for group in self.param_groups:
-            group.setdefault("nesterov", False)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        """
-        Performs a single optimization step.
-
-        Parameters:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        # exclude scaling for params with 0 weight decay
-        for group in self.param_groups:
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-            dampening = group["dampening"]
-            nesterov = group["nesterov"]
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                d_p = p.grad
-                p_norm = torch.norm(p.data)
-                g_norm = torch.norm(p.grad.data)
-
-                # lars scaling + weight decay part
-                if weight_decay != 0:
-                    if p_norm != 0 and g_norm != 0:
-                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                        lars_lr *= group["trust_coefficient"]
-
-                        d_p = d_p.add(p, alpha=weight_decay)
-                        d_p *= lars_lr
-
-                # sgd part
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if "momentum_buffer" not in param_state:
-                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state["momentum_buffer"]
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
-
-                p.add_(d_p, alpha=-group["lr"])
-
-        return loss
-
-

Ancestors

-
    -
  • torch.optim.optimizer.Optimizer
  • -
-

Class variables

-
-
var OptimizerPostHook : typing_extensions.TypeAlias
-
-
-
-
var OptimizerPreHook : typing_extensions.TypeAlias
-
-
-
-
-

Methods

-
-
-def step(self, closure=None) -
-
-

Performs a single optimization step.

-

Parameters

-

closure (callable, optional): A closure that reevaluates the model -and returns the loss.

-
- -Expand source code - -
@torch.no_grad()
-def step(self, closure=None):
-    """
-    Performs a single optimization step.
-
-    Parameters:
-        closure (callable, optional): A closure that reevaluates the model
-            and returns the loss.
-    """
-    loss = None
-    if closure is not None:
-        with torch.enable_grad():
-            loss = closure()
-
-    # exclude scaling for params with 0 weight decay
-    for group in self.param_groups:
-        weight_decay = group["weight_decay"]
-        momentum = group["momentum"]
-        dampening = group["dampening"]
-        nesterov = group["nesterov"]
-
-        for p in group["params"]:
-            if p.grad is None:
-                continue
-
-            d_p = p.grad
-            p_norm = torch.norm(p.data)
-            g_norm = torch.norm(p.grad.data)
-
-            # lars scaling + weight decay part
-            if weight_decay != 0:
-                if p_norm != 0 and g_norm != 0:
-                    lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                    lars_lr *= group["trust_coefficient"]
-
-                    d_p = d_p.add(p, alpha=weight_decay)
-                    d_p *= lars_lr
-
-            # sgd part
-            if momentum != 0:
-                param_state = self.state[p]
-                if "momentum_buffer" not in param_state:
-                    buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                else:
-                    buf = param_state["momentum_buffer"]
-                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                if nesterov:
-                    d_p = d_p.add(buf, alpha=momentum)
-                else:
-                    d_p = buf
-
-            p.add_(d_p, alpha=-group["lr"])
-
-    return loss
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/preprocessing/index.html b/docs/api_docs/fmcib/preprocessing/index.html deleted file mode 100644 index 22b30f3..0000000 --- a/docs/api_docs/fmcib/preprocessing/index.html +++ /dev/null @@ -1,213 +0,0 @@ - - - - - - -fmcib.preprocessing API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.preprocessing

-
-
-
- -Expand source code - -
import monai
-import torchvision
-from loguru import logger
-from monai import transforms as monai_transforms
-
-from .seed_based_crop import SeedBasedPatchCropd
-
-
-def preprocess(image, spatial_size=(50, 50, 50)):
-    T = get_transforms(spatial_size=spatial_size)
-    return T(image)
-
-
-def get_transforms(spatial_size=(50, 50, 50), precropped=False):
-    if precropped:
-        return monai_transforms.Compose(
-            [
-                monai_transforms.LoadImaged(keys=["image_path"], image_only=True),
-                monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-                monai_transforms.ScaleIntensityRanged(
-                    keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
-                ),
-                monai_transforms.SelectItemsd(keys=["image_path"]),
-                monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
-                torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
-            ]
-        )
-    else:
-        return monai_transforms.Compose(
-            [
-                monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
-                monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-                monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
-                monai_transforms.ScaleIntensityRanged(
-                    keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
-                ),
-                monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
-                SeedBasedPatchCropd(
-                    keys=["image_path"], roi_size=spatial_size[::-1], coord_orientation="LPS", global_coordinates=True
-                ),
-                monai_transforms.SelectItemsd(keys=["image_path"]),
-                monai_transforms.Transposed(keys=["image_path"], indices=(0, 3, 2, 1)),
-                monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
-                torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
-            ]
-        )
-
-
-def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False):
-    logger.info("Building dataloader instance ...")
-    T = get_transforms(spatial_size=spatial_size, precropped=precropped)
-    dataset = monai.data.CSVDataset(csv_path, transform=T)
-    dataloader = monai.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
-    return dataloader
-
-
-
-

Sub-modules

-
-
fmcib.preprocessing.seed_based_crop
-
-

Author: Suraj Pai -Email: bspai@bwh.harvard.edu -This script contains two classes: -1. SeedBasedPatchCropd -2. SeedBasedPatchCrop

-
-
-
-
-
-
-

Functions

-
-
-def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False) -
-
-
-
- -Expand source code - -
def get_dataloader(csv_path, batch_size=4, num_workers=4, spatial_size=(50, 50, 50), precropped=False):
-    logger.info("Building dataloader instance ...")
-    T = get_transforms(spatial_size=spatial_size, precropped=precropped)
-    dataset = monai.data.CSVDataset(csv_path, transform=T)
-    dataloader = monai.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
-    return dataloader
-
-
-
-def get_transforms(spatial_size=(50, 50, 50), precropped=False) -
-
-
-
- -Expand source code - -
def get_transforms(spatial_size=(50, 50, 50), precropped=False):
-    if precropped:
-        return monai_transforms.Compose(
-            [
-                monai_transforms.LoadImaged(keys=["image_path"], image_only=True),
-                monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-                monai_transforms.ScaleIntensityRanged(
-                    keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
-                ),
-                monai_transforms.SelectItemsd(keys=["image_path"]),
-                monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
-                torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
-            ]
-        )
-    else:
-        return monai_transforms.Compose(
-            [
-                monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
-                monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-                monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
-                monai_transforms.ScaleIntensityRanged(
-                    keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True
-                ),
-                monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
-                SeedBasedPatchCropd(
-                    keys=["image_path"], roi_size=spatial_size[::-1], coord_orientation="LPS", global_coordinates=True
-                ),
-                monai_transforms.SelectItemsd(keys=["image_path"]),
-                monai_transforms.Transposed(keys=["image_path"], indices=(0, 3, 2, 1)),
-                monai_transforms.SpatialPadd(keys=["image_path"], spatial_size=spatial_size),
-                torchvision.transforms.Lambda(lambda x: x["image_path"].as_tensor()),
-            ]
-        )
-
-
-
-def preprocess(image, spatial_size=(50, 50, 50)) -
-
-
-
- -Expand source code - -
def preprocess(image, spatial_size=(50, 50, 50)):
-    T = get_transforms(spatial_size=spatial_size)
-    return T(image)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/preprocessing/seed_based_crop.html b/docs/api_docs/fmcib/preprocessing/seed_based_crop.html deleted file mode 100644 index 90effd2..0000000 --- a/docs/api_docs/fmcib/preprocessing/seed_based_crop.html +++ /dev/null @@ -1,452 +0,0 @@ - - - - - - -fmcib.preprocessing.seed_based_crop API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.preprocessing.seed_based_crop

-
-
-

Author: Suraj Pai -Email: bspai@bwh.harvard.edu -This script contains two classes: -1. SeedBasedPatchCropd -2. SeedBasedPatchCrop

-
- -Expand source code - -
"""
-Author: Suraj Pai
-Email: bspai@bwh.harvard.edu
-This script contains two classes:
-1. SeedBasedPatchCropd
-2. SeedBasedPatchCrop
-"""
-
-from typing import Any, Dict, Hashable, Mapping, Tuple
-
-import numpy as np
-from monai.config.type_definitions import NdarrayOrTensor
-from monai.transforms import MapTransform, Transform
-
-
-class SeedBasedPatchCropd(MapTransform):
-    """
-    A class representing a seed-based patch crop transformation.
-
-    Inherits from MapTransform.
-
-    Attributes:
-        keys (list): List of keys for images in the input data dictionary.
-        roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
-        allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
-        coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
-        global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
-    """
-
-    def __init__(self, keys, roi_size, allow_missing_keys=False, coord_orientation="RAS", global_coordinates=True) -> None:
-        """
-        Initialize SeedBasedPatchCropd class.
-
-        Args:
-            keys (List): List of keys for images in the input data dictionary.
-            roi_size (Tuple): Tuple indicating the size of the region of interest (ROI).
-            allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
-            coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
-            global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
-        """
-        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
-        self.coord_orientation = coord_orientation
-        self.global_coordinates = global_coordinates
-        self.cropper = SeedBasedPatchCrop(roi_size=roi_size)
-
-    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
-        """
-        Apply transformation to given data.
-
-        Args:
-            data (dict): Dictionary with image keys and required center coordinates.
-
-        Returns:
-            dict: Dictionary with cropped patches for each key in the input data dictionary.
-        """
-        d = dict(data)
-
-        assert "coordX" in d.keys(), "coordX not found in data"
-        assert "coordY" in d.keys(), "coordY not found in data"
-        assert "coordZ" in d.keys(), "coordZ not found in data"
-
-        # Convert coordinates to RAS orientation to match image orientation
-        if self.coord_orientation == "RAS":
-            center = (d["coordX"], d["coordY"], d["coordZ"])
-        elif self.coord_orientation == "LPS":
-            center = (-d["coordX"], -d["coordY"], d["coordZ"])
-
-        for key in self.key_iterator(d):
-            d[key] = self.cropper(d[key], center=center, global_coordinates=self.global_coordinates)
-        return d
-
-
-class SeedBasedPatchCrop(Transform):
-    """
-    A class representing a seed-based patch crop transformation.
-
-    Attributes:
-        roi_size: Tuple indicating the size of the region of interest (ROI).
-
-    Methods:
-        __call__: Crop a patch from the input image centered around the seed coordinate.
-
-    Args:
-        roi_size: Tuple indicating the size of the region of interest (ROI).
-
-    Returns:
-        NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-
-    Raises:
-        AssertionError: If the input image has dimensions other than (C, H, W, D)
-        AssertionError: If the coordinates are invalid (e.g., min_h >= max_h)
-    """
-
-    def __init__(self, roi_size) -> None:
-        """
-        Initialize SeedBasedPatchCrop class.
-
-        Args:
-            roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
-        """
-        super().__init__()
-        self.roi_size = roi_size
-
-    def __call__(self, img: NdarrayOrTensor, center: tuple, global_coordinates=False) -> NdarrayOrTensor:
-        """
-        Crop a patch from the input image centered around the seed coordinate.
-
-        Args:
-            img (NdarrayOrTensor): Image to crop, with dimensions (C, H, W, D). C is the number of channels.
-            center (tuple): Seed coordinate around which to crop the patch (X, Y, Z).
-            global_coordinates (bool): If True, seed coordinate is in global space; otherwise, local space.
-
-        Returns:
-            NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-        """
-        assert len(img.shape) == 4, "Input image must have dimensions: (C, H, W, D)"
-        C, H, W, D = img.shape
-        Ph, Pw, Pd = self.roi_size
-
-        # If global coordinates, convert to local coordinates
-        if global_coordinates:
-            center = np.linalg.inv(np.array(img.affine)) @ np.array(center + (1,))
-            center = [int(x) for x in center[:3]]
-
-        # Calculate and clamp the ranges for cropping
-        min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
-        min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
-        min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
-        # Check if coordinates are valid
-        assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
-        assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
-        assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
-        # Crop the patch from the image
-        patch = img[:, min_h:max_h, min_w:max_w, min_d:max_d]
-
-        return patch
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SeedBasedPatchCrop -(roi_size) -
-
-

A class representing a seed-based patch crop transformation.

-

Attributes

-
-
roi_size
-
Tuple indicating the size of the region of interest (ROI).
-
-

Methods

-

call: Crop a patch from the input image centered around the seed coordinate.

-

Args

-
-
roi_size
-
Tuple indicating the size of the region of interest (ROI).
-
-

Returns

-
-
NdarrayOrTensor
-
Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-
-

Raises

-
-
AssertionError
-
If the input image has dimensions other than (C, H, W, D)
-
AssertionError
-
If the coordinates are invalid (e.g., min_h >= max_h)
-
-

Initialize SeedBasedPatchCrop class.

-

Args

-
-
roi_size : tuple
-
Tuple indicating the size of the region of interest (ROI).
-
-
- -Expand source code - -
class SeedBasedPatchCrop(Transform):
-    """
-    A class representing a seed-based patch crop transformation.
-
-    Attributes:
-        roi_size: Tuple indicating the size of the region of interest (ROI).
-
-    Methods:
-        __call__: Crop a patch from the input image centered around the seed coordinate.
-
-    Args:
-        roi_size: Tuple indicating the size of the region of interest (ROI).
-
-    Returns:
-        NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-
-    Raises:
-        AssertionError: If the input image has dimensions other than (C, H, W, D)
-        AssertionError: If the coordinates are invalid (e.g., min_h >= max_h)
-    """
-
-    def __init__(self, roi_size) -> None:
-        """
-        Initialize SeedBasedPatchCrop class.
-
-        Args:
-            roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
-        """
-        super().__init__()
-        self.roi_size = roi_size
-
-    def __call__(self, img: NdarrayOrTensor, center: tuple, global_coordinates=False) -> NdarrayOrTensor:
-        """
-        Crop a patch from the input image centered around the seed coordinate.
-
-        Args:
-            img (NdarrayOrTensor): Image to crop, with dimensions (C, H, W, D). C is the number of channels.
-            center (tuple): Seed coordinate around which to crop the patch (X, Y, Z).
-            global_coordinates (bool): If True, seed coordinate is in global space; otherwise, local space.
-
-        Returns:
-            NdarrayOrTensor: Cropped patch of shape (C, Ph, Pw, Pd), where (Ph, Pw, Pd) is the patch size.
-        """
-        assert len(img.shape) == 4, "Input image must have dimensions: (C, H, W, D)"
-        C, H, W, D = img.shape
-        Ph, Pw, Pd = self.roi_size
-
-        # If global coordinates, convert to local coordinates
-        if global_coordinates:
-            center = np.linalg.inv(np.array(img.affine)) @ np.array(center + (1,))
-            center = [int(x) for x in center[:3]]
-
-        # Calculate and clamp the ranges for cropping
-        min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
-        min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
-        min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
-        # Check if coordinates are valid
-        assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
-        assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
-        assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
-        # Crop the patch from the image
-        patch = img[:, min_h:max_h, min_w:max_w, min_d:max_d]
-
-        return patch
-
-

Ancestors

-
    -
  • monai.transforms.transform.Transform
  • -
  • abc.ABC
  • -
-

Class variables

-
-
var backend : list[TransformBackends]
-
-
-
-
-
-
-class SeedBasedPatchCropd -(keys, roi_size, allow_missing_keys=False, coord_orientation='RAS', global_coordinates=True) -
-
-

A class representing a seed-based patch crop transformation.

-

Inherits from MapTransform.

-

Attributes

-
-
keys : list
-
List of keys for images in the input data dictionary.
-
roi_size : tuple
-
Tuple indicating the size of the region of interest (ROI).
-
allow_missing_keys : bool
-
If True, do not raise an error if some keys in the input data dictionary are missing.
-
coord_orientation : str
-
Coordinate system (RAS or LPS) of input coordinates.
-
global_coordinates : bool
-
If True, coordinates are in global coordinates; otherwise, local coordinates.
-
-

Initialize SeedBasedPatchCropd class.

-

Args

-
-
keys : List
-
List of keys for images in the input data dictionary.
-
roi_size : Tuple
-
Tuple indicating the size of the region of interest (ROI).
-
allow_missing_keys : bool
-
If True, do not raise an error if some keys in the input data dictionary are missing.
-
coord_orientation : str
-
Coordinate system (RAS or LPS) of input coordinates.
-
global_coordinates : bool
-
If True, coordinates are in global coordinates; otherwise, local coordinates.
-
-
- -Expand source code - -
class SeedBasedPatchCropd(MapTransform):
-    """
-    A class representing a seed-based patch crop transformation.
-
-    Inherits from MapTransform.
-
-    Attributes:
-        keys (list): List of keys for images in the input data dictionary.
-        roi_size (tuple): Tuple indicating the size of the region of interest (ROI).
-        allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
-        coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
-        global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
-    """
-
-    def __init__(self, keys, roi_size, allow_missing_keys=False, coord_orientation="RAS", global_coordinates=True) -> None:
-        """
-        Initialize SeedBasedPatchCropd class.
-
-        Args:
-            keys (List): List of keys for images in the input data dictionary.
-            roi_size (Tuple): Tuple indicating the size of the region of interest (ROI).
-            allow_missing_keys (bool): If True, do not raise an error if some keys in the input data dictionary are missing.
-            coord_orientation (str): Coordinate system (RAS or LPS) of input coordinates.
-            global_coordinates (bool): If True, coordinates are in global coordinates; otherwise, local coordinates.
-        """
-        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
-        self.coord_orientation = coord_orientation
-        self.global_coordinates = global_coordinates
-        self.cropper = SeedBasedPatchCrop(roi_size=roi_size)
-
-    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
-        """
-        Apply transformation to given data.
-
-        Args:
-            data (dict): Dictionary with image keys and required center coordinates.
-
-        Returns:
-            dict: Dictionary with cropped patches for each key in the input data dictionary.
-        """
-        d = dict(data)
-
-        assert "coordX" in d.keys(), "coordX not found in data"
-        assert "coordY" in d.keys(), "coordY not found in data"
-        assert "coordZ" in d.keys(), "coordZ not found in data"
-
-        # Convert coordinates to RAS orientation to match image orientation
-        if self.coord_orientation == "RAS":
-            center = (d["coordX"], d["coordY"], d["coordZ"])
-        elif self.coord_orientation == "LPS":
-            center = (-d["coordX"], -d["coordY"], d["coordZ"])
-
-        for key in self.key_iterator(d):
-            d[key] = self.cropper(d[key], center=center, global_coordinates=self.global_coordinates)
-        return d
-
-

Ancestors

-
    -
  • monai.transforms.transform.MapTransform
  • -
  • monai.transforms.transform.Transform
  • -
  • abc.ABC
  • -
-

Class variables

-
-
var backend : list[TransformBackends]
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/run.html b/docs/api_docs/fmcib/run.html deleted file mode 100644 index f7c01bf..0000000 --- a/docs/api_docs/fmcib/run.html +++ /dev/null @@ -1,212 +0,0 @@ - - - - - - -fmcib.run API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.run

-
-
-
- -Expand source code - -
import numpy as np
-import pandas as pd
-import torch
-from loguru import logger
-from monai.networks.nets import resnet50
-from tqdm import tqdm
-
-from .models import LoadModel, fmcib_model
-from .preprocessing import get_dataloader
-
-
-def get_features(
-    csv_path,
-    weights_path=None,
-    spatial_size=(50, 50, 50),
-    precropped=False,
-):
-    """
-    Extracts features from images specified in a CSV file.
-
-    Args:
-        csv_path (str): Path to the CSV file containing image paths.
-        weights_path (str, optional): Path to the pre-trained weights file. Default is None.
-        spatial_size (tuple, optional): Spatial size of the input images. Default is (50, 50, 50).
-        precropped (bool, optional): Whether the images are already pre-cropped. Default is False.
-
-    Returns:
-        pandas.DataFrame: DataFrame containing the original data from the CSV file along with the extracted features.
-    """
-    logger.info("Loading CSV file ...")
-    df = pd.read_csv(csv_path)
-    dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    if weights_path is None:
-        model = fmcib_model().to(device)
-    else:
-        logger.warning(
-            "Loading custom model provided from weights file. If this is not intended, please do not provide the weights_path argument."
-        )
-        trunk = resnet50(
-            pretrained=False,
-            n_input_channels=1,
-            widen_factor=2,
-            conv1_t_stride=2,
-            feed_forward=False,
-            bias_downsample=True,
-        )
-        model = LoadModel(trunk=trunk, weights_path=weights_path).to(device)
-
-    feature_list = []
-    logger.info("Running inference over batches ...")
-
-    model.eval()
-    for batch in tqdm(dataloader, total=len(dataloader)):
-        feature = model(batch.to(device)).detach().cpu().numpy()
-        feature_list.append(feature)
-
-    features = np.concatenate(feature_list, axis=0)
-    # Flatten features into a list
-    features = features.reshape(-1, 4096)
-
-    # Add the features to the dataframe
-    df = pd.concat([df, pd.DataFrame(features, columns=[f"pred_{idx}" for idx in range(4096)])], axis=1)
-    return df
-
-
-
-
-
-
-
-

Functions

-
-
-def get_features(csv_path, weights_path=None, spatial_size=(50, 50, 50), precropped=False) -
-
-

Extracts features from images specified in a CSV file.

-

Args

-
-
csv_path : str
-
Path to the CSV file containing image paths.
-
weights_path : str, optional
-
Path to the pre-trained weights file. Default is None.
-
spatial_size : tuple, optional
-
Spatial size of the input images. Default is (50, 50, 50).
-
precropped : bool, optional
-
Whether the images are already pre-cropped. Default is False.
-
-

Returns

-
-
pandas.DataFrame
-
DataFrame containing the original data from the CSV file along with the extracted features.
-
-
- -Expand source code - -
def get_features(
-    csv_path,
-    weights_path=None,
-    spatial_size=(50, 50, 50),
-    precropped=False,
-):
-    """
-    Extracts features from images specified in a CSV file.
-
-    Args:
-        csv_path (str): Path to the CSV file containing image paths.
-        weights_path (str, optional): Path to the pre-trained weights file. Default is None.
-        spatial_size (tuple, optional): Spatial size of the input images. Default is (50, 50, 50).
-        precropped (bool, optional): Whether the images are already pre-cropped. Default is False.
-
-    Returns:
-        pandas.DataFrame: DataFrame containing the original data from the CSV file along with the extracted features.
-    """
-    logger.info("Loading CSV file ...")
-    df = pd.read_csv(csv_path)
-    dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    if weights_path is None:
-        model = fmcib_model().to(device)
-    else:
-        logger.warning(
-            "Loading custom model provided from weights file. If this is not intended, please do not provide the weights_path argument."
-        )
-        trunk = resnet50(
-            pretrained=False,
-            n_input_channels=1,
-            widen_factor=2,
-            conv1_t_stride=2,
-            feed_forward=False,
-            bias_downsample=True,
-        )
-        model = LoadModel(trunk=trunk, weights_path=weights_path).to(device)
-
-    feature_list = []
-    logger.info("Running inference over batches ...")
-
-    model.eval()
-    for batch in tqdm(dataloader, total=len(dataloader)):
-        feature = model(batch.to(device)).detach().cpu().numpy()
-        feature_list.append(feature)
-
-    features = np.concatenate(feature_list, axis=0)
-    # Flatten features into a list
-    features = features.reshape(-1, 4096)
-
-    # Add the features to the dataframe
-    df = pd.concat([df, pd.DataFrame(features, columns=[f"pred_{idx}" for idx in range(4096)])], axis=1)
-    return df
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/index.html b/docs/api_docs/fmcib/ssl/index.html deleted file mode 100644 index 64dfbd7..0000000 --- a/docs/api_docs/fmcib/ssl/index.html +++ /dev/null @@ -1,70 +0,0 @@ - - - - - - -fmcib.ssl API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl

-
-
-
-
-

Sub-modules

-
-
fmcib.ssl.losses
-
-
-
-
fmcib.ssl.modules
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/index.html b/docs/api_docs/fmcib/ssl/losses/index.html deleted file mode 100644 index ef5c04e..0000000 --- a/docs/api_docs/fmcib/ssl/losses/index.html +++ /dev/null @@ -1,95 +0,0 @@ - - - - - - -fmcib.ssl.losses API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses

-
-
-
- -Expand source code - -
from .neg_mining_info_nce_loss import NegativeMiningInfoNCECriterion
-from .nnclr_loss import NNCLRLoss
-from .ntxent_loss import NTXentLoss
-from .ntxent_mined_loss import NTXentNegativeMinedLoss
-from .swav_loss import SwaVLoss
-
-
-
-

Sub-modules

-
-
fmcib.ssl.losses.neg_mining_info_nce_loss
-
-
-
-
fmcib.ssl.losses.nnclr_loss
-
-
-
-
fmcib.ssl.losses.ntxent_loss
-
-
-
-
fmcib.ssl.losses.ntxent_mined_loss
-
-

Contrastive Loss Functions

-
-
fmcib.ssl.losses.swav_loss
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/neg_mining_info_nce_loss.html b/docs/api_docs/fmcib/ssl/losses/neg_mining_info_nce_loss.html deleted file mode 100644 index bb08380..0000000 --- a/docs/api_docs/fmcib/ssl/losses/neg_mining_info_nce_loss.html +++ /dev/null @@ -1,708 +0,0 @@ - - - - - - -fmcib.ssl.losses.neg_mining_info_nce_loss API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses.neg_mining_info_nce_loss

-
-
-
- -Expand source code - -
# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pprint
-
-import numpy as np
-import torch
-from lightly.utils import dist
-from loguru import logger
-from torch import nn
-
-
-class NegativeMiningInfoNCECriterion(nn.Module):
-    """
-    The criterion corresponding to the SimCLR loss as defined in the paper
-    https://arxiv.org/abs/2002.05709.
-
-    Args:
-        temperature (float): The temperature to be applied on the logits.
-        buffer_params (dict): A dictionary containing the following keys:
-            - world_size (int): Total number of trainers in training.
-            - embedding_dim (int): Output dimensions of the features projects.
-            - effective_batch_size (int): Total batch size used (includes positives).
-    """
-
-    def __init__(
-        self, embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True
-    ):
-        """
-        Initialize the NegativeMiningInfoNCECriterion class.
-
-        Args:
-            embedding_dim (int): The dimension of the embedding space.
-            batch_size (int): The size of the input batch.
-            world_size (int): The number of distributed processes.
-            gather_distributed (bool): Whether to gather distributed data.
-            temperature (float): The temperature used in the computation.
-            balanced (bool): Whether to use balanced sampling.
-
-        Attributes:
-            embedding_dim (int): The dimension of the embedding space.
-            use_gpu (bool): Whether to use GPU for computations.
-            temperature (float): The temperature used in the computation.
-            num_pos (int): The number of positive samples.
-            num_neg (int): The number of negative samples.
-            criterion (nn.CrossEntropyLoss): The loss function.
-            gather_distributed (bool): Whether to gather distributed data.
-            world_size (int): The number of distributed processes.
-            effective_batch_size (int): The effective batch size, taking into account world size and number of positive samples.
-            pos_mask (None or Tensor): Mask for positive samples.
-            neg_mask (None or Tensor): Mask for negative samples.
-            balanced (bool): Whether to use balanced sampling.
-            setup (bool): Whether the setup has been done.
-        """
-        super(NegativeMiningInfoNCECriterion, self).__init__()
-        self.embedding_dim = embedding_dim
-        self.use_gpu = torch.cuda.is_available()
-        self.temperature = temperature
-        self.num_pos = 2
-
-        # Same number of negatives as positives are loaded
-        self.num_neg = self.num_pos
-        self.criterion = nn.CrossEntropyLoss()
-        self.gather_distributed = gather_distributed
-        self.world_size = world_size
-        self.effective_batch_size = batch_size * self.world_size * self.num_pos
-        self.pos_mask = None
-        self.neg_mask = None
-        self.balanced = balanced
-        self.setup = False
-
-    def precompute_pos_neg_mask(self):
-        """
-        Precompute the positive and negative masks to speed up the loss calculation.
-        """
-        # computed once at the begining of training
-
-        # total_images is x2 SimCLR Info-NCE loss
-        # as we have negative samples for each positive sample
-
-        total_images = self.effective_batch_size * self.num_neg
-        world_size = self.world_size
-
-        # Batch size computation is different from SimCLR paper
-        batch_size = self.effective_batch_size // world_size
-        orig_images = batch_size // self.num_pos
-        rank = dist.rank()
-
-        pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
-        neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
-        all_indices = np.arange(total_images)
-
-        # Index for pairs of images (original + copy)
-        pairs = orig_images * np.arange(self.num_pos)
-
-        # Remove all indices associated with positive samples & copies (for neg_mask)
-        all_pos_members = []
-        for _rank in range(world_size):
-            all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
-        all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
-        # Index of original positive images
-        orig_members = torch.arange(orig_images)
-
-        for anchor in np.arange(self.num_pos):
-            for img_idx in range(orig_images):
-                # delete_inds are spaced by batch_size for each rank as
-                # all_indices_pos_removed (half of the indices) is deleted first
-                delete_inds = batch_size * rank + img_idx + pairs
-                neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
-                neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
-            for pos in np.delete(np.arange(self.num_pos), anchor):
-                # Pos_inds are spaced by batch_size * self.num_neg for each rank
-                pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
-                pos_mask[
-                    torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
-                    pos_inds.long(),
-                ] = 1
-
-        self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
-        self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-
-    def forward(self, out: torch.Tensor):
-        """
-        Calculate the loss. Operates on embeddings tensor.
-        """
-        if not self.setup:
-            logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
-            self.precompute_pos_neg_mask()
-            self.setup = True
-
-        pos0, pos1 = out["positive"]
-        neg0, neg1 = out["negative"]
-        embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
-        embedding = nn.functional.normalize(embedding, dim=1, p=2)
-        assert embedding.ndim == 2
-        assert embedding.shape[1] == int(self.embedding_dim)
-
-        batch_size = embedding.shape[0]
-        T = self.temperature
-        num_pos = self.num_pos
-
-        assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
-        assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
-        # Step 1: gather all the embeddings. Shape example: 4096 x 128
-        embeddings_buffer = self.gather_embeddings(embedding)
-
-        # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
-        # divide by temperature.
-        similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
-        pos = torch.sum(similarity * self.pos_mask, 1)
-        neg = torch.sum(similarity * self.neg_mask, 1)
-
-        # Ignore the negative samples as entries for loss calculation
-        pos = pos[: (batch_size // 2)]
-        neg = neg[: (batch_size // 2)]
-
-        loss = -(torch.mean(torch.log(pos / (pos + neg))))
-        return loss
-
-    def __repr__(self):
-        """
-        Return a string representation of the object.
-
-        Returns:
-            str: A formatted string representation of the object.
-
-        Examples:
-            The following example shows the string representation of the object:
-
-            {
-              'name': <object_name>,
-              'temperature': <temperature_value>,
-              'num_negatives': <num_negatives_value>,
-              'num_pos': <num_pos_value>,
-              'dist_rank': <dist_rank_value>
-            }
-
-        Note:
-            This function is intended to be used with the pprint module for pretty printing.
-        """
-        num_negatives = self.effective_batch_size - 2
-        T = self.temperature
-        num_pos = self.num_pos
-        repr_dict = {
-            "name": self._get_name(),
-            "temperature": T,
-            "num_negatives": num_negatives,
-            "num_pos": num_pos,
-            "dist_rank": dist.rank(),
-        }
-        return pprint.pformat(repr_dict, indent=2)
-
-    def gather_embeddings(self, embedding: torch.Tensor):
-        """
-        Do a gather over all embeddings, so we can compute the loss.
-        Final shape is like: (batch_size * num_gpus) x embedding_dim
-        """
-        if self.gather_distributed:
-            embedding_gathered = torch.cat(dist.gather(embedding), 0)
-        else:
-            embedding_gathered = embedding
-        return embedding_gathered
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class NegativeMiningInfoNCECriterion -(embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True) -
-
-

The criterion corresponding to the SimCLR loss as defined in the paper -https://arxiv.org/abs/2002.05709.

-

Args

-
-
temperature : float
-
The temperature to be applied on the logits.
-
buffer_params : dict
-
A dictionary containing the following keys: -- world_size (int): Total number of trainers in training. -- embedding_dim (int): Output dimensions of the features projects. -- effective_batch_size (int): Total batch size used (includes positives).
-
-

Initialize the NegativeMiningInfoNCECriterion class.

-

Args

-
-
embedding_dim : int
-
The dimension of the embedding space.
-
batch_size : int
-
The size of the input batch.
-
world_size : int
-
The number of distributed processes.
-
gather_distributed : bool
-
Whether to gather distributed data.
-
temperature : float
-
The temperature used in the computation.
-
balanced : bool
-
Whether to use balanced sampling.
-
-

Attributes

-
-
embedding_dim : int
-
The dimension of the embedding space.
-
use_gpu : bool
-
Whether to use GPU for computations.
-
temperature : float
-
The temperature used in the computation.
-
num_pos : int
-
The number of positive samples.
-
num_neg : int
-
The number of negative samples.
-
criterion : nn.CrossEntropyLoss
-
The loss function.
-
gather_distributed : bool
-
Whether to gather distributed data.
-
world_size : int
-
The number of distributed processes.
-
effective_batch_size : int
-
The effective batch size, taking into account world size and number of positive samples.
-
pos_mask : None or Tensor
-
Mask for positive samples.
-
neg_mask : None or Tensor
-
Mask for negative samples.
-
balanced : bool
-
Whether to use balanced sampling.
-
setup : bool
-
Whether the setup has been done.
-
-
- -Expand source code - -
class NegativeMiningInfoNCECriterion(nn.Module):
-    """
-    The criterion corresponding to the SimCLR loss as defined in the paper
-    https://arxiv.org/abs/2002.05709.
-
-    Args:
-        temperature (float): The temperature to be applied on the logits.
-        buffer_params (dict): A dictionary containing the following keys:
-            - world_size (int): Total number of trainers in training.
-            - embedding_dim (int): Output dimensions of the features projects.
-            - effective_batch_size (int): Total batch size used (includes positives).
-    """
-
-    def __init__(
-        self, embedding_dim, batch_size, world_size, gather_distributed=False, temperature: float = 0.1, balanced: bool = True
-    ):
-        """
-        Initialize the NegativeMiningInfoNCECriterion class.
-
-        Args:
-            embedding_dim (int): The dimension of the embedding space.
-            batch_size (int): The size of the input batch.
-            world_size (int): The number of distributed processes.
-            gather_distributed (bool): Whether to gather distributed data.
-            temperature (float): The temperature used in the computation.
-            balanced (bool): Whether to use balanced sampling.
-
-        Attributes:
-            embedding_dim (int): The dimension of the embedding space.
-            use_gpu (bool): Whether to use GPU for computations.
-            temperature (float): The temperature used in the computation.
-            num_pos (int): The number of positive samples.
-            num_neg (int): The number of negative samples.
-            criterion (nn.CrossEntropyLoss): The loss function.
-            gather_distributed (bool): Whether to gather distributed data.
-            world_size (int): The number of distributed processes.
-            effective_batch_size (int): The effective batch size, taking into account world size and number of positive samples.
-            pos_mask (None or Tensor): Mask for positive samples.
-            neg_mask (None or Tensor): Mask for negative samples.
-            balanced (bool): Whether to use balanced sampling.
-            setup (bool): Whether the setup has been done.
-        """
-        super(NegativeMiningInfoNCECriterion, self).__init__()
-        self.embedding_dim = embedding_dim
-        self.use_gpu = torch.cuda.is_available()
-        self.temperature = temperature
-        self.num_pos = 2
-
-        # Same number of negatives as positives are loaded
-        self.num_neg = self.num_pos
-        self.criterion = nn.CrossEntropyLoss()
-        self.gather_distributed = gather_distributed
-        self.world_size = world_size
-        self.effective_batch_size = batch_size * self.world_size * self.num_pos
-        self.pos_mask = None
-        self.neg_mask = None
-        self.balanced = balanced
-        self.setup = False
-
-    def precompute_pos_neg_mask(self):
-        """
-        Precompute the positive and negative masks to speed up the loss calculation.
-        """
-        # computed once at the begining of training
-
-        # total_images is x2 SimCLR Info-NCE loss
-        # as we have negative samples for each positive sample
-
-        total_images = self.effective_batch_size * self.num_neg
-        world_size = self.world_size
-
-        # Batch size computation is different from SimCLR paper
-        batch_size = self.effective_batch_size // world_size
-        orig_images = batch_size // self.num_pos
-        rank = dist.rank()
-
-        pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
-        neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
-        all_indices = np.arange(total_images)
-
-        # Index for pairs of images (original + copy)
-        pairs = orig_images * np.arange(self.num_pos)
-
-        # Remove all indices associated with positive samples & copies (for neg_mask)
-        all_pos_members = []
-        for _rank in range(world_size):
-            all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
-        all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
-        # Index of original positive images
-        orig_members = torch.arange(orig_images)
-
-        for anchor in np.arange(self.num_pos):
-            for img_idx in range(orig_images):
-                # delete_inds are spaced by batch_size for each rank as
-                # all_indices_pos_removed (half of the indices) is deleted first
-                delete_inds = batch_size * rank + img_idx + pairs
-                neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
-                neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
-            for pos in np.delete(np.arange(self.num_pos), anchor):
-                # Pos_inds are spaced by batch_size * self.num_neg for each rank
-                pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
-                pos_mask[
-                    torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
-                    pos_inds.long(),
-                ] = 1
-
-        self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
-        self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-
-    def forward(self, out: torch.Tensor):
-        """
-        Calculate the loss. Operates on embeddings tensor.
-        """
-        if not self.setup:
-            logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
-            self.precompute_pos_neg_mask()
-            self.setup = True
-
-        pos0, pos1 = out["positive"]
-        neg0, neg1 = out["negative"]
-        embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
-        embedding = nn.functional.normalize(embedding, dim=1, p=2)
-        assert embedding.ndim == 2
-        assert embedding.shape[1] == int(self.embedding_dim)
-
-        batch_size = embedding.shape[0]
-        T = self.temperature
-        num_pos = self.num_pos
-
-        assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
-        assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
-        # Step 1: gather all the embeddings. Shape example: 4096 x 128
-        embeddings_buffer = self.gather_embeddings(embedding)
-
-        # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
-        # divide by temperature.
-        similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
-        pos = torch.sum(similarity * self.pos_mask, 1)
-        neg = torch.sum(similarity * self.neg_mask, 1)
-
-        # Ignore the negative samples as entries for loss calculation
-        pos = pos[: (batch_size // 2)]
-        neg = neg[: (batch_size // 2)]
-
-        loss = -(torch.mean(torch.log(pos / (pos + neg))))
-        return loss
-
-    def __repr__(self):
-        """
-        Return a string representation of the object.
-
-        Returns:
-            str: A formatted string representation of the object.
-
-        Examples:
-            The following example shows the string representation of the object:
-
-            {
-              'name': <object_name>,
-              'temperature': <temperature_value>,
-              'num_negatives': <num_negatives_value>,
-              'num_pos': <num_pos_value>,
-              'dist_rank': <dist_rank_value>
-            }
-
-        Note:
-            This function is intended to be used with the pprint module for pretty printing.
-        """
-        num_negatives = self.effective_batch_size - 2
-        T = self.temperature
-        num_pos = self.num_pos
-        repr_dict = {
-            "name": self._get_name(),
-            "temperature": T,
-            "num_negatives": num_negatives,
-            "num_pos": num_pos,
-            "dist_rank": dist.rank(),
-        }
-        return pprint.pformat(repr_dict, indent=2)
-
-    def gather_embeddings(self, embedding: torch.Tensor):
-        """
-        Do a gather over all embeddings, so we can compute the loss.
-        Final shape is like: (batch_size * num_gpus) x embedding_dim
-        """
-        if self.gather_distributed:
-            embedding_gathered = torch.cat(dist.gather(embedding), 0)
-        else:
-            embedding_gathered = embedding
-        return embedding_gathered
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, out: torch.Tensor) ‑> Callable[..., Any] -
-
-

Calculate the loss. Operates on embeddings tensor.

-
- -Expand source code - -
def forward(self, out: torch.Tensor):
-    """
-    Calculate the loss. Operates on embeddings tensor.
-    """
-    if not self.setup:
-        logger.info(f"Running Negative Mining Info-NCE loss on Rank: {dist.rank()}")
-        self.precompute_pos_neg_mask()
-        self.setup = True
-
-    pos0, pos1 = out["positive"]
-    neg0, neg1 = out["negative"]
-    embedding = torch.cat([pos0, pos1, neg0, neg1], dim=0)
-    embedding = nn.functional.normalize(embedding, dim=1, p=2)
-    assert embedding.ndim == 2
-    assert embedding.shape[1] == int(self.embedding_dim)
-
-    batch_size = embedding.shape[0]
-    T = self.temperature
-    num_pos = self.num_pos
-
-    assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
-    assert batch_size == self.pos_mask.shape[0], "Batch size should be equal to pos_mask shape"
-
-    # Step 1: gather all the embeddings. Shape example: 4096 x 128
-    embeddings_buffer = self.gather_embeddings(embedding)
-
-    # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
-    # divide by temperature.
-    similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
-
-    pos = torch.sum(similarity * self.pos_mask, 1)
-    neg = torch.sum(similarity * self.neg_mask, 1)
-
-    # Ignore the negative samples as entries for loss calculation
-    pos = pos[: (batch_size // 2)]
-    neg = neg[: (batch_size // 2)]
-
-    loss = -(torch.mean(torch.log(pos / (pos + neg))))
-    return loss
-
-
-
-def gather_embeddings(self, embedding: torch.Tensor) -
-
-

Do a gather over all embeddings, so we can compute the loss. -Final shape is like: (batch_size * num_gpus) x embedding_dim

-
- -Expand source code - -
def gather_embeddings(self, embedding: torch.Tensor):
-    """
-    Do a gather over all embeddings, so we can compute the loss.
-    Final shape is like: (batch_size * num_gpus) x embedding_dim
-    """
-    if self.gather_distributed:
-        embedding_gathered = torch.cat(dist.gather(embedding), 0)
-    else:
-        embedding_gathered = embedding
-    return embedding_gathered
-
-
-
-def precompute_pos_neg_mask(self) -
-
-

Precompute the positive and negative masks to speed up the loss calculation.

-
- -Expand source code - -
def precompute_pos_neg_mask(self):
-    """
-    Precompute the positive and negative masks to speed up the loss calculation.
-    """
-    # computed once at the begining of training
-
-    # total_images is x2 SimCLR Info-NCE loss
-    # as we have negative samples for each positive sample
-
-    total_images = self.effective_batch_size * self.num_neg
-    world_size = self.world_size
-
-    # Batch size computation is different from SimCLR paper
-    batch_size = self.effective_batch_size // world_size
-    orig_images = batch_size // self.num_pos
-    rank = dist.rank()
-
-    pos_mask = torch.zeros(batch_size * self.num_neg, total_images)
-    neg_mask = torch.zeros(batch_size * self.num_neg, total_images)
-
-    all_indices = np.arange(total_images)
-
-    # Index for pairs of images (original + copy)
-    pairs = orig_images * np.arange(self.num_pos)
-
-    # Remove all indices associated with positive samples & copies (for neg_mask)
-    all_pos_members = []
-    for _rank in range(world_size):
-        all_pos_members += list(_rank * (batch_size * 2) + np.arange(batch_size))
-
-    all_indices_pos_removed = np.delete(all_indices, all_pos_members)
-
-    # Index of original positive images
-    orig_members = torch.arange(orig_images)
-
-    for anchor in np.arange(self.num_pos):
-        for img_idx in range(orig_images):
-            # delete_inds are spaced by batch_size for each rank as
-            # all_indices_pos_removed (half of the indices) is deleted first
-            delete_inds = batch_size * rank + img_idx + pairs
-            neg_inds = torch.tensor(np.delete(all_indices_pos_removed, delete_inds)).long()
-            neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
-
-        for pos in np.delete(np.arange(self.num_pos), anchor):
-            # Pos_inds are spaced by batch_size * self.num_neg for each rank
-            pos_inds = (batch_size * self.num_neg) * rank + pos * orig_images + orig_members
-            pos_mask[
-                torch.arange(anchor * orig_images, (anchor + 1) * orig_images).long(),
-                pos_inds.long(),
-            ] = 1
-
-    self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
-    self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/nnclr_loss.html b/docs/api_docs/fmcib/ssl/losses/nnclr_loss.html deleted file mode 100644 index 34f8289..0000000 --- a/docs/api_docs/fmcib/ssl/losses/nnclr_loss.html +++ /dev/null @@ -1,206 +0,0 @@ - - - - - - -fmcib.ssl.losses.nnclr_loss API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses.nnclr_loss

-
-
-
- -Expand source code - -
from lightly.loss import NTXentLoss
-
-
-class NNCLRLoss(NTXentLoss):
-    """
-    A class representing the NNCLRLoss.
-
-    This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.
-
-    Attributes:
-        temperature (float): The temperature for the loss function. Default is 0.1.
-        gather_distributed (bool): A flag indicating whether the distributed gathering is used. Default is False.
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize a new instance of the class.
-
-        Args:
-            temperature (float): The temperature to use for initialization. Default value is 0.1.
-            gather_distributed (bool): Whether to use gather distributed mode. Default value is False.
-        """
-        super().__init__(temperature, gather_distributed)
-
-    def forward(self, out):
-        """
-        Symmetric loss function for NNCLR.
-        """
-        (z0, p0), (z1, p1) = out
-        loss0 = super().forward(z0, p0)
-        loss1 = super().forward(z1, p1)
-        return (loss0 + loss1) / 2
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class NNCLRLoss -(temperature: float = 0.1, gather_distributed: bool = False) -
-
-

A class representing the NNCLRLoss.

-

This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.

-

Attributes

-
-
temperature : float
-
The temperature for the loss function. Default is 0.1.
-
gather_distributed : bool
-
A flag indicating whether the distributed gathering is used. Default is False.
-
-

Initialize a new instance of the class.

-

Args

-
-
temperature : float
-
The temperature to use for initialization. Default value is 0.1.
-
gather_distributed : bool
-
Whether to use gather distributed mode. Default value is False.
-
-
- -Expand source code - -
class NNCLRLoss(NTXentLoss):
-    """
-    A class representing the NNCLRLoss.
-
-    This class extends the NTXentLoss class and implements a symmetric loss function for NNCLR.
-
-    Attributes:
-        temperature (float): The temperature for the loss function. Default is 0.1.
-        gather_distributed (bool): A flag indicating whether the distributed gathering is used. Default is False.
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize a new instance of the class.
-
-        Args:
-            temperature (float): The temperature to use for initialization. Default value is 0.1.
-            gather_distributed (bool): Whether to use gather distributed mode. Default value is False.
-        """
-        super().__init__(temperature, gather_distributed)
-
-    def forward(self, out):
-        """
-        Symmetric loss function for NNCLR.
-        """
-        (z0, p0), (z1, p1) = out
-        loss0 = super().forward(z0, p0)
-        loss1 = super().forward(z1, p1)
-        return (loss0 + loss1) / 2
-
-

Ancestors

-
    -
  • lightly.loss.ntx_ent_loss.NTXentLoss
  • -
  • lightly.loss.memory_bank.MemoryBankModule
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, out) ‑> Callable[..., Any] -
-
-

Symmetric loss function for NNCLR.

-
- -Expand source code - -
def forward(self, out):
-    """
-    Symmetric loss function for NNCLR.
-    """
-    (z0, p0), (z1, p1) = out
-    loss0 = super().forward(z0, p0)
-    loss1 = super().forward(z1, p1)
-    return (loss0 + loss1) / 2
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/ntxent_loss.html b/docs/api_docs/fmcib/ssl/losses/ntxent_loss.html deleted file mode 100644 index 4ac6021..0000000 --- a/docs/api_docs/fmcib/ssl/losses/ntxent_loss.html +++ /dev/null @@ -1,210 +0,0 @@ - - - - - - -fmcib.ssl.losses.ntxent_loss API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses.ntxent_loss

-
-
-
- -Expand source code - -
from typing import List
-
-from lightly.loss import NTXentLoss as lightly_NTXentLoss
-
-
-class NTXentLoss(lightly_NTXentLoss):
-    """
-    NTXentNegativeMinedLoss:
-    NTXentLoss with explicitly mined negatives
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize an instance of the class.
-
-        Args:
-            temperature (float, optional): The temperature parameter for the instance. Defaults to 0.1.
-            gather_distributed (bool, optional): Whether to gather distributed data. Defaults to False.
-        """
-        super().__init__(temperature, gather_distributed)
-
-    def forward(self, out: List):
-        """
-        Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-        Args:
-            out (List[torch.Tensor]): List of tensors
-
-        Returns:
-            float: Contrastive Cross Entropy Loss value.
-        """
-        return super().forward(*out)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class NTXentLoss -(temperature: float = 0.1, gather_distributed: bool = False) -
-
-

NTXentNegativeMinedLoss: -NTXentLoss with explicitly mined negatives

-

Initialize an instance of the class.

-

Args

-
-
temperature : float, optional
-
The temperature parameter for the instance. Defaults to 0.1.
-
gather_distributed : bool, optional
-
Whether to gather distributed data. Defaults to False.
-
-
- -Expand source code - -
class NTXentLoss(lightly_NTXentLoss):
-    """
-    NTXentNegativeMinedLoss:
-    NTXentLoss with explicitly mined negatives
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize an instance of the class.
-
-        Args:
-            temperature (float, optional): The temperature parameter for the instance. Defaults to 0.1.
-            gather_distributed (bool, optional): Whether to gather distributed data. Defaults to False.
-        """
-        super().__init__(temperature, gather_distributed)
-
-    def forward(self, out: List):
-        """
-        Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-        Args:
-            out (List[torch.Tensor]): List of tensors
-
-        Returns:
-            float: Contrastive Cross Entropy Loss value.
-        """
-        return super().forward(*out)
-
-

Ancestors

-
    -
  • lightly.loss.ntx_ent_loss.NTXentLoss
  • -
  • lightly.loss.memory_bank.MemoryBankModule
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, out: List[~T]) ‑> Callable[..., Any] -
-
-

Forward pass through Negative mining contrastive Cross-Entropy Loss.

-

Args

-
-
out : List[torch.Tensor]
-
List of tensors
-
-

Returns

-
-
float
-
Contrastive Cross Entropy Loss value.
-
-
- -Expand source code - -
def forward(self, out: List):
-    """
-    Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-    Args:
-        out (List[torch.Tensor]): List of tensors
-
-    Returns:
-        float: Contrastive Cross Entropy Loss value.
-    """
-    return super().forward(*out)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/ntxent_mined_loss.html b/docs/api_docs/fmcib/ssl/losses/ntxent_mined_loss.html deleted file mode 100644 index cceff95..0000000 --- a/docs/api_docs/fmcib/ssl/losses/ntxent_mined_loss.html +++ /dev/null @@ -1,447 +0,0 @@ - - - - - - -fmcib.ssl.losses.ntxent_mined_loss API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses.ntxent_mined_loss

-
-
-

Contrastive Loss Functions

-
- -Expand source code - -
""" Contrastive Loss Functions """
-
-# Copyright (c) 2020. Lightly AG and its affiliates.
-# All Rights Reserved
-
-# Modified to function for explicitly selected negatives
-
-from typing import Dict
-
-import torch
-from lightly.utils import dist
-from torch import nn
-
-
-class NTXentNegativeMinedLoss(torch.nn.Module):
-    """
-    NTXentNegativeMinedLoss:
-    NTXentLoss with explicitly mined negatives
-
-    Args:
-        temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
-        gather_distributed (bool): Whether to gather hidden representations from other processes in a distributed setting. Default is False.
-
-    Raises:
-        ValueError: If the absolute value of temperature is less than 1e-8.
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize the NTXentNegativeMinedLoss object.
-
-        Args:
-            temperature (float, optional): The temperature parameter for the loss function. Defaults to 0.1.
-            gather_distributed (bool, optional): Whether to use distributed gathering or not. Defaults to False.
-
-        Raises:
-            ValueError: If the absolute value of the temperature is too small.
-
-        Attributes:
-            temperature (float): The temperature parameter for the loss function.
-            gather_distributed (bool): Whether to use distributed gathering or not.
-            cross_entropy (torch.nn.CrossEntropyLoss): The cross entropy loss function.
-            eps (float): A small value to avoid division by zero.
-        """
-        super(NTXentNegativeMinedLoss, self).__init__()
-        self.temperature = temperature
-        self.gather_distributed = gather_distributed
-        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
-        self.eps = 1e-8
-
-        if abs(self.temperature) < self.eps:
-            raise ValueError("Illegal temperature: abs({}) < 1e-8".format(self.temperature))
-
-    def forward(self, out: Dict):
-        """
-        Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-        Args:
-            out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
-        Returns:
-            torch.Tensor: Contrastive Cross Entropy Loss value.
-
-        Raises:
-            AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
-        """
-
-        assert "positive" in out, "`positive` key needs to be specified"
-        assert "negative" in out, "`negative` key needs to be specified"
-
-        pos0, pos1 = out["positive"]
-        neg0, neg1 = out["negative"]
-
-        device = pos0.device
-        batch_size, _ = pos0.shape
-
-        # normalize the output to length 1
-        pos0 = nn.functional.normalize(pos0, dim=1)
-        pos1 = nn.functional.normalize(pos1, dim=1)
-        neg0 = nn.functional.normalize(neg0, dim=1)
-        neg1 = nn.functional.normalize(neg1, dim=1)
-
-        if self.gather_distributed and dist.world_size() > 1:
-            # gather hidden representations from other processes
-            pos0_large = torch.cat(dist.gather(pos0), 0)
-            pos1_large = torch.cat(dist.gather(pos1), 0)
-            neg0_large = torch.cat(dist.gather(neg0), 0)
-            neg1_large = torch.cat(dist.gather(neg1), 0)
-            diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
-        else:
-            # gather hidden representations from other processes
-            pos0_large = pos0
-            pos1_large = pos1
-            neg0_large = neg0
-            neg1_large = neg1
-            diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
-        logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
-        logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
-        logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
-        logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
-        logits_01 = logits_01[diag_mask].view(batch_size, -1)
-        logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
-        logits_0100 = torch.cat([logits_01, logits_00], dim=1)
-        logits_1011 = torch.cat([logits_10, logits_11], dim=1)
-        logits = torch.cat([logits_0100, logits_1011], dim=0)
-
-        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
-        loss = self.cross_entropy(logits, labels)
-
-        return loss
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class NTXentNegativeMinedLoss -(temperature: float = 0.1, gather_distributed: bool = False) -
-
-

NTXentNegativeMinedLoss: -NTXentLoss with explicitly mined negatives

-

Args

-
-
temperature : float
-
The temperature parameter for the loss calculation. Default is 0.1.
-
gather_distributed : bool
-
Whether to gather hidden representations from other processes in a distributed setting. Default is False.
-
-

Raises

-
-
ValueError
-
If the absolute value of temperature is less than 1e-8.
-
-

Initialize the NTXentNegativeMinedLoss object.

-

Args

-
-
temperature : float, optional
-
The temperature parameter for the loss function. Defaults to 0.1.
-
gather_distributed : bool, optional
-
Whether to use distributed gathering or not. Defaults to False.
-
-

Raises

-
-
ValueError
-
If the absolute value of the temperature is too small.
-
-

Attributes

-
-
temperature : float
-
The temperature parameter for the loss function.
-
gather_distributed : bool
-
Whether to use distributed gathering or not.
-
cross_entropy : torch.nn.CrossEntropyLoss
-
The cross entropy loss function.
-
eps : float
-
A small value to avoid division by zero.
-
-
- -Expand source code - -
class NTXentNegativeMinedLoss(torch.nn.Module):
-    """
-    NTXentNegativeMinedLoss:
-    NTXentLoss with explicitly mined negatives
-
-    Args:
-        temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
-        gather_distributed (bool): Whether to gather hidden representations from other processes in a distributed setting. Default is False.
-
-    Raises:
-        ValueError: If the absolute value of temperature is less than 1e-8.
-    """
-
-    def __init__(self, temperature: float = 0.1, gather_distributed: bool = False):
-        """
-        Initialize the NTXentNegativeMinedLoss object.
-
-        Args:
-            temperature (float, optional): The temperature parameter for the loss function. Defaults to 0.1.
-            gather_distributed (bool, optional): Whether to use distributed gathering or not. Defaults to False.
-
-        Raises:
-            ValueError: If the absolute value of the temperature is too small.
-
-        Attributes:
-            temperature (float): The temperature parameter for the loss function.
-            gather_distributed (bool): Whether to use distributed gathering or not.
-            cross_entropy (torch.nn.CrossEntropyLoss): The cross entropy loss function.
-            eps (float): A small value to avoid division by zero.
-        """
-        super(NTXentNegativeMinedLoss, self).__init__()
-        self.temperature = temperature
-        self.gather_distributed = gather_distributed
-        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
-        self.eps = 1e-8
-
-        if abs(self.temperature) < self.eps:
-            raise ValueError("Illegal temperature: abs({}) < 1e-8".format(self.temperature))
-
-    def forward(self, out: Dict):
-        """
-        Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-        Args:
-            out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
-        Returns:
-            torch.Tensor: Contrastive Cross Entropy Loss value.
-
-        Raises:
-            AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
-        """
-
-        assert "positive" in out, "`positive` key needs to be specified"
-        assert "negative" in out, "`negative` key needs to be specified"
-
-        pos0, pos1 = out["positive"]
-        neg0, neg1 = out["negative"]
-
-        device = pos0.device
-        batch_size, _ = pos0.shape
-
-        # normalize the output to length 1
-        pos0 = nn.functional.normalize(pos0, dim=1)
-        pos1 = nn.functional.normalize(pos1, dim=1)
-        neg0 = nn.functional.normalize(neg0, dim=1)
-        neg1 = nn.functional.normalize(neg1, dim=1)
-
-        if self.gather_distributed and dist.world_size() > 1:
-            # gather hidden representations from other processes
-            pos0_large = torch.cat(dist.gather(pos0), 0)
-            pos1_large = torch.cat(dist.gather(pos1), 0)
-            neg0_large = torch.cat(dist.gather(neg0), 0)
-            neg1_large = torch.cat(dist.gather(neg1), 0)
-            diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
-        else:
-            # gather hidden representations from other processes
-            pos0_large = pos0
-            pos1_large = pos1
-            neg0_large = neg0
-            neg1_large = neg1
-            diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
-        logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
-        logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
-        logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
-        logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
-        logits_01 = logits_01[diag_mask].view(batch_size, -1)
-        logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
-        logits_0100 = torch.cat([logits_01, logits_00], dim=1)
-        logits_1011 = torch.cat([logits_10, logits_11], dim=1)
-        logits = torch.cat([logits_0100, logits_1011], dim=0)
-
-        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
-        loss = self.cross_entropy(logits, labels)
-
-        return loss
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, out: Dict[~KT, ~VT]) ‑> Callable[..., Any] -
-
-

Forward pass through Negative mining contrastive Cross-Entropy Loss.

-

Args

-
-
out : Dict
-
Dictionary with positive and negative keys to represent positive selected and negative selected samples.
-
-

Returns

-
-
torch.Tensor
-
Contrastive Cross Entropy Loss value.
-
-

Raises

-
-
AssertionError
-
If positive or negative keys are not specified in the input dictionary.
-
-
- -Expand source code - -
def forward(self, out: Dict):
-    """
-    Forward pass through Negative mining contrastive Cross-Entropy Loss.
-
-    Args:
-        out (Dict): Dictionary with `positive` and `negative` keys to represent positive selected and negative selected samples.
-
-    Returns:
-        torch.Tensor: Contrastive Cross Entropy Loss value.
-
-    Raises:
-        AssertionError: If `positive` or `negative` keys are not specified in the input dictionary.
-    """
-
-    assert "positive" in out, "`positive` key needs to be specified"
-    assert "negative" in out, "`negative` key needs to be specified"
-
-    pos0, pos1 = out["positive"]
-    neg0, neg1 = out["negative"]
-
-    device = pos0.device
-    batch_size, _ = pos0.shape
-
-    # normalize the output to length 1
-    pos0 = nn.functional.normalize(pos0, dim=1)
-    pos1 = nn.functional.normalize(pos1, dim=1)
-    neg0 = nn.functional.normalize(neg0, dim=1)
-    neg1 = nn.functional.normalize(neg1, dim=1)
-
-    if self.gather_distributed and dist.world_size() > 1:
-        # gather hidden representations from other processes
-        pos0_large = torch.cat(dist.gather(pos0), 0)
-        pos1_large = torch.cat(dist.gather(pos1), 0)
-        neg0_large = torch.cat(dist.gather(neg0), 0)
-        neg1_large = torch.cat(dist.gather(neg1), 0)
-        diag_mask = dist.eye_rank(batch_size, device=pos0.device)
-
-    else:
-        # gather hidden representations from other processes
-        pos0_large = pos0
-        pos1_large = pos1
-        neg0_large = neg0
-        neg1_large = neg1
-        diag_mask = torch.eye(batch_size, device=pos0.device, dtype=torch.bool)
-
-    logits_00 = torch.einsum("nc,mc->nm", pos0, neg0_large) / self.temperature
-    logits_01 = torch.einsum("nc,mc->nm", pos0, pos1_large) / self.temperature
-    logits_10 = torch.einsum("nc,mc->nm", pos1, pos0_large) / self.temperature
-    logits_11 = torch.einsum("nc,mc->nm", pos1, neg1_large) / self.temperature
-
-    logits_01 = logits_01[diag_mask].view(batch_size, -1)
-    logits_10 = logits_10[diag_mask].view(batch_size, -1)
-
-    logits_0100 = torch.cat([logits_01, logits_00], dim=1)
-    logits_1011 = torch.cat([logits_10, logits_11], dim=1)
-    logits = torch.cat([logits_0100, logits_1011], dim=0)
-
-    labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
-    loss = self.cross_entropy(logits, labels)
-
-    return loss
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/losses/swav_loss.html b/docs/api_docs/fmcib/ssl/losses/swav_loss.html deleted file mode 100644 index a3790b1..0000000 --- a/docs/api_docs/fmcib/ssl/losses/swav_loss.html +++ /dev/null @@ -1,247 +0,0 @@ - - - - - - -fmcib.ssl.losses.swav_loss API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.losses.swav_loss

-
-
-
- -Expand source code - -
import lightly
-
-
-class SwaVLoss(lightly.loss.swav_loss.SwaVLoss):
-    """
-    A class representing a custom SwaV loss function.
-
-    Attributes:
-        temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
-        sinkhorn_iterations (int): The number of iterations for Sinkhorn algorithm. Default is 3.
-        sinkhorn_epsilon (float): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-        sinkhorn_gather_distributed (bool): Whether to gather distributed results for Sinkhorn algorithm. Default is False.
-    """
-
-    def __init__(
-        self,
-        temperature: float = 0.1,
-        sinkhorn_iterations: int = 3,
-        sinkhorn_epsilon: float = 0.05,
-        sinkhorn_gather_distributed: bool = False,
-    ):
-        """
-        Initialize the object with specified parameters.
-
-        Args:
-            temperature (float, optional): The temperature parameter. Default is 0.1.
-            sinkhorn_iterations (int, optional): The number of Sinkhorn iterations. Default is 3.
-            sinkhorn_epsilon (float, optional): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-            sinkhorn_gather_distributed (bool, optional): Whether to use distributed computation for Sinkhorn algorithm. Default is False.
-        """
-        super().__init__(temperature, sinkhorn_iterations, sinkhorn_epsilon, sinkhorn_gather_distributed)
-
-    def forward(self, pred):
-        """
-        Perform a forward pass of the model.
-
-        Args:
-            pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
-        Returns:
-            The output of the forward pass.
-        """
-        high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
-        return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SwaVLoss -(temperature: float = 0.1, sinkhorn_iterations: int = 3, sinkhorn_epsilon: float = 0.05, sinkhorn_gather_distributed: bool = False) -
-
-

A class representing a custom SwaV loss function.

-

Attributes

-
-
temperature : float
-
The temperature parameter for the loss calculation. Default is 0.1.
-
sinkhorn_iterations : int
-
The number of iterations for Sinkhorn algorithm. Default is 3.
-
sinkhorn_epsilon : float
-
The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-
sinkhorn_gather_distributed : bool
-
Whether to gather distributed results for Sinkhorn algorithm. Default is False.
-
-

Initialize the object with specified parameters.

-

Args

-
-
temperature : float, optional
-
The temperature parameter. Default is 0.1.
-
sinkhorn_iterations : int, optional
-
The number of Sinkhorn iterations. Default is 3.
-
sinkhorn_epsilon : float, optional
-
The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-
sinkhorn_gather_distributed : bool, optional
-
Whether to use distributed computation for Sinkhorn algorithm. Default is False.
-
-
- -Expand source code - -
class SwaVLoss(lightly.loss.swav_loss.SwaVLoss):
-    """
-    A class representing a custom SwaV loss function.
-
-    Attributes:
-        temperature (float): The temperature parameter for the loss calculation. Default is 0.1.
-        sinkhorn_iterations (int): The number of iterations for Sinkhorn algorithm. Default is 3.
-        sinkhorn_epsilon (float): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-        sinkhorn_gather_distributed (bool): Whether to gather distributed results for Sinkhorn algorithm. Default is False.
-    """
-
-    def __init__(
-        self,
-        temperature: float = 0.1,
-        sinkhorn_iterations: int = 3,
-        sinkhorn_epsilon: float = 0.05,
-        sinkhorn_gather_distributed: bool = False,
-    ):
-        """
-        Initialize the object with specified parameters.
-
-        Args:
-            temperature (float, optional): The temperature parameter. Default is 0.1.
-            sinkhorn_iterations (int, optional): The number of Sinkhorn iterations. Default is 3.
-            sinkhorn_epsilon (float, optional): The epsilon parameter for Sinkhorn algorithm. Default is 0.05.
-            sinkhorn_gather_distributed (bool, optional): Whether to use distributed computation for Sinkhorn algorithm. Default is False.
-        """
-        super().__init__(temperature, sinkhorn_iterations, sinkhorn_epsilon, sinkhorn_gather_distributed)
-
-    def forward(self, pred):
-        """
-        Perform a forward pass of the model.
-
-        Args:
-            pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
-        Returns:
-            The output of the forward pass.
-        """
-        high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
-        return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-
-

Ancestors

-
    -
  • lightly.loss.swav_loss.SwaVLoss
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, pred) ‑> Callable[..., Any] -
-
-

Perform a forward pass of the model.

-

Args

-
-
pred : tuple
-
A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
-

Returns

-

The output of the forward pass.

-
- -Expand source code - -
def forward(self, pred):
-    """
-    Perform a forward pass of the model.
-
-    Args:
-        pred (tuple): A tuple containing the predicted outputs for high resolution, low resolution, and queue.
-
-    Returns:
-        The output of the forward pass.
-    """
-    high_resolution_outputs, low_resolution_outputs, queue_outputs = pred
-    return super().forward(high_resolution_outputs, low_resolution_outputs, queue_outputs)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/models/exneg_simclr.html b/docs/api_docs/fmcib/ssl/models/exneg_simclr.html deleted file mode 100644 index 6029b52..0000000 --- a/docs/api_docs/fmcib/ssl/models/exneg_simclr.html +++ /dev/null @@ -1,210 +0,0 @@ - - - - - - -fmcib.ssl.modules.exneg_simclr API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.exneg_simclr

-
-
-
- -Expand source code - -
from typing import Dict, Union
-
-import torch
-import torch.nn as nn
-from lightly.models import SimCLR
-
-
-class ExNegSimCLR(SimCLR):
-    """
-    Extended Negative Sampling SimCLR model.
-
-    Args:
-        backbone (nn.Module): The backbone model.
-        num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
-        out_dim (int): Dimension of the output feature embeddings. Default is 128.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
-        print(backbone)
-        super().__init__(backbone, num_ftrs, out_dim)
-
-    def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-        """
-        Forward pass of the ExNegSimCLR model.
-
-        Args:
-            x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-            return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-        Returns:
-            out (Dict): Output dictionary containing the forward pass results for each input view.
-        """
-        assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-        out = {}
-        for key, value in x.items():
-            if isinstance(value, list):
-                out[key] = super().forward(*value, return_features)
-
-        return out
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class ExNegSimCLR -(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128) -
-
-

Extended Negative Sampling SimCLR model.

-

Args

-
-
backbone : nn.Module
-
The backbone model.
-
num_ftrs : int
-
Number of features in the bottleneck layer. Default is 32.
-
out_dim : int
-
Dimension of the output feature embeddings. Default is 128.
-
-

Initializes internal Module state, shared by both nn.Module and ScriptModule.

-
- -Expand source code - -
class ExNegSimCLR(SimCLR):
-    """
-    Extended Negative Sampling SimCLR model.
-
-    Args:
-        backbone (nn.Module): The backbone model.
-        num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
-        out_dim (int): Dimension of the output feature embeddings. Default is 128.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
-        print(backbone)
-        super().__init__(backbone, num_ftrs, out_dim)
-
-    def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-        """
-        Forward pass of the ExNegSimCLR model.
-
-        Args:
-            x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-            return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-        Returns:
-            out (Dict): Output dictionary containing the forward pass results for each input view.
-        """
-        assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-        out = {}
-        for key, value in x.items():
-            if isinstance(value, list):
-                out[key] = super().forward(*value, return_features)
-
-        return out
-
-

Ancestors

-
    -
  • lightly.models.simclr.SimCLR
  • -
  • torch.nn.modules.module.Module
  • -
-

Methods

-
-
-def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False) ‑> Callable[..., Any] -
-
-

Forward pass of the ExNegSimCLR model.

-

Args

-
-
x : Union[Dict, torch.Tensor]
-
Input data. If a dictionary, it should contain multiple views of the same image.
-
return_features : bool
-
Whether to return the intermediate feature embeddings. Default is False.
-
-

Returns

-

out (Dict): Output dictionary containing the forward pass results for each input view.

-
- -Expand source code - -
def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-    """
-    Forward pass of the ExNegSimCLR model.
-
-    Args:
-        x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-        return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-    Returns:
-        out (Dict): Output dictionary containing the forward pass results for each input view.
-    """
-    assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-    out = {}
-    for key, value in x.items():
-        if isinstance(value, list):
-            out[key] = super().forward(*value, return_features)
-
-    return out
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/models/index.html b/docs/api_docs/fmcib/ssl/models/index.html deleted file mode 100644 index 6977153..0000000 --- a/docs/api_docs/fmcib/ssl/models/index.html +++ /dev/null @@ -1,77 +0,0 @@ - - - - - - -fmcib.ssl.modules API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules

-
-
-
- -Expand source code - -
from .exneg_simclr import ExNegSimCLR
-from .load_pretrained_resnet import LoadPretrainedResnet3D
-
-
-
-

Sub-modules

-
-
fmcib.ssl.modules.exneg_simclr
-
-
-
-
fmcib.ssl.modules.load_pretrained_resnet
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/models/load_pretrained_resnet.html b/docs/api_docs/fmcib/ssl/models/load_pretrained_resnet.html deleted file mode 100644 index 7f6e5de..0000000 --- a/docs/api_docs/fmcib/ssl/models/load_pretrained_resnet.html +++ /dev/null @@ -1,327 +0,0 @@ - - - - - - -fmcib.ssl.modules.load_pretrained_resnet API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.load_pretrained_resnet

-
-
-
- -Expand source code - -
import monai
-import torch
-from loguru import logger
-from monai.networks.nets.resnet import ResNetBottleneck as Bottleneck
-from torch import nn
-
-
-class LoadPretrainedResnet3D(nn.Module):
-    """
-    LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.
-
-    Args:
-        pretrained (str): Path to the pretrained model file. Default is None.
-        vissl (bool): Whether the pretrained model is from VISSL. Default is False.
-        heads (list): List of integers specifying the number of input and output channels for each head. Default is an empty list.
-
-    Attributes:
-        trunk (nn.Module): The ResNet trunk network.
-        heads (nn.Module): The sequential module containing the heads.
-
-    Methods:
-        forward(x): Forward pass of the model.
-        load(pretrained): Load the pretrained model weights.
-
-    Example:
-        model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])
-    """
-
-    def __init__(self, pretrained=None, vissl=False, heads=[]) -> None:
-        super().__init__()
-        self.trunk = monai.networks.nets.resnet.ResNet(
-            block=Bottleneck,
-            layers=(3, 4, 6, 3),
-            block_inplanes=(64, 128, 256, 512),
-            spatial_dims=3,
-            n_input_channels=1,
-            conv1_t_stride=2,
-            conv1_t_size=7,
-            widen_factor=2,
-            feed_forward=False,
-        )
-
-        head_layers = []
-        for idx in range(len(heads) - 1):
-            current_layers = []
-            current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
-            if idx != (len(heads) - 2):
-                current_layers.append(nn.ReLU(inplace=True))
-
-            head_layers.append(nn.Sequential(*current_layers))
-
-        if len(head_layers):
-            self.heads = nn.Sequential(*head_layers)
-        else:
-            self.heads = nn.Identity()
-
-        if pretrained is not None:
-            self.load(pretrained)
-
-    def forward(self, x: torch.Tensor):
-        out = self.trunk(x)
-        out = self.heads(out)
-        return out
-
-    def load(self, pretrained):
-        pretrained_model = torch.load(pretrained)
-
-        # Load trained trunk
-        trained_trunk = pretrained_model["trunk_state_dict"]
-        msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        # Load trained heads
-        if "head_state_dict" in pretrained_model:
-            trained_heads = pretrained_model["head_state_dict"]
-
-            try:
-                msg = self.heads.load_state_dict(trained_heads, strict=False)
-            except Exception as e:
-                logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        logger.info(f"Loaded pretrained model weights \n")
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class LoadPretrainedResnet3D -(pretrained=None, vissl=False, heads=[]) -
-
-

LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.

-

Args

-
-
pretrained : str
-
Path to the pretrained model file. Default is None.
-
vissl : bool
-
Whether the pretrained model is from VISSL. Default is False.
-
heads : list
-
List of integers specifying the number of input and output channels for each head. Default is an empty list.
-
-

Attributes

-
-
trunk : nn.Module
-
The ResNet trunk network.
-
heads : nn.Module
-
The sequential module containing the heads.
-
-

Methods

-

forward(x): Forward pass of the model. -load(pretrained): Load the pretrained model weights.

-

Example

-

model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])

-

Initializes internal Module state, shared by both nn.Module and ScriptModule.

-
- -Expand source code - -
class LoadPretrainedResnet3D(nn.Module):
-    """
-    LoadPretrainedResnet3D is a PyTorch module for loading a pretrained ResNet-3D model with optional heads.
-
-    Args:
-        pretrained (str): Path to the pretrained model file. Default is None.
-        vissl (bool): Whether the pretrained model is from VISSL. Default is False.
-        heads (list): List of integers specifying the number of input and output channels for each head. Default is an empty list.
-
-    Attributes:
-        trunk (nn.Module): The ResNet trunk network.
-        heads (nn.Module): The sequential module containing the heads.
-
-    Methods:
-        forward(x): Forward pass of the model.
-        load(pretrained): Load the pretrained model weights.
-
-    Example:
-        model = LoadPretrainedResnet3D(pretrained='path/to/pretrained_model.pth', vissl=True, heads=[512, 256, 128])
-    """
-
-    def __init__(self, pretrained=None, vissl=False, heads=[]) -> None:
-        super().__init__()
-        self.trunk = monai.networks.nets.resnet.ResNet(
-            block=Bottleneck,
-            layers=(3, 4, 6, 3),
-            block_inplanes=(64, 128, 256, 512),
-            spatial_dims=3,
-            n_input_channels=1,
-            conv1_t_stride=2,
-            conv1_t_size=7,
-            widen_factor=2,
-            feed_forward=False,
-        )
-
-        head_layers = []
-        for idx in range(len(heads) - 1):
-            current_layers = []
-            current_layers.append(nn.Linear(heads[idx], heads[idx + 1], bias=True))
-
-            if idx != (len(heads) - 2):
-                current_layers.append(nn.ReLU(inplace=True))
-
-            head_layers.append(nn.Sequential(*current_layers))
-
-        if len(head_layers):
-            self.heads = nn.Sequential(*head_layers)
-        else:
-            self.heads = nn.Identity()
-
-        if pretrained is not None:
-            self.load(pretrained)
-
-    def forward(self, x: torch.Tensor):
-        out = self.trunk(x)
-        out = self.heads(out)
-        return out
-
-    def load(self, pretrained):
-        pretrained_model = torch.load(pretrained)
-
-        # Load trained trunk
-        trained_trunk = pretrained_model["trunk_state_dict"]
-        msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        # Load trained heads
-        if "head_state_dict" in pretrained_model:
-            trained_heads = pretrained_model["head_state_dict"]
-
-            try:
-                msg = self.heads.load_state_dict(trained_heads, strict=False)
-            except Exception as e:
-                logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-            logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-        logger.info(f"Loaded pretrained model weights \n")
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Methods

-
-
-def forward(self, x: torch.Tensor) ‑> Callable[..., Any] -
-
-

Defines the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
-
- -Expand source code - -
def forward(self, x: torch.Tensor):
-    out = self.trunk(x)
-    out = self.heads(out)
-    return out
-
-
-
-def load(self, pretrained) -
-
-
-
- -Expand source code - -
def load(self, pretrained):
-    pretrained_model = torch.load(pretrained)
-
-    # Load trained trunk
-    trained_trunk = pretrained_model["trunk_state_dict"]
-    msg = self.trunk.load_state_dict(trained_trunk, strict=False)
-    logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-    # Load trained heads
-    if "head_state_dict" in pretrained_model:
-        trained_heads = pretrained_model["head_state_dict"]
-
-        try:
-            msg = self.heads.load_state_dict(trained_heads, strict=False)
-        except Exception as e:
-            logger.error(f"Failed to load trained heads with error {e}. This is expected if the models do not match!")
-        logger.warning(f"Missing keys: {msg[0]} and unexpected keys: {msg[1]}")
-
-    logger.info(f"Loaded pretrained model weights \n")
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/modules/exneg_simclr.html b/docs/api_docs/fmcib/ssl/modules/exneg_simclr.html deleted file mode 100644 index 3683cfa..0000000 --- a/docs/api_docs/fmcib/ssl/modules/exneg_simclr.html +++ /dev/null @@ -1,275 +0,0 @@ - - - - - - -fmcib.ssl.modules.exneg_simclr API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.exneg_simclr

-
-
-
- -Expand source code - -
from typing import Dict, Union
-
-import torch
-import torch.nn as nn
-from lightly.models import SimCLR as lightly_SimCLR
-from lightly.models.modules import SimCLRProjectionHead
-
-
-class ExNegSimCLR(lightly_SimCLR):
-    """
-    Extended Negative Sampling SimCLR model.
-
-    Args:
-        backbone (nn.Module): The backbone model.
-        num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
-        out_dim (int): Dimension of the output feature embeddings. Default is 128.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
-        """
-        Initialize the object.
-
-        Args:
-            backbone (nn.Module): The backbone neural network.
-            num_ftrs (int, optional): The number of input features for the projection head. Default is 32.
-            out_dim (int, optional): The output dimension of the projection head. Default is 128.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__(backbone, num_ftrs, out_dim)
-        # replace the projection head with a new one
-        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
-    def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-        """
-        Forward pass of the ExNegSimCLR model.
-
-        Args:
-            x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-            return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-        Returns:
-            Dict: Output dictionary containing the forward pass results for each input view.
-        """
-        assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-        out = {}
-        for key, value in x.items():
-            if isinstance(value, list):
-                out[key] = super().forward(*value, return_features)
-
-        return out
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class ExNegSimCLR -(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128) -
-
-

Extended Negative Sampling SimCLR model.

-

Args

-
-
backbone : nn.Module
-
The backbone model.
-
num_ftrs : int
-
Number of features in the bottleneck layer. Default is 32.
-
out_dim : int
-
Dimension of the output feature embeddings. Default is 128.
-
-

Initialize the object.

-

Args

-
-
backbone : nn.Module
-
The backbone neural network.
-
num_ftrs : int, optional
-
The number of input features for the projection head. Default is 32.
-
out_dim : int, optional
-
The output dimension of the projection head. Default is 128.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class ExNegSimCLR(lightly_SimCLR):
-    """
-    Extended Negative Sampling SimCLR model.
-
-    Args:
-        backbone (nn.Module): The backbone model.
-        num_ftrs (int): Number of features in the bottleneck layer. Default is 32.
-        out_dim (int): Dimension of the output feature embeddings. Default is 128.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128) -> None:
-        """
-        Initialize the object.
-
-        Args:
-            backbone (nn.Module): The backbone neural network.
-            num_ftrs (int, optional): The number of input features for the projection head. Default is 32.
-            out_dim (int, optional): The output dimension of the projection head. Default is 128.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__(backbone, num_ftrs, out_dim)
-        # replace the projection head with a new one
-        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
-    def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-        """
-        Forward pass of the ExNegSimCLR model.
-
-        Args:
-            x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-            return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-        Returns:
-            Dict: Output dictionary containing the forward pass results for each input view.
-        """
-        assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-        out = {}
-        for key, value in x.items():
-            if isinstance(value, list):
-                out[key] = super().forward(*value, return_features)
-
-        return out
-
-

Ancestors

-
    -
  • lightly.models.simclr.SimCLR
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x: Union[Dict[~KT, ~VT], torch.Tensor], return_features: bool = False) ‑> Callable[..., Any] -
-
-

Forward pass of the ExNegSimCLR model.

-

Args

-
-
x : Union[Dict, torch.Tensor]
-
Input data. If a dictionary, it should contain multiple views of the same image.
-
return_features : bool
-
Whether to return the intermediate feature embeddings. Default is False.
-
-

Returns

-
-
Dict
-
Output dictionary containing the forward pass results for each input view.
-
-
- -Expand source code - -
def forward(self, x: Union[Dict, torch.Tensor], return_features: bool = False):
-    """
-    Forward pass of the ExNegSimCLR model.
-
-    Args:
-        x (Union[Dict, torch.Tensor]): Input data. If a dictionary, it should contain multiple views of the same image.
-        return_features (bool): Whether to return the intermediate feature embeddings. Default is False.
-
-    Returns:
-        Dict: Output dictionary containing the forward pass results for each input view.
-    """
-    assert isinstance(x, dict), "Input to forward must be a `dict` for ExNegSimCLR"
-    out = {}
-    for key, value in x.items():
-        if isinstance(value, list):
-            out[key] = super().forward(*value, return_features)
-
-    return out
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/modules/index.html b/docs/api_docs/fmcib/ssl/modules/index.html deleted file mode 100644 index 8b1157e..0000000 --- a/docs/api_docs/fmcib/ssl/modules/index.html +++ /dev/null @@ -1,89 +0,0 @@ - - - - - - -fmcib.ssl.modules API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules

-
-
-
- -Expand source code - -
from .exneg_simclr import ExNegSimCLR
-from .nnclr import NNCLR
-from .simclr import SimCLR
-from .swav import SwaV
-
-
-
-

Sub-modules

-
-
fmcib.ssl.modules.exneg_simclr
-
-
-
-
fmcib.ssl.modules.nnclr
-
-
-
-
fmcib.ssl.modules.simclr
-
-
-
-
fmcib.ssl.modules.swav
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/modules/nnclr.html b/docs/api_docs/fmcib/ssl/modules/nnclr.html deleted file mode 100644 index 8bba380..0000000 --- a/docs/api_docs/fmcib/ssl/modules/nnclr.html +++ /dev/null @@ -1,340 +0,0 @@ - - - - - - -fmcib.ssl.modules.nnclr API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.nnclr

-
-
-
- -Expand source code - -
from typing import Any, Dict, List, Optional, Union
-
-import torch
-import torch.nn as nn
-from lightly.models.modules import NNCLRPredictionHead, NNCLRProjectionHead, NNMemoryBankModule
-
-
-class NNCLR(nn.Module):
-    """
-    Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py
-    """
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_ftrs: int = 4096,
-        proj_hidden_dim: int = 4096,
-        pred_hidden_dim: int = 4096,
-        out_dim: int = 256,
-        memory_bank_size: int = 4096,
-    ) -> None:
-        """
-        Initialize the NNCLR model.
-
-        Args:
-            backbone (nn.Module): The backbone neural network model.
-            num_ftrs (int, optional): The number of features in the backbone output. Default is 4096.
-            proj_hidden_dim (int, optional): The hidden dimension of the projection head. Default is 4096.
-            pred_hidden_dim (int, optional): The hidden dimension of the prediction head. Default is 4096.
-            out_dim (int, optional): The output dimension of the model. Default is 256.
-            memory_bank_size (int, optional): The size of the memory bank module. Default is 4096.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__()
-        self.backbone = backbone
-        self.projection_head = NNCLRProjectionHead(num_ftrs, proj_hidden_dim, out_dim)
-        self.prediction_head = NNCLRPredictionHead(out_dim, pred_hidden_dim, out_dim)
-        self.memory_bank = NNMemoryBankModule(memory_bank_size)
-
-    def forward(
-        self,
-        x: List[torch.Tensor],
-        get_nearest_neighbor: bool = True,
-    ):
-        # forward pass of first input x0
-        """
-        Forward pass of the model.
-
-        Args:
-            x (List[torch.Tensor]): A list containing two input tensors.
-            get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
-                Defaults to True.
-
-        Returns:
-            Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
-                A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
-                for each input tensor.
-        """
-        x0, x1 = x
-        f0 = self.backbone(x0).flatten(start_dim=1)
-        z0 = self.projection_head(f0)
-        p0 = self.prediction_head(z0)
-
-        if get_nearest_neighbor:
-            z0 = self.memory_bank(z0, update=False)
-
-        # forward pass of second input x1
-        f1 = self.backbone(x1).flatten(start_dim=1)
-        z1 = self.projection_head(f1)
-        p1 = self.prediction_head(z1)
-
-        if get_nearest_neighbor:
-            z1 = self.memory_bank(z1, update=True)
-
-        return (z0, p0), (z1, p1)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class NNCLR -(backbone: torch.nn.modules.module.Module, num_ftrs: int = 4096, proj_hidden_dim: int = 4096, pred_hidden_dim: int = 4096, out_dim: int = 256, memory_bank_size: int = 4096) -
-
-

Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py

-

Initialize the NNCLR model.

-

Args

-
-
backbone : nn.Module
-
The backbone neural network model.
-
num_ftrs : int, optional
-
The number of features in the backbone output. Default is 4096.
-
proj_hidden_dim : int, optional
-
The hidden dimension of the projection head. Default is 4096.
-
pred_hidden_dim : int, optional
-
The hidden dimension of the prediction head. Default is 4096.
-
out_dim : int, optional
-
The output dimension of the model. Default is 256.
-
memory_bank_size : int, optional
-
The size of the memory bank module. Default is 4096.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class NNCLR(nn.Module):
-    """
-    Taken largely from https://github.com/lightly-ai/lightly/blob/master/lightly/models/nnclr.py
-    """
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_ftrs: int = 4096,
-        proj_hidden_dim: int = 4096,
-        pred_hidden_dim: int = 4096,
-        out_dim: int = 256,
-        memory_bank_size: int = 4096,
-    ) -> None:
-        """
-        Initialize the NNCLR model.
-
-        Args:
-            backbone (nn.Module): The backbone neural network model.
-            num_ftrs (int, optional): The number of features in the backbone output. Default is 4096.
-            proj_hidden_dim (int, optional): The hidden dimension of the projection head. Default is 4096.
-            pred_hidden_dim (int, optional): The hidden dimension of the prediction head. Default is 4096.
-            out_dim (int, optional): The output dimension of the model. Default is 256.
-            memory_bank_size (int, optional): The size of the memory bank module. Default is 4096.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__()
-        self.backbone = backbone
-        self.projection_head = NNCLRProjectionHead(num_ftrs, proj_hidden_dim, out_dim)
-        self.prediction_head = NNCLRPredictionHead(out_dim, pred_hidden_dim, out_dim)
-        self.memory_bank = NNMemoryBankModule(memory_bank_size)
-
-    def forward(
-        self,
-        x: List[torch.Tensor],
-        get_nearest_neighbor: bool = True,
-    ):
-        # forward pass of first input x0
-        """
-        Forward pass of the model.
-
-        Args:
-            x (List[torch.Tensor]): A list containing two input tensors.
-            get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
-                Defaults to True.
-
-        Returns:
-            Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
-                A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
-                for each input tensor.
-        """
-        x0, x1 = x
-        f0 = self.backbone(x0).flatten(start_dim=1)
-        z0 = self.projection_head(f0)
-        p0 = self.prediction_head(z0)
-
-        if get_nearest_neighbor:
-            z0 = self.memory_bank(z0, update=False)
-
-        # forward pass of second input x1
-        f1 = self.backbone(x1).flatten(start_dim=1)
-        z1 = self.projection_head(f1)
-        p1 = self.prediction_head(z1)
-
-        if get_nearest_neighbor:
-            z1 = self.memory_bank(z1, update=True)
-
-        return (z0, p0), (z1, p1)
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x: List[torch.Tensor], get_nearest_neighbor: bool = True) ‑> Callable[..., Any] -
-
-

Forward pass of the model.

-

Args

-
-
x : List[torch.Tensor]
-
A list containing two input tensors.
-
get_nearest_neighbor : bool, optional
-
Whether to compute and update the nearest neighbor vectors. -Defaults to True.
-
-

Returns

-

Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: -A tuple containing two tuples. The inner tuples contain the projection and prediction vectors -for each input tensor.

-
- -Expand source code - -
def forward(
-    self,
-    x: List[torch.Tensor],
-    get_nearest_neighbor: bool = True,
-):
-    # forward pass of first input x0
-    """
-    Forward pass of the model.
-
-    Args:
-        x (List[torch.Tensor]): A list containing two input tensors.
-        get_nearest_neighbor (bool, optional): Whether to compute and update the nearest neighbor vectors.
-            Defaults to True.
-
-    Returns:
-        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
-            A tuple containing two tuples. The inner tuples contain the projection and prediction vectors
-            for each input tensor.
-    """
-    x0, x1 = x
-    f0 = self.backbone(x0).flatten(start_dim=1)
-    z0 = self.projection_head(f0)
-    p0 = self.prediction_head(z0)
-
-    if get_nearest_neighbor:
-        z0 = self.memory_bank(z0, update=False)
-
-    # forward pass of second input x1
-    f1 = self.backbone(x1).flatten(start_dim=1)
-    z1 = self.projection_head(f1)
-    p1 = self.prediction_head(z1)
-
-    if get_nearest_neighbor:
-        z1 = self.memory_bank(z1, update=True)
-
-    return (z0, p0), (z1, p1)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/modules/simclr.html b/docs/api_docs/fmcib/ssl/modules/simclr.html deleted file mode 100644 index 60e730e..0000000 --- a/docs/api_docs/fmcib/ssl/modules/simclr.html +++ /dev/null @@ -1,272 +0,0 @@ - - - - - - -fmcib.ssl.modules.simclr API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.simclr

-
-
-
- -Expand source code - -
import torch
-import torch.nn as nn
-from lightly.models import SimCLR as lightly_SimCLR
-from lightly.models.modules import SimCLRProjectionHead
-
-
-class SimCLR(lightly_SimCLR):
-    """
-    A class representing a SimCLR model.
-
-    Attributes:
-        backbone (nn.Module): The backbone model used in the SimCLR model.
-        num_ftrs (int): The number of output features from the backbone model.
-        out_dim (int): The dimension of the output representations.
-        projection_head (SimCLRProjectionHead): The projection head used for projection head training.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128):
-        """
-        Initialize the object with a backbone network, number of features, and output dimension.
-
-        Args:
-            backbone (nn.Module): The backbone network.
-            num_ftrs (int): The number of features. Default is 32.
-            out_dim (int): The output dimension. Default is 128.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__(backbone, num_ftrs, out_dim)
-        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
-    def forward(self, x, return_features=False):
-        """
-        Perform a forward pass of the neural network.
-
-        Args:
-            x (tuple): A tuple of input data. Each element of the tuple represents a different input.
-            return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
-        Returns:
-            torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
-                If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
-        Raises:
-            None.
-        """
-        return super().forward(*x, return_features)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SimCLR -(backbone: torch.nn.modules.module.Module, num_ftrs: int = 32, out_dim: int = 128) -
-
-

A class representing a SimCLR model.

-

Attributes

-
-
backbone : nn.Module
-
The backbone model used in the SimCLR model.
-
num_ftrs : int
-
The number of output features from the backbone model.
-
out_dim : int
-
The dimension of the output representations.
-
projection_head : SimCLRProjectionHead
-
The projection head used for projection head training.
-
-

Initialize the object with a backbone network, number of features, and output dimension.

-

Args

-
-
backbone : nn.Module
-
The backbone network.
-
num_ftrs : int
-
The number of features. Default is 32.
-
out_dim : int
-
The output dimension. Default is 128.
-
-

Returns

-

None

-

Raises

-

None

-
- -Expand source code - -
class SimCLR(lightly_SimCLR):
-    """
-    A class representing a SimCLR model.
-
-    Attributes:
-        backbone (nn.Module): The backbone model used in the SimCLR model.
-        num_ftrs (int): The number of output features from the backbone model.
-        out_dim (int): The dimension of the output representations.
-        projection_head (SimCLRProjectionHead): The projection head used for projection head training.
-    """
-
-    def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128):
-        """
-        Initialize the object with a backbone network, number of features, and output dimension.
-
-        Args:
-            backbone (nn.Module): The backbone network.
-            num_ftrs (int): The number of features. Default is 32.
-            out_dim (int): The output dimension. Default is 128.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        super().__init__(backbone, num_ftrs, out_dim)
-        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs // 2, out_dim, batch_norm=False)
-
-    def forward(self, x, return_features=False):
-        """
-        Perform a forward pass of the neural network.
-
-        Args:
-            x (tuple): A tuple of input data. Each element of the tuple represents a different input.
-            return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
-        Returns:
-            torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
-                If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
-        Raises:
-            None.
-        """
-        return super().forward(*x, return_features)
-
-

Ancestors

-
    -
  • lightly.models.simclr.SimCLR
  • -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, x, return_features=False) ‑> Callable[..., Any] -
-
-

Perform a forward pass of the neural network.

-

Args

-
-
x : tuple
-
A tuple of input data. Each element of the tuple represents a different input.
-
return_features : bool, optional
-
Whether to return the intermediate features. Default is False.
-
-

Returns

-
-
torch.Tensor or tuple
-
The output of the forward pass. If return_features is False, a single tensor is returned. -If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
-

Raises

-

None.

-
- -Expand source code - -
def forward(self, x, return_features=False):
-    """
-    Perform a forward pass of the neural network.
-
-    Args:
-        x (tuple): A tuple of input data. Each element of the tuple represents a different input.
-        return_features (bool, optional): Whether to return the intermediate features. Default is False.
-
-    Returns:
-        torch.Tensor or tuple: The output of the forward pass. If return_features is False, a single tensor is returned.
-            If return_features is True, a tuple is returned consisting of the output tensor and the intermediate features.
-
-    Raises:
-        None.
-    """
-    return super().forward(*x, return_features)
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/modules/swav.html b/docs/api_docs/fmcib/ssl/modules/swav.html deleted file mode 100644 index 43daf71..0000000 --- a/docs/api_docs/fmcib/ssl/modules/swav.html +++ /dev/null @@ -1,561 +0,0 @@ - - - - - - -fmcib.ssl.modules.swav API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.modules.swav

-
-
-
- -Expand source code - -
import torch
-from lightly.loss.memory_bank import MemoryBankModule
-from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
-from torch import nn
-
-torch.set_float32_matmul_precision("medium")
-
-
-class SwaV(nn.Module):
-    """
-    Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
-
-    Args:
-        backbone (nn.Module): CNN backbone for feature extraction.
-        num_ftrs (int): Number of input features for the projection head.
-        out_dim (int): Output dimension for the projection head.
-        n_prototypes (int): Number of prototypes to compute.
-        n_queues (int): Number of memory banks (queues). Should be equal to the number of high-resolution inputs.
-        queue_length (int, optional): Length of the memory bank. Defaults to 0.
-        start_queue_at_epoch (int, optional): Number of the epoch at which SwaV starts using the queued features. Defaults to 0.
-        n_steps_frozen_prototypes (int, optional): Number of steps during which we keep the prototypes fixed. Defaults to 0.
-    """
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_ftrs: int,
-        out_dim: int,
-        n_prototypes: int,
-        n_queues: int,
-        queue_length: int = 0,
-        start_queue_at_epoch: int = 0,
-        n_steps_frozen_prototypes: int = 0,
-    ):
-        """
-        Initialize a SwaV model.
-
-        Args:
-            backbone (nn.Module): The backbone model.
-            num_ftrs (int): The number of input features.
-            out_dim (int): The dimension of the output.
-            n_prototypes (int): The number of prototypes.
-            n_queues (int): The number of queues.
-            queue_length (int, optional): The length of the queue. Default is 0.
-            start_queue_at_epoch (int, optional): The epoch at which to start using the queue. Default is 0.
-            n_steps_frozen_prototypes (int, optional): The number of steps to freeze prototypes. Default is 0.
-
-        Returns:
-            None
-
-        Attributes:
-            backbone (nn.Module): The backbone model.
-            projection_head (SwaVProjectionHead): The projection head.
-            prototypes (SwaVPrototypes): The prototypes.
-            queues (nn.ModuleList, optional): The queues. If n_queues > 0, this will be initialized with MemoryBankModules.
-            queue_length (int, optional): The length of the queue.
-            num_features_queued (int): The number of features queued.
-            start_queue_at_epoch (int): The epoch at which to start using the queue.
-        """
-        super().__init__()
-        # Backbone for feature extraction
-        self.backbone = backbone
-        # Projection head to project features to a lower-dimensional space
-        self.projection_head = SwaVProjectionHead(num_ftrs, num_ftrs // 2, out_dim)
-        # SwAV Prototypes module for prototype computation
-        self.prototypes = SwaVPrototypes(out_dim, n_prototypes, n_steps_frozen_prototypes)
-
-        self.queues = None
-        if n_queues > 0:
-            # Initialize the memory banks (queues)
-            self.queues = nn.ModuleList([MemoryBankModule(size=queue_length) for _ in range(n_queues)])
-            self.queue_length = queue_length
-            self.num_features_queued = 0
-            self.start_queue_at_epoch = start_queue_at_epoch
-
-    def forward(self, input, epoch=None, step=None):
-        """
-        Performs the forward pass for the SwAV model.
-
-        Args:
-            input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
-                and a list of low-resolution input images.
-            epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
-            step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
-        Returns:
-            Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
-                low-resolution prototypes, and queue prototypes.
-        """
-        high_resolution, low_resolution = input
-
-        # Normalize prototypes
-        self.prototypes.normalize()
-
-        # Compute high and low resolution features
-        high_resolution_features = [self._subforward(x) for x in high_resolution]
-        low_resolution_features = [self._subforward(x) for x in low_resolution]
-
-        # Compute prototypes for high and low resolution features
-        high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
-        low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
-        # Compute prototypes for queued features
-        queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
-        return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-
-    def _subforward(self, input):
-        """
-        Subforward pass to compute features for the input image.
-
-        Args:
-            input (Tensor): Input image tensor.
-
-        Returns:
-            Tensor: L2-normalized feature tensor.
-        """
-        # Extract features using the backbone
-        features = self.backbone(input).flatten(start_dim=1)
-        # Project features using the projection head
-        features = self.projection_head(features)
-        # L2-normalize features
-        features = nn.functional.normalize(features, dim=1, p=2)
-        return features
-
-    @torch.no_grad()
-    def _get_queue_prototypes(self, high_resolution_features, epoch=None):
-        """
-        Compute the queue prototypes for the given high-resolution features.
-
-        Args:
-            high_resolution_features (List[Tensor]): List of high-resolution feature tensors.
-            epoch (int, optional): Current epoch number. Required if `start_queue_at_epoch` > 0. Defaults to None.
-
-        Returns:
-            List[Tensor] or None: List of queue prototype tensors if conditions are met, otherwise None.
-        """
-        if self.queues is None:
-            return None
-
-        if len(high_resolution_features) != len(self.queues):
-            raise ValueError(
-                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
-                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
-            )
-
-        # Get the queue features
-        queue_features = []
-        for i in range(len(self.queues)):
-            _, features = self.queues[i](high_resolution_features[i], update=True)
-            # Queue features are in (num_ftrs X queue_length) shape, while the high res
-            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
-            features = torch.permute(features, (1, 0))
-            queue_features.append(features)
-
-        # Do not return queue prototypes if not enough features have been queued
-        self.num_features_queued += high_resolution_features[0].shape[0]
-        if self.num_features_queued < self.queue_length:
-            return None
-
-        # If loss calculation with queue prototypes starts at a later epoch,
-        # just queue the features and return None instead of queue prototypes.
-        if self.start_queue_at_epoch > 0:
-            if epoch is None:
-                raise ValueError(
-                    "The epoch number must be passed to the `forward()` " "method if `start_queue_at_epoch` is greater than 0."
-                )
-            if epoch < self.start_queue_at_epoch:
-                return None
-
-        # Assign prototypes
-        queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
-        # Do not return queue prototypes if not enough features have been queued
-        return queue_prototypes
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class SwaV -(backbone: torch.nn.modules.module.Module, num_ftrs: int, out_dim: int, n_prototypes: int, n_queues: int, queue_length: int = 0, start_queue_at_epoch: int = 0, n_steps_frozen_prototypes: int = 0) -
-
-

Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.

-

Args

-
-
backbone : nn.Module
-
CNN backbone for feature extraction.
-
num_ftrs : int
-
Number of input features for the projection head.
-
out_dim : int
-
Output dimension for the projection head.
-
n_prototypes : int
-
Number of prototypes to compute.
-
n_queues : int
-
Number of memory banks (queues). Should be equal to the number of high-resolution inputs.
-
queue_length : int, optional
-
Length of the memory bank. Defaults to 0.
-
start_queue_at_epoch : int, optional
-
Number of the epoch at which SwaV starts using the queued features. Defaults to 0.
-
n_steps_frozen_prototypes : int, optional
-
Number of steps during which we keep the prototypes fixed. Defaults to 0.
-
-

Initialize a SwaV model.

-

Args

-
-
backbone : nn.Module
-
The backbone model.
-
num_ftrs : int
-
The number of input features.
-
out_dim : int
-
The dimension of the output.
-
n_prototypes : int
-
The number of prototypes.
-
n_queues : int
-
The number of queues.
-
queue_length : int, optional
-
The length of the queue. Default is 0.
-
start_queue_at_epoch : int, optional
-
The epoch at which to start using the queue. Default is 0.
-
n_steps_frozen_prototypes : int, optional
-
The number of steps to freeze prototypes. Default is 0.
-
-

Returns

-

None

-

Attributes

-
-
backbone : nn.Module
-
The backbone model.
-
projection_head : SwaVProjectionHead
-
The projection head.
-
prototypes : SwaVPrototypes
-
The prototypes.
-
queues : nn.ModuleList, optional
-
The queues. If n_queues > 0, this will be initialized with MemoryBankModules.
-
queue_length : int, optional
-
The length of the queue.
-
num_features_queued : int
-
The number of features queued.
-
start_queue_at_epoch : int
-
The epoch at which to start using the queue.
-
-
- -Expand source code - -
class SwaV(nn.Module):
-    """
-    Implements the SwAV (Swapping Assignments between multiple Views of the same image) model.
-
-    Args:
-        backbone (nn.Module): CNN backbone for feature extraction.
-        num_ftrs (int): Number of input features for the projection head.
-        out_dim (int): Output dimension for the projection head.
-        n_prototypes (int): Number of prototypes to compute.
-        n_queues (int): Number of memory banks (queues). Should be equal to the number of high-resolution inputs.
-        queue_length (int, optional): Length of the memory bank. Defaults to 0.
-        start_queue_at_epoch (int, optional): Number of the epoch at which SwaV starts using the queued features. Defaults to 0.
-        n_steps_frozen_prototypes (int, optional): Number of steps during which we keep the prototypes fixed. Defaults to 0.
-    """
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_ftrs: int,
-        out_dim: int,
-        n_prototypes: int,
-        n_queues: int,
-        queue_length: int = 0,
-        start_queue_at_epoch: int = 0,
-        n_steps_frozen_prototypes: int = 0,
-    ):
-        """
-        Initialize a SwaV model.
-
-        Args:
-            backbone (nn.Module): The backbone model.
-            num_ftrs (int): The number of input features.
-            out_dim (int): The dimension of the output.
-            n_prototypes (int): The number of prototypes.
-            n_queues (int): The number of queues.
-            queue_length (int, optional): The length of the queue. Default is 0.
-            start_queue_at_epoch (int, optional): The epoch at which to start using the queue. Default is 0.
-            n_steps_frozen_prototypes (int, optional): The number of steps to freeze prototypes. Default is 0.
-
-        Returns:
-            None
-
-        Attributes:
-            backbone (nn.Module): The backbone model.
-            projection_head (SwaVProjectionHead): The projection head.
-            prototypes (SwaVPrototypes): The prototypes.
-            queues (nn.ModuleList, optional): The queues. If n_queues > 0, this will be initialized with MemoryBankModules.
-            queue_length (int, optional): The length of the queue.
-            num_features_queued (int): The number of features queued.
-            start_queue_at_epoch (int): The epoch at which to start using the queue.
-        """
-        super().__init__()
-        # Backbone for feature extraction
-        self.backbone = backbone
-        # Projection head to project features to a lower-dimensional space
-        self.projection_head = SwaVProjectionHead(num_ftrs, num_ftrs // 2, out_dim)
-        # SwAV Prototypes module for prototype computation
-        self.prototypes = SwaVPrototypes(out_dim, n_prototypes, n_steps_frozen_prototypes)
-
-        self.queues = None
-        if n_queues > 0:
-            # Initialize the memory banks (queues)
-            self.queues = nn.ModuleList([MemoryBankModule(size=queue_length) for _ in range(n_queues)])
-            self.queue_length = queue_length
-            self.num_features_queued = 0
-            self.start_queue_at_epoch = start_queue_at_epoch
-
-    def forward(self, input, epoch=None, step=None):
-        """
-        Performs the forward pass for the SwAV model.
-
-        Args:
-            input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
-                and a list of low-resolution input images.
-            epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
-            step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
-        Returns:
-            Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
-                low-resolution prototypes, and queue prototypes.
-        """
-        high_resolution, low_resolution = input
-
-        # Normalize prototypes
-        self.prototypes.normalize()
-
-        # Compute high and low resolution features
-        high_resolution_features = [self._subforward(x) for x in high_resolution]
-        low_resolution_features = [self._subforward(x) for x in low_resolution]
-
-        # Compute prototypes for high and low resolution features
-        high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
-        low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
-        # Compute prototypes for queued features
-        queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
-        return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-
-    def _subforward(self, input):
-        """
-        Subforward pass to compute features for the input image.
-
-        Args:
-            input (Tensor): Input image tensor.
-
-        Returns:
-            Tensor: L2-normalized feature tensor.
-        """
-        # Extract features using the backbone
-        features = self.backbone(input).flatten(start_dim=1)
-        # Project features using the projection head
-        features = self.projection_head(features)
-        # L2-normalize features
-        features = nn.functional.normalize(features, dim=1, p=2)
-        return features
-
-    @torch.no_grad()
-    def _get_queue_prototypes(self, high_resolution_features, epoch=None):
-        """
-        Compute the queue prototypes for the given high-resolution features.
-
-        Args:
-            high_resolution_features (List[Tensor]): List of high-resolution feature tensors.
-            epoch (int, optional): Current epoch number. Required if `start_queue_at_epoch` > 0. Defaults to None.
-
-        Returns:
-            List[Tensor] or None: List of queue prototype tensors if conditions are met, otherwise None.
-        """
-        if self.queues is None:
-            return None
-
-        if len(high_resolution_features) != len(self.queues):
-            raise ValueError(
-                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
-                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
-            )
-
-        # Get the queue features
-        queue_features = []
-        for i in range(len(self.queues)):
-            _, features = self.queues[i](high_resolution_features[i], update=True)
-            # Queue features are in (num_ftrs X queue_length) shape, while the high res
-            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
-            features = torch.permute(features, (1, 0))
-            queue_features.append(features)
-
-        # Do not return queue prototypes if not enough features have been queued
-        self.num_features_queued += high_resolution_features[0].shape[0]
-        if self.num_features_queued < self.queue_length:
-            return None
-
-        # If loss calculation with queue prototypes starts at a later epoch,
-        # just queue the features and return None instead of queue prototypes.
-        if self.start_queue_at_epoch > 0:
-            if epoch is None:
-                raise ValueError(
-                    "The epoch number must be passed to the `forward()` " "method if `start_queue_at_epoch` is greater than 0."
-                )
-            if epoch < self.start_queue_at_epoch:
-                return None
-
-        # Assign prototypes
-        queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
-        # Do not return queue prototypes if not enough features have been queued
-        return queue_prototypes
-
-

Ancestors

-
    -
  • torch.nn.modules.module.Module
  • -
-

Class variables

-
-
var call_super_init : bool
-
-
-
-
var dump_patches : bool
-
-
-
-
var training : bool
-
-
-
-
-

Methods

-
-
-def forward(self, input, epoch=None, step=None) ‑> Callable[..., Any] -
-
-

Performs the forward pass for the SwAV model.

-

Args

-
-
input : Tuple[List[Tensor], List[Tensor]]
-
A tuple consisting of a list of high-resolution input images -and a list of low-resolution input images.
-
epoch : int, optional
-
Current training epoch. Required if start_queue_at_epoch > 0. Defaults to None.
-
step : int, optional
-
Current training step. Required if n_steps_frozen_prototypes > 0. Defaults to None.
-
-

Returns

-
-
Tuple[List[Tensor], List[Tensor], List[Tensor]]
-
A tuple containing lists of high-resolution prototypes, -low-resolution prototypes, and queue prototypes.
-
-
- -Expand source code - -
def forward(self, input, epoch=None, step=None):
-    """
-    Performs the forward pass for the SwAV model.
-
-    Args:
-        input (Tuple[List[Tensor], List[Tensor]]): A tuple consisting of a list of high-resolution input images
-            and a list of low-resolution input images.
-        epoch (int, optional): Current training epoch. Required if `start_queue_at_epoch` > 0. Defaults to None.
-        step (int, optional): Current training step. Required if `n_steps_frozen_prototypes` > 0. Defaults to None.
-
-    Returns:
-        Tuple[List[Tensor], List[Tensor], List[Tensor]]: A tuple containing lists of high-resolution prototypes,
-            low-resolution prototypes, and queue prototypes.
-    """
-    high_resolution, low_resolution = input
-
-    # Normalize prototypes
-    self.prototypes.normalize()
-
-    # Compute high and low resolution features
-    high_resolution_features = [self._subforward(x) for x in high_resolution]
-    low_resolution_features = [self._subforward(x) for x in low_resolution]
-
-    # Compute prototypes for high and low resolution features
-    high_resolution_prototypes = [self.prototypes(x, epoch) for x in high_resolution_features]
-    low_resolution_prototypes = [self.prototypes(x, epoch) for x in low_resolution_features]
-    # Compute prototypes for queued features
-    queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
-
-    return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/optimizers/index.html b/docs/api_docs/fmcib/ssl/optimizers/index.html deleted file mode 100644 index dc09bcd..0000000 --- a/docs/api_docs/fmcib/ssl/optimizers/index.html +++ /dev/null @@ -1,75 +0,0 @@ - - - - - - -fmcib.ssl.optimizers API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.optimizers

-
-
-
- -Expand source code - -
from .lars import LARS
-
-
-
-

Sub-modules

-
-
fmcib.ssl.optimizers.lars
-
- -
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/optimizers/lars.html b/docs/api_docs/fmcib/ssl/optimizers/lars.html deleted file mode 100644 index 05e0643..0000000 --- a/docs/api_docs/fmcib/ssl/optimizers/lars.html +++ /dev/null @@ -1,509 +0,0 @@ - - - - - - -fmcib.ssl.optimizers.lars API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.optimizers.lars

-
-
-

References

- -
- -Expand source code - -
"""
-References:
-    - https://arxiv.org/pdf/1708.03888.pdf
-    - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
-"""
-import torch
-from torch.optim.optimizer import Optimizer, required
-
-
-class LARS(Optimizer):
-    """Extends SGD in PyTorch with LARS scaling from the paper
-    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float): learning rate
-        momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): dampening for momentum (default: 0)
-        nesterov (bool, optional): enables Nesterov momentum (default: False)
-        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
-        eps (float, optional): eps for division denominator (default: 1e-8)
-
-    Example:
-        >>> model = torch.nn.Linear(10, 1)
-        >>> input = torch.Tensor(10)
-        >>> target = torch.Tensor([1.])
-        >>> loss_fn = lambda input, target: (input - target) ** 2
-        >>> #
-        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
-        >>> optimizer.zero_grad()
-        >>> loss_fn(model(input), target).backward()
-        >>> optimizer.step()
-
-    .. note::
-        The application of momentum in the SGD part is modified according to
-        the PyTorch standards. LARS scaling fits into the equation in the
-        following fashion.
-
-        .. math::
-            \begin{aligned}
-                g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
-                v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
-                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
-            \\end{aligned}
-
-        where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
-        parameters, gradient, velocity, momentum, and weight decay respectively.
-        The :math:`lars_lr` is defined by Eq. 6 in the paper.
-        The Nesterov version is analogously modified.
-
-    .. warning::
-        Parameters with weight decay set to 0 will automatically be excluded from
-        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
-        and BYOL.
-    """
-
-    def __init__(
-        self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
-        trust_coefficient=0.001,
-        eps=1e-8,
-    ):
-        if lr is not required and lr < 0.0:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if momentum < 0.0:
-            raise ValueError(f"Invalid momentum value: {momentum}")
-        if weight_decay < 0.0:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            trust_coefficient=trust_coefficient,
-            eps=eps,
-        )
-        if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
-        super().__init__(params, defaults)
-
-    def __setstate__(self, state):
-        super().__setstate__(state)
-
-        for group in self.param_groups:
-            group.setdefault("nesterov", False)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Args:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        # exclude scaling for params with 0 weight decay
-        for group in self.param_groups:
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-            dampening = group["dampening"]
-            nesterov = group["nesterov"]
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                d_p = p.grad
-                p_norm = torch.norm(p.data)
-                g_norm = torch.norm(p.grad.data)
-
-                # lars scaling + weight decay part
-                if weight_decay != 0:
-                    if p_norm != 0 and g_norm != 0:
-                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                        lars_lr *= group["trust_coefficient"]
-
-                        d_p = d_p.add(p, alpha=weight_decay)
-                        d_p *= lars_lr
-
-                # sgd part
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if "momentum_buffer" not in param_state:
-                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state["momentum_buffer"]
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
-
-                p.add_(d_p, alpha=-group["lr"])
-
-        return loss
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class LARS -(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08) -
-
-

Extends SGD in PyTorch with LARS scaling from the paper -Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>_.

-

Args

-
-
params : iterable
-
iterable of parameters to optimize or dicts defining -parameter groups
-
lr : float
-
learning rate
-
momentum : float, optional
-
momentum factor (default: 0)
-
weight_decay : float, optional
-
weight decay (L2 penalty) (default: 0)
-
dampening : float, optional
-
dampening for momentum (default: 0)
-
nesterov : bool, optional
-
enables Nesterov momentum (default: False)
-
trust_coefficient : float, optional
-
trust coefficient for computing LR (default: 0.001)
-
eps : float, optional
-
eps for division denominator (default: 1e-8)
-
-

Example

-
>>> model = torch.nn.Linear(10, 1)
->>> input = torch.Tensor(10)
->>> target = torch.Tensor([1.])
->>> loss_fn = lambda input, target: (input - target) ** 2
->>> #
->>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
->>> optimizer.zero_grad()
->>> loss_fn(model(input), target).backward()
->>> optimizer.step()
-
-
-

Note

-

The application of momentum in the SGD part is modified according to -the PyTorch standards. LARS scaling fits into the equation in the -following fashion.

-

[ egin{aligned} -g_{t+1} & = -ext{lars_lr} * (eta * p_{t} + g_{t+1}), \ -v_{t+1} & = \mu * v_{t} + g_{t+1}, \ -p_{t+1} & = p_{t} - -ext{lr} * v_{t+1}, -\end{aligned} ] -where :math:p, :math:g, :math:v, :math:\mu and :math:eta denote the -parameters, gradient, velocity, momentum, and weight decay respectively. -The :math:lars_lr is defined by Eq. 6 in the paper. -The Nesterov version is analogously modified.

-
-
-

Warning

-

Parameters with weight decay set to 0 will automatically be excluded from -layer-wise LR scaling. This is to ensure consistency with papers like SimCLR -and BYOL.

-
-
- -Expand source code - -
class LARS(Optimizer):
-    """Extends SGD in PyTorch with LARS scaling from the paper
-    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float): learning rate
-        momentum (float, optional): momentum factor (default: 0)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): dampening for momentum (default: 0)
-        nesterov (bool, optional): enables Nesterov momentum (default: False)
-        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
-        eps (float, optional): eps for division denominator (default: 1e-8)
-
-    Example:
-        >>> model = torch.nn.Linear(10, 1)
-        >>> input = torch.Tensor(10)
-        >>> target = torch.Tensor([1.])
-        >>> loss_fn = lambda input, target: (input - target) ** 2
-        >>> #
-        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
-        >>> optimizer.zero_grad()
-        >>> loss_fn(model(input), target).backward()
-        >>> optimizer.step()
-
-    .. note::
-        The application of momentum in the SGD part is modified according to
-        the PyTorch standards. LARS scaling fits into the equation in the
-        following fashion.
-
-        .. math::
-            \begin{aligned}
-                g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
-                v_{t+1} & = \\mu * v_{t} + g_{t+1}, \\
-                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
-            \\end{aligned}
-
-        where :math:`p`, :math:`g`, :math:`v`, :math:`\\mu` and :math:`\beta` denote the
-        parameters, gradient, velocity, momentum, and weight decay respectively.
-        The :math:`lars_lr` is defined by Eq. 6 in the paper.
-        The Nesterov version is analogously modified.
-
-    .. warning::
-        Parameters with weight decay set to 0 will automatically be excluded from
-        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
-        and BYOL.
-    """
-
-    def __init__(
-        self,
-        params,
-        lr=required,
-        momentum=0,
-        dampening=0,
-        weight_decay=0,
-        nesterov=False,
-        trust_coefficient=0.001,
-        eps=1e-8,
-    ):
-        if lr is not required and lr < 0.0:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if momentum < 0.0:
-            raise ValueError(f"Invalid momentum value: {momentum}")
-        if weight_decay < 0.0:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            trust_coefficient=trust_coefficient,
-            eps=eps,
-        )
-        if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-
-        super().__init__(params, defaults)
-
-    def __setstate__(self, state):
-        super().__setstate__(state)
-
-        for group in self.param_groups:
-            group.setdefault("nesterov", False)
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Args:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        # exclude scaling for params with 0 weight decay
-        for group in self.param_groups:
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-            dampening = group["dampening"]
-            nesterov = group["nesterov"]
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                d_p = p.grad
-                p_norm = torch.norm(p.data)
-                g_norm = torch.norm(p.grad.data)
-
-                # lars scaling + weight decay part
-                if weight_decay != 0:
-                    if p_norm != 0 and g_norm != 0:
-                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                        lars_lr *= group["trust_coefficient"]
-
-                        d_p = d_p.add(p, alpha=weight_decay)
-                        d_p *= lars_lr
-
-                # sgd part
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if "momentum_buffer" not in param_state:
-                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state["momentum_buffer"]
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
-
-                p.add_(d_p, alpha=-group["lr"])
-
-        return loss
-
-

Ancestors

-
    -
  • torch.optim.optimizer.Optimizer
  • -
-

Methods

-
-
-def step(self, closure=None) -
-
-

Performs a single optimization step.

-

Args

-
-
closure : callable, optional
-
A closure that reevaluates the model -and returns the loss.
-
-
- -Expand source code - -
@torch.no_grad()
-def step(self, closure=None):
-    """Performs a single optimization step.
-
-    Args:
-        closure (callable, optional): A closure that reevaluates the model
-            and returns the loss.
-    """
-    loss = None
-    if closure is not None:
-        with torch.enable_grad():
-            loss = closure()
-
-    # exclude scaling for params with 0 weight decay
-    for group in self.param_groups:
-        weight_decay = group["weight_decay"]
-        momentum = group["momentum"]
-        dampening = group["dampening"]
-        nesterov = group["nesterov"]
-
-        for p in group["params"]:
-            if p.grad is None:
-                continue
-
-            d_p = p.grad
-            p_norm = torch.norm(p.data)
-            g_norm = torch.norm(p.grad.data)
-
-            # lars scaling + weight decay part
-            if weight_decay != 0:
-                if p_norm != 0 and g_norm != 0:
-                    lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
-                    lars_lr *= group["trust_coefficient"]
-
-                    d_p = d_p.add(p, alpha=weight_decay)
-                    d_p *= lars_lr
-
-            # sgd part
-            if momentum != 0:
-                param_state = self.state[p]
-                if "momentum_buffer" not in param_state:
-                    buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                else:
-                    buf = param_state["momentum_buffer"]
-                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                if nesterov:
-                    d_p = d_p.add(buf, alpha=momentum)
-                else:
-                    d_p = buf
-
-            p.add_(d_p, alpha=-group["lr"])
-
-    return loss
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/transforms/duplicate.html b/docs/api_docs/fmcib/ssl/transforms/duplicate.html deleted file mode 100644 index cc725a5..0000000 --- a/docs/api_docs/fmcib/ssl/transforms/duplicate.html +++ /dev/null @@ -1,153 +0,0 @@ - - - - - - -fmcib.ssl.transforms.duplicate API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.transforms.duplicate

-
-
-
- -Expand source code - -
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-
-
-class Duplicate:
-    """Duplicate an input and apply two different transforms. Used for SimCLR primarily."""
-
-    def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
-        """Duplicates an input and applies the given transformations to each copy separately.
-
-        Args:
-            transforms1 (Optional[Callable], optional): _description_. Defaults to None.
-            transforms2 (Optional[Callable], optional): _description_. Defaults to None.
-        """
-        # Wrapped into a list if it isn't one already to allow both a
-        # list of transforms as well as `torchvision.transform.Compose` transforms.
-        self.transforms1 = transforms1
-        self.transforms2 = transforms2
-
-    def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            input (torch.Tensor or any other type supported by the given transforms): Input.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: a tuple of two tensors.
-        """
-        out1, out2 = input, deepcopy(input)
-        if self.transforms1 is not None:
-            out1 = self.transforms1(out1)
-        if self.transforms2 is not None:
-            out2 = self.transforms2(out2)
-        return (out1, out2)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class Duplicate -(transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None) -
-
-

Duplicate an input and apply two different transforms. Used for SimCLR primarily.

-

Duplicates an input and applies the given transformations to each copy separately.

-

Args

-
-
transforms1 : Optional[Callable], optional
-
description. Defaults to None.
-
transforms2 : Optional[Callable], optional
-
description. Defaults to None.
-
-
- -Expand source code - -
class Duplicate:
-    """Duplicate an input and apply two different transforms. Used for SimCLR primarily."""
-
-    def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
-        """Duplicates an input and applies the given transformations to each copy separately.
-
-        Args:
-            transforms1 (Optional[Callable], optional): _description_. Defaults to None.
-            transforms2 (Optional[Callable], optional): _description_. Defaults to None.
-        """
-        # Wrapped into a list if it isn't one already to allow both a
-        # list of transforms as well as `torchvision.transform.Compose` transforms.
-        self.transforms1 = transforms1
-        self.transforms2 = transforms2
-
-    def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            input (torch.Tensor or any other type supported by the given transforms): Input.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: a tuple of two tensors.
-        """
-        out1, out2 = input, deepcopy(input)
-        if self.transforms1 is not None:
-            out1 = self.transforms1(out1)
-        if self.transforms2 is not None:
-            out2 = self.transforms2(out2)
-        return (out1, out2)
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/transforms/index.html b/docs/api_docs/fmcib/ssl/transforms/index.html deleted file mode 100644 index 93d17e8..0000000 --- a/docs/api_docs/fmcib/ssl/transforms/index.html +++ /dev/null @@ -1,77 +0,0 @@ - - - - - - -fmcib.ssl.transforms API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.transforms

-
-
-
- -Expand source code - -
from .duplicate import Duplicate
-from .random_resized_crop import RandomResizedCrop3D
-
-
-
-

Sub-modules

-
-
fmcib.ssl.transforms.duplicate
-
-
-
-
fmcib.ssl.transforms.random_resized_crop
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/ssl/transforms/random_resized_crop.html b/docs/api_docs/fmcib/ssl/transforms/random_resized_crop.html deleted file mode 100644 index bf47f01..0000000 --- a/docs/api_docs/fmcib/ssl/transforms/random_resized_crop.html +++ /dev/null @@ -1,165 +0,0 @@ - - - - - - -fmcib.ssl.transforms.random_resized_crop API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.ssl.transforms.random_resized_crop

-
-
-
- -Expand source code - -
from typing import Any, Dict, List
-
-import torch
-from monai.transforms import RandScaleCrop, Resize, Transform
-
-
-class RandomResizedCrop3D(Transform):
-    """
-    Combines monai's random spatial crop followed by resize to the desired size.
-
-    Modification:
-    1. The spatial crop is done with same dimensions for all the axes
-    2. Handles cases where the image_size is less than the crop_size by choosing
-        the smallest dimension as the random scale.
-
-    """
-
-    def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
-        """
-        Args:
-            scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
-             before resizing. The scale is defined with respect to the area of the original image.
-        """
-        super().__init__()
-        self.prob = prob
-        self.scale = scale
-        self.size = [size] * 3
-
-    def __call__(self, image):
-        if torch.rand(1) < self.prob:
-            random_scale = torch.empty(1).uniform_(*self.scale).item()
-            rand_cropper = RandScaleCrop(random_scale, random_size=False)
-            resizer = Resize(self.size, mode="trilinear")
-
-            for transform in [rand_cropper, resizer]:
-                image = transform(image)
-
-        return image
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class RandomResizedCrop3D -(prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]) -
-
-

Combines monai's random spatial crop followed by resize to the desired size.

-

Modification: -1. The spatial crop is done with same dimensions for all the axes -2. Handles cases where the image_size is less than the crop_size by choosing -the smallest dimension as the random scale.

-

Args

-
-
scale : List[int]
-
Specifies the lower and upper bounds for the random area of the crop,
-
-

before resizing. The scale is defined with respect to the area of the original image.

-
- -Expand source code - -
class RandomResizedCrop3D(Transform):
-    """
-    Combines monai's random spatial crop followed by resize to the desired size.
-
-    Modification:
-    1. The spatial crop is done with same dimensions for all the axes
-    2. Handles cases where the image_size is less than the crop_size by choosing
-        the smallest dimension as the random scale.
-
-    """
-
-    def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
-        """
-        Args:
-            scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
-             before resizing. The scale is defined with respect to the area of the original image.
-        """
-        super().__init__()
-        self.prob = prob
-        self.scale = scale
-        self.size = [size] * 3
-
-    def __call__(self, image):
-        if torch.rand(1) < self.prob:
-            random_scale = torch.empty(1).uniform_(*self.scale).item()
-            rand_cropper = RandScaleCrop(random_scale, random_size=False)
-            resizer = Resize(self.size, mode="trilinear")
-
-            for transform in [rand_cropper, resizer]:
-                image = transform(image)
-
-        return image
-
-

Ancestors

-
    -
  • monai.transforms.transform.Transform
  • -
  • abc.ABC
  • -
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/transforms/duplicate.html b/docs/api_docs/fmcib/transforms/duplicate.html deleted file mode 100644 index 7a4f69f..0000000 --- a/docs/api_docs/fmcib/transforms/duplicate.html +++ /dev/null @@ -1,159 +0,0 @@ - - - - - - -fmcib.transforms.duplicate API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.transforms.duplicate

-
-
-
- -Expand source code - -
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-
-
-class Duplicate:
-    """
-    Duplicate an input and apply two different transforms. Used for SimCLR primarily.
-    """
-
-    def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
-        """
-        Duplicates an input and applies the given transformations to each copy separately.
-
-        Args:
-            transforms1 (Optional[Callable]): _description_. Default is None.
-            transforms2 (Optional[Callable]): _description_. Default is None.
-        """
-        # Wrapped into a list if it isn't one already to allow both a
-        # list of transforms as well as `torchvision.transform.Compose` transforms.
-        self.transforms1 = transforms1
-        self.transforms2 = transforms2
-
-    def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            input (torch.Tensor or any other type supported by the given transforms): Input.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors.
-        """
-        out1, out2 = input, deepcopy(input)
-        if self.transforms1 is not None:
-            out1 = self.transforms1(out1)
-        if self.transforms2 is not None:
-            out2 = self.transforms2(out2)
-        return (out1, out2)
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class Duplicate -(transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None) -
-
-

Duplicate an input and apply two different transforms. Used for SimCLR primarily.

-

Duplicates an input and applies the given transformations to each copy separately.

-

Args

-
-
transforms1 : Optional[Callable]
-
description. Default is None.
-
transforms2 : Optional[Callable]
-
description. Default is None.
-
-
- -Expand source code - -
class Duplicate:
-    """
-    Duplicate an input and apply two different transforms. Used for SimCLR primarily.
-    """
-
-    def __init__(self, transforms1: Optional[Callable] = None, transforms2: Optional[Callable] = None):
-        """
-        Duplicates an input and applies the given transformations to each copy separately.
-
-        Args:
-            transforms1 (Optional[Callable]): _description_. Default is None.
-            transforms2 (Optional[Callable]): _description_. Default is None.
-        """
-        # Wrapped into a list if it isn't one already to allow both a
-        # list of transforms as well as `torchvision.transform.Compose` transforms.
-        self.transforms1 = transforms1
-        self.transforms2 = transforms2
-
-    def __call__(self, input: Any) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            input (torch.Tensor or any other type supported by the given transforms): Input.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors.
-        """
-        out1, out2 = input, deepcopy(input)
-        if self.transforms1 is not None:
-            out1 = self.transforms1(out1)
-        if self.transforms2 is not None:
-            out2 = self.transforms2(out2)
-        return (out1, out2)
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/transforms/index.html b/docs/api_docs/fmcib/transforms/index.html deleted file mode 100644 index 879369e..0000000 --- a/docs/api_docs/fmcib/transforms/index.html +++ /dev/null @@ -1,89 +0,0 @@ - - - - - - -fmcib.transforms API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.transforms

-
-
-
- -Expand source code - -
from .duplicate import Duplicate
-from .med3d import IntensityNormalizeOneVolume
-from .multicrop import MultiCrop
-from .random_resized_crop import RandomResizedCrop3D
-
-
-
-

Sub-modules

-
-
fmcib.transforms.duplicate
-
-
-
-
fmcib.transforms.med3d
-
-
-
-
fmcib.transforms.multicrop
-
-
-
-
fmcib.transforms.random_resized_crop
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/transforms/med3d.html b/docs/api_docs/fmcib/transforms/med3d.html deleted file mode 100644 index ca14a0f..0000000 --- a/docs/api_docs/fmcib/transforms/med3d.html +++ /dev/null @@ -1,205 +0,0 @@ - - - - - - -fmcib.transforms.med3d API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.transforms.med3d

-
-
-
- -Expand source code - -
import numpy as np
-from monai.transforms import Transform
-
-
-class IntensityNormalizeOneVolume(Transform):
-    """
-    A class representing an intensity normalized volume.
-
-    Attributes:
-        None
-
-    Methods:
-        __call__(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.
-
-        Args:
-            volume (numpy.ndarray): The input n-dimensional volume.
-
-        Returns:
-            out (numpy.ndarray): The normalized n-dimensional volume.
-    """
-
-    def __init__(self):
-        """
-        Initialize the object.
-
-        Returns:
-            None
-        """
-        super().__init__()
-
-    def __call__(self, volume):
-        """
-        Normalize the intensity of an nd volume based on the mean and std of the non-zero region.
-
-        Args:
-            volume: The input nd volume.
-
-        Returns:
-            out: The normalized nd volume.
-        """
-        volume = volume.astype(np.float32)
-        low, high = np.percentile(volume, [0.5, 99.5])
-        if high > 0:
-            volume = np.clip(volume, low, high)
-
-        pixels = volume[volume > 0]
-        mean = pixels.mean()
-        std = pixels.std()
-        out = (volume - mean) / std
-        out_random = np.random.normal(0, 1, size=volume.shape)
-        out[volume == 0] = out_random[volume == 0]
-        return out
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class IntensityNormalizeOneVolume -
-
-

A class representing an intensity normalized volume.

-

Attributes

-

None

-

Methods

-

call(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.

-

Args: -volume (numpy.ndarray): The input n-dimensional volume.

-

Returns: -out (numpy.ndarray): The normalized n-dimensional volume.

-

Initialize the object.

-

Returns

-

None

-
- -Expand source code - -
class IntensityNormalizeOneVolume(Transform):
-    """
-    A class representing an intensity normalized volume.
-
-    Attributes:
-        None
-
-    Methods:
-        __call__(self, volume): Normalize the intensity of an n-dimensional volume based on the mean and standard deviation of the non-zero region.
-
-        Args:
-            volume (numpy.ndarray): The input n-dimensional volume.
-
-        Returns:
-            out (numpy.ndarray): The normalized n-dimensional volume.
-    """
-
-    def __init__(self):
-        """
-        Initialize the object.
-
-        Returns:
-            None
-        """
-        super().__init__()
-
-    def __call__(self, volume):
-        """
-        Normalize the intensity of an nd volume based on the mean and std of the non-zero region.
-
-        Args:
-            volume: The input nd volume.
-
-        Returns:
-            out: The normalized nd volume.
-        """
-        volume = volume.astype(np.float32)
-        low, high = np.percentile(volume, [0.5, 99.5])
-        if high > 0:
-            volume = np.clip(volume, low, high)
-
-        pixels = volume[volume > 0]
-        mean = pixels.mean()
-        std = pixels.std()
-        out = (volume - mean) / std
-        out_random = np.random.normal(0, 1, size=volume.shape)
-        out[volume == 0] = out_random[volume == 0]
-        return out
-
-

Ancestors

-
    -
  • monai.transforms.transform.Transform
  • -
  • abc.ABC
  • -
-

Class variables

-
-
var backend : list[TransformBackends]
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/transforms/multicrop.html b/docs/api_docs/fmcib/transforms/multicrop.html deleted file mode 100644 index bacda27..0000000 --- a/docs/api_docs/fmcib/transforms/multicrop.html +++ /dev/null @@ -1,158 +0,0 @@ - - - - - - -fmcib.transforms.multicrop API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.transforms.multicrop

-
-
-
- -Expand source code - -
from typing import Any, Callable, List, Optional, Tuple
-
-from copy import deepcopy
-
-import torch
-from lighter.utils.misc import ensure_list
-
-
-class MultiCrop:
-    """
-    Multi-Crop augmentation.
-    """
-
-    def __init__(self, high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]]):
-        """
-        Initialize an instance of a class with transformations for high-resolution and low-resolution images.
-
-        Args:
-            high_resolution_transforms (list): A list of Callable objects representing the transformations to be applied to high-resolution images.
-            low_resolution_transforms (list, optional): A list of Callable objects representing the transformations to be applied to low-resolution images. Default is None.
-        """
-        self.high_resolution_transforms = ensure_list(high_resolution_transforms)
-        self.low_resolution_transforms = ensure_list(low_resolution_transforms)
-
-    def __call__(self, input):
-        """
-        This function applies a set of transformations to an input image and returns high and low-resolution crops.
-
-        Args:
-            input (image): The input image to be transformed.
-
-        Returns:
-            tuple: A tuple containing two lists:
-                - high_resolution_crops (list): A list of high-resolution cropped images.
-                - low_resolution_crops (list): A list of low-resolution cropped images.
-        """
-        high_resolution_crops = [transform(input) for transform in self.high_resolution_transforms]
-        low_resolution_crops = [transform(input) for transform in self.low_resolution_transforms]
-        return high_resolution_crops, low_resolution_crops
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class MultiCrop -(high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]]) -
-
-

Multi-Crop augmentation.

-

Initialize an instance of a class with transformations for high-resolution and low-resolution images.

-

Args

-
-
high_resolution_transforms : list
-
A list of Callable objects representing the transformations to be applied to high-resolution images.
-
low_resolution_transforms : list, optional
-
A list of Callable objects representing the transformations to be applied to low-resolution images. Default is None.
-
-
- -Expand source code - -
class MultiCrop:
-    """
-    Multi-Crop augmentation.
-    """
-
-    def __init__(self, high_resolution_transforms: List[Callable], low_resolution_transforms: Optional[List[Callable]]):
-        """
-        Initialize an instance of a class with transformations for high-resolution and low-resolution images.
-
-        Args:
-            high_resolution_transforms (list): A list of Callable objects representing the transformations to be applied to high-resolution images.
-            low_resolution_transforms (list, optional): A list of Callable objects representing the transformations to be applied to low-resolution images. Default is None.
-        """
-        self.high_resolution_transforms = ensure_list(high_resolution_transforms)
-        self.low_resolution_transforms = ensure_list(low_resolution_transforms)
-
-    def __call__(self, input):
-        """
-        This function applies a set of transformations to an input image and returns high and low-resolution crops.
-
-        Args:
-            input (image): The input image to be transformed.
-
-        Returns:
-            tuple: A tuple containing two lists:
-                - high_resolution_crops (list): A list of high-resolution cropped images.
-                - low_resolution_crops (list): A list of low-resolution cropped images.
-        """
-        high_resolution_crops = [transform(input) for transform in self.high_resolution_transforms]
-        low_resolution_crops = [transform(input) for transform in self.low_resolution_transforms]
-        return high_resolution_crops, low_resolution_crops
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/transforms/random_resized_crop.html b/docs/api_docs/fmcib/transforms/random_resized_crop.html deleted file mode 100644 index aca875c..0000000 --- a/docs/api_docs/fmcib/transforms/random_resized_crop.html +++ /dev/null @@ -1,188 +0,0 @@ - - - - - - -fmcib.transforms.random_resized_crop API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.transforms.random_resized_crop

-
-
-
- -Expand source code - -
from typing import Any, Dict, List
-
-import torch
-from monai.transforms import RandScaleCrop, Resize, Transform
-
-
-class RandomResizedCrop3D(Transform):
-    """
-    Combines monai's random spatial crop followed by resize to the desired size.
-
-    Modifications:
-    1. The spatial crop is done with the same dimensions for all the axes.
-    2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.
-    """
-
-    def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
-        """
-        Args:
-            scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
-             before resizing. The scale is defined with respect to the area of the original image.
-        """
-        super().__init__()
-        self.prob = prob
-        self.scale = scale
-        self.size = [size] * 3
-
-    def __call__(self, image):
-        """
-        Call method to apply random scale cropping and resizing to an image.
-
-        Args:
-            image (torch.Tensor): The input image.
-
-        Returns:
-            torch.Tensor: The transformed image.
-        """
-        if torch.rand(1) < self.prob:
-            random_scale = torch.empty(1).uniform_(*self.scale).item()
-            rand_cropper = RandScaleCrop(random_scale, random_size=False)
-            resizer = Resize(self.size, mode="trilinear")
-
-            for transform in [rand_cropper, resizer]:
-                image = transform(image)
-
-        return image
-
-
-
-
-
-
-
-
-
-

Classes

-
-
-class RandomResizedCrop3D -(prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]) -
-
-

Combines monai's random spatial crop followed by resize to the desired size.

-

Modifications: -1. The spatial crop is done with the same dimensions for all the axes. -2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.

-

Args

-
-
scale : List[int]
-
Specifies the lower and upper bounds for the random area of the crop,
-
-

before resizing. The scale is defined with respect to the area of the original image.

-
- -Expand source code - -
class RandomResizedCrop3D(Transform):
-    """
-    Combines monai's random spatial crop followed by resize to the desired size.
-
-    Modifications:
-    1. The spatial crop is done with the same dimensions for all the axes.
-    2. Handles cases where the image_size is less than the crop_size by choosing the smallest dimension as the random scale.
-    """
-
-    def __init__(self, prob: float = 1, size: int = 50, scale: List[float] = [0.5, 1.0]):
-        """
-        Args:
-            scale (List[int]): Specifies the lower and upper bounds for the random area of the crop,
-             before resizing. The scale is defined with respect to the area of the original image.
-        """
-        super().__init__()
-        self.prob = prob
-        self.scale = scale
-        self.size = [size] * 3
-
-    def __call__(self, image):
-        """
-        Call method to apply random scale cropping and resizing to an image.
-
-        Args:
-            image (torch.Tensor): The input image.
-
-        Returns:
-            torch.Tensor: The transformed image.
-        """
-        if torch.rand(1) < self.prob:
-            random_scale = torch.empty(1).uniform_(*self.scale).item()
-            rand_cropper = RandScaleCrop(random_scale, random_size=False)
-            resizer = Resize(self.size, mode="trilinear")
-
-            for transform in [rand_cropper, resizer]:
-                image = transform(image)
-
-        return image
-
-

Ancestors

-
    -
  • monai.transforms.transform.Transform
  • -
  • abc.ABC
  • -
-

Class variables

-
-
var backend : list[TransformBackends]
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/utils/download_utils.html b/docs/api_docs/fmcib/utils/download_utils.html deleted file mode 100644 index 65ad2aa..0000000 --- a/docs/api_docs/fmcib/utils/download_utils.html +++ /dev/null @@ -1,132 +0,0 @@ - - - - - - -fmcib.utils.download_utils API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.utils.download_utils

-
-
-
- -Expand source code - -
import sys
-
-
-# create this bar_progress method which is invoked automatically from wget
-def bar_progress(current, total, width=80):
-    """
-    Display a progress bar for a download.
-
-    Args:
-        current (int): The current progress value.
-        total (int): The total progress value.
-        width (int, optional): The width of the progress bar in characters. Defaults to 80.
-
-    Raises:
-        None
-
-    Returns:
-        None
-    """
-    progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
-    # Don't use print() as it will print in new line every time.
-    sys.stdout.write("\r" + progress_message)
-    sys.stdout.flush()
-
-
-
-
-
-
-
-

Functions

-
-
-def bar_progress(current, total, width=80) -
-
-

Display a progress bar for a download.

-

Args

-
-
current : int
-
The current progress value.
-
total : int
-
The total progress value.
-
width : int, optional
-
The width of the progress bar in characters. Defaults to 80.
-
-

Raises

-

None

-

Returns

-

None

-
- -Expand source code - -
def bar_progress(current, total, width=80):
-    """
-    Display a progress bar for a download.
-
-    Args:
-        current (int): The current progress value.
-        total (int): The total progress value.
-        width (int, optional): The width of the progress bar in characters. Defaults to 80.
-
-    Raises:
-        None
-
-    Returns:
-        None
-    """
-    progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
-    # Don't use print() as it will print in new line every time.
-    sys.stdout.write("\r" + progress_message)
-    sys.stdout.flush()
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/utils/idc_helper.html b/docs/api_docs/fmcib/utils/idc_helper.html deleted file mode 100644 index 44d6edd..0000000 --- a/docs/api_docs/fmcib/utils/idc_helper.html +++ /dev/null @@ -1,731 +0,0 @@ - - - - - - -fmcib.utils.idc_helper API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.utils.idc_helper

-
-
-
- -Expand source code - -
import concurrent.futures
-import os
-import sys
-from pathlib import Path
-
-import numpy as np
-import pandas as pd
-import pydicom
-import pydicom_seg
-import SimpleITK as sitk
-import wget
-
-
-class SuppressPrint:
-    """
-    A class that temporarily suppresses print statements.
-
-    Methods:
-        __enter__(): Sets sys.stdout to a dummy file object, suppressing print output.
-        __exit__(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.
-    """
-
-    def __enter__(self):
-        """
-        Enter the context manager and redirect the standard output to nothing.
-
-        Returns:
-            object: The context manager object.
-
-        Notes:
-            This context manager is used to redirect the standard output to nothing using the `open` function.
-            It saves the original standard output and assigns a new output destination as `/dev/null` on Unix-like systems.
-        """
-        self._original_stdout = sys.stdout
-        sys.stdout = open(os.devnull, "w")
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """
-        Restores the original stdout and closes the modified stdout.
-
-        Args:
-            exc_type (type): The exception type, if an exception occurred. Otherwise, None.
-            exc_val (Exception): The exception instance, if an exception occurred. Otherwise, None.
-            exc_tb (traceback): The traceback object, if an exception occurred. Otherwise, None.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        sys.stdout.close()
-        sys.stdout = self._original_stdout
-
-
-with SuppressPrint():
-    from dcmrtstruct2nii import dcmrtstruct2nii
-    from dcmrtstruct2nii.adapters.input.image.dcminputadapter import DcmInputAdapter
-    from dcmrtstruct2nii.adapters.output.niioutputadapter import NiiOutputAdapter
-
-from google.cloud import storage
-from loguru import logger
-from tqdm import tqdm
-
-
-def dcmseg2nii(dcmseg_path, output_dir, tag=""):
-    """
-    Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.
-
-    Args:
-        dcmseg_path (str): The file path of the DICOM Segmentation object.
-        output_dir (str): The directory where the NIfTI files will be saved.
-        tag (str, optional): An optional tag to prepend to the output file names. Defaults to "".
-    """
-    dcm = pydicom.dcmread(dcmseg_path)
-    reader = pydicom_seg.SegmentReader()
-    result = reader.read(dcm)
-
-    for segment_number in result.available_segments:
-        image = result.segment_image(segment_number)  # lazy construction
-        sitk.WriteImage(image, output_dir + f"/{tag}{segment_number}.nii.gz", True)
-
-
-def download_from_manifest(df, save_dir, samples):
-    """
-    Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.
-
-    Parameters:
-        df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files.
-        save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved.
-        samples (int): The number of random samples to download. If None, all available samples will be downloaded.
-
-    Returns:
-        None
-    """
-    # Instantiates a client
-    storage_client = storage.Client()
-    logger.info("Downloading DICOM data from IDC (Imaging Data Commons) ...")
-    (save_dir / "dicom").mkdir(exist_ok=True, parents=True)
-
-    if samples is not None:
-        assert "PatientID" in df.columns
-        rows_with_annotations = df[df["Modality"].isin(["RTSTRUCT", "SEG"])]
-        unique_elements = rows_with_annotations["PatientID"].unique()
-        selected_elements = np.random.choice(unique_elements, min(len(unique_elements), samples), replace=False)
-        df = df[df["PatientID"].isin(selected_elements)]
-
-    def download_file(row):
-        """
-        Download a file from Google Cloud Storage.
-
-        Args:
-            row (dict): A dictionary containing the row data.
-
-        Raises:
-            None
-
-        Returns:
-            None
-        """
-        bucket_name, directory, file = row["gcs_url"].split("/")[-3:]
-        fn = f"{directory}/{file}"
-        bucket = storage_client.bucket(bucket_name)
-        blob = bucket.blob(fn)
-
-        current_save_dir = save_dir / "dicom" / row["PatientID"] / row["StudyInstanceUID"]
-        current_save_dir.mkdir(exist_ok=True, parents=True)
-        blob.download_to_filename(
-            str(current_save_dir / f'{row["Modality"]}_{row["SeriesInstanceUID"]}_{row["InstanceNumber"]}.dcm')
-        )
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for idx, row in df.iterrows():
-            futures.append(executor.submit(download_file, row))
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
-            pass
-
-
-def download_LUNG1(path, samples=None):
-    """
-    Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.
-
-    Parameters:
-        path (str): The directory path where the LUNG1 data manifest will be saved.
-        samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.
-
-    Returns:
-        None
-    """
-    save_dir = Path(path).resolve()
-    save_dir.mkdir(exist_ok=True, parents=True)
-
-    logger.info("Downloading LUNG1 manifest from Dropbox ...")
-    # Download LUNG1 data manifest, this is precomputed but any set of GCS dicom files can be used here
-    wget.download(
-        "https://www.dropbox.com/s/lkvv33nmepecyu5/nsclc_radiomics.csv?dl=1",
-        out=f"{save_dir}/nsclc_radiomics.csv",
-    )
-
-    df = pd.read_csv(f"{save_dir}/nsclc_radiomics.csv")
-
-    download_from_manifest(df, save_dir, samples)
-
-
-def download_RADIO(path, samples=None):
-    """
-    Downloads the RADIO manifest from Dropbox and saves it to the specified path.
-
-    Args:
-        path (str): The path where the manifest file will be saved.
-        samples (list, optional): A list of sample names to download. If None, all samples will be downloaded.
-
-    Returns:
-        None
-    """
-    save_dir = Path(path).resolve()
-    save_dir.mkdir(exist_ok=True, parents=True)
-
-    logger.info("Downloading RADIO manifest from Dropbox ...")
-    # Download RADIO data manifest, this is precomputed but any set of GCS dicom files can be used here
-    wget.download(
-        "https://www.dropbox.com/s/nhh1tb0rclrb7mw/nsclc_radiogenomics.csv?dl=1",
-        out=f"{save_dir}/nsclc_radiogenomics.csv",
-    )
-
-    df = pd.read_csv(f"{save_dir}/nsclc_radiogenomics.csv")
-
-    download_from_manifest(df, save_dir, samples)
-
-
-def process_series_dir(series_dir):
-    """
-    Process the series directory and extract relevant information.
-
-    Args:
-        series_dir (Path): The path to the series directory.
-
-    Returns:
-        dict: A dictionary containing the extracted information, including the image path, patient ID, and coordinates.
-
-    Raises:
-        None
-    """
-    # Check if RTSTRUCT file exists
-    rtstuct_files = list(series_dir.glob("*RTSTRUCT*"))
-    seg_files = list(series_dir.glob("*SEG*"))
-
-    if len(rtstuct_files) != 0:
-        dcmrtstruct2nii(str(rtstuct_files[0]), str(series_dir), str(series_dir))
-
-    elif len(seg_files) != 0:
-        dcmseg2nii(str(seg_files[0]), str(series_dir), tag="GTV-")
-
-        series_id = str(list(series_dir.glob("CT*.dcm"))[0]).split("_")[-2]
-        dicom_image = DcmInputAdapter().ingest(str(series_dir), series_id=series_id)
-        nii_output_adapter = NiiOutputAdapter()
-        nii_output_adapter.write(dicom_image, f"{series_dir}/image", gzip=True)
-    else:
-        logger.warning("Skipped file without any RTSTRUCT or SEG file")
-        return None
-
-    image = sitk.ReadImage(str(series_dir / "image.nii.gz"))
-    mask = sitk.ReadImage(str(list(series_dir.glob("*GTV-1*"))[0]))
-
-    # Get centroid from label shape filter
-    label_shape_filter = sitk.LabelShapeStatisticsImageFilter()
-    label_shape_filter.Execute(mask)
-
-    try:
-        centroid = label_shape_filter.GetCentroid(255)
-    except:
-        centroid = label_shape_filter.GetCentroid(1)
-
-    x, y, z = centroid
-
-    row = {
-        "image_path": str(series_dir / "image.nii.gz"),
-        "PatientID": series_dir.parent.name,
-        "coordX": x,
-        "coordY": y,
-        "coordZ": z,
-    }
-
-    return row
-
-
-def build_image_seed_dict(path, samples=None):
-    """
-    Build a dictionary of image seeds from DICOM files.
-
-    Args:
-        path (str): The path to the directory containing DICOM files.
-        samples (int, optional): The number of samples to process. If None, all samples will be processed.
-
-    Returns:
-        pd.DataFrame: A DataFrame containing the image seeds.
-    """
-    sorted_dir = Path(path).resolve()
-    series_dirs = [x.parent for x in sorted_dir.rglob("*.dcm")]
-    series_dirs = sorted(list(set(series_dirs)))
-
-    logger.info("Converting DICOM files to NIFTI ...")
-
-    if samples is None:
-        samples = len(series_dirs)
-
-    rows = []
-
-    num_workers = os.cpu_count()  # Adjust this value based on the number of available CPU cores
-    with concurrent.futures.ProcessPoolExecutor(num_workers) as executor:
-        processed_rows = list(tqdm(executor.map(process_series_dir, series_dirs[:samples]), total=samples))
-
-    rows = [row for row in processed_rows if row]
-    return pd.DataFrame(rows)
-
-
-
-
-
-
-
-

Functions

-
-
-def build_image_seed_dict(path, samples=None) -
-
-

Build a dictionary of image seeds from DICOM files.

-

Args

-
-
path : str
-
The path to the directory containing DICOM files.
-
samples : int, optional
-
The number of samples to process. If None, all samples will be processed.
-
-

Returns

-
-
pd.DataFrame
-
A DataFrame containing the image seeds.
-
-
- -Expand source code - -
def build_image_seed_dict(path, samples=None):
-    """
-    Build a dictionary of image seeds from DICOM files.
-
-    Args:
-        path (str): The path to the directory containing DICOM files.
-        samples (int, optional): The number of samples to process. If None, all samples will be processed.
-
-    Returns:
-        pd.DataFrame: A DataFrame containing the image seeds.
-    """
-    sorted_dir = Path(path).resolve()
-    series_dirs = [x.parent for x in sorted_dir.rglob("*.dcm")]
-    series_dirs = sorted(list(set(series_dirs)))
-
-    logger.info("Converting DICOM files to NIFTI ...")
-
-    if samples is None:
-        samples = len(series_dirs)
-
-    rows = []
-
-    num_workers = os.cpu_count()  # Adjust this value based on the number of available CPU cores
-    with concurrent.futures.ProcessPoolExecutor(num_workers) as executor:
-        processed_rows = list(tqdm(executor.map(process_series_dir, series_dirs[:samples]), total=samples))
-
-    rows = [row for row in processed_rows if row]
-    return pd.DataFrame(rows)
-
-
-
-def dcmseg2nii(dcmseg_path, output_dir, tag='') -
-
-

Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.

-

Args

-
-
dcmseg_path : str
-
The file path of the DICOM Segmentation object.
-
output_dir : str
-
The directory where the NIfTI files will be saved.
-
tag : str, optional
-
An optional tag to prepend to the output file names. Defaults to "".
-
-
- -Expand source code - -
def dcmseg2nii(dcmseg_path, output_dir, tag=""):
-    """
-    Convert a DICOM Segmentation object to NIfTI format and save the resulting segment images.
-
-    Args:
-        dcmseg_path (str): The file path of the DICOM Segmentation object.
-        output_dir (str): The directory where the NIfTI files will be saved.
-        tag (str, optional): An optional tag to prepend to the output file names. Defaults to "".
-    """
-    dcm = pydicom.dcmread(dcmseg_path)
-    reader = pydicom_seg.SegmentReader()
-    result = reader.read(dcm)
-
-    for segment_number in result.available_segments:
-        image = result.segment_image(segment_number)  # lazy construction
-        sitk.WriteImage(image, output_dir + f"/{tag}{segment_number}.nii.gz", True)
-
-
-
-def download_LUNG1(path, samples=None) -
-
-

Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.

-

Parameters

-

path (str): The directory path where the LUNG1 data manifest will be saved. -samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.

-

Returns

-

None

-
- -Expand source code - -
def download_LUNG1(path, samples=None):
-    """
-    Downloads the LUNG1 data manifest from Dropbox and saves it to the specified path.
-
-    Parameters:
-        path (str): The directory path where the LUNG1 data manifest will be saved.
-        samples (list, optional): A list of specific samples to download. If None, all samples will be downloaded.
-
-    Returns:
-        None
-    """
-    save_dir = Path(path).resolve()
-    save_dir.mkdir(exist_ok=True, parents=True)
-
-    logger.info("Downloading LUNG1 manifest from Dropbox ...")
-    # Download LUNG1 data manifest, this is precomputed but any set of GCS dicom files can be used here
-    wget.download(
-        "https://www.dropbox.com/s/lkvv33nmepecyu5/nsclc_radiomics.csv?dl=1",
-        out=f"{save_dir}/nsclc_radiomics.csv",
-    )
-
-    df = pd.read_csv(f"{save_dir}/nsclc_radiomics.csv")
-
-    download_from_manifest(df, save_dir, samples)
-
-
-
-def download_RADIO(path, samples=None) -
-
-

Downloads the RADIO manifest from Dropbox and saves it to the specified path.

-

Args

-
-
path : str
-
The path where the manifest file will be saved.
-
samples : list, optional
-
A list of sample names to download. If None, all samples will be downloaded.
-
-

Returns

-

None

-
- -Expand source code - -
def download_RADIO(path, samples=None):
-    """
-    Downloads the RADIO manifest from Dropbox and saves it to the specified path.
-
-    Args:
-        path (str): The path where the manifest file will be saved.
-        samples (list, optional): A list of sample names to download. If None, all samples will be downloaded.
-
-    Returns:
-        None
-    """
-    save_dir = Path(path).resolve()
-    save_dir.mkdir(exist_ok=True, parents=True)
-
-    logger.info("Downloading RADIO manifest from Dropbox ...")
-    # Download RADIO data manifest, this is precomputed but any set of GCS dicom files can be used here
-    wget.download(
-        "https://www.dropbox.com/s/nhh1tb0rclrb7mw/nsclc_radiogenomics.csv?dl=1",
-        out=f"{save_dir}/nsclc_radiogenomics.csv",
-    )
-
-    df = pd.read_csv(f"{save_dir}/nsclc_radiogenomics.csv")
-
-    download_from_manifest(df, save_dir, samples)
-
-
-
-def download_from_manifest(df, save_dir, samples) -
-
-

Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.

-

Parameters

-

df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files. -save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved. -samples (int): The number of random samples to download. If None, all available samples will be downloaded.

-

Returns

-

None

-
- -Expand source code - -
def download_from_manifest(df, save_dir, samples):
-    """
-    Downloads DICOM data from IDC (Imaging Data Commons) based on the provided manifest.
-
-    Parameters:
-        df (pandas.DataFrame): The manifest DataFrame containing information about the DICOM files.
-        save_dir (pathlib.Path): The directory where the downloaded DICOM files will be saved.
-        samples (int): The number of random samples to download. If None, all available samples will be downloaded.
-
-    Returns:
-        None
-    """
-    # Instantiates a client
-    storage_client = storage.Client()
-    logger.info("Downloading DICOM data from IDC (Imaging Data Commons) ...")
-    (save_dir / "dicom").mkdir(exist_ok=True, parents=True)
-
-    if samples is not None:
-        assert "PatientID" in df.columns
-        rows_with_annotations = df[df["Modality"].isin(["RTSTRUCT", "SEG"])]
-        unique_elements = rows_with_annotations["PatientID"].unique()
-        selected_elements = np.random.choice(unique_elements, min(len(unique_elements), samples), replace=False)
-        df = df[df["PatientID"].isin(selected_elements)]
-
-    def download_file(row):
-        """
-        Download a file from Google Cloud Storage.
-
-        Args:
-            row (dict): A dictionary containing the row data.
-
-        Raises:
-            None
-
-        Returns:
-            None
-        """
-        bucket_name, directory, file = row["gcs_url"].split("/")[-3:]
-        fn = f"{directory}/{file}"
-        bucket = storage_client.bucket(bucket_name)
-        blob = bucket.blob(fn)
-
-        current_save_dir = save_dir / "dicom" / row["PatientID"] / row["StudyInstanceUID"]
-        current_save_dir.mkdir(exist_ok=True, parents=True)
-        blob.download_to_filename(
-            str(current_save_dir / f'{row["Modality"]}_{row["SeriesInstanceUID"]}_{row["InstanceNumber"]}.dcm')
-        )
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for idx, row in df.iterrows():
-            futures.append(executor.submit(download_file, row))
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
-            pass
-
-
-
-def process_series_dir(series_dir) -
-
-

Process the series directory and extract relevant information.

-

Args

-
-
series_dir : Path
-
The path to the series directory.
-
-

Returns

-
-
dict
-
A dictionary containing the extracted information, including the image path, patient ID, and coordinates.
-
-

Raises

-

None

-
- -Expand source code - -
def process_series_dir(series_dir):
-    """
-    Process the series directory and extract relevant information.
-
-    Args:
-        series_dir (Path): The path to the series directory.
-
-    Returns:
-        dict: A dictionary containing the extracted information, including the image path, patient ID, and coordinates.
-
-    Raises:
-        None
-    """
-    # Check if RTSTRUCT file exists
-    rtstuct_files = list(series_dir.glob("*RTSTRUCT*"))
-    seg_files = list(series_dir.glob("*SEG*"))
-
-    if len(rtstuct_files) != 0:
-        dcmrtstruct2nii(str(rtstuct_files[0]), str(series_dir), str(series_dir))
-
-    elif len(seg_files) != 0:
-        dcmseg2nii(str(seg_files[0]), str(series_dir), tag="GTV-")
-
-        series_id = str(list(series_dir.glob("CT*.dcm"))[0]).split("_")[-2]
-        dicom_image = DcmInputAdapter().ingest(str(series_dir), series_id=series_id)
-        nii_output_adapter = NiiOutputAdapter()
-        nii_output_adapter.write(dicom_image, f"{series_dir}/image", gzip=True)
-    else:
-        logger.warning("Skipped file without any RTSTRUCT or SEG file")
-        return None
-
-    image = sitk.ReadImage(str(series_dir / "image.nii.gz"))
-    mask = sitk.ReadImage(str(list(series_dir.glob("*GTV-1*"))[0]))
-
-    # Get centroid from label shape filter
-    label_shape_filter = sitk.LabelShapeStatisticsImageFilter()
-    label_shape_filter.Execute(mask)
-
-    try:
-        centroid = label_shape_filter.GetCentroid(255)
-    except:
-        centroid = label_shape_filter.GetCentroid(1)
-
-    x, y, z = centroid
-
-    row = {
-        "image_path": str(series_dir / "image.nii.gz"),
-        "PatientID": series_dir.parent.name,
-        "coordX": x,
-        "coordY": y,
-        "coordZ": z,
-    }
-
-    return row
-
-
-
-
-
-

Classes

-
-
-class SuppressPrint -
-
-

A class that temporarily suppresses print statements.

-

Methods

-

enter(): Sets sys.stdout to a dummy file object, suppressing print output. -exit(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.

-
- -Expand source code - -
class SuppressPrint:
-    """
-    A class that temporarily suppresses print statements.
-
-    Methods:
-        __enter__(): Sets sys.stdout to a dummy file object, suppressing print output.
-        __exit__(exc_type, exc_val, exc_tb): Restores sys.stdout to its original value.
-    """
-
-    def __enter__(self):
-        """
-        Enter the context manager and redirect the standard output to nothing.
-
-        Returns:
-            object: The context manager object.
-
-        Notes:
-            This context manager is used to redirect the standard output to nothing using the `open` function.
-            It saves the original standard output and assigns a new output destination as `/dev/null` on Unix-like systems.
-        """
-        self._original_stdout = sys.stdout
-        sys.stdout = open(os.devnull, "w")
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """
-        Restores the original stdout and closes the modified stdout.
-
-        Args:
-            exc_type (type): The exception type, if an exception occurred. Otherwise, None.
-            exc_val (Exception): The exception instance, if an exception occurred. Otherwise, None.
-            exc_tb (traceback): The traceback object, if an exception occurred. Otherwise, None.
-
-        Returns:
-            None
-
-        Raises:
-            None
-        """
-        sys.stdout.close()
-        sys.stdout = self._original_stdout
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/utils/index.html b/docs/api_docs/fmcib/utils/index.html deleted file mode 100644 index 3048cf4..0000000 --- a/docs/api_docs/fmcib/utils/index.html +++ /dev/null @@ -1,77 +0,0 @@ - - - - - - -fmcib.utils API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.utils

-
-
-
- -Expand source code - -
from .download_utils import bar_progress
-from .idc_helper import *
-
-
-
-

Sub-modules

-
-
fmcib.utils.download_utils
-
-
-
-
fmcib.utils.idc_helper
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/visualization/index.html b/docs/api_docs/fmcib/visualization/index.html deleted file mode 100644 index 4430606..0000000 --- a/docs/api_docs/fmcib/visualization/index.html +++ /dev/null @@ -1,71 +0,0 @@ - - - - - - -fmcib.visualization API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.visualization

-
-
-
- -Expand source code - -
from .verify_io import visualize_seed_point
-
-
-
-

Sub-modules

-
-
fmcib.visualization.verify_io
-
-
-
-
-
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/api_docs/fmcib/visualization/verify_io.html b/docs/api_docs/fmcib/visualization/verify_io.html deleted file mode 100644 index d59bd61..0000000 --- a/docs/api_docs/fmcib/visualization/verify_io.html +++ /dev/null @@ -1,244 +0,0 @@ - - - - - - -fmcib.visualization.verify_io API documentation - - - - - - - - - - - -
-
-
-

Module fmcib.visualization.verify_io

-
-
-
- -Expand source code - -
import matplotlib.pyplot as plt
-import monai.transforms as monai_transforms
-import numpy as np
-import torch
-from monai.visualize import blend_images
-
-
-def visualize_seed_point(row):
-    """
-    This function visualizes a seed point on an image.
-
-    Args:
-        row (pandas.Series): A row containing the information of the seed point, including the image path and the coordinates.
-            The following columns are expected: "image_path", "coordX", "coordY", "coordZ".
-
-    Returns:
-        None
-    """
-    # Define the transformation pipeline
-    T = monai_transforms.Compose(
-        [
-            monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
-            monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-            monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
-            monai_transforms.ScaleIntensityRanged(keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True),
-            monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
-            monai_transforms.SelectItemsd(keys=["image_path", "coordX", "coordY", "coordZ"]),
-        ]
-    )
-
-    # Apply the transformation pipeline
-    out = T(row)
-
-    # Calculate the center of the image
-    center = (-out["coordX"], -out["coordY"], out["coordZ"])
-    center = np.linalg.inv(np.array(out["image_path"].affine)) @ np.array(center + (1,))
-    center = [int(x) for x in center[:3]]
-
-    # Define the image and label
-    image = out["image_path"]
-    label = torch.zeros_like(image)
-
-    # Define the dimensions of the image and the patch
-    C, H, W, D = image.shape
-    Ph, Pw, Pd = 50, 50, 50
-
-    # Calculate and clamp the ranges for cropping
-    min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
-    min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
-    min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
-    # Check if coordinates are valid
-    assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
-    assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
-    assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
-    # Define the label for the cropped region
-    label[:, min_h:max_h, min_w:max_w, min_d:max_d] = 1
-
-    # Blend the image and the label
-    ret = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
-    ret = ret.permute(3, 2, 1, 0)
-
-    # Plot axial slice
-    plt.figure(figsize=(10, 10))
-    plt.subplot(1, 3, 1)
-    plt.imshow(ret[center[2], :, :])
-    plt.title("Axial")
-    plt.axis("off")
-
-    # Plot sagittal slice
-    plt.subplot(1, 3, 2)
-    plt.imshow(np.flipud(ret[:, center[1], :]))
-    plt.title("Coronal")
-    plt.axis("off")
-
-    # Plot coronal slice
-    plt.subplot(1, 3, 3)
-    plt.imshow(np.flipud(ret[:, :, center[0]]))
-    plt.title("Sagittal")
-
-    plt.axis("off")
-    plt.show()
-
-
-
-
-
-
-
-

Functions

-
-
-def visualize_seed_point(row) -
-
-

This function visualizes a seed point on an image.

-

Args

-
-
row : pandas.Series
-
A row containing the information of the seed point, including the image path and the coordinates. -The following columns are expected: "image_path", "coordX", "coordY", "coordZ".
-
-

Returns

-

None

-
- -Expand source code - -
def visualize_seed_point(row):
-    """
-    This function visualizes a seed point on an image.
-
-    Args:
-        row (pandas.Series): A row containing the information of the seed point, including the image path and the coordinates.
-            The following columns are expected: "image_path", "coordX", "coordY", "coordZ".
-
-    Returns:
-        None
-    """
-    # Define the transformation pipeline
-    T = monai_transforms.Compose(
-        [
-            monai_transforms.LoadImaged(keys=["image_path"], image_only=True, reader="ITKReader"),
-            monai_transforms.EnsureChannelFirstd(keys=["image_path"]),
-            monai_transforms.Spacingd(keys=["image_path"], pixdim=1, mode="bilinear", align_corners=True, diagonal=True),
-            monai_transforms.ScaleIntensityRanged(keys=["image_path"], a_min=-1024, a_max=3072, b_min=0, b_max=1, clip=True),
-            monai_transforms.Orientationd(keys=["image_path"], axcodes="LPS"),
-            monai_transforms.SelectItemsd(keys=["image_path", "coordX", "coordY", "coordZ"]),
-        ]
-    )
-
-    # Apply the transformation pipeline
-    out = T(row)
-
-    # Calculate the center of the image
-    center = (-out["coordX"], -out["coordY"], out["coordZ"])
-    center = np.linalg.inv(np.array(out["image_path"].affine)) @ np.array(center + (1,))
-    center = [int(x) for x in center[:3]]
-
-    # Define the image and label
-    image = out["image_path"]
-    label = torch.zeros_like(image)
-
-    # Define the dimensions of the image and the patch
-    C, H, W, D = image.shape
-    Ph, Pw, Pd = 50, 50, 50
-
-    # Calculate and clamp the ranges for cropping
-    min_h, max_h = max(center[0] - Ph // 2, 0), min(center[0] + Ph // 2, H)
-    min_w, max_w = max(center[1] - Pw // 2, 0), min(center[1] + Pw // 2, W)
-    min_d, max_d = max(center[2] - Pd // 2, 0), min(center[2] + Pd // 2, D)
-
-    # Check if coordinates are valid
-    assert min_h < max_h, "Invalid coordinates: min_h >= max_h"
-    assert min_w < max_w, "Invalid coordinates: min_w >= max_w"
-    assert min_d < max_d, "Invalid coordinates: min_d >= max_d"
-
-    # Define the label for the cropped region
-    label[:, min_h:max_h, min_w:max_w, min_d:max_d] = 1
-
-    # Blend the image and the label
-    ret = blend_images(image=image, label=label, alpha=0.3, cmap="hsv", rescale_arrays=False)
-    ret = ret.permute(3, 2, 1, 0)
-
-    # Plot axial slice
-    plt.figure(figsize=(10, 10))
-    plt.subplot(1, 3, 1)
-    plt.imshow(ret[center[2], :, :])
-    plt.title("Axial")
-    plt.axis("off")
-
-    # Plot sagittal slice
-    plt.subplot(1, 3, 2)
-    plt.imshow(np.flipud(ret[:, center[1], :]))
-    plt.title("Coronal")
-    plt.axis("off")
-
-    # Plot coronal slice
-    plt.subplot(1, 3, 3)
-    plt.imshow(np.flipud(ret[:, :, center[0]]))
-    plt.title("Sagittal")
-
-    plt.axis("off")
-    plt.show()
-
-
-
-
-
-
-
- -
- - - \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index da8d051..065c13f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -20,6 +20,27 @@ theme: accent: blue extra_css: - assets/extra.css + + +# Plugins +# Auto API reference generation: https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages +plugins: + - search + - autorefs + - gen-files: + scripts: + - scripts/generate_api_reference_pages.py + - literate-nav: + nav_file: SUMMARY.md + - section-index + - mkdocstrings: + handlers: + python: + docstring_style: google + options: + # Removed the default filter that excludes private members (that is, members whose names start with a single underscore). + filters: null + nav: - 'index.md' - 'Getting Started': @@ -34,7 +55,7 @@ nav: - 'Reproduce Analysis': 'user-guide/analysis.md' # - 'Training baselines': 'user-guide/reproduce_baselines.md' - 'Tutorials': https://github.com/AIM-Harvard/foundation-cancer-image-biomarker/tree/master/tutorials - - 'API Reference': 'api_docs/fmcib' + - 'API Reference': 'reference/' markdown_extensions: - pymdownx.highlight: @@ -47,7 +68,11 @@ markdown_extensions: - admonition - pymdownx.details - pymdownx.superfences + repo_url: https://github.com/AIM-Harvard/foundation-cancer-image-biomarker +repo_name: AIM-Harvard/foundation-cancer-image-biomarker copyright: AIM © 2023 extra: generator: false + + diff --git a/scripts/generate_api_reference_pages.py b/scripts/generate_api_reference_pages.py new file mode 100644 index 0000000..98d6ec9 --- /dev/null +++ b/scripts/generate_api_reference_pages.py @@ -0,0 +1,51 @@ +"""Generate the code reference pages and navigation automatically. +Modified from https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages. +""" + +from pathlib import Path + +import mkdocs_gen_files + +PACKAGE = "fmcib" + +# Modules to exclude +EXCLUDE = [ + "fmcib.__init__", +] + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent +src = root / PACKAGE + +for path in sorted(src.rglob("*.py")): + print(f"Processing {path}") + module_path = path.relative_to(src).with_suffix("") + + module_py_notation = PACKAGE + "." + ".".join(module_path.parts) + if module_py_notation in EXCLUDE: + print(f"Excluding '{module_py_notation}' from the API reference.") + continue + + doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path("reference", doc_path) + + parts = (PACKAGE, *module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1] == "__main__": + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav())