Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EricMarcus-ai committed Oct 4, 2023
1 parent 0f341a1 commit 9797db1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions ahcore/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch.utils.data import Dataset

from ahcore.lit_module import AhCoreLightningModule
from ahcore.metrics import WSIMetric
from ahcore.metrics import WSIMetricFactory
from ahcore.readers import H5FileImageReader, StitchingMode
from ahcore.transforms.pre_transforms import one_hot_encoding
from ahcore.utils.data import DataDescription, GridDescription
Expand Down Expand Up @@ -633,7 +633,7 @@ def compute_metrics_for_case(
wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff"))
if filename.is_file():
wsi_metrics_dictionary["h5_fn"] = str(filename)
for metric in wsi_metrics._tile_metric:
for metric in wsi_metrics._metrics:
metric.get_wsi_score(str(filename))
wsi_metrics_dictionary[metric.name] = {
class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item()
Expand Down Expand Up @@ -679,7 +679,7 @@ def __init__(self, max_processes: int = 10, save_per_image: bool = True) -> None
self._save_per_image = save_per_image
self._filenames: dict[Path, Path] = {}

self._wsi_metrics: WSIMetric | None = None
self._wsi_metrics: WSIMetricFactory | None = None
self._class_names: dict[int, str] = {}
self._data_manager = None
self._validate_filenames_gen = None
Expand Down Expand Up @@ -776,7 +776,7 @@ def on_validation_batch_end(
"Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler."
)

def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> list[dict[str, Any]]:
def compute_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> list[list[dict[str, dict[str, float]]]]:
assert self._dump_dir
assert self._data_description
assert self._validate_metadata
Expand Down
14 changes: 7 additions & 7 deletions ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import nn

from ahcore.exceptions import ConfigurationError
from ahcore.metrics import TileMetric, WSIMetric
from ahcore.metrics import MetricFactory, WSIMetricFactory
from ahcore.utils.data import DataDescription
from ahcore.utils.io import get_logger

Expand All @@ -37,7 +37,7 @@ def __init__(
data_description: DataDescription,
loss: nn.Module | None = None,
augmentations: dict[str, nn.Module] | None = None,
metrics: dict[str, WSIMetric | TileMetric] | None = None,
metrics: dict[str, MetricFactory | WSIMetricFactory] | None = None,
scheduler: Any | None = None, # noqa
):
super().__init__()
Expand All @@ -61,18 +61,18 @@ def __init__(
if metrics is not None:
tile_metric = metrics.get("tile_level")
wsi_metric = metrics.get("wsi_level", None)
if tile_metric is not None and not isinstance(tile_metric, TileMetric):
raise ConfigurationError("Tile metrics must be of type TileMetric")
if wsi_metric is not None and not isinstance(wsi_metric, WSIMetric):
raise ConfigurationError("WSI metrics must be of type WSIMetric")
if tile_metric is not None and not isinstance(tile_metric, MetricFactory):
raise ConfigurationError("Tile metrics must be of type MetricFactory")
if wsi_metric is not None and not isinstance(wsi_metric, WSIMetricFactory):
raise ConfigurationError("WSI metrics must be of type WSIMetricFactory")

self._tile_metric = tile_metric
self._wsi_metrics = wsi_metric

self._data_description = data_description

@property
def wsi_metrics(self) -> WSIMetric | None:
def wsi_metrics(self) -> WSIMetricFactory | None:
return self._wsi_metrics

@property
Expand Down
1 change: 0 additions & 1 deletion ahcore/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import h5py
import numpy as np
import numpy.typing as npt
from scipy.ndimage import map_coordinates

from ahcore.utils.io import get_logger
Expand Down

0 comments on commit 9797db1

Please sign in to comment.