Hello @rasbt,

first of all, thanks for making all this material available online, along with your video lectures! A really helpful resource!
A small issue and a fix: the classic softmax regression implementation in L08/code/softmax-regression_scratch.ipynb has, I think, a small error in the bias computation. The training output (cell 8) shows the same value for every bias term, whereas the second implementation, using the nn.Module API, learns distinct bias terms.
The problem lies in the torch.sum call in SoftmaxRegression1.backward: it computes a single scalar sum over the entire residual, which is then broadcast across all bias terms. You can fix this by changing
```python
def backward(self, x, y, probas):
    grad_loss_wrt_w = -torch.mm(x.t(), y - probas).t()
    grad_loss_wrt_b = -torch.sum(y - probas)
    return grad_loss_wrt_w, grad_loss_wrt_b
```
to
```python
def backward(self, x, y, probas):
    grad_loss_wrt_w = -torch.mm(x.t(), y - probas).t()
    grad_loss_wrt_b = -torch.sum(y - probas, dim=0)
    return grad_loss_wrt_w, grad_loss_wrt_b
```
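To illustrate the difference, here is a minimal standalone sketch (the `resid` tensor is a made-up stand-in for `y - probas`, not data from the notebook): without `dim`, `torch.sum` collapses the whole batch-by-class residual into one scalar, so every bias receives the same gradient; with `dim=0`, each class gets its own gradient, summed over the batch.

```python
import torch

# Hypothetical residual (y - probas) for a batch of 3 examples, 4 classes.
resid = torch.tensor([[ 0.5, -0.2, -0.2, -0.1],
                      [-0.3,  0.6, -0.2, -0.1],
                      [-0.1, -0.2,  0.4, -0.1]])

grad_b_wrong = -torch.sum(resid)         # 0-dim scalar, broadcast to all biases
grad_b_fixed = -torch.sum(resid, dim=0)  # shape (4,): one gradient per class

print(grad_b_wrong.shape)  # torch.Size([])
print(grad_b_fixed.shape)  # torch.Size([4])
```

With the scalar version, the bias update moves all classes in lockstep, which is why cell 8 prints identical bias values.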
With this change, it also learns the toy problem a (very slight) bit better.