diff --git a/nnunetv2/training/loss/compound_losses.py b/nnunetv2/training/loss/compound_losses.py index eaeb5d8e0..900acf155 100644 --- a/nnunetv2/training/loss/compound_losses.py +++ b/nnunetv2/training/loss/compound_losses.py @@ -83,14 +83,20 @@ def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use def forward(self, net_output: torch.Tensor, target: torch.Tensor): if self.use_ignore_label: # target is one hot encoded here. invert it so that it is True wherever we can compute the loss - mask = (1 - target[:, -1:]).bool() + if target.dtype == torch.bool: + mask = ~target[:, -1:] + else: + mask = (1 - target[:, -1:]).bool() # remove ignore channel now that we have the mask - target_regions = torch.clone(target[:, :-1]) + # why did we use clone in the past? Should have documented that... + # target_regions = torch.clone(target[:, :-1]) + target_regions = target[:, :-1] else: target_regions = target mask = None dc_loss = self.dc(net_output, target_regions, loss_mask=mask) + target_regions = target_regions.float() if mask is not None: ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8) else: diff --git a/nnunetv2/training/loss/dice.py b/nnunetv2/training/loss/dice.py index 574435754..7b1fcd40f 100644 --- a/nnunetv2/training/loss/dice.py +++ b/nnunetv2/training/loss/dice.py @@ -142,13 +142,13 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): # if this is the case then gt is probably already a one hot encoding y_onehot = gt else: - y_onehot = torch.zeros(net_output.shape, device=net_output.device) + y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.bool) y_onehot.scatter_(1, gt.long(), 1) tp = net_output * y_onehot - fp = net_output * (1 - y_onehot) + fp = net_output * (~y_onehot) fn = (1 - net_output) * y_onehot - tn = (1 - net_output) * (1 - y_onehot) + tn = (1 - net_output) * (~y_onehot) if mask is not None: with torch.no_grad(): diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index c709b6de6..15b26917c 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -1043,7 +1043,10 @@ def validation_step(self, batch: dict) -> dict: # CAREFUL that you don't rely on target after this line! target[target == self.label_manager.ignore_label] = 0 else: - mask = 1 - target[:, -1:] + if target.dtype == torch.bool: + mask = ~target[:, -1:] + else: + mask = 1 - target[:, -1:] # CAREFUL that you don't rely on target after this line! target = target[:, :-1] else: