Using @cached_property to cache @property calls. #1734

Closed · wants to merge 1 commit
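For context: this PR replaces the @property + @lru_cache(maxsize=1) idiom with functools.cached_property (available since Python 3.8). The old idiom has two known problems: the lru_cache lives on the class and keys on (self,), so it holds a strong reference to every instance it has seen (the instance can never be garbage collected while cached), and with maxsize=1 an access from a second instance evicts the first instance's cached value. A minimal before/after sketch, using a hypothetical PlansLike class rather than code from this repository:

from functools import cached_property, lru_cache

class PlansLike:
    """Hypothetical stand-in for the classes touched by this PR."""

    def __init__(self, configuration: dict):
        self.configuration = configuration

    # Before: the cache is attached to the class and keyed on (self,).
    # It pins self in memory and, at maxsize=1, is shared (and thrashed)
    # across all instances.
    @property
    @lru_cache(maxsize=1)
    def expensive_old(self) -> str:
        return self.configuration['preprocessor_name']

    # After: the value is computed once per instance and stored in the
    # instance's __dict__, so it is freed together with the instance.
    @cached_property
    def expensive_new(self) -> str:
        return self.configuration['preprocessor_name']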
8 changes: 5 additions & 3 deletions nnunetv2/utilities/label_handling/label_handling.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
 
+from functools import cached_property
+
 from time import time
 from typing import Union, List, Tuple, Type

@@ -218,15 +220,15 @@ def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple
                 (isinstance(i, (tuple, list)) and not (
                         len(np.unique(i)) == 1 and np.unique(i)[0] == 0))]
 
-    @property
+    @cached_property
     def foreground_regions(self):
         return self.filter_background(self.all_regions)
 
-    @property
+    @cached_property
     def foreground_labels(self):
         return self.filter_background(self.all_labels)
 
-    @property
+    @cached_property
     def num_segmentation_heads(self):
         if self.has_regions:
             return len(self.foreground_regions)
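A usage note on the new decorator (illustrative toy class, not nnU-Net's real LabelManager): cached_property runs the getter once on first access, stores the result in the instance's __dict__, and serves the stored value afterwards; deleting the attribute forces recomputation.

from functools import cached_property

class LabelManagerLike:
    """Toy illustration only; not nnU-Net's LabelManager."""

    def __init__(self, all_labels):
        self.all_labels = all_labels

    @cached_property
    def foreground_labels(self):
        print("computing")                 # executes once per instance
        return [i for i in self.all_labels if i != 0]

lm = LabelManagerLike([0, 1, 2])
assert lm.foreground_labels == [1, 2]      # prints "computing"
assert lm.foreground_labels == [1, 2]      # served from lm.__dict__, no print
del lm.foreground_labels                   # invalidates the cached value
assert lm.foreground_labels == [1, 2]      # recomputed, prints again

Two consequences worth keeping in mind: mutations of all_labels after first access are not reflected in the cached value, and classes that define __slots__ without a __dict__ cannot use cached_property.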
34 changes: 13 additions & 21 deletions nnunetv2/utilities/plans_handling/plans_handler.py
@@ -2,7 +2,7 @@
 
 import dynamic_network_architectures
 from copy import deepcopy
-from functools import lru_cache, partial
+from functools import lru_cache, partial, cached_property
 from typing import Union, Tuple, List, Type, Callable
 
 import numpy as np
@@ -44,8 +44,7 @@ def data_identifier(self) -> str:
     def preprocessor_name(self) -> str:
         return self.configuration['preprocessor_name']
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def preprocessor_class(self) -> Type[DefaultPreprocessor]:
         preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"),
                                                          self.preprocessor_name,
@@ -80,8 +79,7 @@ def use_mask_for_norm(self) -> List[bool]:
     def UNet_class_name(self) -> str:
         return self.configuration['UNet_class_name']
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def UNet_class(self) -> Type[nn.Module]:
         unet_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"),
                                                  self.UNet_class_name,
@@ -121,8 +119,7 @@ def conv_kernel_sizes(self) -> List[List[int]]:
     def unet_max_num_features(self) -> int:
         return self.configuration['unet_max_num_features']
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def resampling_fn_data(self) -> Callable[
         [Union[torch.Tensor, np.ndarray],
          Union[Tuple[int, ...], List[int], np.ndarray],
@@ -134,8 +131,7 @@ def resampling_fn_data(self) -> Callable[
         fn = partial(fn, **self.configuration['resampling_fn_data_kwargs'])
         return fn
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def resampling_fn_probabilities(self) -> Callable[
         [Union[torch.Tensor, np.ndarray],
          Union[Tuple[int, ...], List[int], np.ndarray],
@@ -147,8 +143,7 @@ def resampling_fn_probabilities(self) -> Callable[
         fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs'])
         return fn
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def resampling_fn_seg(self) -> Callable[
         [Union[torch.Tensor, np.ndarray],
          Union[Tuple[int, ...], List[int], np.ndarray],
@@ -164,7 +159,7 @@ def resampling_fn_seg(self) -> Callable[
     def batch_dice(self) -> bool:
         return self.configuration['batch_dice']
 
-    @property
+    @cached_property
     def next_stage_names(self) -> Union[List[str], None]:
         ret = self.configuration.get('next_stage')
         if ret is not None:
@@ -218,7 +213,7 @@ def _internal_resolve_configuration_inheritance(self, configuration_name: str,
             configuration = base_config
         return configuration
 
-    @lru_cache(maxsize=10)
+    @lru_cache(maxsize=None)
     def get_configuration(self, configuration_name: str):
         if configuration_name not in self.plans['configurations'].keys():
             raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. "
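get_configuration takes an argument, so cached_property does not apply to it; the PR instead widens the lru_cache from maxsize=10 to unbounded. A sketch of what that means, again with a hypothetical class:

from functools import lru_cache

class PlansManagerLike:
    """Hypothetical sketch; not the real PlansManager."""

    def __init__(self, plans: dict):
        self.plans = plans

    # maxsize=None: every (self, configuration_name) pair stays cached
    # for the life of the process. Unlike cached_property, this cache
    # also keeps a strong reference to self.
    @lru_cache(maxsize=None)
    def get_configuration(self, configuration_name: str):
        return self.plans['configurations'][configuration_name]

# The cache can be inspected or reset via the wrapper:
# PlansManagerLike.get_configuration.cache_info()
# PlansManagerLike.get_configuration.cache_clear()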
@@ -243,8 +238,7 @@ def original_median_spacing_after_transp(self) -> List[float]:
     def original_median_shape_after_transp(self) -> List[float]:
         return self.plans['original_median_shape_after_transp']
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def image_reader_writer_class(self) -> Type[BaseReaderWriter]:
         return recursive_find_reader_writer_by_name(self.plans['image_reader_writer'])
 
@@ -256,12 +250,11 @@ def transpose_forward(self) -> List[int]:
     def transpose_backward(self) -> List[int]:
         return self.plans['transpose_backward']
 
-    @property
+    @cached_property
     def available_configurations(self) -> List[str]:
         return list(self.plans['configurations'].keys())
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def experiment_planner_class(self) -> Type[ExperimentPlanner]:
         planner_name = self.experiment_planner_name
         experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
@@ -273,8 +266,7 @@ def experiment_planner_class(self) -> Type[ExperimentPlanner]:
     def experiment_planner_name(self) -> str:
         return self.plans['experiment_planner_used']
 
-    @property
-    @lru_cache(maxsize=1)
+    @cached_property
     def label_manager_class(self) -> Type[LabelManager]:
         return get_labelmanager_class_from_plans(self.plans)
 
@@ -283,7 +275,7 @@ def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager:
                                         regions_class_order=dataset_json.get('regions_class_order'),
                                         **kwargs)
 
-    @property
+    @cached_property
     def foreground_intensity_properties_per_channel(self) -> dict:
         if 'foreground_intensity_properties_per_channel' not in self.plans.keys():
             if 'foreground_intensity_properties_by_modality' in self.plans.keys():
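The garbage-collection difference that motivates moving away from @property + @lru_cache can be demonstrated directly. A self-contained check (assuming CPython, whose reference counting makes the collection timing deterministic enough for this test):

import gc
import weakref
from functools import cached_property, lru_cache

class WithLru:
    @property
    @lru_cache(maxsize=1)
    def value(self):
        return 42

class WithCachedProperty:
    @cached_property
    def value(self):
        return 42

a = WithLru()
_ = a.value
ref_a = weakref.ref(a)
del a
gc.collect()
print(ref_a() is None)   # False: the class-level lru_cache still holds the instance

b = WithCachedProperty()
_ = b.value
ref_b = weakref.ref(b)
del b
gc.collect()
print(ref_b() is None)   # True: the cached value lived in b.__dict__ and died with b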