Skip to content

Commit

Permalink
add optional tissue masking to Reinhard
Browse files Browse the repository at this point in the history
  • Loading branch information
CielAl committed Jan 3, 2024
1 parent ae44bf6 commit a4f0c50
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 19 deletions.
40 changes: 36 additions & 4 deletions tests/images/test_functionals.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""borrow Staintool's test cases
"""
import unittest
import numpy as np
from tests.util import fix_seed, dummy_from_numpy, psnr
from torch_staintools.functional.conversion.lab import rgb_to_lab
from torch_staintools.functional.stain_extraction.macenko import MacenkoExtractor
from torch_staintools.functional.stain_extraction.vahadane import VahadaneExtractor
from torch_staintools.functional.optimization.dict_learning import get_concentrations
from torch_staintools.functional.tissue_mask import get_tissue_mask, TissueMaskException
from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration
from functools import partial
from torchvision.transforms.functional import convert_image_dtype
from skimage.util import img_as_float32
from torch_staintools.normalizer.reinhard import ReinhardNormalizer
import torch
import cv2
import os
Expand Down Expand Up @@ -114,3 +111,38 @@ def test_tissue_mask(self):

with self.assertRaises(TissueMaskException):
get_tissue_mask(torch.zeros_like(dummy_scaled), luminosity_threshold=0.8, throw_error=True)

@staticmethod
def mean_std_compare_squeezed(x, mask):
masked = x * mask
mean_list = []
std_list = []
for c in range(masked.shape[1]):
nonzero = masked[:, c: c + 1, :, :][mask != 0]
mean_list.append(nonzero.mean())
std_list.append(nonzero.std())
return torch.stack(mean_list).squeeze(), torch.stack(std_list).squeeze()

def test_reinhard(self):
device = TestFunctional.device
dummy_scaled = convert_image_dtype(TestFunctional.new_dummy_img_tensor_ubyte(), torch.float32).to(device)
# not None mask
mask = get_tissue_mask(dummy_scaled, luminosity_threshold=0.8)
# 1 x 3 x 1 x 1

means_input, stds_input = ReinhardNormalizer._mean_std_helper(dummy_scaled, mask=mask)

manual_mean, manual_std = TestFunctional.mean_std_compare_squeezed(dummy_scaled, mask)
self.assertTrue(torch.isclose(manual_mean, means_input.squeeze()).all())
self.assertTrue(torch.isclose(manual_std, stds_input.squeeze()).all())

# no mask
rand_dummy = torch.randn(dummy_scaled.shape, device=dummy_scaled.device)
rand_mean, rand_std = ReinhardNormalizer._mean_std_helper(rand_dummy, mask=None)

rand_mean_truth = rand_dummy.mean(dim=(2, 3), keepdim=True)
rand_std_truth = rand_dummy.std(dim=(2, 3), keepdim=True)

self.assertTrue(torch.isclose(rand_mean, rand_mean_truth).all())
self.assertTrue(torch.isclose(rand_std, rand_std_truth).all())

28 changes: 28 additions & 0 deletions torch_staintools/functional/utility/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,31 @@ def default_rng(rng: Optional[torch.Generator | int], device: Optional[torch.dev
return torch.Generator(device=device).manual_seed(rng)
assert isinstance(rng, torch.Generator)
return rng


def nanstd(data: torch.Tensor, dim: Optional[int | tuple] = None,
correction: float = 1) -> torch.Tensor:
"""Compute the standard deviation while ignoring NaNs.
Always keep the dim.
Args:
data: Input tensor.
dim: The dimension or dimensions to reduce. If None (default), reduces all dimensions.
correction: Difference between the sample size and sample degrees of freedom. Defaults 1 (Bessel's).
Returns:
torch.Tensor: Standard deviation with NaNs ignored. If `dim` is provided,
it reduces along that dimension(s), otherwise reduces all dimensions.
"""

non_nan_mask = ~torch.isnan(data)
# find not-nan element
non_nan_count = non_nan_mask.sum(dim=dim, keepdim=True)
# compute mean of not-nan elements
mean = torch.nanmean(data, dim=dim, keepdim=True)

# \Sigma (x - mean)^2 --> any x that is nan will be filtered by using nansum
sum_dev2 = ((data - mean) ** 2).nansum(dim=dim, keepdim=True)
# sqrt and normalize by corrected degrees of freedom
return torch.sqrt(sum_dev2 / (non_nan_count - correction))
2 changes: 1 addition & 1 deletion torch_staintools/normalizer/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build(method: TYPE_SUPPORTED,
norm_method: Callable
match method:
case 'reinhard':
return ReinhardNormalizer.build()
return ReinhardNormalizer.build(luminosity_threshold=luminosity_threshold)
case 'macenko' | 'vahadane':
return StainSeparation.build(method=method, reconst_method=reconst_method,
num_stains=num_stains,
Expand Down
45 changes: 32 additions & 13 deletions torch_staintools/normalizer/reinhard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from torch_staintools.functional.conversion.lab import rgb_to_lab, lab_to_rgb
from torch_staintools.normalizer.base import Normalizer
from typing import Tuple
from torch_staintools.functional.tissue_mask import get_tissue_mask
from typing import Tuple, Optional
from torch_staintools.functional.eps import get_eps
from torch_staintools.functional.utility.implementation import nanstd


class ReinhardNormalizer(Normalizer):
Expand All @@ -12,28 +14,40 @@ class ReinhardNormalizer(Normalizer):
"""
target_means: torch.Tensor
target_stds: torch.Tensor
luminosity_threshold: float

def __init__(self):
def __init__(self, luminosity_threshold: Optional[float]):
super().__init__(cache=None, device=None, rng=None)
self.luminosity_threshold = luminosity_threshold

@staticmethod
def _mean_std_helper(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _mean_std_helper(image: torch.Tensor, *, mask: Optional[torch.Tensor] = None)\
-> Tuple[torch.Tensor, torch.Tensor]:
"""Get the channel-wise mean and std of input
Args:
image: BCHW scaled to [0, 1] torch.float32. Usually in LAB.
mask: luminosity tissue mask of image. Mean and std are only computed within the tissue regions.
Returns:
means,
"""
assert mask is None or mask.dtype is torch.bool, f"{mask.dtype}"
assert image.ndimension() == 4, f"{image.shape}"
means = image.mean(dim=(2, 3)).unsqueeze(-1).unsqueeze(-1)
stds = image.std(dim=(2, 3)).unsqueeze(-1).unsqueeze(-1)
if mask is None:
mask = torch.ones_like(image, dtype=torch.bool)

image_masked = image * mask
image_masked[mask.expand_as(image_masked) == 0] = torch.nan
means = image_masked.nanmean(dim=(2, 3), keepdim=True)
stds = nanstd(image_masked, dim=(2, 3))
return means, stds

def fit(self, image: torch.Tensor):
"""Fit - compute the means and stds of template in lab space.
Statistics are computed within tissue regions if a luminosity threshold is given to the normalizer upon
creation.
Args:
image: template. BCHW. [0, 1] torch.float32.
Expand All @@ -43,26 +57,28 @@ def fit(self, image: torch.Tensor):
# BCHW
img_lab: torch.Tensor = rgb_to_lab(image)
assert img_lab.ndimension() == 4 and img_lab.shape[1] == 3, f"{img_lab.shape}"
mask = get_tissue_mask(image, luminosity_threshold=self.luminosity_threshold)
# B1HW
# 1, C, 1, 1
means, stds = ReinhardNormalizer._mean_std_helper(img_lab)
means, stds = ReinhardNormalizer._mean_std_helper(img_lab, mask=mask)

self.register_buffer('target_means', means)
self.register_buffer('target_stds', stds)

@staticmethod
def normalize_helper(image: torch.Tensor, target_means: torch.Tensor, target_stds: torch.Tensor):
def normalize_helper(image: torch.Tensor, target_means: torch.Tensor, target_stds: torch.Tensor,
mask: Optional[torch.Tensor] = None):
"""Helper.
Args:
image: BCHW format. torch.float32 type in range [0, 1].
target_means: channel-wise means of template
target_stds: channel-wise stds of template
mask: Optional luminosity tissue mask to compute the stats within masked region
Returns:
"""
means_input, stds_input = ReinhardNormalizer._mean_std_helper(image)
means_input, stds_input = ReinhardNormalizer._mean_std_helper(image, mask=mask)
return (image - means_input) * (target_stds / (stds_input + get_eps(image))) + target_means

def transform(self, x: torch.Tensor, *args, **kwargs):
Expand All @@ -80,12 +96,15 @@ def transform(self, x: torch.Tensor, *args, **kwargs):
"""
# 1 C 1 1
lab_input = rgb_to_lab(x)
normalized_lab = ReinhardNormalizer.normalize_helper(lab_input, self.target_means, self.target_stds)
mask = get_tissue_mask(x, luminosity_threshold=self.luminosity_threshold, throw_error=False,
true_when_empty=True)
normalized_lab = ReinhardNormalizer.normalize_helper(lab_input, self.target_means, self.target_stds,
mask)
return lab_to_rgb(normalized_lab).clamp_(0, 1)

def forward(self, x: torch.Tensor, *args, **kwargs):
return self.transform(x)

@classmethod
def build(cls, *args, **kwargs):
return cls()
def build(cls, luminosity_threshold: Optional[float] = None, **kwargs):
return cls(luminosity_threshold=luminosity_threshold)
2 changes: 1 addition & 1 deletion torch_staintools/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.2'
__version__ = '1.0.3'

0 comments on commit a4f0c50

Please sign in to comment.