Refactor callback
jonasteuwen committed Sep 19, 2023
1 parent 99f72d6 commit b5b361e
Showing 8 changed files with 195 additions and 83 deletions.
211 changes: 142 additions & 69 deletions ahcore/callbacks.py
@@ -6,6 +6,7 @@
import logging
import multiprocessing
import time
from collections import namedtuple
from multiprocessing import Pipe, Process, Queue, Semaphore
from multiprocessing.connection import Connection
from pathlib import Path
@@ -372,18 +373,17 @@ def on_validation_batch_end(
self._semaphore.release()

self._semaphore.acquire()
validation_dataset: ConcatDataset = pl_module.validation_dataset # type: ignore
validate_dataset: ConcatDataset = trainer.datamodule.validate_dataset # type: ignore

current_dataset: TiledROIsSlideImageDataset
current_dataset, _ = validation_dataset.index_to_dataset(self._validation_index) # type: ignore
current_dataset, _ = validate_dataset.index_to_dataset(self._validation_index) # type: ignore
slide_image = current_dataset.slide_image

data_description: DataDescription = pl_module.data_description # type: ignore
inference_grid: GridDescription = data_description.inference_grid

mpp = inference_grid.mpp
if mpp is None:
self._logger.info("mpp is not set. Retrieving from slide image.")
mpp = slide_image.mpp

size = slide_image.get_scaled_size(slide_image.get_scaling(mpp))
@@ -486,7 +486,6 @@ def __init__(self, max_concurrent_writers: int):
# TODO: Handle tile operation such that we avoid repetitions.

self._tile_process_function = tile_process_function # function that is applied to the tile.

self._filenames: dict[Path, Path] = {} # This has all the h5 files

@property
@@ -498,7 +497,12 @@ def _validate_parameters(self):
if not dump_dir:
raise ValueError("Dump directory is not set.")

def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
def setup(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
stage: Optional[str] = None,
) -> None:
_callback: Optional[WriteH5Callback] = None
for idx, callback in enumerate(trainer.callbacks): # type: ignore
if isinstance(callback, WriteH5Callback):
@@ -566,6 +570,92 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule
self._filenames = {} # Reset the filenames


# Create a data structure to hold all required information for each task
TaskData = namedtuple("TaskData", ["filename", "h5_filename", "metadata", "mask", "annotations"])


def prepare_task_data(filename, dump_dir, pl_module, data_description, data_manager):
h5_filename = _get_h5_output_filename(
dump_dir=dump_dir,
input_path=data_description.data_dir / filename,
model_name=str(pl_module.name),
step=pl_module.global_step,
)
image = data_manager.get_image_by_filename(str(filename))
metadata = fetch_image_metadata(image)
mask, annotations = get_mask_and_annotations_from_record(data_description.annotations_dir, image)

return TaskData(filename, h5_filename, metadata, mask, annotations)


def compute_metrics_for_case(
task_data: TaskData,
class_names,
data_description,
wsi_metrics,
save_per_image: bool,
):
# Extract the data from the namedtuple
filename, h5_filename, metadata, mask, annotations = task_data

dump_list = []

with H5FileImageReader(h5_filename, stitching_mode=StitchingMode.CROP) as h5reader:
dataset_of_validation_image = _ValidationDataset(
data_description=data_description,
native_mpp=metadata.mpp,
mask=mask,
annotations=annotations,
reader=h5reader,
)
for sample in dataset_of_validation_image:
prediction = torch.from_numpy(sample["prediction"]).unsqueeze(0).float()
target = torch.from_numpy(sample["target"]).unsqueeze(0)
roi = torch.from_numpy(sample["roi"]).unsqueeze(0)

wsi_metrics.process_batch(
predictions=prediction,
target=target,
roi=roi,
wsi_name=str(filename),
)
if save_per_image:
wsi_metrics_dictionary = {
"image_fn": str(data_description.data_dir / metadata.filename),
"uuid": filename.stem,
}
if filename.with_suffix(".tiff").is_file():
wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff"))
if filename.is_file():
wsi_metrics_dictionary["h5_fn"] = str(filename)
for metric in wsi_metrics._metrics:
metric.get_wsi_score(str(filename))
wsi_metrics_dictionary[metric.name] = {
class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item()
for class_idx in range(data_description.num_classes)
}
dump_list.append(wsi_metrics_dictionary)

return dump_list


# Stand-alone function that schedules a single metrics-computation task on the worker pool.
def schedule_task(
task_data,
pool,
results_dict,
class_names,
data_description,
wsi_metrics,
save_per_image,
):
result = pool.apply_async(
compute_metrics_for_case,
args=(task_data, class_names, data_description, wsi_metrics, save_per_image),
)
results_dict[result] = task_data.filename


class ComputeWsiMetricsCallback(Callback):
def __init__(self, max_processes=10, save_per_image: bool = True):
"""
@@ -592,7 +682,12 @@ def __init__(self, max_processes=10, save_per_image: bool = True):
self._dump_list: list[dict[str, str]] = []
self._logger = get_logger(type(self).__name__)

def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
def setup(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
stage: Optional[str] = None,
) -> None:
pl_module = cast(AhCoreLightningModule, pl_module)

_callback: Optional[WriteH5Callback] = None
@@ -629,7 +724,9 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:

self._validate_metadata_gen = self._create_validate_image_metadata_gen()

def _create_validate_image_metadata_gen(self) -> Generator[ImageMetadata, None, None]:
def _create_validate_image_metadata_gen(
self,
) -> Generator[ImageMetadata, None, None]:
assert self._data_description
assert self._data_manager
gen = self._data_manager.get_image_metadata_by_split(
@@ -661,7 +758,7 @@ def on_validation_batch_end(
raise ValueError("Dump directory is not set.")

filenames = batch["path"] # Filenames are constant across the batch.
if not len(set(filenames)) != 1:
if len(set(filenames)) != 1:
raise ValueError(
"All paths in a batch must be the same. "
"Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler."
@@ -670,7 +767,6 @@ def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
assert self._dump_dir
assert self._data_description
data_dir = self._data_description.data_dir
metrics = []

with multiprocessing.Pool(processes=self._max_processes) as pool:
@@ -679,17 +775,26 @@ def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

# Fill up the initial task pool
for image_metadata in itertools.islice(self._validate_metadata, self._max_processes):
filename = image_metadata.filename
h5_filename = _get_h5_output_filename(
dump_dir=self._dump_dir,
input_path=data_dir / filename,
model_name=str(pl_module.name),
step=pl_module.global_step,
# Assemble the task data
# filename", "h5_filename", "metadata", "mask", "annotations
task_data = prepare_task_data(
image_metadata.filename,
self._dump_dir,
pl_module,
self._data_description,
self._data_manager,
)

# We need the full filename here so
result = pool.apply_async(self.compute_metrics_for_case, args=(filename, h5_filename))
results_to_filename[result] = filename
# Schedule task
schedule_task(
task_data,
pool,
results_to_filename,
self._class_names,
self._data_description,
self._wsi_metrics,
self._save_per_image,
)

while results_to_filename:
time.sleep(0.1) # Reduce excessive polling
@@ -709,60 +814,28 @@ def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

# Schedule a new task if there are more filenames left in the generator
next_metadata = next(self._validate_metadata, None)
next_filename = next_metadata.filename if next_metadata else None
if next_filename:
new_result = pool.apply_async(self.compute_metrics_for_case, args=(next_filename,))
results_to_filename[new_result] = next_filename
if next_metadata:
task_data = prepare_task_data(
next_metadata.filename,
self._dump_dir,
pl_module,
self._data_description,
self._data_manager,
)

# Schedule task
schedule_task(
task_data,
pool,
results_to_filename,
self._class_names,
self._data_description,
self._wsi_metrics,
self._save_per_image,
)

return metrics

def compute_metrics_for_case(self, filename: Path, h5_filename: Path):
# for mypy
assert self._class_names
assert self._data_description
assert self._wsi_metrics
# Given the filename we can request all the annotations and everything else.
# We can then compute the metrics for this case.
image = self._data_manager.get_image_by_filename(str(filename))
metadata = fetch_image_metadata(image)

with H5FileImageReader(h5_filename, stitching_mode=StitchingMode.CROP) as h5reader:
mask, annotations = get_mask_and_annotations_from_record(self._data_description.annotation_dir, image)
dataset_of_validation_image = _ValidationDataset(
data_description=self._data_description,
native_mpp=metadata.mpp,
mask=mask,
annotations=annotations,
reader=h5reader,
)
for sample in dataset_of_validation_image:
prediction = torch.from_numpy(sample["prediction"]).unsqueeze(0).float()
target = torch.from_numpy(sample["target"]).unsqueeze(0)
roi = torch.from_numpy(sample["roi"]).unsqueeze(0)

self._wsi_metrics.process_batch(
predictions=prediction,
target=target,
roi=roi,
wsi_name=str(filename),
)
if self._save_per_image is True:
wsi_metrics_dictionary = {
"image_fn": str(self._data_description.data_dir / metadata.filename),
"uuid": filename.stem,
}
if filename.with_suffix(".tiff").is_file():
wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff"))
if filename.is_file():
wsi_metrics_dictionary["h5_fn"] = str(filename)
for metric in self._wsi_metrics._metrics:
metric.get_wsi_score(str(filename))
wsi_metrics_dictionary[metric.name] = {
self._class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item()
for class_idx in range(self._data_description.num_classes)
}
self._dump_list.append(wsi_metrics_dictionary)

def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
if not self._dump_dir:
raise ValueError("Dump directory is not set.")
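The refactored compute_metrics above fills a multiprocessing pool with an initial batch of tasks, polls the pending results, and tops the pool back up from the metadata generator as each case finishes. A minimal, self-contained sketch of that scheduling pattern follows; the toy square task stands in for compute_metrics_for_case and a plain integer range stands in for the image-metadata generator, so it runs without any ahcore objects.

import itertools
import time
from multiprocessing import Pool


def square_task(value):
    # Stand-in for compute_metrics_for_case: returns a list of per-case result dicts.
    return [{"input": value, "score": value * value}]


def run_all(values, max_processes=4):
    source = iter(values)  # stands in for the image-metadata generator
    collected = []
    with Pool(processes=max_processes) as pool:
        pending = {}  # AsyncResult -> identifier, like results_to_filename above
        # Fill up the initial task pool
        for value in itertools.islice(source, max_processes):
            pending[pool.apply_async(square_task, args=(value,))] = value
        while pending:
            time.sleep(0.1)  # reduce excessive polling
            for result in [r for r in pending if r.ready()]:
                pending.pop(result)
                collected.extend(result.get())
                # Schedule a new task if the generator still has items
                nxt = next(source, None)
                if nxt is not None:
                    pending[pool.apply_async(square_task, args=(nxt,))] = nxt
    return collected


if __name__ == "__main__":
    print(run_all(range(10)))
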
3 changes: 2 additions & 1 deletion ahcore/data/dataset.py
@@ -144,8 +144,9 @@ def construct_dataset() -> ConcatDataset:
datasets.append(ds)
return ConcatDataset(datasets)

self._logger.info("Constructing dataset for stage %s (this can take a while)", stage)
self._logger.info("Constructing dataset for stage %s (this can take a while)", stage.value)
dataset = self._load_from_cache(construct_dataset, stage=stage)
setattr(self, f"{stage}_dataset", dataset)

lengths = np.asarray([len(ds) for ds in dataset.datasets])
self._logger.info(
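The added setattr line is what exposes the per-stage dataset under a stage-named attribute (for example validate_dataset), which is how the refactored callback above can read trainer.datamodule.validate_dataset. A small illustrative sketch of that pattern, with a placeholder list standing in for the real ConcatDataset:

class TinyDataModule:
    def setup(self, stage: str) -> None:
        dataset = [f"{stage}-sample-{i}" for i in range(3)]  # placeholder dataset
        # After this call, e.g. `module.validate_dataset` is available to callbacks.
        setattr(self, f"{stage}_dataset", dataset)


module = TinyDataModule()
module.setup("validate")
print(module.validate_dataset)  # ['validate-sample-0', 'validate-sample-1', 'validate-sample-2']
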
7 changes: 4 additions & 3 deletions ahcore/lit_module.py
@@ -164,9 +164,10 @@ def _get_inference_prediction(self, _input: torch.Tensor) -> dict[str, torch.Tensor]:
return output

def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]:
if self.global_step == 0:
if self._tensorboard:
self._tensorboard.add_graph(self._model, batch["image"])
# TODO: This is problematic as you first need to pass through the augmentations to get the correct shape
# if self.global_step == 0:
# if self._tensorboard:
# self._tensorboard.add_graph(self._model, batch["image"])

output = self.do_step(batch, batch_idx, stage=TrainerFn.FITTING)
return output
10 changes: 8 additions & 2 deletions ahcore/readers.py
@@ -158,8 +158,14 @@ def read_region(
l1, l2 = location
s1, s2 = size

original_location = (int(math.floor(l1 / scaling)) - order, int(math.floor(l2 / scaling)) - order)
original_size = (int(math.ceil(s1 / scaling)) + order, int(math.ceil(s2 / scaling)) + order)
original_location = (
int(math.floor(l1 / scaling)) - order,
int(math.floor(l2 / scaling)) - order,
)
original_size = (
int(math.ceil(s1 / scaling)) + order,
int(math.ceil(s2 / scaling)) + order,
)

raw_region = self.read_region_raw(original_location, original_size)

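The reformatted block maps a region requested at a scaled resolution back to native coordinates: floor the scaled location, ceil the scaled size, and pad both by the interpolation order. A stand-alone sketch of that arithmetic (the native_region wrapper and the example numbers are illustrative, not part of ahcore):

import math


def native_region(location, size, scaling, order=1):
    # Same floor/ceil-plus-order padding as H5FileImageReader.read_region.
    l1, l2 = location
    s1, s2 = size
    original_location = (
        int(math.floor(l1 / scaling)) - order,
        int(math.floor(l2 / scaling)) - order,
    )
    original_size = (
        int(math.ceil(s1 / scaling)) + order,
        int(math.ceil(s2 / scaling)) + order,
    )
    return original_location, original_size


# A 512 x 512 tile requested at half resolution (scaling=0.5), linear interpolation (order=1):
print(native_region((100, 200), (512, 512), scaling=0.5))  # ((199, 399), (1025, 1025))
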
8 changes: 7 additions & 1 deletion ahcore/transforms/augmentations.py
@@ -95,7 +95,13 @@ class Identity(K.AugmentationBase2D):
Identity transform.
"""

def __init__(self, p: float = 1.0, p_batch: float = 1.0, same_on_batch: bool = True, keepdim: bool = True):
def __init__(
self,
p: float = 1.0,
p_batch: float = 1.0,
same_on_batch: bool = True,
keepdim: bool = True,
):
if p != 1.0 or p_batch != 1.0 or not same_on_batch:
raise ValueError("Identity is always applied. No probabilities can be applied.")

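Because the constructor guard rejects any non-default probability, Identity can only be built so that it is always applied. A short usage sketch, assuming ahcore and its kornia dependency are installed:

from ahcore.transforms.augmentations import Identity

aug = Identity()  # fine: defaults keep p == 1.0, p_batch == 1.0, same_on_batch == True
try:
    Identity(p=0.5)  # rejected by the guard shown above
except ValueError as exc:
    print(exc)  # "Identity is always applied. No probabilities can be applied."
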