Thank you for the awesome repository!
I've noticed `torch.nn.CrossEntropyLoss` is used for the cross entropy loss and a custom loss from `utils.losses` is used for the Dice loss, as follows:

SSL4MIS/code/train_uncertainty_aware_mean_teacher_3D.py, lines 124 to 125 in 30e05d8
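For context, here is a minimal sketch of the pattern I am describing; the tensor shapes and the commented-out Dice call are placeholders I made up for illustration, not the code from the repository:

```python
import torch
import torch.nn as nn

# Made-up shapes for a 3D segmentation batch: (B, C, H, W, D) logits, (B, H, W, D) labels.
B, C, H, W, D = 2, 2, 32, 32, 32
outputs = torch.randn(B, C, H, W, D)
labels = torch.randint(0, C, (B, H, W, D))

# torch.nn.CrossEntropyLoss with the default reduction='mean' averages over all B*H*W*D voxels.
ce_loss = nn.CrossEntropyLoss()
loss_ce = ce_loss(outputs, labels)

# Placeholder for the custom Dice loss from utils.losses (call signature assumed, not copied):
# loss_dice = dice_loss(torch.softmax(outputs, dim=1), labels)
```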
The Dice loss seems to use a 'sum' reduction as follows:
SSL4MIS/code/utils/losses.py, lines 169 to 177 in 30e05d8
However, the default reduction method for `torch.nn.CrossEntropyLoss` is 'mean', so the Dice loss is always roughly H*W(*D) times bigger than the CE loss. So a direct mean of the two losses, as used in the following code, would not actually be the intended average:

SSL4MIS/code/train_uncertainty_aware_mean_teacher_3D.py, line 171 in 30e05d8
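To make the scale argument concrete, here is a small self-contained illustration (again with made-up tensors, not the repository's code). It uses `torch.nn.CrossEntropyLoss` itself with the two reductions as a stand-in, since a 'sum'-style term grows with the number of voxels while a 'mean'-style term stays O(1):

```python
import torch
import torch.nn as nn

B, C, H, W, D = 2, 2, 32, 32, 32
outputs = torch.randn(B, C, H, W, D)
labels = torch.randint(0, C, (B, H, W, D))

# Same per-voxel losses, two different reductions.
loss_mean = nn.CrossEntropyLoss(reduction='mean')(outputs, labels)
loss_sum = nn.CrossEntropyLoss(reduction='sum')(outputs, labels)

print(loss_mean.item())                    # roughly ln(C) ~ 0.69 for random logits
print(loss_sum.item() / loss_mean.item())  # ~ B*H*W*D = 65536

# If loss_dice is effectively sum-reduced, then 0.5 * (loss_dice + loss_ce)
# is dominated by the Dice term and is really just ~0.5 * loss_dice.
supervised_loss = 0.5 * (loss_sum + loss_mean)
```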
Although I am sure this has minimal effect on most of your SSL methods, since it effectively amounts to using Dice alone instead of Dice + CE for the supervised loss, I still think it should be checked.
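If the Dice term really is summed over voxels, one possible adjustment (just a sketch on my side, not a patch against the repository) would be to bring it back to a per-voxel scale, or to weight the two terms explicitly, before averaging:

```python
import torch

# Stand-ins continuing the numbers above: a sum-scale "Dice" term and a mean-reduced CE term.
num_voxels = 2 * 32 * 32 * 32                # B*H*W*D
loss_dice = torch.tensor(0.7 * num_voxels)   # pretend sum-reduced term, grows with voxel count
loss_ce = torch.tensor(0.7)                  # mean-reduced CE, stays O(1)

# Normalize the summed term before averaging, so neither term dominates by construction.
supervised_loss = 0.5 * (loss_dice / num_voxels + loss_ce)
print(supervised_loss.item())  # ~0.7, instead of ~0.5 * loss_dice ≈ 2.3e4

# Alternatively, expose explicit weights instead of a plain 0.5 / 0.5 average:
# supervised_loss = w_dice * (loss_dice / num_voxels) + w_ce * loss_ce
```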