diff --git a/wekws/model/loss.py b/wekws/model/loss.py index fef17d6..42045a0 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -169,7 +169,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - loss = F.cross_entropy(logits, target) + loss = F.cross_entropy(logits, target.type(torch.int64)) acc = acc_frame(logits, target) return loss, acc