Skip to content

Commit

Permalink
Preliminary implementation of TTA
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Jul 12, 2023
1 parent b71c668 commit 1561ec6
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
48 changes: 38 additions & 10 deletions ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from __future__ import annotations

from typing import Any

import kornia as K
import ahcore.transforms.augmentations
import pytorch_lightning as pl
import torch.optim.optimizer
from dlup.data.dataset import ConcatDataset
Expand Down Expand Up @@ -79,6 +80,16 @@ def __init__(
self.predict_metadata: InferenceMetadata = self.INFERENCE_DICT # Used for saving metadata during prediction
self._validation_dataset: ConcatDataset | None = None

# Setup test-time augmentation
self._tta_augmentations = [
ahcore.transforms.augmentations.Identity,
K.augmentation.RandomHorizontalFlip(p=1.0),
K.augmentation.RandomVerticalFlip(p=1.0),
]
self._use_test_time_augmentation = True
self._tta_steps = len(self._tta_augmentations)


@property
def wsi_metrics(self):
return self._wsi_metrics
Expand Down Expand Up @@ -125,21 +136,18 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn):
f"This is required during training and validation"
)

_input = batch["image"]
_target = batch["target"]
# Batch size is required for accurate loss calculation and logging
batch_size = _input.shape[0]
batch_size = batch["image"].shape[0]
# ROIs can reduce the usable area of the inputs, the loss should be scaled appropriately
roi = batch.get("roi", None)

# Extract features only when not training
layer_names = [] if stage == TrainerFn.FITTING else self._attach_feature_layers
with ExtractFeaturesHook(self._model, layer_names=layer_names) as hook:
_prediction = self._model(_input)
if layer_names is not []: # Only add the features if they are requested
batch["features"] = hook.features

if not self._use_test_time_augmentation and stage != TrainerFn.FITTING:
_prediction = self._get_prediction(batch_size, stage)
else:
_prediction = self._get_tta_prediction(batch)
batch["prediction"] = _prediction

loss = self._loss(_prediction, _target, roi)
# The relevant_dict contains values to know where the tiles originate.
_relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS}
Expand All @@ -159,6 +167,26 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn):

return output

def _get_prediction(self, batch: dict[str, torch.Tensor], stage: TrainerFn):
# Extract features only when not training
layer_names = [] if stage == TrainerFn.FITTING else self._attach_feature_layers
with ExtractFeaturesHook(self._model, layer_names=layer_names) as hook:
_prediction = self._model(batch["image"])
if layer_names is not []: # Only add the features if they are requested
batch["features"] = hook.features
return _prediction

def _get_tta_prediction(self, batch: dict[str, torch.Tensor]):
_predictions = torch.zeros([self._tta_steps, *batch["image"].size()], device=self.device)
for idx in range(self._tta_steps):
augmentation = self._tta_augmentations[idx]
if idx == 0:
_predictions[0] = self._get_prediction(batch, stage=TrainerFn.VALIDATING)
else:
_predictions[idx] = augmentation.inverse(self._model(augmentation(batch["image"])))

return _predictions.mean(dim=0)

def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]:
output = self.do_step(batch, batch_idx, stage=TrainerFn.FITTING)
if self.global_step == 0:
Expand Down
11 changes: 11 additions & 0 deletions ahcore/transforms/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ def forward(self, *args: torch.Tensor, **kwargs):

return output

class Identity(nn.Module):
"""
Identity transform.
"""

def __init__(self):
super().__init__()

def forward(self, *args: torch.Tensor, **kwargs):
return args


class CenterCrop(nn.Module):
"""Perform a center crop of the image and target"""
Expand Down
5 changes: 4 additions & 1 deletion ahcore/transforms/image_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,10 @@ def _get_staining_vectors_from_cache_or_file(self, filenames):
# Now we need to compute it.
kwargs = {}
if Path(filename) in self._overwrite_mpp:
kwargs["overwrite_mpp"] = (self._overwrite_mpp[Path(filename)], self._overwrite_mpp[Path(filename)])
kwargs["overwrite_mpp"] = (
self._overwrite_mpp[Path(filename)],
self._overwrite_mpp[Path(filename)],
)
with SlideImage.from_file_path(filename, **kwargs) as slide_image:
logger.info("Computing Macenko staining vector for %s", filename)
stain_computer = MacenkoNormalizer(return_stains=False)
Expand Down

0 comments on commit 1561ec6

Please sign in to comment.