Skip to content

Commit

Permalink
Implement HED Kornia augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
BPdeRooij authored Jul 11, 2023
1 parent b71c668 commit 04d166d
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions ahcore/transforms/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import cast

import kornia.augmentation as K
from kornia.augmentation import random_generator as rg
import torch
from kornia.constants import DataKey, Resample
from omegaconf import ListConfig
Expand Down Expand Up @@ -85,6 +86,85 @@ def forward(self, *args: torch.Tensor, **kwargs):
return output


class HEDColorAugmentation(K.IntensityAugmentationBase2D):
"""
A torch implementation of the color stain augmentation algorithm on the
deconvolved Hemaetoxylin-Eosin-DAB (HED) channels of an image as described
by Tellez et al. (2018) in Appendix A & B here: https://arxiv.org/pdf/1808.05896.pdf.
"""

# Normalized OD matrix from Ruifrok et al. (2001)
HED_REFERENCE = torch.Tensor([[0.65, 0.70, 0.29],
[0.07, 0.99, 0.11],
[0.27, 0.57, 0.78]])

def __init__(self,
scale_sigma: float | tuple[float, float, float],
bias_sigma: float | tuple[float, float, float],
epsilon: float = 1e-6,
p: float = 0.5,
p_batch: float = 1.0,
same_on_batch=False,
keepdim=False,
**kwargs) -> None:
"""
Apply a color stain augmentation in the Hemaetoxylin-Eosin-DAB (HED) color space based on [1].
The fixed normalized OD matrix values are based on [2].
Parameters
----------
scale_sigma: float or tuple of floats
For each channel in the HED space a random scaling factor is drawn from alpha_i ~ U(1-sigma_i,1+sigma_i).
bias_sigma: float or tuple of floats
For each channel in the HED space a random bias is added drawn from beta_i ~ U(-sigma_i,sigma_i).
epsilon: float
Small positive bias to avoid numerical errors
References
----------
[1] Tellez, David, et al. "Whole-slide mitosis detection in H&E breast histology using PHH3 as a reference to train distilled stain-invariant convolutional networks." IEEE transactions on medical imaging 37.9 (2018): 2126-2136.
[2] Ruifrok AC, Johnston DA. Quantification of histochemical staining by color deconvolution. Anal Quant Cytol Histol. 2001 Aug;23(4):291-9. PMID: 11531144.
"""
super().__init__(p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim)
if (isinstance(scale_sigma, tuple) and len(scale_sigma) != 3) or (isinstance(bias_sigma, tuple) and len(bias_sigma) != 3):
raise ValueError(f"scale_sigma and bias_sigma should have either 1 or 3 values, got {scale_sigma} and {bias_sigma} instead.")

if isinstance(scale_sigma, float):
scale_sigma = tuple([scale_sigma for _ in range(3)])

if isinstance(bias_sigma, float):
bias_sigma = tuple([bias_sigma for _ in range(3)])

scale_sigma = torch.tensor(scale_sigma)
bias_sigma = torch.tensor(bias_sigma)

scale_factor = torch.stack([1-scale_sigma, 1+scale_sigma])
bias_factor = torch.stack([-bias_sigma, bias_sigma])

self._param_generator = rg.PlainUniformGenerator((scale_factor, 'scale', None, None), (bias_factor, 'bias', None, None))
self.flags = {'epsilon': torch.tensor([epsilon]),
'M': self.HED_REFERENCE,
'M_inv': torch.linalg.inv(self.HED_REFERENCE)}

def apply_transform(self, sample: torch.Tensor, params = None, flags = None, transform = None, data_keys: list[str | int | DataKey] = None, **kwargs) -> torch.Tensor:
"""
Apply HED color augmentation on an input tensor.
"""
epsilon = flags['epsilon'].to(sample)
reference_matrix = flags['M'].to(sample)
reference_matrix_inv = flags['M_inv'].to(sample)

alpha = params['scale'][:, None, None, :].to(sample)
beta = params['bias'][:, None, None, :].to(sample)

rgb_tensor = sample.permute(0, 2, 3, 1)
optical_density = -torch.log(rgb_tensor + epsilon)
hed_tensor = optical_density @ reference_matrix_inv

augmented_hed_tensor = torch.where(hed_tensor > epsilon, alpha * hed_tensor + beta, hed_tensor)
augmented_rgb_tensor = torch.exp(-augmented_hed_tensor @ reference_matrix) - epsilon
augmented_sample = augmented_rgb_tensor.permute(0, 3, 1, 2)
return augmented_sample


class CenterCrop(nn.Module):
"""Perform a center crop of the image and target"""

Expand Down

0 comments on commit 04d166d

Please sign in to comment.