From f7a5e0ded025758feaba9a363f956f74e961b663 Mon Sep 17 00:00:00 2001 From: Jonas Teuwen Date: Mon, 18 Sep 2023 21:52:59 +0200 Subject: [PATCH] Fix metric computation callback --- .mypy.ini | 2 +- .pre-commit-config.yaml | 49 ++++ ahcore/callbacks.py | 411 +++++++++++++++++------------- ahcore/data/dataset.py | 15 +- ahcore/lit_module.py | 3 +- ahcore/utils/manifest.py | 42 ++- ahcore/utils/manifest_database.py | 72 ++++-- tools/convert_wsi_to_h5.py | 1 - 8 files changed, 379 insertions(+), 216 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.mypy.ini b/.mypy.ini index e9b9eee..159fa83 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -34,4 +34,4 @@ ignore_missing_imports = True # TODO: This needs to be fixed obviously [mypy-dlup.*] -ignore_errors = True \ No newline at end of file +ignore_errors = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..1d75036 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,49 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files +- repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black +- repo: https://github.com/pycqa/flake8 + rev: 6.1.0 + # Ignore the configuration files + hooks: + - id: flake8 + exclude: ^docs/ +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) +- repo: local # Use pylint from local environment as it requires to import packages + hooks: + - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + args: [--strict] +- repo: https://github.com/Yelp/detect-secrets + rev: v1.4.0 + hooks: + - id: detect-secrets +- repo: local # Use pylint from local environment as it requires to import packages + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + types: [python] + args: + [ + "-rn", # Only display messages + "-sn", # Don't display the score + ] diff --git a/ahcore/callbacks.py b/ahcore/callbacks.py index 010486c..180792d 100644 --- a/ahcore/callbacks.py +++ b/ahcore/callbacks.py @@ -1,13 +1,15 @@ from __future__ import annotations -import concurrent.futures import hashlib +import itertools import json import logging -from multiprocessing import Pipe, Pool, Process, Queue, Semaphore +import multiprocessing +import time +from multiprocessing import Pipe, Process, Queue, Semaphore from multiprocessing.connection import Connection from pathlib import Path -from typing import Any, Iterator, Optional, TypedDict, cast +from typing import Any, Generator, Iterator, Optional, TypedDict, cast import numpy as np import numpy.typing as npt @@ -24,10 +26,13 @@ from shapely.geometry import MultiPoint, Point from torch.utils.data import Dataset +from ahcore.lit_module import AhCoreLightningModule from ahcore.readers import H5FileImageReader, StitchingMode from ahcore.transforms.pre_transforms import one_hot_encoding from ahcore.utils.data import DataDescription, GridDescription from ahcore.utils.io import get_logger +from ahcore.utils.manifest import get_mask_and_annotations_from_record +from ahcore.utils.manifest_database import DataManager, ImageMetadata, fetch_image_metadata from ahcore.writers import H5FileImageWriter logger = get_logger(__name__) @@ -244,7 +249,7 @@ def _get_uuid_for_filename(input_path: Path) -> str: return hex_dig -def _get_output_filename(dump_dir: Path, input_path: Path, model_name: str, step: None | int | str = None) -> Path: +def _get_h5_output_filename(dump_dir: Path, input_path: Path, model_name: str, step: None | int | str = None) -> Path: hex_dig = _get_uuid_for_filename(input_path=input_path) # Return the hashed filename with the new extension @@ -350,7 +355,7 @@ def on_validation_batch_end( ) if filename != self._current_filename: - output_filename = _get_output_filename( + output_filename = _get_h5_output_filename( self.dump_dir, filename, model_name=str(pl_module.name), @@ -471,7 +476,7 @@ def tile_process_function(x): class WriteTiffCallback(Callback): def __init__(self, max_concurrent_writers: int): - self._pool = Pool(max_concurrent_writers) + self._pool = multiprocessing.Pool(max_concurrent_writers) self._logger = get_logger(type(self).__name__) self._dump_dir: Optional[Path] = None self.__write_h5_callback_index = -1 @@ -520,7 +525,7 @@ def on_validation_batch_end( filename = Path(batch["path"][0]) # Filenames are constant across the batch. if filename not in self._filenames: - output_filename = _get_output_filename( + output_filename = _get_h5_output_filename( dump_dir=self.dump_dir, input_path=filename, model_name=str(pl_module.name), @@ -561,177 +566,221 @@ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningMo self._filenames = {} # Reset the filenames -# class ComputeWsiMetricsCallback(Callback): -# def __init__(self, max_threads=10, save_per_image: bool = True): -# """ -# Callback to compute metrics on whole-slide images. This callback is used to compute metrics on whole-slide -# images in separate threads. -# -# Parameters -# ---------- -# max_threads : int -# The maximum number of concurrent threads. -# """ -# self._data_description = None -# self._reader = H5FileImageReader -# self._dump_dir = None -# self._save_per_image = save_per_image -# self._filenames: dict[Path, Path] = {} -# self._logger = get_logger(type(self).__name__) -# self._semaphore = Semaphore(max_threads) # Limit the number of threads -# -# self._wsi_metrics = None -# -# self._dump_list: list[dict[str, str]] = [] -# -# def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: -# _callback: Optional[WriteH5Callback] = None -# for idx, callback in enumerate(trainer.callbacks): -# if isinstance(callback, WriteH5Callback): -# _callback = cast(WriteH5Callback, trainer.callbacks[idx]) -# break -# -# if _callback is None: -# raise ValueError( -# "WriteH5Callback is not in the trainer's callbacks. " -# "This is required before WSI metrics can be computed using this Callback" -# ) -# -# self._dump_dir = _callback.dump_dir -# -# self._wsi_metrics = pl_module.wsi_metrics -# self._data_description = trainer.datamodule.data_description # type: ignore -# -# if not self._data_description: -# raise ValueError("Data description is not set.") -# -# self._class_names = dict([(v, k) for k, v in self._data_description.index_map.items()]) -# self._class_names[0] = "background" -# -# # We should also attach the validation manifest to the class, but convert it to a dictionary mapping -# # the UUID -# data_dir = self._data_description.data_dir -# for manifest in trainer.datamodule.val_manifest: -# image_fn = data_dir / manifest.image[0] -# self._validation_manifests[_get_uuid_for_filename(image_fn)] = manifest -# -# self._logger.info("Added %s images to validation manifest.", len(self._validation_manifests)) -# -# @property -# def metrics(self): -# return self._metrics -# -# def on_validation_batch_end( -# self, -# trainer: pl.Trainer, -# pl_module: pl.LightningModule, -# outputs, -# batch, -# batch_idx, -# dataloader_idx=0, -# ): -# if not self._dump_dir: -# raise ValueError("Dump directory is not set.") -# -# filename = Path(batch["path"][0]) # Filenames are constant across the batch. -# if filename not in self._filenames: -# output_filename = _get_output_filename( -# dump_dir=self._dump_dir, -# input_path=filename, -# model_name=pl_module.name, -# step=pl_module.global_step, -# ) -# self._logger.debug("%s -> %s", filename, output_filename) -# self._filenames[output_filename] = filename -# -# def compute_metrics(self): -# metrics = [] -# with concurrent.futures.ThreadPoolExecutor() as executor: -# future_to_filename = { -# executor.submit(self.compute_metrics_for_case, filename): filename for filename in self._filenames -# } -# -# for future in concurrent.futures.as_completed(future_to_filename): -# filename = future_to_filename[future] -# try: -# metric = future.result() -# except Exception as exc: -# self._logger.error("%r generated an exception: %s" % (filename, exc)) -# else: -# metrics.append(metric) -# self._logger.debug("Metric for %r is %s" % (filename, metric)) -# return metrics -# -# def compute_metrics_for_case(self, filename): -# wsi_filename = self._filenames[filename] -# validation_manifest = self._validation_manifests[_get_uuid_for_filename(wsi_filename)] -# native_mpp = validation_manifest.mpp -# with self._semaphore: # Only allow a certain number of threads to compute metrics concurrently -# # Compute the metric for one filename here... -# with H5FileImageReader(filename, stitching_mode=StitchingMode.CROP) as h5reader: -# mask = _parse_annotations( -# validation_manifest.mask, -# base_dir=self._data_description.annotations_dir, -# ) -# annotations = _parse_annotations( -# validation_manifest.annotations, -# base_dir=self._data_description.annotations_dir, -# ) -# dataset_of_validation_image = _ValidationDataset( -# data_description=self._data_description, -# native_mpp=native_mpp, -# mask=mask, -# annotations=annotations, -# reader=h5reader, -# ) -# for sample in dataset_of_validation_image: -# prediction = torch.from_numpy(sample["prediction"]).unsqueeze(0).float() -# target = torch.from_numpy(sample["target"]).unsqueeze(0) -# roi = torch.from_numpy(sample["roi"]).unsqueeze(0) -# -# self._wsi_metrics.process_batch( -# predictions=prediction, -# target=target, -# roi=roi, -# wsi_name=str(filename), -# ) -# if self._save_per_image is True: -# wsi_metrics_dictionary = { -# "image_fn": str(wsi_filename), -# "uuid": filename.stem, -# } -# if filename.with_suffix(".tiff").is_file(): -# wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff")) -# if filename.is_file(): -# wsi_metrics_dictionary["h5_fn"] = str(filename) -# for metric in self._wsi_metrics._metrics: -# metric.get_wsi_score(str(filename)) -# wsi_metrics_dictionary[metric.name] = { -# self._class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item() -# for class_idx in range(self._data_description.num_classes) -# } -# self._dump_list.append(wsi_metrics_dictionary) -# -# def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): -# if not self._dump_dir: -# raise ValueError("Dump directory is not set.") -# if not self._wsi_metrics: -# raise ValueError("WSI metrics are not set.") -# -# # Ensure that all h5 files have been written -# self._logger.debug("Computing metrics for %s predictions", len(self._filenames)) -# self.compute_metrics() -# metrics = self._wsi_metrics.get_average_score() -# with open( -# self._dump_dir / "outputs" / pl_module.name / f"step_{pl_module.global_step}" / "results.json", -# "w", -# encoding="utf-8", -# ) as json_file: -# json.dump(self._dump_list, json_file, indent=2) -# self._wsi_metrics.reset() -# # Reset stuff -# self._dump_list = [] -# self._filenames = {} -# -# self._logger.debug("Metrics: %s", metrics) -# pl_module.log_dict(metrics, prog_bar=True) +class ComputeWsiMetricsCallback(Callback): + def __init__(self, max_processes=10, save_per_image: bool = True): + """ + Callback to compute metrics on whole-slide images. This callback is used to compute metrics on whole-slide + images in separate processes. + + Parameters + ---------- + max_processes : int + The maximum number of concurrent processes. + """ + self._data_description: Optional[DataDescription] = None + self._reader = H5FileImageReader + self._max_processes: int = max_processes + self._dump_dir: Optional[Path] = None + self._save_per_image = save_per_image + self._filenames: dict[Path, Path] = {} + + self._wsi_metrics = None + self._class_names: dict[int, str] = {} + self._data_manager = None + self._validate_filenames_gen = None + + self._dump_list: list[dict[str, str]] = [] + self._logger = get_logger(type(self).__name__) + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: + assert self._data_description + index_map = self._data_description.index_map + assert index_map + pl_module = cast(AhCoreLightningModule, pl_module) + + _callback: Optional[WriteH5Callback] = None + for idx, callback in enumerate(trainer.callbacks): # type: ignore + if isinstance(callback, WriteH5Callback): + _callback = cast(WriteH5Callback, trainer.callbacks[idx]) # type: ignore + break + + if _callback is None: + raise ValueError( + "WriteH5Callback is not in the trainer's callbacks. " + "This is required before WSI metrics can be computed using this Callback" + ) + + self._dump_dir = _callback.dump_dir + + self._wsi_metrics = pl_module.wsi_metrics + self._data_description = trainer.datamodule.data_description # type: ignore + + if not self._data_description: + raise ValueError("Data description is not set.") + + self._class_names = dict([(v, k) for k, v in index_map.items()]) + self._class_names[0] = "background" + + # Here we can query the database for the validation images + self._data_manager: DataManager = trainer.datamodule.data_manager # type: ignore + # Initialize the generator here + + self._validate_metadata_gen = self._create_validate_image_metadata_gen() + + def _create_validate_image_metadata_gen(self) -> Generator[ImageMetadata, None, None]: + assert self._data_description + assert self._data_manager + gen = self._data_manager.get_image_metadata_by_split( + manifest_name=self._data_description.manifest_name, + split_version=self._data_description.split_version, + split_category="validate", + ) + for image_metadata in gen: + yield image_metadata + + @property + def _validate_metadata(self) -> Generator[ImageMetadata, None, None]: + return self._validate_metadata_gen + + @property + def metrics(self): + return self._metrics + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx=0, + ): + if not self._dump_dir: + raise ValueError("Dump directory is not set.") + + filenames = batch["path"] # Filenames are constant across the batch. + if not len(set(filenames)) != 1: + raise ValueError( + "All paths in a batch must be the same. " + "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler." + ) + + def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + assert self._dump_dir + assert self._data_description + data_dir = self._data_description.data_dir + metrics = [] + + with multiprocessing.Pool(processes=self._max_processes) as pool: + results_to_filename = {} + completed_tasks = 0 + + # Fill up the initial task pool + for image_metadata in itertools.islice(self._validate_metadata, self._max_processes): + filename = image_metadata.filename + h5_filename = _get_h5_output_filename( + dump_dir=self._dump_dir, + input_path=data_dir / filename, + model_name=str(pl_module.name), + step=pl_module.global_step, + ) + + # We need the full filename here so + result = pool.apply_async(self.compute_metrics_for_case, args=(filename, h5_filename)) + results_to_filename[result] = filename + + while results_to_filename: + time.sleep(0.1) # Reduce excessive polling + # Check for completed tasks + for result in list(results_to_filename.keys()): + if result.ready(): + filename = results_to_filename.pop(result) + try: + metric = result.get() + except Exception as exc: + self._logger.error("%r generated an exception: %s" % (filename, exc)) + else: + metrics.append(metric) + self._logger.debug("Metric for %r is %s" % (filename, metric)) + + completed_tasks += 1 + + # Schedule a new task if there are more filenames left in the generator + next_metadata = next(self._validate_metadata, None) + next_filename = next_metadata.filename if next_metadata else None + if next_filename: + new_result = pool.apply_async(self.compute_metrics_for_case, args=(next_filename,)) + results_to_filename[new_result] = next_filename + + return metrics + + def compute_metrics_for_case(self, filename: Path, h5_filename: Path): + # for mypy + assert self._class_names + assert self._data_description + assert self._wsi_metrics + # Given the filename we can request all the annotations and everything else. + # We can then compute the metrics for this case. + image = self._data_manager.get_image_by_filename(str(filename)) + metadata = fetch_image_metadata(image) + + with H5FileImageReader(h5_filename, stitching_mode=StitchingMode.CROP) as h5reader: + mask, annotations = get_mask_and_annotations_from_record(self._data_description.annotation_dir, image) + dataset_of_validation_image = _ValidationDataset( + data_description=self._data_description, + native_mpp=metadata.mpp, + mask=mask, + annotations=annotations, + reader=h5reader, + ) + for sample in dataset_of_validation_image: + prediction = torch.from_numpy(sample["prediction"]).unsqueeze(0).float() + target = torch.from_numpy(sample["target"]).unsqueeze(0) + roi = torch.from_numpy(sample["roi"]).unsqueeze(0) + + self._wsi_metrics.process_batch( + predictions=prediction, + target=target, + roi=roi, + wsi_name=str(filename), + ) + if self._save_per_image is True: + wsi_metrics_dictionary = { + "image_fn": str(self._data_description.data_dir / metadata.filename), + "uuid": filename.stem, + } + if filename.with_suffix(".tiff").is_file(): + wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff")) + if filename.is_file(): + wsi_metrics_dictionary["h5_fn"] = str(filename) + for metric in self._wsi_metrics._metrics: + metric.get_wsi_score(str(filename)) + wsi_metrics_dictionary[metric.name] = { + self._class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item() + for class_idx in range(self._data_description.num_classes) + } + self._dump_list.append(wsi_metrics_dictionary) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + if not self._dump_dir: + raise ValueError("Dump directory is not set.") + if not self._wsi_metrics: + raise ValueError("WSI metrics are not set.") + + # Ensure that all h5 files have been written + self._logger.debug("Computing metrics for %s predictions", len(self._filenames)) + self.compute_metrics(trainer, pl_module) + metrics = self._wsi_metrics.get_average_score() + with open( + self._dump_dir / "outputs" / pl_module.name / f"step_{pl_module.global_step}" / "results.json", + "w", + encoding="utf-8", + ) as json_file: + json.dump(self._dump_list, json_file, indent=2) + self._wsi_metrics.reset() + # Reset stuff + self._dump_list = [] + self._filenames = {} + + self._logger.debug("Metrics: %s", metrics) + pl_module.log_dict(metrics, prog_bar=True) diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index c2aebcc..bdd4ecd 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -20,17 +20,6 @@ from ahcore.utils.manifest_database import DataManager -def hello(params=None): - if params: - print("Hello", params) - else: - print("Hello") - - import sys - - sys.exit() - - class DlupDataModule(pl.LightningDataModule): """Datamodule for the Ahcore framework. This datamodule is based on `dlup`.""" @@ -117,6 +106,10 @@ def __init__( } self._num_classes = data_description.num_classes + @property + def data_manager(self) -> DataManager: + return self._data_manager + def setup(self, stage: Optional[str] = None) -> None: if not stage: return diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 6142ec9..fd7a53d 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -5,11 +5,10 @@ - Wrapping models""" from __future__ import annotations -from typing import Any, Optional, cast +from typing import Any import pytorch_lightning as pl import torch.optim.optimizer -from dlup.data.dataset import ConcatDataset from pytorch_lightning.trainer.states import TrainerFn from torch import nn from torch.utils.tensorboard import SummaryWriter diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 27a0195..da13eab 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -21,6 +21,7 @@ from pytorch_lightning.trainer.states import TrainerFn from ahcore.utils.data import DataDescription +from ahcore.utils.database_models import Image, ImageAnnotations from ahcore.utils.io import get_logger from ahcore.utils.manifest_database import DataManager from ahcore.utils.rois import compute_rois @@ -59,7 +60,22 @@ class _AnnotationReadersDict(TypedDict): _Stages = Enum("Stages", [(_, _) for _ in ["fit", "validate", "test", "predict"]]) # type: ignore -def _parse_annotations(annotations_root: Path, record): +def parse_annotations_from_record(annotations_root: Path, record: list[ImageAnnotations]): + """ + Parse the annotations from a record of type ImageAnnotations. + + Parameters + ---------- + annotations_root : Path + The root directory of the annotations. + record : list[Type[ImageAnnotations]] + The record containing the annotations. + + Returns + ------- + WsiAnnotations + The parsed annotations. + """ if record is None: return assert len(record) == 1 @@ -73,6 +89,27 @@ def _parse_annotations(annotations_root: Path, record): raise NotImplementedError +def get_mask_and_annotations_from_record(annotations_root: Path, record: Image): + """ + Get the mask and annotations from a record of type Image. + + Parameters + ---------- + annotations_root : Path + The root directory of the annotations. + record : Type[Image] + The record containing the mask and annotations. + + Returns + ------- + tuple[WsiAnnotations, WsiAnnotations] + The mask and annotations. + """ + _masks = parse_annotations_from_record(annotations_root, record.masks) + _annotations = parse_annotations_from_record(annotations_root, record.annotations) + return _masks, _annotations + + def _get_rois(mask, data_description: DataDescription, stage: str): if (mask is None) or (stage != TrainerFn.FITTING) or (not data_description.convert_mask_to_rois): return None @@ -104,8 +141,7 @@ def datasets_from_data_description(db_manager: DataManager, data_description: Da labels = [(label.key, label.value) for label in record.labels] if record.labels else None for image in record.images: - mask = _parse_annotations(annotations_root, image.masks) - annotations = _parse_annotations(annotations_root, image.annotations) + mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) rois = _get_rois(mask, data_description, stage) mask_threshold = 0.0 if stage != TrainerFn.FITTING else data_description.mask_threshold diff --git a/ahcore/utils/manifest_database.py b/ahcore/utils/manifest_database.py index 8945184..c5d1f81 100644 --- a/ahcore/utils/manifest_database.py +++ b/ahcore/utils/manifest_database.py @@ -1,10 +1,12 @@ # encoding: utf-8 from pathlib import Path from types import TracebackType -from typing import Generator, Literal, NamedTuple, Optional, Type +from typing import Generator, Literal, Optional, Type +from pydantic import AfterValidator, BaseModel from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing_extensions import Annotated from ahcore.utils.database_models import Base, Image, Manifest, Patient, Split, SplitDefinitions from ahcore.utils.io import get_logger @@ -15,11 +17,23 @@ class RecordNotFoundError(Exception): pass -class ImageMetadata(NamedTuple): +def is_positive(v: int | float) -> int | float: + assert v > 0, f"{v} is not a positive a positive {type(v)}" + return v + + +PositiveInt = Annotated[int, AfterValidator(is_positive)] +PositiveFloat = Annotated[float, AfterValidator(is_positive)] + + +class ImageMetadata(BaseModel): + class Config: + allow_mutation = False + filename: Path - height: int - width: int - mpp: float + height: PositiveInt + width: PositiveInt + mpp: PositiveFloat def open_db(database_uri: str): @@ -47,6 +61,13 @@ def get_or_create_patient(session, patient_code, manifest): return existing_patient +def fetch_image_metadata(image: Image) -> ImageMetadata: + """Extract metadata from an Image object.""" + return ImageMetadata( + filename=Path(image.filename), height=int(image.height), width=int(image.width), mpp=float(image.mpp) + ) + + class DataManager: def __init__(self, database_uri: str) -> None: self._database_uri = database_uri @@ -91,12 +112,29 @@ def get_records_by_split( for patient in patients: yield patient - @staticmethod - def _fetch_image_metadata(image: Image) -> ImageMetadata: - """Extract metadata from an Image object.""" - return ImageMetadata( - filename=Path(image.filename), height=int(image.height), width=int(image.width), mpp=float(image.mpp) - ) + def get_image_metadata_by_split( + self, manifest_name: str, split_version: str, split_category: Optional[str] = None + ) -> Generator[ImageMetadata, None, None]: + """ + Yields the metadata of images for a given manifest name, split version, and optional split category. + + Parameters + ---------- + manifest_name : str + The name of the manifest. + split_version : str + The version of the split. + split_category : Optional[str], default=None + The category of the split (e.g., "fit", "validate", "test"). + + Yields + ------- + ImageMetadata + The metadata of the image. + """ + for patient in self.get_records_by_split(manifest_name, split_version, split_category): + for image in patient.images: + yield fetch_image_metadata(image) def get_image_metadata_by_patient(self, patient_code: str) -> list[ImageMetadata]: """ @@ -115,9 +153,9 @@ def get_image_metadata_by_patient(self, patient_code: str) -> list[ImageMetadata patient = self._session.query(Patient).filter_by(patient_code=patient_code).first() # type: ignore self._ensure_record(patient, f"Patient with code {patient_code} not found") - return [self._fetch_image_metadata(image) for image in patient.images] + return [fetch_image_metadata(image) for image in patient.images] - def get_image_metadata_by_filename(self, filename: str) -> ImageMetadata: + def get_image_by_filename(self, filename: str) -> Type[Image]: """ Fetch the metadata for an image based on its filename. @@ -128,13 +166,13 @@ def get_image_metadata_by_filename(self, filename: str) -> ImageMetadata: Returns ------- - ImageMetadata - Metadata of the image. + Image + The image from the database. """ image = self._session.query(Image).filter_by(filename=filename).first() self._ensure_record(image, f"Image with filename {filename} not found") assert image - return self._fetch_image_metadata(image) + return image def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: """ @@ -152,7 +190,7 @@ def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: """ image = self._session.query(Image).filter_by(id=image_id).first() self._ensure_record(image, f"No image found with ID {image_id}") - return self._fetch_image_metadata(image) + return fetch_image_metadata(image) def __enter__(self) -> "DataManager": return self diff --git a/tools/convert_wsi_to_h5.py b/tools/convert_wsi_to_h5.py index 7483e77..f3cdb28 100644 --- a/tools/convert_wsi_to_h5.py +++ b/tools/convert_wsi_to_h5.py @@ -28,7 +28,6 @@ from rich.progress import Progress from ahcore.cli import dir_path, file_path -from ahcore.readers import H5FileImageReader from ahcore.writers import H5FileImageWriter logger = getLogger(__name__)