diff --git a/README.md b/README.md index dfb4f49..aa643b1 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,5 @@ # AI for Oncology Core for Computational Pathology -[![Tox](https://github.com/NKI-AI/ahcore/actions/workflows/tox.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/tox.yml) -[![mypy](https://github.com/NKI-AI/ahcore/actions/workflows/mypy.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/mypy.yml) -[![Pylint](https://github.com/NKI-AI/ahcore/actions/workflows/pylint.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/pylint.yml) -[![Black](https://github.com/NKI-AI/ahcore/actions/workflows/black.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/black.yml) +[![Run Precommit Checks on PR Approval](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml) [![codecov](https://codecov.io/gh/NKI-AI/ahcore/branch/main/graph/badge.svg?token=OIJ7F9G7OO)](https://codecov.io/gh/NKI-AI/ahcore) Ahcore are the [AI for Oncology](https://aiforoncology.nl) core components for computational pathology. diff --git a/ahcore/callbacks.py b/ahcore/callbacks.py index 31ae055..6cafded 100644 --- a/ahcore/callbacks.py +++ b/ahcore/callbacks.py @@ -26,6 +26,7 @@ from dlup.tiling import Grid, GridOrder, TilingMode from dlup.writers import TiffCompression, TifffileImageWriter from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import TrainerFn from shapely.geometry import MultiPoint, Point from torch.utils.data import Dataset @@ -286,7 +287,7 @@ def __init__(self, max_queue_size: int, max_concurrent_writers: int, dump_dir: P self._dump_dir = Path(dump_dir) self._max_queue_size = max_queue_size self._semaphore = Semaphore(max_concurrent_writers) - self._validation_index = 0 + self._dataset_index = 0 self._logger = get_logger(type(self).__name__) @@ -314,7 +315,7 @@ def __process_management(self) -> None: def writers(self): return self._writers - def on_validation_batch_end( + def _batch_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule, @@ -322,7 +323,7 @@ def on_validation_batch_end( batch: Any, batch_idx: int, dataloader_idx: int = 0, - ) -> None: + ): filename = batch["path"][0] # Filenames are constant across the batch. if any([filename != path for path in batch["path"]]): raise ValueError( @@ -350,10 +351,18 @@ def on_validation_batch_end( self._semaphore.release() self._semaphore.acquire() - validate_dataset: ConcatDataset = trainer.datamodule.validate_dataset # type: ignore + + if trainer.state.fn == TrainerFn.VALIDATING: + total_dataset: ConcatDataset = trainer.datamodule.validate_dataset # type: ignore + elif trainer.state.fn == TrainerFn.PREDICTING: + total_dataset: ConcatDataset = trainer.predict_dataloaders.dataset # type: ignore + else: + raise NotImplementedError( + f"TrainerFn {trainer.state.fn} is not supported for {self.__class__.__name__}." + ) current_dataset: TiledROIsSlideImageDataset - current_dataset, _ = validate_dataset.index_to_dataset(self._validation_index) # type: ignore + current_dataset, _ = total_dataset.index_to_dataset(self._dataset_index) # type: ignore slide_image = current_dataset.slide_image data_description: DataDescription = pl_module.data_description # type: ignore @@ -395,18 +404,46 @@ def on_validation_batch_end( coordinates_x, coordinates_y = batch["coordinates"] coordinates = torch.stack([coordinates_x, coordinates_y]).T.detach().cpu().numpy() self._writers[filename]["queue"].put((coordinates, prediction)) - self._validation_index += prediction.shape[0] + self._dataset_index += prediction.shape[0] - def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._current_filename is not None: self.__process_management() self._semaphore.release() - self._validation_index = 0 + self._dataset_index = 0 # Reset current filename to None for correct execution of subsequent validation loop self._current_filename = None # Clear all the writers from the current epoch self._writers = {} + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + @staticmethod def generator( queue: Queue[Optional[GenericArray]], # pylint: disable=unsubscriptable-object @@ -511,7 +548,7 @@ def setup( assert _callback.dump_dir, "_callback.dump_dir should never be None after the setup." self._dump_dir = _callback.dump_dir - def on_validation_batch_end( + def _batch_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule, @@ -532,7 +569,7 @@ def on_validation_batch_end( ) self._filenames[filename] = output_filename - def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: assert self.dump_dir, "dump_dir should never be None here." self._logger.info("Writing TIFF files to %s", self.dump_dir / "outputs" / f"{pl_module.name}") results = [] @@ -564,6 +601,34 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo result.get() # Wait for the process to complete. self._filenames = {} # Reset the filenames + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + # Create a data structure to hold all required information for each task TaskData = namedtuple("TaskData", ["filename", "h5_filename", "metadata", "mask", "annotations"]) @@ -776,7 +841,9 @@ def on_validation_batch_end( "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler." ) - def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> list[list[dict[str, dict[str, float]]]]: + def compute_metrics( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> list[list[dict[str, dict[str, float]]]]: assert self._dump_dir assert self._data_description assert self._validate_metadata diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index ed106e7..93fa6c7 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -166,6 +166,12 @@ def construct_dataset() -> ConcatDataset[DlupDatasetSample]: drop_last=True, ) + elif stage == "predict": + batch_sampler = ahcore.data.samplers.WsiBatchSamplerPredict( + dataset=dataset, + batch_size=batch_size, + ) + else: batch_sampler = ahcore.data.samplers.WsiBatchSampler( dataset=dataset, @@ -235,6 +241,15 @@ def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: self._validate_data_iterator, batch_size=batch_size, stage="test" ) + def predict_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: + if not self._predict_data_iterator: + self.setup("predict") + batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size + assert self._predict_data_iterator + return self._construct_concatenated_dataloader( + self._predict_data_iterator, batch_size=batch_size, stage="predict" + ) + def teardown(self, stage: str | None = None) -> None: if stage is not None: getattr(self, f"_{stage}_data_iterator").__del__() diff --git a/ahcore/data/samplers.py b/ahcore/data/samplers.py index 2e0f5c6..15f128a 100644 --- a/ahcore/data/samplers.py +++ b/ahcore/data/samplers.py @@ -7,7 +7,7 @@ from typing import Generator, List from dlup.data.dataset import ConcatDataset, TiledROIsSlideImageDataset -from torch.utils.data import Sampler +from torch.utils.data import Sampler, SequentialSampler from ahcore.utils.io import get_logger @@ -18,7 +18,7 @@ class WsiBatchSampler(Sampler[List[int]]): def __init__(self, dataset: ConcatDataset[TiledROIsSlideImageDataset], batch_size: int) -> None: super().__init__(data_source=dataset) self._dataset = dataset - self._batch_size = batch_size + self.batch_size = batch_size self._slices: List[slice] = [] self._populate_slices() @@ -35,7 +35,7 @@ def __iter__(self) -> Generator[List[int], None, None]: # Within each slice, create batches of size self._batch_size for idx in range(slice_.start, slice_.stop): batch.append(idx) - if len(batch) == self._batch_size: + if len(batch) == self.batch_size: yield batch batch = [] @@ -45,12 +45,66 @@ def __iter__(self) -> Generator[List[int], None, None]: def __len__(self) -> int: # The total number of batches is the sum of the number of batches in each slice - return sum(math.ceil((s.stop - s.start) / self._batch_size) for s in self._slices) + return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" - f"batch_size={self._batch_size}, " + f"batch_size={self.batch_size}, " + f"num_batches={self.__len__()}, " + f"num_wsis={len(self._dataset.datasets)})" + ) + + +class WsiBatchSamplerPredict(Sampler[List[int]]): + """This Sampler is identical to the WsiBatchSampler, + but its signature is changed for compatibility with the predict phase of Lightning.""" + + def __init__( + self, + sampler: SequentialSampler | None = None, + batch_size: int | None = None, + drop_last: bool | None = None, + dataset: ConcatDataset[TiledROIsSlideImageDataset] | None = None, + ) -> None: + if sampler is not None: # During the predict phase, the sampler is passed as a parameter + self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = sampler.data_source # type: ignore + else: + self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = dataset # type: ignore + super().__init__(data_source=self._dataset) + self.batch_size = batch_size + + self._slices: List[slice] = [] + self._populate_slices() + + def _populate_slices(self) -> None: + for idx, _ in enumerate(self._dataset.datasets): + slice_start = 0 if len(self._slices) == 0 else self._slices[-1].stop + slice_stop = self._dataset.cumulative_sizes[idx] + self._slices.append(slice(slice_start, slice_stop)) + + def __iter__(self) -> Generator[List[int], None, None]: + for slice_ in self._slices: + batch = [] + # Within each slice, create batches of size self._batch_size + for idx in range(slice_.start, slice_.stop): + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + # If there are remaining items that couldn't form a full batch, yield them as a smaller batch + if len(batch) > 0: + yield batch + + def __len__(self) -> int: + # The total number of batches is the sum of the number of batches in each slice + return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"batch_size={self.batch_size}, " f"num_batches={self.__len__()}, " f"num_wsis={len(self._dataset.datasets)})" ) diff --git a/ahcore/entrypoints.py b/ahcore/entrypoints.py index bbe9084..62b3c17 100644 --- a/ahcore/entrypoints.py +++ b/ahcore/entrypoints.py @@ -218,9 +218,11 @@ def inference(config: DictConfig) -> None: raise NotImplementedError(f"No augmentations target found in <{config.augmentations[stage]}>") logger.info(f"Instantiating {stage} augmentations <{config.augmentations[stage]._target_}>") # noqa augmentations[stage] = hydra.utils.instantiate( - config.augmentations[stage], data_description=data_description + config.augmentations[stage], + data_description=data_description, + data_module=datamodule, + _convert_="object", ) - # Init lightning model if not config.lit_module.get("_target_"): raise NotImplementedError(f"No model target found in <{config.lit_module}>") @@ -266,10 +268,7 @@ def inference(config: DictConfig) -> None: # Inference logger.info("Starting inference...") - predict_dataloader = datamodule.predict_dataloader() - for metadata, dataloader in predict_dataloader: - model.predict_metadata = metadata # update the metadata to the current WSI - trainer.predict(model=model, dataloaders=dataloader) + trainer.predict(model=model, datamodule=datamodule) # Make sure everything closed properly logger.info("Finalizing...") diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 1e85417..941e556 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -183,13 +183,19 @@ def validation_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, An return output def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - if self._augmentations: + if self._augmentations and "predict" in self._augmentations: batch = self._augmentations["predict"](batch) - inputs = batch["image"] - predictions = self._model(inputs) - gathered_predictions = self.all_gather(predictions) - return gathered_predictions + _relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS} + batch = {**batch, **self._get_inference_prediction(batch["image"])} + _prediction = batch["prediction"] + output = {"prediction": _prediction, **_relevant_dict} + + # This is a sanity check. We expect the filenames to be constant across the batch. + filename = batch["path"][0] + if any([filename != f for f in batch["path"]]): + raise ValueError("Filenames are not constant across the batch.") + return output def configure_optimizers(self): optimizer = self.hparams.optimizer(params=self.parameters()) diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index e2a271f..5e01bbe 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -49,9 +49,9 @@ class GridDescription(BaseModel): class DataDescription(BaseModel): - mask_label: Optional[str] - mask_threshold: Optional[float] # This is only used for training - roi_name: Optional[str] + mask_label: Optional[str] = None + mask_threshold: Optional[float] = None # This is only used for training + roi_name: Optional[str] = None num_classes: PositiveInt data_dir: Path manifest_database_uri: str diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 02dd923..0731d3d 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -10,6 +10,7 @@ class CategoryEnum(PyEnum): TRAIN = "fit" VALIDATE = "validate" TEST = "test" + PREDICT = "predict" class Base(DeclarativeBase): diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 9b665ff..ae9f304 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -61,7 +61,7 @@ class _AnnotationReadersDict(TypedDict): def parse_annotations_from_record( annotations_root: Path, record: list[Mask] | list[ImageAnnotations] -) -> _AnnotationReturnTypes: +) -> _AnnotationReturnTypes | None: """ Parse the annotations from a record of type ImageAnnotations. @@ -77,8 +77,8 @@ def parse_annotations_from_record( WsiAnnotations The parsed annotations. """ - if record is None: - return + if not record: + return None assert len(record) == 1 valid_readers = list(_AnnotationReaders.keys()) @@ -100,7 +100,7 @@ def parse_annotations_from_record( def get_mask_and_annotations_from_record( annotations_root: Path, record: Image -) -> tuple[_AnnotationReturnTypes, _AnnotationReturnTypes]: +) -> tuple[_AnnotationReturnTypes | None, _AnnotationReturnTypes | None]: """ Get the mask and annotations from a record of type Image. @@ -121,7 +121,7 @@ def get_mask_and_annotations_from_record( return _masks, _annotations -def _get_rois(mask: WsiAnnotations, data_description: DataDescription, stage: str) -> Optional[Rois]: +def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, stage: str) -> Optional[Rois]: if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): return None diff --git a/config/callbacks/inference.yaml b/config/callbacks/inference.yaml new file mode 100644 index 0000000..ce37b35 --- /dev/null +++ b/config/callbacks/inference.yaml @@ -0,0 +1,16 @@ +defaults: + - model_summary.yaml + - rich_progress_bar.yaml + - write_h5_callback.yaml + - write_tiff_callback.yaml + - _self_ + +model_summary: + max_depth: -1 + +write_h5_callback: + max_queue_size: 100 + max_concurrent_writers: 4 + +write_tiff_callback: + max_concurrent_writers: 4 diff --git a/config/data_description/tissue_subtypes/segmentation_inference.yaml b/config/data_description/tissue_subtypes/segmentation_inference.yaml new file mode 100644 index 0000000..e3a9f12 --- /dev/null +++ b/config/data_description/tissue_subtypes/segmentation_inference.yaml @@ -0,0 +1,42 @@ +_target_: ahcore.utils.data.DataDescription +data_dir: ${oc.env:DATA_DIR} +annotations_dir: ${oc.env:ANNOTATIONS_DIR}/tissue_subtypes/v20230228_debug # specify in .env +manifest_database_uri: sqlite:///${oc.env:MANIFEST_PATH}/tissue_subtypes/v20230228_debug/manifest.db +manifest_name: "v20230228" +split_version: "v1" +use_roi: False +# Eg. 512 x 512 is: tile_size 726 x 726 and tile_overlap 107 x 107 +# Tiles are cropped in the transforms downstream, to ensure the patch is completely visible, they are extracted +# slightly larger (sqrt 2) with sufficient overlap so we have a 512 stride. +training_grid: + mpp: 1.0 + output_tile_size: [726, 726] + tile_overlap: [0, 0] + tile_size: [512, 512] +inference_grid: + mpp: 1.0 + tile_size: [512, 512] + tile_overlap: [128, 128] + +num_classes: 4 +use_class_weights: false # Use the class weights in the loss +remap_labels: + stroma: stroma + tumor: tumor + inflamed: stroma + dcis: tumor + lymphoid aggregates: ignore + dcis immune cells: stroma + necrotic areas: ignore + normal glands: stroma + blood vessels: stroma + fat cell area: ignore + red blood cells: stroma + fibrosis areas: ignore + artefacts: ignore + other: ignore + +index_map: + stroma: 1 + tumor: 2 + ignore: 3 diff --git a/config/inference.yaml b/config/inference.yaml index f2db865..8d6e978 100644 --- a/config/inference.yaml +++ b/config/inference.yaml @@ -10,6 +10,7 @@ defaults: - task: segmentation_inference.yaml - logger: null - lit_module: ??? + - callbacks: inference.yaml - trainer: default.yaml - hydra: default.yaml # if we want to setup different hydra logging dirs / color log - paths: default.yaml diff --git a/tools/populate_tcga_db.py b/tools/populate_tcga_db.py index 8cf5705..072163f 100644 --- a/tools/populate_tcga_db.py +++ b/tools/populate_tcga_db.py @@ -24,7 +24,10 @@ def get_patient_from_tcga_id(tcga_filename: str) -> str: return tcga_filename[:12] -def populate_from_annotated_tcga(session, image_folder: Path, annotation_folder: Path, path_to_mapping: Path): +def populate_from_annotated_tcga( + session, image_folder: Path, annotation_folder: Path, path_to_mapping: Path, predict: bool = False +): + """This is a basic example, adjust to your needs.""" # TODO: We should do the mpp as well here with open(path_to_mapping, "r") as f: @@ -40,8 +43,9 @@ def populate_from_annotated_tcga(session, image_folder: Path, annotation_folder: for folder in annotation_folder.glob("TCGA*"): patient_code = get_patient_from_tcga_id(folder.name) - annotation_path = folder / "annotations.json" - mask_path = folder / "roi.json" + if not predict: + annotation_path = folder / "annotations.json" + mask_path = folder / "roi.json" # Only add patient if it doesn't exist existing_patient = session.query(Patient).filter_by(patient_code=patient_code).first() # type: ignore @@ -53,9 +57,12 @@ def populate_from_annotated_tcga(session, image_folder: Path, annotation_folder: session.flush() # For now random. - split_category = random.choices( - [CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST], [70, 20, 10] - )[0] + if predict: + split_category = CategoryEnum.PREDICT + else: + split_category = random.choices( + [CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST], [70, 20, 10] + )[0] split = Split( category=split_category, @@ -98,11 +105,12 @@ def populate_from_annotated_tcga(session, image_folder: Path, annotation_folder: session.add(image) session.flush() # Flush so that Image ID is populated for future records - mask = Mask(filename=str(mask_path), reader="GEOJSON", image=image) - session.add(mask) + if not predict: + mask = Mask(filename=str(mask_path), reader="GEOJSON", image=image) + session.add(mask) - image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image) - session.add(image_annotation) + image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image) + session.add(image_annotation) label_data = "cancer" if random.choice([True, False]) else "benign" # Randomly decide if it's cancer or benign image_label = ImageLabels(label_data=label_data, image=image) @@ -112,8 +120,8 @@ def populate_from_annotated_tcga(session, image_folder: Path, annotation_folder: if __name__ == "__main__": - annotation_folder = Path("tissue_subtypes/v20230228_combined_v2/") + annotation_folder = Path("tissue_subtypes/v20230228_debug/") image_folder = Path("/data/groups/aiforoncology/archive/pathology/TCGA/images/") path_to_mapping = Path("/data/groups/aiforoncology/archive/pathology/TCGA/identifier_mapping.json") with open_db("manifest.db") as session: - populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping) + populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping, predict=True)