another implementation + partial reproduction #17

Open

lunixbochs opened this issue Jan 13, 2021 · 2 comments

Thanks for the great paper!

I've created another open source implementation of the SHA RNN here: https://github.com/talonvoice/sha-rnn

I trained with similar parameters to the single head model at the end of your readme and achieved a bpc of 1.113 on test with LAMB, slightly worse than your 1.077, but still better than the Mogrifier LSTM. My epochs also took about 1h each instead of 30m. I used PyTorch's built-in MultiheadAttention for now instead of reimplementing the custom attention, which may account for the difference in both speed and bpc.
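For reference, this is roughly what the substitution looks like; the dimensions, sequence layout, and causal mask here are illustrative assumptions rather than the exact values from either codebase:

```python
# Minimal sketch: nn.MultiheadAttention standing in for the custom single-head
# attention. All sizes below are placeholders, not the real hyperparameters.
import torch
import torch.nn as nn

d_model, seq_len, batch = 1024, 128, 8

# The built-in module projects q, k, and v with learned matmuls.
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=1)

x = torch.randn(seq_len, batch, d_model)  # (seq, batch, dim) layout
# Causal mask: True entries are blocked from attending.
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
out, _ = attn(x, x, x, attn_mask=mask, need_weights=False)
print(out.shape)  # torch.Size([128, 8, 1024])
```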

I have some notes in my README about the effort. It's possible I made some mistakes in the model or training as nobody else has reviewed the code yet.

Smerity (Owner) commented Jan 17, 2021

Great work! I had a quick skim over the code and like some of the refactors :)

After your initial training run, did you repeat training with a decreased learning rate?
Ah, just saw ReduceLROnPlateau. Perhaps try training without the plateau behaviour until much later in the run?
I've had issues where the learning rate reduction was premature.
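A minimal sketch of what I mean, assuming an Adam-style optimizer (the values below are placeholders, not the settings from either repo):

```python
# Illustrative only: raise the scheduler's patience, or skip stepping it early on,
# so the LR only drops once validation loss has genuinely stalled.
import torch

model = torch.nn.Linear(10, 10)                      # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10)  # larger patience = later reduction

for epoch in range(50):
    val_loss = 1.0 / (epoch + 1)                     # placeholder for real validation loss
    if epoch >= 20:                                  # hold the LR flat for the first N epochs
        scheduler.step(val_loss)
```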

I also saw your feedforward dimensions are 2048 rather than 4096, which could well be part of the difference.

I could imagine the built-in PyTorch MultiheadAttention may have been part of the bpc drop too. Many of the small choices in my codebase and the paper were about making the gradients flow as cleanly as possible, which seems to be surprisingly important. This may also be the cause of the speed slowdown, given that the values pass through un-matmulled (i.e. only inexpensive layer norms or similar).
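A rough sketch of that value path (a simplified illustration of the idea, not the exact attention module from my repo or the paper):

```python
# Simplified single-head attention where only the query is matmulled; keys and
# values only pass through cheap layer norms, so gradients reach them directly.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalSingleHeadAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)   # the only learned matmul
        self.ln_k = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)
        self.scale = d_model ** -0.5

    def forward(self, x, attn_mask=None):
        q = self.q_proj(x)
        k = self.ln_k(x)
        v = self.ln_v(x)                            # values: no projection at all
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))
        return torch.matmul(F.softmax(scores, dim=-1), v)
```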

Regardless, thanks for the reproduction! I'm sure with a tiny amount of targeted polish it'd match or beat my own perplexities =]

P.S. Checked out Talon and your various other projects - love the work!

lunixbochs (Author) commented Jul 20, 2021

Just coming back to this now. I did some ablation testing with both of our codebases, and swapping basically every combination of the components made almost no difference in speed, so I realized I'd made a few mistakes 🤦. I may have some new results soon.

  • The 2048/4096 FF issue you mentioned: I had accidentally duplicated some method arguments, so the model was using the default FF dimension instead of the one I was trying to pass (though this didn't really affect perf).
  • The actual perf issue: I forgot to invoke the AMP context manager in my train loop, so training was running in fp32. Fixing AMP makes my codebase about the same speed as yours (see the sketch below).
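Roughly what the corrected loop looks like (the model, sizes, and loss here are placeholders, and this assumes a CUDA device):

```python
# Sketch of the AMP fix: wrap the forward pass in autocast and scale the loss.
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for step in range(100):
    x = torch.randn(8, 512, device='cuda')
    optimizer.zero_grad()
    with autocast():                     # the missing piece: without this context,
        loss = model(x).pow(2).mean()    # the forward/backward ran in fp32
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```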

Other differences: I'm not using warmup or SplitCrossEntropyLoss (adaptive softmax) yet, though adaptive softmax shouldn't matter much on the char model.

(Also, have you looked into SRU++?)
