Skip to content

Commit

Permalink
fix: check class of dataloader before blindly calling _finish() (#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jan 18, 2024
1 parent e5644ea commit ca23628
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import numpy as np
import torch
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
Expand Down Expand Up @@ -865,9 +867,11 @@ def on_train_end(self):
old_stdout = sys.stdout
with open(os.devnull, 'w') as f:
sys.stdout = f
if self.dataloader_train is not None:
if self.dataloader_train is not None and \
isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
self.dataloader_train._finish()
if self.dataloader_val is not None:
if self.dataloader_val is not None and \
isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
self.dataloader_val._finish()
sys.stdout = old_stdout

Expand Down

0 comments on commit ca23628

Please sign in to comment.