diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 318be58e9..690a15fb2 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -11,6 +11,8 @@ import numpy as np import torch +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ @@ -865,9 +867,11 @@ def on_train_end(self): old_stdout = sys.stdout with open(os.devnull, 'w') as f: sys.stdout = f - if self.dataloader_train is not None: + if self.dataloader_train is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_train._finish() - if self.dataloader_val is not None: + if self.dataloader_val is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_val._finish() sys.stdout = old_stdout