diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index a7040eb..24fd0f5 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -3,11 +3,12 @@ dataset.""" from __future__ import annotations -from typing import Callable, Any +from typing import Any, Callable import numpy as np import numpy.typing as npt import torch +from dlup.data.dataset import TileSample, TileSampleWithAnnotationData from dlup.data.transforms import ContainsPolygonToLabel, ConvertAnnotationsToMask, RenameLabels from torchvision.transforms import functional as F @@ -15,7 +16,6 @@ from ahcore.utils.data import DataDescription from ahcore.utils.io import get_logger from ahcore.utils.types import DlupDatasetSample -from dlup.data.dataset import TileSample, TileSampleWithAnnotationData PreTransformCallable = Callable[[TileSample], TileSampleWithAnnotationData] @@ -194,10 +194,15 @@ def __call__(self, sample: TileSample) -> dict[str, Any]: # Path objects cannot be collated sample["path"] = str(sample["path"]) - # This would prevent collate - if sample["annotations"] is None: + # Not required anymore + if "annotation_data" in sample: + del sample["annotation_data"] + + # Not required anymore + if "annotations" in sample: del sample["annotations"] - if sample["labels"] is None: + + if sample.get("labels") is None: del sample["labels"] return sample @@ -229,11 +234,6 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: roi = sample["annotation_data"]["roi"] sample["roi"] = torch.from_numpy(roi[np.newaxis, ...]).float() - # Not required anymore - del sample["annotation_data"] - # This might be empty. - del sample["annotations"] - return sample def __repr__(self) -> str: