Skip to content

Commit

Permalink
Remove files to have mypy pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Sep 18, 2023
1 parent d959abe commit 4fd73a7
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 182 deletions.
2 changes: 1 addition & 1 deletion ahcore/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def copy_data(args: argparse.Namespace):
with DataManager(manifest_fn) as dm:
all_records = dm.get_records_by_split(args.manifest_name, args.split_name, split_category=None)
with Progress() as progress:
task = progress.add_task("[cyan]Copying...", total=len(all_records))
task = progress.add_task("[cyan]Copying...")
for patient in all_records:
for image in patient.images:
image_fn = image.filename
Expand Down
62 changes: 4 additions & 58 deletions ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@

from typing import Any, Optional, cast

import kornia as K
import pytorch_lightning as pl
import torch.optim.optimizer
from dlup.data.dataset import ConcatDataset
from pytorch_lightning.trainer.states import TrainerFn
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import ahcore.transforms.augmentations
from ahcore.utils.data import DataDescription
from ahcore.utils.io import get_logger
from ahcore.utils.model import ExtractFeaturesHook

logger = get_logger(__name__)

Expand All @@ -44,7 +41,6 @@ def __init__(
augmentations: dict[str, nn.Module] | None = None,
metrics: dict[str, nn.Module] | None = None,
scheduler: Any | None = None, # noqa
attach_feature_layers: list[str] | None = None,
):
super().__init__()

Expand All @@ -69,18 +65,6 @@ def __init__(
self._wsi_metrics = metrics.get("wsi_level")

self._data_description = data_description
self._attach_feature_layers = attach_feature_layers

self._validation_dataset: ConcatDataset | None = None

# Setup test-time augmentation
self._tta_augmentations = [
ahcore.transforms.augmentations.Identity(),
K.augmentation.RandomHorizontalFlip(p=1.0),
K.augmentation.RandomVerticalFlip(p=1.0),
]
self._use_test_time_augmentation = False
self._tta_steps = len(self._tta_augmentations)

@property
def wsi_metrics(self):
Expand All @@ -99,10 +83,6 @@ def forward(self, sample):
def data_description(self) -> DataDescription:
return self._data_description

@property
def validation_dataset(self) -> Optional[ConcatDataset]:
return self._validation_dataset

@property
def _tensorboard(self) -> SummaryWriter | None:
_tensorboard = [_ for _ in self.loggers if isinstance(_, pl.loggers.tensorboard.TensorBoardLogger)]
Expand Down Expand Up @@ -181,41 +161,13 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn | str):

def _get_inference_prediction(self, _input: torch.Tensor) -> dict[str, torch.Tensor]:
output = {}

output_size = (self._tta_steps, _input.shape[0], self._num_classes, *map(int, _input.shape[2:]))
# FIXME: this mypy error I can't figure. Pytorch documentation says it can be tuple[int,...]
_predictions = torch.zeros(output_size, device=self.device) # type: ignore
_collected_features = {k: None for k in self._attach_feature_layers} if self._attach_feature_layers else {}

with ExtractFeaturesHook(self._model, layer_names=self._attach_feature_layers) as hook:
for idx, augmentation in enumerate(self._tta_augmentations):
model_prediction = self._model(augmentation(_input))
_predictions[idx] = augmentation.inverse(model_prediction)

if self._attach_feature_layers:
_features = hook.features
for key in _features:
if _collected_features[key] is None:
_collected_features[key] = torch.zeros(
(self._tta_steps, *map(int, _features[key].size())), # type: ignore
device=self.device,
)
_features[key] = _collected_features[key]

output["prediction"] = _predictions.mean(dim=0)

if self._attach_feature_layers:
output["features"] = _collected_features

output["prediction"] = self._model(_input)
return output

def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]:
# FIXME: This gives very weird errors...
del batch["labels"]

# if self.global_step == 0:
# if self._tensorboard:
# self._tensorboard.add_graph(self._model, batch["image"])
if self.global_step == 0:
if self._tensorboard:
self._tensorboard.add_graph(self._model, batch["image"])

output = self.do_step(batch, batch_idx, stage=TrainerFn.FITTING)
return output
Expand All @@ -229,12 +181,6 @@ def validation_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, An
raise ValueError("Filenames are not constant across the batch.")
return output

def on_validation_start(self) -> None:
assert hasattr(self.trainer, "datamodule"), "Datamodule is not defined for the trainer. Required for validation"
datamodule: AhCoreLightningModule = cast(AhCoreLightningModule, getattr(self.trainer, "datamodule"))
assert hasattr(datamodule, "val_concat_dataset"), "Validation dataset is not defined for the datamodule"
self._validation_dataset = datamodule.val_concat_dataset

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
if self._augmentations:
batch = self._augmentations["predict"](batch)
Expand Down
117 changes: 0 additions & 117 deletions ahcore/utils/model.py

This file was deleted.

3 changes: 0 additions & 3 deletions config/lit_module/monai_segmentation/attention_unet.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
_target_: ahcore.lit_module.AhCoreLightningModule
attach_feature_layers:
- "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.0.adn.A"

model:
# Do not set out_channels, this is derived from data_description.
_target_: monai.networks.nets.attentionunet.AttentionUnet
Expand Down
3 changes: 0 additions & 3 deletions config/lit_module/monai_segmentation/unet.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
_target_: ahcore.lit_module.AhCoreLightningModule
attach_feature_layers:
- 'model.0.conv'
- 'model.1.submodule.1.submodule.1.submodule'

model:
# Do not set out_channels, this is derived from data_description.
Expand Down

0 comments on commit 4fd73a7

Please sign in to comment.