Skip to content

Commit

Permalink
Implement an inverse for the identity.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Jul 12, 2023
1 parent 37cfebf commit 11418dd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
39 changes: 14 additions & 25 deletions ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(
self._use_test_time_augmentation = True
self._tta_steps = len(self._tta_augmentations)


@property
def wsi_metrics(self):
return self._wsi_metrics
Expand Down Expand Up @@ -170,34 +169,24 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn):
def _get_inference_prediction(self, _input: torch.Tensor) -> dict[str, torch.Tensor]:
output = {}

if self._use_test_time_augmentation:
_predictions = torch.zeros([self._tta_steps, *_input.size()], device=self.device)
_collected_features = None

with ExtractFeaturesHook(self._model, layer_names=self._attach_feature_layers) as hook:
for idx, augmentation in enumerate(self._tta_augmentations):
model_prediction = self._model(augmentation(_input))
_predictions[idx] = augmentation.inverse(model_prediction)
_predictions = torch.zeros([self._tta_steps, *_input.size()], device=self.device)
_collected_features = None

if self._attach_feature_layers:
_features = hook.features
if _collected_features is None:
_collected_features = torch.zeros([self._tta_steps, *_features.size()], device=self.device)
_features[idx] = _collected_features
with ExtractFeaturesHook(self._model, layer_names=self._attach_feature_layers) as hook:
for idx, augmentation in enumerate(self._tta_augmentations):
model_prediction = self._model(augmentation(_input))
_predictions[idx] = augmentation.inverse(model_prediction)

output["prediction"] = _predictions.mean(dim=0)
if self._attach_feature_layers:
_features = hook.features
if _collected_features is None:
_collected_features = torch.zeros([self._tta_steps, *_features.size()], device=self.device)
_features[idx] = _collected_features

if self._attach_feature_layers:
if _collected_features is None:
output["features"] = hook.features.unsqueeze(0)
else:
output["features"] = _collected_features
output["prediction"] = _predictions.mean(dim=0)

else:
with ExtractFeaturesHook(self._model, layer_names=self._attach_feature_layers) as hook:
_prediction = self._model(_input)
if self._attach_feature_layers:
output["features"] = hook.features.unsqueeze(0)
if self._attach_feature_layers:
output["features"] = _collected_features

return output

Expand Down
5 changes: 4 additions & 1 deletion ahcore/transforms/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ahcore.utils.data import DataDescription
from ahcore.utils.io import get_logger

import kornia as K
logger = get_logger(__name__)


Expand Down Expand Up @@ -84,6 +84,7 @@ def forward(self, *args: torch.Tensor, **kwargs):

return output


class Identity(nn.Module):
"""
Identity transform.
Expand All @@ -95,6 +96,8 @@ def __init__(self):
def forward(self, *args: torch.Tensor, **kwargs):
return args

def inverse(self, *args: torch.Tensor, **kwargs):
return args

class CenterCrop(nn.Module):
"""Perform a center crop of the image and target"""
Expand Down

0 comments on commit 11418dd

Please sign in to comment.