Inference Pipeline (#26)
* Inference pipeline + callbacks

* Cleanup + populate_db example

* Update README.md -- Precommit badge
EricMarcus-ai authored Oct 6, 2023
1 parent 9797db1 commit fdc8ad1
Showing 13 changed files with 257 additions and 51 deletions.
5 changes: 1 addition & 4 deletions README.md
@@ -1,8 +1,5 @@
 # AI for Oncology Core for Computational Pathology
-[![Tox](https://github.com/NKI-AI/ahcore/actions/workflows/tox.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/tox.yml)
-[![mypy](https://github.com/NKI-AI/ahcore/actions/workflows/mypy.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/mypy.yml)
-[![Pylint](https://github.com/NKI-AI/ahcore/actions/workflows/pylint.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/pylint.yml)
-[![Black](https://github.com/NKI-AI/ahcore/actions/workflows/black.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/black.yml)
+[![Run Precommit Checks on PR Approval](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml)
 [![codecov](https://codecov.io/gh/NKI-AI/ahcore/branch/main/graph/badge.svg?token=OIJ7F9G7OO)](https://codecov.io/gh/NKI-AI/ahcore)
 
 Ahcore are the [AI for Oncology](https://aiforoncology.nl) core components for computational pathology.
89 changes: 78 additions & 11 deletions ahcore/callbacks.py
@@ -26,6 +26,7 @@
 from dlup.tiling import Grid, GridOrder, TilingMode
 from dlup.writers import TiffCompression, TifffileImageWriter
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.trainer.states import TrainerFn
 from shapely.geometry import MultiPoint, Point
 from torch.utils.data import Dataset

@@ -286,7 +287,7 @@ def __init__(self, max_queue_size: int, max_concurrent_writers: int, dump_dir: P
         self._dump_dir = Path(dump_dir)
         self._max_queue_size = max_queue_size
         self._semaphore = Semaphore(max_concurrent_writers)
-        self._validation_index = 0
+        self._dataset_index = 0
 
         self._logger = get_logger(type(self).__name__)

@@ -314,15 +315,15 @@ def __process_management(self) -> None:
     def writers(self):
         return self._writers

-    def on_validation_batch_end(
+    def _batch_end(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
         outputs: Any,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
-    ) -> None:
+    ):
         filename = batch["path"][0]  # Filenames are constant across the batch.
         if any([filename != path for path in batch["path"]]):
             raise ValueError(
@@ -350,10 +351,18 @@ def on_validation_batch_end(
                 self._semaphore.release()
 
             self._semaphore.acquire()
-            validate_dataset: ConcatDataset = trainer.datamodule.validate_dataset  # type: ignore
 
+            if trainer.state.fn == TrainerFn.VALIDATING:
+                total_dataset: ConcatDataset = trainer.datamodule.validate_dataset  # type: ignore
+            elif trainer.state.fn == TrainerFn.PREDICTING:
+                total_dataset: ConcatDataset = trainer.predict_dataloaders.dataset  # type: ignore
+            else:
+                raise NotImplementedError(
+                    f"TrainerFn {trainer.state.fn} is not supported for {self.__class__.__name__}."
+                )
+
             current_dataset: TiledROIsSlideImageDataset
-            current_dataset, _ = validate_dataset.index_to_dataset(self._validation_index)  # type: ignore
+            current_dataset, _ = total_dataset.index_to_dataset(self._dataset_index)  # type: ignore
             slide_image = current_dataset.slide_image
 
             data_description: DataDescription = pl_module.data_description  # type: ignore
@@ -395,18 +404,46 @@ def on_validation_batch_end(
         coordinates_x, coordinates_y = batch["coordinates"]
         coordinates = torch.stack([coordinates_x, coordinates_y]).T.detach().cpu().numpy()
         self._writers[filename]["queue"].put((coordinates, prediction))
-        self._validation_index += prediction.shape[0]
+        self._dataset_index += prediction.shape[0]

-    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         if self._current_filename is not None:
             self.__process_management()
             self._semaphore.release()
-        self._validation_index = 0
+        self._dataset_index = 0
         # Reset current filename to None for correct execution of subsequent validation loop
         self._current_filename = None
         # Clear all the writers from the current epoch
         self._writers = {}

+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+
+    def on_predict_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+
+    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._epoch_end(trainer, pl_module)
+
+    def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._epoch_end(trainer, pl_module)

     @staticmethod
     def generator(
         queue: Queue[Optional[GenericArray]],  # pylint: disable=unsubscriptable-object
@@ -511,7 +548,7 @@ def setup(
         assert _callback.dump_dir, "_callback.dump_dir should never be None after the setup."
         self._dump_dir = _callback.dump_dir

-    def on_validation_batch_end(
+    def _batch_end(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
@@ -532,7 +569,7 @@ def on_validation_batch_end(
             )
             self._filenames[filename] = output_filename

-    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         assert self.dump_dir, "dump_dir should never be None here."
         self._logger.info("Writing TIFF files to %s", self.dump_dir / "outputs" / f"{pl_module.name}")
         results = []
@@ -564,6 +601,34 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo
             result.get()  # Wait for the process to complete.
         self._filenames = {}  # Reset the filenames

+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+
+    def on_predict_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+
+    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._epoch_end(trainer, pl_module)
+
+    def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        self._epoch_end(trainer, pl_module)


 # Create a data structure to hold all required information for each task
 TaskData = namedtuple("TaskData", ["filename", "h5_filename", "metadata", "mask", "annotations"])
@@ -776,7 +841,9 @@ def on_validation_batch_end(
"Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler."
)

-    def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> list[list[dict[str, dict[str, float]]]]:
+    def compute_metrics(
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule
+    ) -> list[list[dict[str, dict[str, float]]]]:
         assert self._dump_dir
         assert self._data_description
         assert self._validate_metadata
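Taken together, the callbacks.py changes let a single callback instance serve both loops: every Lightning hook forwards to the shared _batch_end/_epoch_end helpers, which branch on trainer.state.fn. A minimal usage sketch under assumed names (the callback class name and the model/datamodule variables are hypothetical; only the constructor parameters appear in this diff):

    from pathlib import Path

    import pytorch_lightning as pl

    from ahcore.callbacks import WriteH5Callback  # assumed name for the H5 writer edited above

    writer = WriteH5Callback(max_queue_size=16, max_concurrent_writers=4, dump_dir=Path("/tmp/dumps"))
    trainer = pl.Trainer(callbacks=[writer])

    # trainer.state.fn == TrainerFn.VALIDATING, so the callback reads trainer.datamodule.validate_dataset
    trainer.validate(model, datamodule=datamodule)

    # trainer.state.fn == TrainerFn.PREDICTING, so it reads trainer.predict_dataloaders.dataset
    trainer.predict(model, datamodule=datamodule)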
15 changes: 15 additions & 0 deletions ahcore/data/dataset.py
@@ -166,6 +166,12 @@ def construct_dataset() -> ConcatDataset[DlupDatasetSample]:
                 drop_last=True,
             )
 
+        elif stage == "predict":
+            batch_sampler = ahcore.data.samplers.WsiBatchSamplerPredict(
+                dataset=dataset,
+                batch_size=batch_size,
+            )
+
         else:
             batch_sampler = ahcore.data.samplers.WsiBatchSampler(
                 dataset=dataset,
@@ -235,6 +241,15 @@ def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
             self._validate_data_iterator, batch_size=batch_size, stage="test"
         )
 
+    def predict_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
+        if not self._predict_data_iterator:
+            self.setup("predict")
+        batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size
+        assert self._predict_data_iterator
+        return self._construct_concatenated_dataloader(
+            self._predict_data_iterator, batch_size=batch_size, stage="predict"
+        )
+
     def teardown(self, stage: str | None = None) -> None:
         if stage is not None:
             getattr(self, f"_{stage}_data_iterator").__del__()
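A short sketch of how the new predict path is exercised (the datamodule variable is hypothetical, standing for an instance of this datamodule):

    # predict_dataloader() calls setup("predict") itself if the iterator is missing.
    loader = datamodule.predict_dataloader()

    for batch in loader:
        # Each batch contains tiles from a single WSI, enforced by WsiBatchSamplerPredict.
        ...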
64 changes: 59 additions & 5 deletions ahcore/data/samplers.py
@@ -7,7 +7,7 @@
 from typing import Generator, List
 
 from dlup.data.dataset import ConcatDataset, TiledROIsSlideImageDataset
-from torch.utils.data import Sampler
+from torch.utils.data import Sampler, SequentialSampler
 
 from ahcore.utils.io import get_logger

@@ -18,7 +18,7 @@ class WsiBatchSampler(Sampler[List[int]]):
     def __init__(self, dataset: ConcatDataset[TiledROIsSlideImageDataset], batch_size: int) -> None:
         super().__init__(data_source=dataset)
         self._dataset = dataset
-        self._batch_size = batch_size
+        self.batch_size = batch_size
 
         self._slices: List[slice] = []
         self._populate_slices()
@@ -35,7 +35,7 @@ def __iter__(self) -> Generator[List[int], None, None]:
             # Within each slice, create batches of size self._batch_size
             for idx in range(slice_.start, slice_.stop):
                 batch.append(idx)
-                if len(batch) == self._batch_size:
+                if len(batch) == self.batch_size:
                     yield batch
                     batch = []

@@ -45,12 +45,66 @@ def __iter__(self) -> Generator[List[int], None, None]:

     def __len__(self) -> int:
         # The total number of batches is the sum of the number of batches in each slice
-        return sum(math.ceil((s.stop - s.start) / self._batch_size) for s in self._slices)
+        return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices)
 
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}("
-            f"batch_size={self._batch_size}, "
+            f"batch_size={self.batch_size}, "
             f"num_batches={self.__len__()}, "
             f"num_wsis={len(self._dataset.datasets)})"
         )


+class WsiBatchSamplerPredict(Sampler[List[int]]):
+    """This Sampler is identical to the WsiBatchSampler,
+    but its signature is changed for compatibility with the predict phase of Lightning."""
+
+    def __init__(
+        self,
+        sampler: SequentialSampler | None = None,
+        batch_size: int | None = None,
+        drop_last: bool | None = None,
+        dataset: ConcatDataset[TiledROIsSlideImageDataset] | None = None,
+    ) -> None:
+        if sampler is not None:  # During the predict phase, the sampler is passed as a parameter
+            self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = sampler.data_source  # type: ignore
+        else:
+            self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = dataset  # type: ignore
+        super().__init__(data_source=self._dataset)
+        self.batch_size = batch_size
+
+        self._slices: List[slice] = []
+        self._populate_slices()
+
+    def _populate_slices(self) -> None:
+        for idx, _ in enumerate(self._dataset.datasets):
+            slice_start = 0 if len(self._slices) == 0 else self._slices[-1].stop
+            slice_stop = self._dataset.cumulative_sizes[idx]
+            self._slices.append(slice(slice_start, slice_stop))
+
+    def __iter__(self) -> Generator[List[int], None, None]:
+        for slice_ in self._slices:
+            batch = []
+            # Within each slice, create batches of size self._batch_size
+            for idx in range(slice_.start, slice_.stop):
+                batch.append(idx)
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+
+            # If there are remaining items that couldn't form a full batch, yield them as a smaller batch
+            if len(batch) > 0:
+                yield batch
+
+    def __len__(self) -> int:
+        # The total number of batches is the sum of the number of batches in each slice
+        return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"batch_size={self.batch_size}, "
+            f"num_batches={self.__len__()}, "
+            f"num_wsis={len(self._dataset.datasets)})"
+        )
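The dual sampler/dataset signature exists because Lightning's predict loop re-creates custom batch samplers by passing a sampler positionally, whose data_source attribute then yields the dataset, while the datamodule above constructs the sampler directly from the dataset. A sketch of both construction paths, assuming concat_dataset is a dlup ConcatDataset of per-WSI datasets (hypothetical variable):

    from torch.utils.data import DataLoader, SequentialSampler

    from ahcore.data.samplers import WsiBatchSamplerPredict

    # Direct construction, as in the datamodule's stage == "predict" branch:
    batch_sampler = WsiBatchSamplerPredict(dataset=concat_dataset, batch_size=8)
    loader = DataLoader(concat_dataset, batch_sampler=batch_sampler)

    # Lightning-style construction during predict; the dataset is recovered
    # from sampler.data_source inside __init__:
    sampler = SequentialSampler(concat_dataset)
    batch_sampler = WsiBatchSamplerPredict(sampler=sampler, batch_size=8, drop_last=False)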
11 changes: 5 additions & 6 deletions ahcore/entrypoints.py
@@ -218,9 +218,11 @@ def inference(config: DictConfig) -> None:
             raise NotImplementedError(f"No augmentations target found in <{config.augmentations[stage]}>")
         logger.info(f"Instantiating {stage} augmentations <{config.augmentations[stage]._target_}>")  # noqa
         augmentations[stage] = hydra.utils.instantiate(
-            config.augmentations[stage], data_description=data_description
+            config.augmentations[stage],
+            data_description=data_description,
+            data_module=datamodule,
+            _convert_="object",
         )

     # Init lightning model
     if not config.lit_module.get("_target_"):
         raise NotImplementedError(f"No model target found in <{config.lit_module}>")
@@ -266,10 +268,7 @@ def inference(config: DictConfig) -> None:

     # Inference
     logger.info("Starting inference...")
-    predict_dataloader = datamodule.predict_dataloader()
-    for metadata, dataloader in predict_dataloader:
-        model.predict_metadata = metadata  # update the metadata to the current WSI
-        trainer.predict(model=model, dataloaders=dataloader)
+    trainer.predict(model=model, datamodule=datamodule)
 
     # Make sure everything closed properly
     logger.info("Finalizing...")
16 changes: 11 additions & 5 deletions ahcore/lit_module.py
@@ -183,13 +183,19 @@ def validation_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, An
         return output

     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
-        if self._augmentations:
+        if self._augmentations and "predict" in self._augmentations:
             batch = self._augmentations["predict"](batch)
 
-        inputs = batch["image"]
-        predictions = self._model(inputs)
-        gathered_predictions = self.all_gather(predictions)
-        return gathered_predictions
+        _relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS}
+        batch = {**batch, **self._get_inference_prediction(batch["image"])}
+        _prediction = batch["prediction"]
+        output = {"prediction": _prediction, **_relevant_dict}
+
+        # This is a sanity check. We expect the filenames to be constant across the batch.
+        filename = batch["path"][0]
+        if any([filename != f for f in batch["path"]]):
+            raise ValueError("Filenames are not constant across the batch.")
+        return output

     def configure_optimizers(self):
         optimizer = self.hparams.optimizer(params=self.parameters())
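predict_step now returns the prediction together with the bookkeeping entries selected by RELEVANT_KEYS, which is what lets the writer callbacks route tiles by filename and coordinates. An illustrative, self-contained sketch of the output shape (the exact members of RELEVANT_KEYS are not shown in this diff; "path" and "coordinates" are assumed from their use in callbacks.py):

    import torch

    # Illustrative only: what predict_step might return for a batch of two tiles.
    output = {
        "prediction": torch.zeros(2, 3, 512, 512),   # filled in via _get_inference_prediction
        "path": ["slide_001.svs", "slide_001.svs"],  # constant across the batch (sanity-checked)
        "coordinates": (torch.tensor([0, 512]), torch.tensor([0, 0])),  # consumed by the H5 writer
    }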
6 changes: 3 additions & 3 deletions ahcore/utils/data.py
@@ -49,9 +49,9 @@ class GridDescription(BaseModel):


 class DataDescription(BaseModel):
-    mask_label: Optional[str]
-    mask_threshold: Optional[float]  # This is only used for training
-    roi_name: Optional[str]
+    mask_label: Optional[str] = None
+    mask_threshold: Optional[float] = None  # This is only used for training
+    roi_name: Optional[str] = None
     num_classes: PositiveInt
     data_dir: Path
     manifest_database_uri: str
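The explicit = None defaults are significant under Pydantic v2, where Optional[str] only widens the accepted type and no longer implies a default: without one, the field stays required. A minimal, standalone illustration (not ahcore code):

    from typing import Optional

    from pydantic import BaseModel


    class Example(BaseModel):
        required_nullable: Optional[str]         # must be passed, though None is accepted
        optional_nullable: Optional[str] = None  # may be omitted entirely


    Example(required_nullable=None)  # OK
    # Example()  # would raise ValidationError: required_nullable is missing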
1 change: 1 addition & 0 deletions ahcore/utils/database_models.py
@@ -10,6 +10,7 @@ class CategoryEnum(PyEnum):
     TRAIN = "fit"
     VALIDATE = "validate"
     TEST = "test"
+    PREDICT = "predict"


 class Base(DeclarativeBase):