From 1561ec6266cc42e72fe0dffa54b172c9ff43557b Mon Sep 17 00:00:00 2001
From: Jonas Teuwen
Date: Wed, 12 Jul 2023 12:40:54 +0200
Subject: [PATCH] Preliminary implementation of TTA

---
 ahcore/lit_module.py                      | 48 +++++++++++++++++++-----
 ahcore/transforms/augmentations.py        | 11 ++++++
 ahcore/transforms/image_normalization.py  |  5 ++-
 3 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py
index 9245a66..13c64bd 100644
--- a/ahcore/lit_module.py
+++ b/ahcore/lit_module.py
@@ -6,7 +6,8 @@ from __future__ import annotations
 
 from typing import Any
-
+import kornia as K
+import ahcore.transforms.augmentations
 import pytorch_lightning as pl
 import torch.optim.optimizer
 from dlup.data.dataset import ConcatDataset
 
@@ -79,6 +80,16 @@ def __init__(
         self.predict_metadata: InferenceMetadata = self.INFERENCE_DICT  # Used for saving metadata during prediction
         self._validation_dataset: ConcatDataset | None = None
+        # Set up test-time augmentation: the identity plus deterministic flips.
+        # p=1.0 guarantees each flip is applied, so it can later be inverted on the prediction.
+        self._tta_augmentations = [
+            ahcore.transforms.augmentations.Identity(),
+            K.augmentation.RandomHorizontalFlip(p=1.0),
+            K.augmentation.RandomVerticalFlip(p=1.0),
+        ]
+        self._use_test_time_augmentation = True
+        self._tta_steps = len(self._tta_augmentations)
+
 
     @property
     def wsi_metrics(self):
         return self._wsi_metrics
@@ -125,21 +136,18 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn):
                 f"This is required during training and validation"
             )
 
-        _input = batch["image"]
         _target = batch["target"]
         # Batch size is required for accurate loss calculation and logging
-        batch_size = _input.shape[0]
+        batch_size = batch["image"].shape[0]
         # ROIs can reduce the usable area of the inputs, the loss should be scaled appropriately
         roi = batch.get("roi", None)
 
-        # Extract features only when not training
-        layer_names = [] if stage == TrainerFn.FITTING else self._attach_feature_layers
-        with ExtractFeaturesHook(self._model, layer_names=layer_names) as hook:
-            _prediction = self._model(_input)
-            if layer_names is not []:  # Only add the features if they are requested
-                batch["features"] = hook.features
-
+        if self._use_test_time_augmentation and stage != TrainerFn.FITTING:
+            _prediction = self._get_tta_prediction(batch)
+        else:
+            _prediction = self._get_prediction(batch, stage)
         batch["prediction"] = _prediction
+
         loss = self._loss(_prediction, _target, roi)
         # The relevant_dict contains values to know where the tiles originate.
         _relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS}
@@ -159,6 +167,26 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn):
 
         return output
 
+    def _get_prediction(self, batch: dict[str, torch.Tensor], stage: TrainerFn):
+        # Extract features only when not training
+        layer_names = [] if stage == TrainerFn.FITTING else self._attach_feature_layers
+        with ExtractFeaturesHook(self._model, layer_names=layer_names) as hook:
+            _prediction = self._model(batch["image"])
+            if layer_names:  # Only add the features if they are requested
+                batch["features"] = hook.features
+        return _prediction
+
+    def _get_tta_prediction(self, batch: dict[str, torch.Tensor]):
+        # The unaugmented prediction fixes the output shape, which can differ from the input shape.
+        _first_prediction = self._get_prediction(batch, stage=TrainerFn.VALIDATING)
+        _predictions = torch.zeros([self._tta_steps, *_first_prediction.size()], device=self.device)
+        _predictions[0] = _first_prediction
+        for idx in range(1, self._tta_steps):
+            augmentation = self._tta_augmentations[idx]
+            _predictions[idx] = augmentation.inverse(self._model(augmentation(batch["image"])))
+
+        return _predictions.mean(dim=0)
+
     def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]:
         output = self.do_step(batch, batch_idx, stage=TrainerFn.FITTING)
         if self.global_step == 0:
diff --git a/ahcore/transforms/augmentations.py b/ahcore/transforms/augmentations.py
index 411b479..f0d0417 100644
--- a/ahcore/transforms/augmentations.py
+++ b/ahcore/transforms/augmentations.py
@@ -84,6 +84,17 @@ def forward(self, *args: torch.Tensor, **kwargs):
         return output
 
 
+class Identity(nn.Module):
+    """
+    Identity transform; placeholder for the unaugmented TTA step.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args: torch.Tensor, **kwargs):
+        return args
+
 class CenterCrop(nn.Module):
     """Perform a center crop of the image and target"""
 
diff --git a/ahcore/transforms/image_normalization.py b/ahcore/transforms/image_normalization.py
index b4570b2..036c960 100644
--- a/ahcore/transforms/image_normalization.py
+++ b/ahcore/transforms/image_normalization.py
@@ -430,7 +430,10 @@ def _get_staining_vectors_from_cache_or_file(self, filenames):
             # Now we need to compute it.
             kwargs = {}
             if Path(filename) in self._overwrite_mpp:
-                kwargs["overwrite_mpp"] = (self._overwrite_mpp[Path(filename)], self._overwrite_mpp[Path(filename)])
+                kwargs["overwrite_mpp"] = (
+                    self._overwrite_mpp[Path(filename)],
+                    self._overwrite_mpp[Path(filename)],
+                )
             with SlideImage.from_file_path(filename, **kwargs) as slide_image:
                 logger.info("Computing Macenko staining vector for %s", filename)
                 stain_computer = MacenkoNormalizer(return_stains=False)
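
The predict-on-augmented, invert, then average scheme used in _get_tta_prediction can be sanity-checked in isolation. Below is a minimal standalone sketch (not part of the patch); nn.Identity is a hypothetical stand-in for the segmentation model, so every de-augmented prediction should reproduce the input:

# Standalone sanity check of the TTA round trip (assumes kornia's augmentation
# API as used in the patch; nn.Identity is a stand-in for the real model).
import kornia as K
import torch
import torch.nn as nn

model = nn.Identity()  # hypothetical stand-in for the segmentation model
flips = [
    K.augmentation.RandomHorizontalFlip(p=1.0),  # p=1.0: applied on every call
    K.augmentation.RandomVerticalFlip(p=1.0),
]

image = torch.rand(2, 3, 16, 16)  # (batch, channels, height, width)
predictions = [model(image)]  # step 0: the unaugmented prediction
for flip in flips:
    # Predict on the flipped image, then undo the flip on the prediction.
    predictions.append(flip.inverse(model(flip(image))))

tta_prediction = torch.stack(predictions).mean(dim=0)
# With an identity model, every de-augmented prediction equals the input.
assert torch.allclose(tta_prediction, image, atol=1e-6)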