Dlup update compatibility (#27)
* Upgrade Dlup version + dataset

* dlup 0.3.29

* Fixed typing

* Fix mypy
---------

Co-authored-by: Jonas Teuwen <[email protected]>
EricMarcus-ai and jonasteuwen authored Oct 10, 2023
1 parent 38e3028 commit ee3656f
Showing 7 changed files with 76 additions and 53 deletions.
4 changes: 2 additions & 2 deletions ahcore/callbacks.py
@@ -119,7 +119,7 @@ def _generate_regions(self) -> list[tuple[int, int]]:
         """
         regions = []
         for coordinates in self._grid:
-            if self._mask is None or self._is_masked(coordinates):
+            if self._mask is None or self._is_masked((coordinates[0], coordinates[1])):
                 regions.append(coordinates)
         return regions
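Side note (not part of the commit): the explicit 2-tuple is what satisfies mypy here. A minimal, self-contained sketch of the same fix, assuming _is_masked is annotated to take tuple[int, int] while the grid yields a more general sequence type:

    from typing import Sequence

    def _is_masked(coordinates: tuple[int, int]) -> bool:
        # Hypothetical stand-in for the real mask lookup.
        return (coordinates[0] + coordinates[1]) % 2 == 0

    coords: Sequence[int] = (4, 8)
    # _is_masked(coords)                  # rejected by mypy: Sequence[int] is not tuple[int, int]
    _is_masked((coords[0], coords[1]))    # accepted: an explicit tuple[int, int]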

@@ -481,7 +481,7 @@ def _write_tiff(
     tile_process_function: Callable[[GenericArray], GenericArray],
     generator_from_reader: Callable[
         [H5FileImageReader, tuple[int, int], Callable[[GenericArray], GenericArray]],
-        Generator[GenericArray, None, None],
+        Iterator[npt.NDArray[np.int_]],
     ],
 ) -> None:
     logger.debug("Writing TIFF %s", filename.with_suffix(".tiff"))
44 changes: 28 additions & 16 deletions ahcore/cli/tiling.py
@@ -18,7 +18,7 @@
 import numpy.typing as npt
 import PIL.Image
 from dlup import SlideImage
-from dlup.data.dataset import TiledROIsSlideImageDataset
+from dlup.data.dataset import TiledWsiDataset
 from dlup.tiling import GridOrder, TilingMode
 from PIL import Image
 from pydantic import BaseModel
@@ -30,7 +30,7 @@
 logger = getLogger(__name__)


-def read_mask(path: Path) -> npt.NDArray[np.uint8]:
+def read_mask(path: Path) -> npt.NDArray[np.int_]:
     return iio.imread(path)[..., 0]


@@ -45,7 +45,7 @@ class SlideImageMetaData(BaseModel):
     vendor: str | None

     @classmethod
-    def from_dataset(cls, dataset: TiledROIsSlideImageDataset) -> "SlideImageMetaData":
+    def from_dataset(cls, dataset: TiledWsiDataset) -> "SlideImageMetaData":
         _relevant_keys = ["aspect_ratio", "magnification", "mpp", "size", "vendor"]
         return cls(
             **{
@@ -85,15 +85,15 @@ class DatasetConfigs(BaseModel):
 def _save_thumbnail(
     image_fn: Path,
     dataset_cfg: DatasetConfigs,
-    mask: npt.NDArray[np.uint8] | None,
+    mask: npt.NDArray[np.int_] | None,
 ) -> tuple[npt.NDArray[np.uint8], npt.NDArray[np.uint8] | None, npt.NDArray[np.uint8]]:
     target_mpp = max(dataset_cfg.mpp * 30, 30)
     tile_size = (
         min(30, dataset_cfg.tile_size[0] // 30),
         min(30, dataset_cfg.tile_size[1] // 30),
     )

-    dataset = TiledROIsSlideImageDataset.from_standard_tiling(
+    dataset = TiledWsiDataset.from_standard_tiling(
         image_fn,
         target_mpp,
         tile_size,
@@ -113,7 +113,12 @@ def _save_thumbnail(
         mask_arr = None

     thumbnail_io = io.BytesIO()
-    thumbnail = dataset.slide_image.get_thumbnail(tuple(scaled_region_view.size))
+
+    # TODO: This needs to change in dlup: scaled_region_view should return its size as (int, int).
+    _tile_size = tuple(scaled_region_view.size)
+    tile_size = (_tile_size[0], _tile_size[1])
+
+    thumbnail = dataset.slide_image.get_thumbnail(tile_size)
     thumbnail.convert("RGB").save(thumbnail_io, quality=75)
     thumbnail_arr = np.frombuffer(thumbnail_io.getvalue(), dtype="uint8")

@@ -137,10 +142,10 @@ def _save_thumbnail(

 def create_slide_image_dataset(
     slide_image_path: Path,
-    mask: SlideImage | npt.NDArray[np.uint8 | np.bool_] | None,
+    mask: SlideImage | npt.NDArray[np.int_] | None,
     cfg: DatasetConfigs,
     overwrite_mpp: tuple[float, float] | None = None,
-) -> TiledROIsSlideImageDataset:
+) -> TiledWsiDataset:
     """
     Initializes and returns a slide image dataset.
@@ -162,7 +167,7 @@ def create_slide_image_dataset(
     """

-    return TiledROIsSlideImageDataset.from_standard_tiling(
+    return TiledWsiDataset.from_standard_tiling(
         path=slide_image_path,
         mpp=cfg.mpp,
         tile_size=cfg.tile_size,
@@ -177,15 +182,22 @@ def create_slide_image_dataset(


 def _generator(
-    dataset: TiledROIsSlideImageDataset, quality: int | None = 80, compression: str = "JPEG"
+    dataset: TiledWsiDataset, quality: int | None = 80, compression: str = "JPEG"
 ) -> Generator[Any, Any, Any]:
-    for idx, sample in enumerate(dataset):
+    for idx in range(len(dataset)):
+        # TODO: To use `for idx, sample in enumerate(dataset)` again, the following needs
+        # to be added to TiledWsiDataset:
+        #     def __iter__(self) -> Iterator[RegionFromWsiDatasetSample]:
+        #         for i in range(len(self)):
+        #             yield self[i]
+        sample = dataset[idx]
         buffered = io.BytesIO()
         if quality is not None:
             # If we just cast the PIL.Image to RGB, the alpha channel is set to black,
             # which is a bit unnatural in the image pyramid, where it would be white at
             # lower resolutions; this is why we take the following approach.
-            tile = sample["image"]
+            tile: PIL.Image.Image = sample["image"]
             background = PIL.Image.new("RGB", tile.size, (255, 255, 255))  # Create a white background
             background.paste(tile, mask=tile.split()[3])  # Paste the image using the alpha channel as mask
             background.convert("RGB").save(buffered, format=compression, quality=quality)
@@ -199,7 +211,7 @@ def _generator(


 def save_tiles(
-    dataset: TiledROIsSlideImageDataset,
+    dataset: TiledWsiDataset,
     h5_writer: H5FileImageWriter,
     quality: int | None = 80,
 ) -> None:
@@ -252,11 +264,11 @@ def _tiling_pipeline(
         )
         save_tiles(dataset, h5_writer, quality)
         if save_thumbnail:
-            thumbnail, mask, overlay = _save_thumbnail(image_path, dataset_cfg, mask)
+            thumbnail, thumbnail_mask, overlay = _save_thumbnail(image_path, dataset_cfg, mask)

-            if mask is not None:
+            if thumbnail_mask is not None:
                 h5_writer.add_associated_images(
-                    images=(("thumbnail", thumbnail), ("mask", mask), ("overlay", overlay)),
+                    images=(("thumbnail", thumbnail), ("mask", thumbnail_mask), ("overlay", overlay)),
                     description="thumbnail, mask and overlay",
                 )
             else:
6 changes: 3 additions & 3 deletions ahcore/data/dataset.py
@@ -4,12 +4,12 @@
 from __future__ import annotations

 import uuid as uuid_module
-from typing import Any, Callable, Iterator, Optional
+from typing import Any, Callable, Generator, Iterator, Optional

 import numpy as np
 import pytorch_lightning as pl
 import torch
-from dlup.data.dataset import ConcatDataset
+from dlup.data.dataset import ConcatDataset, TiledWsiDataset
 from pytorch_lightning.trainer.states import TrainerFn
 from torch.utils.data import DataLoader, Sampler

@@ -119,7 +119,7 @@ def setup(self, stage: str) -> None:

         self._logger.info("Constructing dataset iterator for stage %s", stage)

-        def dataset_iterator() -> Iterator[_DlupDataset]:
+        def dataset_iterator() -> Generator[TiledWsiDataset, None, None]:
             gen = datasets_from_data_description(
                 db_manager=self._data_manager,
                 data_description=self.data_description,
14 changes: 7 additions & 7 deletions ahcore/lit_module.py
@@ -5,7 +5,7 @@
 """
 from __future__ import annotations

-from typing import Any
+from typing import Any, Callable

 import pytorch_lightning as pl
 import torch.optim.optimizer
@@ -16,7 +16,7 @@
 from ahcore.metrics import MetricFactory, WSIMetricFactory
 from ahcore.utils.data import DataDescription
 from ahcore.utils.io import get_logger
-
+from ahcore.utils.types import DlupDatasetSample
 logger = get_logger(__name__)


@@ -77,9 +77,9 @@ def wsi_metrics(self) -> WSIMetricFactory | None:

     @property
     def name(self) -> str:
-        return self._model.__class__.__name__
+        return str(self._model.__class__.__name__)

-    def forward(self, sample):
+    def forward(self, sample: DlupDatasetSample) -> DlupDatasetSample:
         """This function is only used during inference"""
         self._model.eval()
         return self._model.forward(sample)
@@ -198,9 +198,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
         return output

     def configure_optimizers(self):
-        optimizer = self.hparams.optimizer(params=self.parameters())
-        if self.hparams.scheduler is not None:
-            scheduler = self.hparams.scheduler(optimizer=optimizer)
+        optimizer = self.hparams.optimizer(params=self.parameters())  # type: ignore
+        if self.hparams.scheduler is not None:  # type: ignore
+            scheduler = self.hparams.scheduler(optimizer=optimizer)  # type: ignore
         return {
             "optimizer": optimizer,
             "lr_scheduler": {
31 changes: 18 additions & 13 deletions ahcore/transforms/pre_transforms.py
@@ -3,24 +3,27 @@
 dataset."""
 from __future__ import annotations

-from typing import Any, Callable
+from typing import Callable, Any

 import numpy as np
 import numpy.typing as npt
 import torch
 from dlup.data.transforms import ContainsPolygonToLabel, ConvertAnnotationsToMask, RenameLabels
-from torchvision.transforms import Compose
 from torchvision.transforms import functional as F

 from ahcore.exceptions import ConfigurationError
 from ahcore.utils.data import DataDescription
 from ahcore.utils.io import get_logger
+from ahcore.utils.types import DlupDatasetSample
+from dlup.data.dataset import TileSample, TileSampleWithAnnotationData
+
+PreTransformCallable = Callable[[TileSample], TileSampleWithAnnotationData]

 logger = get_logger(__name__)


 class PreTransformTaskFactory:
-    def __init__(self, transforms: list[Callable]):
+    def __init__(self, transforms: list[PreTransformCallable]):
         """
         Pre-transforms are transforms that are applied to the samples directly originating from the dataset.
         These transforms are typically the same for the specific tasks (e.g., segmentation,
@@ -39,7 +42,7 @@ def __init__(self, transforms: list[Callable]):
             ImageToTensor(),
             AllowCollate(),
         ]
-        self._transforms = Compose(transforms)
+        self._transforms = transforms

     @classmethod
     def for_segmentation(
@@ -61,7 +64,7 @@ def for_segmentation(
         PreTransformTaskFactory
             The `PreTransformTaskFactory` initialized for segmentation tasks.
         """
-        transforms: list[Callable] = []
+        transforms: list[PreTransformCallable] = []
         if not requires_target:
             return cls(transforms)

@@ -82,7 +85,7 @@ def for_segmentation(
     def for_wsi_classification(
         cls, data_description: DataDescription, requires_target: bool = True
     ) -> PreTransformTaskFactory:
-        transforms: list[Callable] = []
+        transforms: list[PreTransformCallable] = []
        if not requires_target:
            return cls(transforms)

@@ -100,8 +103,10 @@ def for_tile_classification(cls, roi_name: str, label: str, threshold: float) ->
         convert_annotations = ContainsPolygonToLabel(roi_name=roi_name, label=label, threshold=threshold)
         return cls([convert_annotations])

-    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
-        return self._transforms(data)
+    def __call__(self, data: DlupDatasetSample) -> DlupDatasetSample:
+        for transform in self._transforms:
+            data = transform(data)
+        return data

     def __repr__(self) -> str:
         return f"PreTransformTaskFactory(transforms={self._transforms})"
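Side note (not part of the commit): replacing torchvision's Compose with a plain list and an explicit loop keeps the precise PreTransformCallable typing visible to mypy. A minimal sketch of the same pattern, with hypothetical names:

    from typing import Any, Callable

    Sample = dict[str, Any]
    Transform = Callable[[Sample], Sample]

    def apply_all(transforms: list[Transform], sample: Sample) -> Sample:
        # Apply each transform in order, feeding the output of one into the next;
        # roughly what Compose(transforms)(sample) would do, but fully typed.
        for transform in transforms:
            sample = transform(sample)
        return sample

    result = apply_all([lambda s: {**s, "a": 1}, lambda s: {**s, "b": 2}], {})
    print(result)  # {'a': 1, 'b': 2}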
@@ -122,7 +127,7 @@ class LabelToClassIndex:
     def __init__(self, index_map: dict[str, int]):
         self._index_map = index_map

-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+    def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample:
         sample["labels"] = {
             label_name: self._index_map[label_value] for label_name, label_value in sample["labels"].items()
         }
@@ -145,7 +150,7 @@ def __init__(self, index_map: dict[str, int]):
         # Check the max value in the mask
         self._largest_index = max(index_map.values())

-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+    def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample:
         mask = sample["annotation_data"]["mask"]

         new_mask = np.zeros((self._largest_index + 1, *mask.shape))
@@ -156,7 +161,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
         return sample


-def one_hot_encoding(index_map: dict[str, int], mask: npt.NDArray[Any]) -> npt.NDArray[Any]:
+def one_hot_encoding(index_map: dict[str, int], mask: npt.NDArray[np.int_ | np.float_]) -> npt.NDArray[np.float32]:
     """
     functional interface to convert labels/predictions into one-hot codes
@@ -185,7 +190,7 @@ class AllowCollate:
     This transform converts the path to a string. Same holds for the annotations and labels
     """

-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+    def __call__(self, sample: TileSample) -> dict[str, Any]:
         # Path objects cannot be collated
         sample["path"] = str(sample["path"])

@@ -203,7 +208,7 @@ class ImageToTensor:
     Transform to translate the output of a dlup dataset to data_description supported by AhCore
     """

-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+    def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]:
         sample["image"] = F.pil_to_tensor(sample["image"].convert("RGB")).float()

         if sample["image"].sum() == 0:
(Diffs for the remaining two changed files were not loaded on this page.)

