Skip to content

Commit

Permalink
[fix] Convert target to torch.int64 for cross_entropy (#141)
Browse files Browse the repository at this point in the history
On my machine the original code threw an error 
RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int'

I followed https://github.com/wenet-e2e/wekws#installation to setup the environment so I'm curious if this error has ever occured to other people.
  • Loading branch information
wangtiance authored Sep 10, 2023
1 parent 2c3c9ce commit 6ae98ef
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion wekws/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
(float): loss of current batch
(float): accuracy of current batch
"""
loss = F.cross_entropy(logits, target)
loss = F.cross_entropy(logits, target.type(torch.int64))
acc = acc_frame(logits, target)
return loss, acc

Expand Down

0 comments on commit 6ae98ef

Please sign in to comment.