Using (Absolute) Positional Embeddings with Hyena Operators #29
Hi @DanFu09,

Hope you're well.

I was reading the source code and the config files, and I realized that `use_positional_encodings` is `True` (link). So the M2-BERT model applies absolute positional embeddings (link) before feeding the tokens to the Hyena operators. I checked the original Hyena and HyenaDNA source code, and neither uses positional embeddings.

My question is: why did you use positional embeddings? Have you tried not using them? Did it worsen performance?
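For reference, this is roughly what that flag toggles (my own sketch with hypothetical names, not the repo's code):

```python
import torch
import torch.nn as nn

# Hypothetical sketch (not the actual M2-BERT code) of learned absolute
# positional embeddings added to token embeddings before the Hyena layers.
class EmbeddingsWithAbsolutePositions(nn.Module):
    def __init__(self, vocab_size, max_seq_len, dim, use_positional_encodings=True):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, dim)
        self.pos = nn.Embedding(max_seq_len, dim) if use_positional_encodings else None

    def forward(self, input_ids):  # input_ids: (batch, seq_len)
        x = self.tok(input_ids)
        if self.pos is not None:
            positions = torch.arange(input_ids.shape[1], device=input_ids.device)
            x = x + self.pos(positions)  # broadcasts over the batch dimension
        return x
```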
Thanks for the question! There was a very slight degradation in performance (maybe half a point MLM accuracy), so we kept them in. Our hypothesis is that since it's bidirectional, the CLS token (token zero in the sequence) acts as a bit of a scratch pad for understanding the sequence, and the positional embeddings help carry some of that information forward. This is why we have that extra conv residual branch on the side as well. This is less important for autoregressive models since most of the sequence information is local and only goes in one direction.
Nice, thanks for the prompt response. OK, that makes sense... but my intuition about Hyena filters is that, due to convolution, they implicitly embed positional information, so adding extra positional embeddings shouldn't help and could even hurt... Also, regarding the bidirectionality of the M2-BERT model, your method differs from how HyenaDNA imposes it (by changing the padding of the sequence)... you train two (implicit) filters and concatenate (pad and sum, link) them to generate the bidirectional filter... is that right? If so, how does this impose bidirectionality?
Good questions! This blog post may clear most things up: https://hazyresearch.stanford.edu/blog/2023-12-11-conv-tutorial. It has to do with the details of what the FFT convolution implements, and that explains the differences between an autoregressive and a bidirectional model. Some more details below.
Convolutions do not embed positional information; they're actually positionally invariant (this is the same intuition behind the "shift invariance" you think about for ResNet-style convolutions). The Hyena filters in particular have an inductive bias towards local weights, so most of the sequence mixing is local.
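For example, shifting the input to a convolution just shifts the output by the same amount; the filter weights carry no notion of absolute position. A quick NumPy illustration (mine, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.standard_normal(16)   # input signal
k = rng.standard_normal(4)    # convolution filter

# Shift the input right by 3 positions (zero-fill on the left).
u_shifted = np.concatenate([np.zeros(3), u])

y = np.convolve(u, k)                  # full linear convolution
y_shifted = np.convolve(u_shifted, k)

# Shift equivariance: the output is the same, just shifted by 3.
assert np.allclose(y_shifted[3:], y)
assert np.allclose(y_shifted[:3], 0.0)
```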
HyenaDNA is not a bidirectional model; it's an autoregressive model (next-token prediction). A bidirectional model trained on next-token prediction wouldn't learn anything, as it could just look up the next token.
The key is in the FFT convolution - it actually implements a circular convolution (https://en.wikipedia.org/wiki/Convolution_theorem). In an autoregressive model, we pad with zeros so that the wrap-around part of the circular convolution turns entirely into zeros. This corresponds to the "throw them away" option in the GIF at the bottom. In a bidirectional model, we pad the input with zeros but make the convolution filter twice as long. This is analogous to the "wrap it around" version, except the filter has now wrapped all the way back to the beginning of the sequence.
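Concretely, here's a minimal NumPy sketch of the two padding schemes (my illustration of the idea, not the actual kernel code):

```python
import numpy as np

def causal_fft_conv(u, k):
    """Autoregressive case: input u and filter k both have length L.

    Pad the FFT to length 2L so the circular convolution's wrap-around
    falls on zeros, then keep only the first L outputs ("throw them away").
    Output position t only sees u[0..t].
    """
    L = len(u)
    y = np.fft.irfft(np.fft.rfft(u, n=2 * L) * np.fft.rfft(k, n=2 * L), n=2 * L)
    return y[:L]

def bidirectional_fft_conv(u, k2):
    """Bidirectional case: filter k2 is twice as long as the input (2L taps).

    With the same 2L-point FFT, the second half of the filter wraps around
    back to the start of the sequence, so output position t also sees
    u[t+1..L-1], i.e., future tokens.
    """
    L = len(u)
    y = np.fft.irfft(np.fft.rfft(u, n=2 * L) * np.fft.rfft(k2, n=2 * L), n=2 * L)
    return y[:L]

# Sanity check: the causal output matches a direct masked convolution.
rng = np.random.default_rng(0)
u, k = rng.standard_normal(8), rng.standard_normal(8)
direct = np.array([sum(k[s] * u[t - s] for s in range(t + 1)) for t in range(8)])
assert np.allclose(causal_fft_conv(u, k), direct)
```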
Nice, thanks for the detailed response. I'll check the blog and the Wikipedia article to understand the bidirectionality of the model. Exactly! You're right that convolution (convolving signal A over signal B) is position invariant... but the signals/filters/kernels of a convolution are not... signal A searches for a pattern/order of tokens... Also, you're right that bidirectionality conflicts with CLM (next-token prediction), but HyenaDNA did implement bidirectionality (link) as an experimental test (link).