From ea985cf8b003934cdcdea076e1964844d3eb5ee5 Mon Sep 17 00:00:00 2001 From: "Wald, Tassilo" Date: Wed, 11 Oct 2023 16:41:01 +0200 Subject: [PATCH] Fix init call of NoDeepSupervision trainer --- .../nnUNetTrainerNoDeepSupervision.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py index de047c592..1152fbeb4 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -1,7 +1,16 @@ from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import torch class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) self.enable_deep_supervision = False