Skip to content

Commit

Permalink
fix #2300; dataloader did not stack segmentations properly if deep su…
Browse files Browse the repository at this point in the history
…pervision was disabled
  • Loading branch information
FabianIsensee committed Jun 19, 2024
1 parent 852b303 commit e42b9ed
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion nnunetv2/training/dataloading/data_loader_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images
if torch is not None:
torch.set_num_threads(torch_nthreads)
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/dataloading/data_loader_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images
if torch is not None:
torch.set_num_threads(torch_nthreads)
Expand Down

0 comments on commit e42b9ed

Please sign in to comment.