diff --git a/nnunetv2/imageio/base_reader_writer.py b/nnunetv2/imageio/base_reader_writer.py
index 2847478ae..4ca536e5d 100644
--- a/nnunetv2/imageio/base_reader_writer.py
+++ b/nnunetv2/imageio/base_reader_writer.py
@@ -21,11 +21,11 @@ class BaseReaderWriter(ABC):
 
     @staticmethod
     def _check_all_same(input_list):
-        # compare all entries to the first
-        for i in input_list[1:]:
-            if i != input_list[0]:
-                return False
-        return True
+        # <= 1 (not == 1) so an empty list does not hit input_list[0] below
+        if len(input_list) <= 1:
+            return True
+        # compare all entries to the first with float tolerance; a shape
+        # mismatch means "not all same" (np.allclose would raise on it)
+        return all(np.shape(i) == np.shape(input_list[0]) and np.allclose(input_list[0], i) for i in input_list[1:])
 
     @staticmethod
     def _check_all_same_array(input_list):
diff --git a/nnunetv2/training/dataloading/data_loader_2d.py b/nnunetv2/training/dataloading/data_loader_2d.py
index db1597bcc..8597a3b38 100644
--- a/nnunetv2/training/dataloading/data_loader_2d.py
+++ b/nnunetv2/training/dataloading/data_loader_2d.py
@@ -90,6 +90,7 @@ def generate_train_batch(self):
 
         if self.transforms is not None:
             with torch.no_grad():
                 with threadpool_limits(limits=1, user_api=None):
+                    data_all = torch.from_numpy(data_all).float()
                     seg_all = torch.from_numpy(seg_all).to(torch.int16)
                     images = []
@@ -99,7 +100,10 @@ def generate_train_batch(self):
                         images.append(tmp['image'])
                         segs.append(tmp['segmentation'])
                     data_all = torch.stack(images)
-                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    if isinstance(segs[0], list):
+                        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    else:
+                        seg_all = torch.stack(segs)
                     del segs, images
 
         return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
diff --git a/nnunetv2/training/dataloading/data_loader_3d.py b/nnunetv2/training/dataloading/data_loader_3d.py
index 36da8ded9..964fa67e8 100644
--- a/nnunetv2/training/dataloading/data_loader_3d.py
+++ b/nnunetv2/training/dataloading/data_loader_3d.py
@@ -62,7 +62,10 @@ def generate_train_batch(self):
                         images.append(tmp['image'])
                         segs.append(tmp['segmentation'])
                     data_all = torch.stack(images)
-                    seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    if isinstance(segs[0], list):
+                        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+                    else:
+                        seg_all = torch.stack(segs)
                     del segs, images
 
         return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
diff --git a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
index e3a71a000..9d4867003 100644
--- a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
+++ b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
@@ -55,6 +55,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
         self.num_epochs = 250
 
 
+class nnUNetTrainer_500epochs(nnUNetTrainer):
+    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+                 device: torch.device = torch.device('cuda')):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        self.num_epochs = 500
+
+
+class nnUNetTrainer_750epochs(nnUNetTrainer):
+    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+                 device: torch.device = torch.device('cuda')):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        self.num_epochs = 750
+
+
 class nnUNetTrainer_2000epochs(nnUNetTrainer):
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                  device: torch.device = torch.device('cuda')):