diff --git a/tests/images/test_functionals.py b/tests/images/test_functionals.py index ba774de..ce6b0eb 100644 --- a/tests/images/test_functionals.py +++ b/tests/images/test_functionals.py @@ -1,17 +1,14 @@ """borrow Staintool's test cases """ import unittest -import numpy as np from tests.util import fix_seed, dummy_from_numpy, psnr -from torch_staintools.functional.conversion.lab import rgb_to_lab from torch_staintools.functional.stain_extraction.macenko import MacenkoExtractor from torch_staintools.functional.stain_extraction.vahadane import VahadaneExtractor from torch_staintools.functional.optimization.dict_learning import get_concentrations from torch_staintools.functional.tissue_mask import get_tissue_mask, TissueMaskException from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration -from functools import partial from torchvision.transforms.functional import convert_image_dtype -from skimage.util import img_as_float32 +from torch_staintools.normalizer.reinhard import ReinhardNormalizer import torch import cv2 import os @@ -114,3 +111,38 @@ def test_tissue_mask(self): with self.assertRaises(TissueMaskException): get_tissue_mask(torch.zeros_like(dummy_scaled), luminosity_threshold=0.8, throw_error=True) + + @staticmethod + def mean_std_compare_squeezed(x, mask): + masked = x * mask + mean_list = [] + std_list = [] + for c in range(masked.shape[1]): + nonzero = masked[:, c: c + 1, :, :][mask != 0] + mean_list.append(nonzero.mean()) + std_list.append(nonzero.std()) + return torch.stack(mean_list).squeeze(), torch.stack(std_list).squeeze() + + def test_reinhard(self): + device = TestFunctional.device + dummy_scaled = convert_image_dtype(TestFunctional.new_dummy_img_tensor_ubyte(), torch.float32).to(device) + # not None mask + mask = get_tissue_mask(dummy_scaled, luminosity_threshold=0.8) + # 1 x 3 x 1 x 1 + + means_input, stds_input = ReinhardNormalizer._mean_std_helper(dummy_scaled, mask=mask) + + manual_mean, manual_std = TestFunctional.mean_std_compare_squeezed(dummy_scaled, mask) + self.assertTrue(torch.isclose(manual_mean, means_input.squeeze()).all()) + self.assertTrue(torch.isclose(manual_std, stds_input.squeeze()).all()) + + # no mask + rand_dummy = torch.randn(dummy_scaled.shape, device=dummy_scaled.device) + rand_mean, rand_std = ReinhardNormalizer._mean_std_helper(rand_dummy, mask=None) + + rand_mean_truth = rand_dummy.mean(dim=(2, 3), keepdim=True) + rand_std_truth = rand_dummy.std(dim=(2, 3), keepdim=True) + + self.assertTrue(torch.isclose(rand_mean, rand_mean_truth).all()) + self.assertTrue(torch.isclose(rand_std, rand_std_truth).all()) + diff --git a/torch_staintools/functional/utility/implementation.py b/torch_staintools/functional/utility/implementation.py index 28267f9..dcfd2b1 100644 --- a/torch_staintools/functional/utility/implementation.py +++ b/torch_staintools/functional/utility/implementation.py @@ -63,3 +63,31 @@ def default_rng(rng: Optional[torch.Generator | int], device: Optional[torch.dev return torch.Generator(device=device).manual_seed(rng) assert isinstance(rng, torch.Generator) return rng + + +def nanstd(data: torch.Tensor, dim: Optional[int | tuple] = None, + correction: float = 1) -> torch.Tensor: + """Compute the standard deviation while ignoring NaNs. + + Always keep the dim. + + Args: + data: Input tensor. + dim: The dimension or dimensions to reduce. If None (default), reduces all dimensions. + correction: Difference between the sample size and sample degrees of freedom. Defaults 1 (Bessel's). + + Returns: + torch.Tensor: Standard deviation with NaNs ignored. If `dim` is provided, + it reduces along that dimension(s), otherwise reduces all dimensions. + """ + + non_nan_mask = ~torch.isnan(data) + # find not-nan element + non_nan_count = non_nan_mask.sum(dim=dim, keepdim=True) + # compute mean of not-nan elements + mean = torch.nanmean(data, dim=dim, keepdim=True) + + # \Sigma (x - mean)^2 --> any x that is nan will be filtered by using nansum + sum_dev2 = ((data - mean) ** 2).nansum(dim=dim, keepdim=True) + # sqrt and normalize by corrected degrees of freedom + return torch.sqrt(sum_dev2 / (non_nan_count - correction)) \ No newline at end of file diff --git a/torch_staintools/normalizer/factory.py b/torch_staintools/normalizer/factory.py index dea12e6..651f75d 100644 --- a/torch_staintools/normalizer/factory.py +++ b/torch_staintools/normalizer/factory.py @@ -57,7 +57,7 @@ def build(method: TYPE_SUPPORTED, norm_method: Callable match method: case 'reinhard': - return ReinhardNormalizer.build() + return ReinhardNormalizer.build(luminosity_threshold=luminosity_threshold) case 'macenko' | 'vahadane': return StainSeparation.build(method=method, reconst_method=reconst_method, num_stains=num_stains, diff --git a/torch_staintools/normalizer/reinhard.py b/torch_staintools/normalizer/reinhard.py index d216929..612af42 100644 --- a/torch_staintools/normalizer/reinhard.py +++ b/torch_staintools/normalizer/reinhard.py @@ -2,8 +2,10 @@ from torch_staintools.functional.conversion.lab import rgb_to_lab, lab_to_rgb from torch_staintools.normalizer.base import Normalizer -from typing import Tuple +from torch_staintools.functional.tissue_mask import get_tissue_mask +from typing import Tuple, Optional from torch_staintools.functional.eps import get_eps +from torch_staintools.functional.utility.implementation import nanstd class ReinhardNormalizer(Normalizer): @@ -12,28 +14,40 @@ class ReinhardNormalizer(Normalizer): """ target_means: torch.Tensor target_stds: torch.Tensor + luminosity_threshold: float - def __init__(self): + def __init__(self, luminosity_threshold: Optional[float]): super().__init__(cache=None, device=None, rng=None) + self.luminosity_threshold = luminosity_threshold @staticmethod - def _mean_std_helper(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _mean_std_helper(image: torch.Tensor, *, mask: Optional[torch.Tensor] = None)\ + -> Tuple[torch.Tensor, torch.Tensor]: """Get the channel-wise mean and std of input Args: image: BCHW scaled to [0, 1] torch.float32. Usually in LAB. - + mask: luminosity tissue mask of image. Mean and std are only computed within the tissue regions. Returns: means, """ + assert mask is None or mask.dtype is torch.bool, f"{mask.dtype}" assert image.ndimension() == 4, f"{image.shape}" - means = image.mean(dim=(2, 3)).unsqueeze(-1).unsqueeze(-1) - stds = image.std(dim=(2, 3)).unsqueeze(-1).unsqueeze(-1) + if mask is None: + mask = torch.ones_like(image, dtype=torch.bool) + + image_masked = image * mask + image_masked[mask.expand_as(image_masked) == 0] = torch.nan + means = image_masked.nanmean(dim=(2, 3), keepdim=True) + stds = nanstd(image_masked, dim=(2, 3)) return means, stds def fit(self, image: torch.Tensor): """Fit - compute the means and stds of template in lab space. + Statistics are computed within tissue regions if a luminosity threshold is given to the normalizer upon + creation. + Args: image: template. BCHW. [0, 1] torch.float32. @@ -43,26 +57,28 @@ def fit(self, image: torch.Tensor): # BCHW img_lab: torch.Tensor = rgb_to_lab(image) assert img_lab.ndimension() == 4 and img_lab.shape[1] == 3, f"{img_lab.shape}" + mask = get_tissue_mask(image, luminosity_threshold=self.luminosity_threshold) # B1HW # 1, C, 1, 1 - means, stds = ReinhardNormalizer._mean_std_helper(img_lab) + means, stds = ReinhardNormalizer._mean_std_helper(img_lab, mask=mask) self.register_buffer('target_means', means) self.register_buffer('target_stds', stds) @staticmethod - def normalize_helper(image: torch.Tensor, target_means: torch.Tensor, target_stds: torch.Tensor): + def normalize_helper(image: torch.Tensor, target_means: torch.Tensor, target_stds: torch.Tensor, + mask: Optional[torch.Tensor] = None): """Helper. Args: image: BCHW format. torch.float32 type in range [0, 1]. target_means: channel-wise means of template target_stds: channel-wise stds of template - + mask: Optional luminosity tissue mask to compute the stats within masked region Returns: """ - means_input, stds_input = ReinhardNormalizer._mean_std_helper(image) + means_input, stds_input = ReinhardNormalizer._mean_std_helper(image, mask=mask) return (image - means_input) * (target_stds / (stds_input + get_eps(image))) + target_means def transform(self, x: torch.Tensor, *args, **kwargs): @@ -80,12 +96,15 @@ def transform(self, x: torch.Tensor, *args, **kwargs): """ # 1 C 1 1 lab_input = rgb_to_lab(x) - normalized_lab = ReinhardNormalizer.normalize_helper(lab_input, self.target_means, self.target_stds) + mask = get_tissue_mask(x, luminosity_threshold=self.luminosity_threshold, throw_error=False, + true_when_empty=True) + normalized_lab = ReinhardNormalizer.normalize_helper(lab_input, self.target_means, self.target_stds, + mask) return lab_to_rgb(normalized_lab).clamp_(0, 1) def forward(self, x: torch.Tensor, *args, **kwargs): return self.transform(x) @classmethod - def build(cls, *args, **kwargs): - return cls() + def build(cls, luminosity_threshold: Optional[float] = None, **kwargs): + return cls(luminosity_threshold=luminosity_threshold) diff --git a/torch_staintools/version.py b/torch_staintools/version.py index a6221b3..3f6fab6 100644 --- a/torch_staintools/version.py +++ b/torch_staintools/version.py @@ -1 +1 @@ -__version__ = '1.0.2' +__version__ = '1.0.3'