another implementation + partial reproduction #17
Thanks for the great paper!

I've created another open source implementation of the SHA-RNN here: https://github.com/talonvoice/sha-rnn

I trained with similar parameters to the single-head model at the end of your readme and achieved a bpc of 1.113 on test with LAMB, slightly worse than your 1.077, but still better than the Mogrifier LSTM. My epochs were also 1h instead of 30m. I used PyTorch `MultiheadAttention` for now instead of reimplementing the custom attention, which might explain the difference in speed and bpc.

I have some notes in my README about the effort. It's possible I made some mistakes in the model or training, as nobody else has reviewed the code yet.
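For reference, here's roughly what that swap looks like; the dimensions and masking below are illustrative assumptions, not the repo's actual config:

```python
import torch
import torch.nn as nn

d_model = 1024  # illustrative width, not the actual config

# num_heads=1 mirrors the single-head setup; the custom attention's
# extra projections and caching are not replicated here.
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=1)

seq_len, batch = 128, 8
x = torch.randn(seq_len, batch, d_model)  # default (seq, batch, dim) layout

# Boolean causal mask: True marks positions that may NOT be attended to,
# so each step only sees itself and the past.
causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
)

out, attn_weights = attn(x, x, x, attn_mask=causal_mask)
```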
Comments

Great work! I had a quick skim over the code and like some of the refactors :) After your initial training run, did you repeat training with a decreased learning rate? I also saw your feedforwards are 2048 rather than 4096, which could well be part of the difference. I could imagine the built-in PyTorch attention accounting for part of the gap as well. Regardless, thanks for the reproduction! I'm sure with a tiny amount of targeted polish it'd match or beat my own perplexities =]

P.S. Checked out Talon and your various other projects - love the work!
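To make the feedforward width concrete, here's a sketch of a position-wise feedforward block where `d_ff` is the 2048-vs-4096 knob; the exact activation and dropout placement are assumptions, not lifted from either codebase:

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise feedforward block; d_ff is the width under
    discussion (2048 vs 4096). Structure is illustrative."""
    def __init__(self, d_model=1024, d_ff=4096, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),   # expand: d_model -> d_ff
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),   # project back: d_ff -> d_model
        )

    def forward(self, x):
        return self.net(x)
```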
Just coming back to this now. I did some ablation testing with both of our codebases, and swapping basically every combination of the components made almost no difference in speed. So I realized I made a few mistakes 🤦. I'll maybe have some new results soon.

Other differences: I'm not using warmup or … (Also, have you looked into SRU++?)
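For context on the warmup difference, a minimal sketch of linear learning-rate warmup using PyTorch's `LambdaLR`; the model, base LR, and step count below are placeholders:

```python
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

warmup_steps = 1000  # placeholder value, not from either codebase

# Scale the LR linearly from near zero up to its base value over
# warmup_steps, then hold it flat; post-warmup decay is out of scope.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

for step in range(3):  # in a real loop this wraps each training step
    optimizer.step()
    scheduler.step()
```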