diff --git a/ahcore/transforms/augmentations.py b/ahcore/transforms/augmentations.py index 411b479..1c336d3 100644 --- a/ahcore/transforms/augmentations.py +++ b/ahcore/transforms/augmentations.py @@ -8,6 +8,7 @@ from typing import cast import kornia.augmentation as K +from kornia.augmentation import random_generator as rg import torch from kornia.constants import DataKey, Resample from omegaconf import ListConfig @@ -85,6 +86,85 @@ def forward(self, *args: torch.Tensor, **kwargs): return output +class HEDColorAugmentation(K.IntensityAugmentationBase2D): + """ + A torch implementation of the color stain augmentation algorithm on the + deconvolved Hemaetoxylin-Eosin-DAB (HED) channels of an image as described + by Tellez et al. (2018) in Appendix A & B here: https://arxiv.org/pdf/1808.05896.pdf. + """ + + # Normalized OD matrix from Ruifrok et al. (2001) + HED_REFERENCE = torch.Tensor([[0.65, 0.70, 0.29], + [0.07, 0.99, 0.11], + [0.27, 0.57, 0.78]]) + + def __init__(self, + scale_sigma: float | tuple[float, float, float], + bias_sigma: float | tuple[float, float, float], + epsilon: float = 1e-6, + p: float = 0.5, + p_batch: float = 1.0, + same_on_batch=False, + keepdim=False, + **kwargs) -> None: + """ + Apply a color stain augmentation in the Hemaetoxylin-Eosin-DAB (HED) color space based on [1]. + The fixed normalized OD matrix values are based on [2]. + Parameters + ---------- + scale_sigma: float or tuple of floats + For each channel in the HED space a random scaling factor is drawn from alpha_i ~ U(1-sigma_i,1+sigma_i). + bias_sigma: float or tuple of floats + For each channel in the HED space a random bias is added drawn from beta_i ~ U(-sigma_i,sigma_i). + epsilon: float + Small positive bias to avoid numerical errors + References + ---------- + [1] Tellez, David, et al. "Whole-slide mitosis detection in H&E breast histology using PHH3 as a reference to train distilled stain-invariant convolutional networks." IEEE transactions on medical imaging 37.9 (2018): 2126-2136. + [2] Ruifrok AC, Johnston DA. Quantification of histochemical staining by color deconvolution. Anal Quant Cytol Histol. 2001 Aug;23(4):291-9. PMID: 11531144. + """ + super().__init__(p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim) + if (isinstance(scale_sigma, tuple) and len(scale_sigma) != 3) or (isinstance(bias_sigma, tuple) and len(bias_sigma) != 3): + raise ValueError(f"scale_sigma and bias_sigma should have either 1 or 3 values, got {scale_sigma} and {bias_sigma} instead.") + + if isinstance(scale_sigma, float): + scale_sigma = tuple([scale_sigma for _ in range(3)]) + + if isinstance(bias_sigma, float): + bias_sigma = tuple([bias_sigma for _ in range(3)]) + + scale_sigma = torch.tensor(scale_sigma) + bias_sigma = torch.tensor(bias_sigma) + + scale_factor = torch.stack([1-scale_sigma, 1+scale_sigma]) + bias_factor = torch.stack([-bias_sigma, bias_sigma]) + + self._param_generator = rg.PlainUniformGenerator((scale_factor, 'scale', None, None), (bias_factor, 'bias', None, None)) + self.flags = {'epsilon': torch.tensor([epsilon]), + 'M': self.HED_REFERENCE, + 'M_inv': torch.linalg.inv(self.HED_REFERENCE)} + + def apply_transform(self, sample: torch.Tensor, params = None, flags = None, transform = None, data_keys: list[str | int | DataKey] = None, **kwargs) -> torch.Tensor: + """ + Apply HED color augmentation on an input tensor. + """ + epsilon = flags['epsilon'].to(sample) + reference_matrix = flags['M'].to(sample) + reference_matrix_inv = flags['M_inv'].to(sample) + + alpha = params['scale'][:, None, None, :].to(sample) + beta = params['bias'][:, None, None, :].to(sample) + + rgb_tensor = sample.permute(0, 2, 3, 1) + optical_density = -torch.log(rgb_tensor + epsilon) + hed_tensor = optical_density @ reference_matrix_inv + + augmented_hed_tensor = torch.where(hed_tensor > epsilon, alpha * hed_tensor + beta, hed_tensor) + augmented_rgb_tensor = torch.exp(-augmented_hed_tensor @ reference_matrix) - epsilon + augmented_sample = augmented_rgb_tensor.permute(0, 3, 1, 2) + return augmented_sample + + class CenterCrop(nn.Module): """Perform a center crop of the image and target"""