diff --git a/ahcore/cli/data.py b/ahcore/cli/data.py index 049cba4..f4fa8c1 100644 --- a/ahcore/cli/data.py +++ b/ahcore/cli/data.py @@ -34,7 +34,7 @@ def copy_data(args: argparse.Namespace): with DataManager(manifest_fn) as dm: all_records = dm.get_records_by_split(args.manifest_name, args.split_name, split_category=None) with Progress() as progress: - task = progress.add_task("[cyan]Copying...", total=len(all_records)) + task = progress.add_task("[cyan]Copying...") for patient in all_records: for image in patient.images: image_fn = image.filename diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 7cb2a1f..6142ec9 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -7,7 +7,6 @@ from typing import Any, Optional, cast -import kornia as K import pytorch_lightning as pl import torch.optim.optimizer from dlup.data.dataset import ConcatDataset @@ -15,10 +14,8 @@ from torch import nn from torch.utils.tensorboard import SummaryWriter -import ahcore.transforms.augmentations from ahcore.utils.data import DataDescription from ahcore.utils.io import get_logger -from ahcore.utils.model import ExtractFeaturesHook logger = get_logger(__name__) @@ -44,7 +41,6 @@ def __init__( augmentations: dict[str, nn.Module] | None = None, metrics: dict[str, nn.Module] | None = None, scheduler: Any | None = None, # noqa - attach_feature_layers: list[str] | None = None, ): super().__init__() @@ -69,18 +65,6 @@ def __init__( self._wsi_metrics = metrics.get("wsi_level") self._data_description = data_description - self._attach_feature_layers = attach_feature_layers - - self._validation_dataset: ConcatDataset | None = None - - # Setup test-time augmentation - self._tta_augmentations = [ - ahcore.transforms.augmentations.Identity(), - K.augmentation.RandomHorizontalFlip(p=1.0), - K.augmentation.RandomVerticalFlip(p=1.0), - ] - self._use_test_time_augmentation = False - self._tta_steps = len(self._tta_augmentations) @property def wsi_metrics(self): @@ -99,10 +83,6 @@ def forward(self, sample): def data_description(self) -> DataDescription: return self._data_description - @property - def validation_dataset(self) -> Optional[ConcatDataset]: - return self._validation_dataset - @property def _tensorboard(self) -> SummaryWriter | None: _tensorboard = [_ for _ in self.loggers if isinstance(_, pl.loggers.tensorboard.TensorBoardLogger)] @@ -181,41 +161,13 @@ def do_step(self, batch, batch_idx: int, stage: TrainerFn | str): def _get_inference_prediction(self, _input: torch.Tensor) -> dict[str, torch.Tensor]: output = {} - - output_size = (self._tta_steps, _input.shape[0], self._num_classes, *map(int, _input.shape[2:])) - # FIXME: this mypy error I can't figure. Pytorch documentation says it can be tuple[int,...] - _predictions = torch.zeros(output_size, device=self.device) # type: ignore - _collected_features = {k: None for k in self._attach_feature_layers} if self._attach_feature_layers else {} - - with ExtractFeaturesHook(self._model, layer_names=self._attach_feature_layers) as hook: - for idx, augmentation in enumerate(self._tta_augmentations): - model_prediction = self._model(augmentation(_input)) - _predictions[idx] = augmentation.inverse(model_prediction) - - if self._attach_feature_layers: - _features = hook.features - for key in _features: - if _collected_features[key] is None: - _collected_features[key] = torch.zeros( - (self._tta_steps, *map(int, _features[key].size())), # type: ignore - device=self.device, - ) - _features[key] = _collected_features[key] - - output["prediction"] = _predictions.mean(dim=0) - - if self._attach_feature_layers: - output["features"] = _collected_features - + output["prediction"] = self._model(_input) return output def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]: - # FIXME: This gives very weird errors... - del batch["labels"] - - # if self.global_step == 0: - # if self._tensorboard: - # self._tensorboard.add_graph(self._model, batch["image"]) + if self.global_step == 0: + if self._tensorboard: + self._tensorboard.add_graph(self._model, batch["image"]) output = self.do_step(batch, batch_idx, stage=TrainerFn.FITTING) return output @@ -229,12 +181,6 @@ def validation_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, An raise ValueError("Filenames are not constant across the batch.") return output - def on_validation_start(self) -> None: - assert hasattr(self.trainer, "datamodule"), "Datamodule is not defined for the trainer. Required for validation" - datamodule: AhCoreLightningModule = cast(AhCoreLightningModule, getattr(self.trainer, "datamodule")) - assert hasattr(datamodule, "val_concat_dataset"), "Validation dataset is not defined for the datamodule" - self._validation_dataset = datamodule.val_concat_dataset - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: if self._augmentations: batch = self._augmentations["predict"](batch) diff --git a/ahcore/utils/model.py b/ahcore/utils/model.py deleted file mode 100644 index a0f286f..0000000 --- a/ahcore/utils/model.py +++ /dev/null @@ -1,117 +0,0 @@ -# encoding: utf-8 -from __future__ import annotations - -from pprint import pformat -from typing import Any, Optional, Type - -import torch -import torch.nn as nn -from torch.utils.hooks import RemovableHandle - - -class ExtractFeaturesHook: - """ - Context manager to add hooks to the layers of a PyTorch model to extract feature maps. - - Parameters - ---------- - model : torch.nn.Module - The model to which the hooks will be added. - layer_names : list[str] - The names of the layers from which the feature maps will be extracted. - - Attributes - ---------- - features : dict - The feature maps extracted from the layers. - - Examples - -------- - >>> model = MyModel() - >>> layer_names = ["conv1", "conv2"] - >>> with ExtractFeaturesHook(model, layer_names) as hook: - ... output = model(input_data) - ... conv1_features = hook.features["conv1"] - ... conv2_features = hook.features["conv2"] - """ - - def __init__(self, model: nn.Module, *, layer_names: list[str] | None): - self.model = model - self.layer_names = layer_names if layer_names is not None else [] - self.hooks: list[RemovableHandle] = [] - self.features: dict[str, torch.Tensor] = {} - - def save_output( - self, - layer_name: str, - module: nn.Module, - input: torch.Tensor, - output: torch.Tensor, - ): - """ - Hook function that saves the output of a layer. - - In this setting we are not using the `module` or `input` parameters. However, they are required to be - present in the function signature to be used as a hook. For future reference, they can be used as follows: - - You can use module if you want to save or analyze the weights or biases of the layer. - - You can use input if you want to see what goes into the layer, - or if you want to save or analyze the input data. - - Parameters - ---------- - layer_name : str - The name of the layer. - module : torch.nn.Module - The layer module. - input : torch.Tensor - The input to the layer. - output : torch.Tensor - The output of the layer. - """ - self.features[layer_name] = output.detach() - - def __enter__(self): - """ - Registers the hooks when entering the context. - """ - if not hasattr(self, "layer_cache"): - self.layer_cache = {name: module for name, module in self.model.named_modules()} - - for layer_name in self.layer_names: - if layer_name in self.layer_cache: - module = self.layer_cache[layer_name] - hook = module.register_forward_hook( - lambda module, input, output, layer_name=layer_name: self.save_output( - layer_name, module, input, output - ) - ) - self.hooks.append(hook) - else: - raise ValueError( - f"No layer named {layer_name} in model. " - f"These are the available ones: {pformat(self.layer_cache)}." - ) - return self - - def __exit__( - self, - type: Optional[Type[BaseException]], - value: Optional[BaseException], - traceback: Optional[Any], - ) -> None: - """ - Removes the hooks when exiting the context. - - Parameters - ---------- - type : Optional[Type[BaseException]] - The type of exception that caused the context to be exited, if any. - value : Optional[BaseException] - The instance of exception that caused the context to be exited, if any. - traceback : Optional[Any] - A traceback object encapsulating the call stack at the point where the exception originally occurred, - if any. - """ - for hook in self.hooks: - hook.remove() - self.hooks = [] diff --git a/config/lit_module/monai_segmentation/attention_unet.yaml b/config/lit_module/monai_segmentation/attention_unet.yaml index 40eeb95..31360da 100644 --- a/config/lit_module/monai_segmentation/attention_unet.yaml +++ b/config/lit_module/monai_segmentation/attention_unet.yaml @@ -1,7 +1,4 @@ _target_: ahcore.lit_module.AhCoreLightningModule -attach_feature_layers: - - "model.1.submodule.1.submodule.1.submodule.1.submodule.conv.0.adn.A" - model: # Do not set out_channels, this is derived from data_description. _target_: monai.networks.nets.attentionunet.AttentionUnet diff --git a/config/lit_module/monai_segmentation/unet.yaml b/config/lit_module/monai_segmentation/unet.yaml index 2d76b1c..06eeb9f 100644 --- a/config/lit_module/monai_segmentation/unet.yaml +++ b/config/lit_module/monai_segmentation/unet.yaml @@ -1,7 +1,4 @@ _target_: ahcore.lit_module.AhCoreLightningModule -attach_feature_layers: - - 'model.0.conv' - - 'model.1.submodule.1.submodule.1.submodule' model: # Do not set out_channels, this is derived from data_description.