Skip to content

Commit

Permalink
nerge
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jun 21, 2024
2 parents a3f7935 + ed88855 commit 2eaa371
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
10 changes: 5 additions & 5 deletions nnunetv2/imageio/base_reader_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
class BaseReaderWriter(ABC):
@staticmethod
def _check_all_same(input_list):
# compare all entries to the first
for i in input_list[1:]:
if i != input_list[0]:
return False
return True
if len(input_list) == 1:
return True
else:
# compare all entries to the first
return np.allclose(input_list[0], input_list[1:])

@staticmethod
def _check_all_same_array(input_list):
Expand Down
6 changes: 5 additions & 1 deletion nnunetv2/training/dataloading/data_loader_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def generate_train_batch(self):
if self.transforms is not None:
with torch.no_grad():
with threadpool_limits(limits=1, user_api=None):

data_all = torch.from_numpy(data_all).float()
seg_all = torch.from_numpy(seg_all).to(torch.int16)
images = []
Expand All @@ -99,7 +100,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images

return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/dataloading/data_loader_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images

return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
self.num_epochs = 250


class nnUNetTrainer_500epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 500


class nnUNetTrainer_750epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 750


class nnUNetTrainer_2000epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
Expand Down

0 comments on commit 2eaa371

Please sign in to comment.