
Small error in bias-computation in L08/code/softmax-regression_scratch.ipynb #5

Open
alanakbik opened this issue Jun 10, 2022 · 0 comments


Hello @rasbt,

First of all, thanks for making all this material available online, as well as your video lectures! A really helpful resource!

A small issue and fix: the from-scratch softmax regression implementation in L08/code/softmax-regression_scratch.ipynb has, I think, a small error in the bias computation. The training output (cell 8) shows the same value for every bias term:

Epoch: 049 | Train ACC: 0.858 | Cost: 0.484
Epoch: 050 | Train ACC: 0.858 | Cost: 0.481

Model parameters:
  Weights: tensor([[ 0.5582, -1.0240],
        [-0.5462,  0.0258],
        [-0.0119,  0.9982]])
  Bias: tensor([-1.2020e-08, -1.2020e-08, -1.2020e-08])

whereas the second implementation, using the nn.Module API, learns distinct bias terms.

The problem lies in the torch.sum call in SoftmaxRegression1.backward: it sums the residual y - probas over both the batch and the class dimension, producing a single scalar that is then broadcast across all bias terms. In fact, since each row of y (one-hot) and each row of probas (a softmax output) sums to 1, this scalar is always numerically zero, which is why the biases barely move from their initialization. You can fix this by changing

    def backward(self, x, y, probas):  
        grad_loss_wrt_w = -torch.mm(x.t(), y - probas).t()
        grad_loss_wrt_b = -torch.sum(y - probas)
        return grad_loss_wrt_w, grad_loss_wrt_b

to

    def backward(self, x, y, probas):  
        grad_loss_wrt_w = -torch.mm(x.t(), y - probas).t()
        grad_loss_wrt_b = -torch.sum(y - probas, dim=0)
        return grad_loss_wrt_w, grad_loss_wrt_b

With this change, the model learns the toy problem slightly better.
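For illustration, here is a small standalone sketch (with made-up label and probability tensors, not taken from the notebook) showing why the scalar sum cancels to zero while the dim=0 sum yields one gradient per class:

```python
import torch

# Toy batch: 4 examples, 3 classes.
# y holds one-hot labels; probas holds (hypothetical) softmax outputs,
# so every row of both tensors sums to 1.
y = torch.tensor([[1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.],
                  [1., 0., 0.]])
probas = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.2, 0.2, 0.6],
                       [0.5, 0.3, 0.2]])

# Original version: one scalar summed over batch AND classes.
# Because each row of (y - probas) sums to 0, this is always ~0.
grad_b_scalar = -torch.sum(y - probas)

# Fixed version: sum over the batch dimension only,
# giving one gradient entry per bias term.
grad_b_per_class = -torch.sum(y - probas, dim=0)

print(grad_b_scalar)     # ~0, broadcast to all biases in the original code
print(grad_b_per_class)  # shape (3,), one distinct gradient per class
```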
