Issue with only adding sink tokens in cache #17
Comments
Hello! I'm having a bit of a hard time trying to understand what you're saying. Let me try to break it down a bit.
When the input token sequence is larger than the window size, the later tokens do still attend to the sink tokens. To give a toy example, consider a scenario with a window size of 10, including 4 attention sink tokens, and a text which is just a space-separated alphabet. When generating, the model sees the 4 sink tokens plus the most recent 6 tokens of the window, so the later tokens, e.g. those near the end of the alphabet, still attend to the sink tokens.
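For illustration, here is a minimal sketch of that sliding-window-plus-sinks cache behaviour. It is a toy simulation, not the repository's actual cache code, and the variable names are made up:

```python
# Toy simulation of a sink + sliding-window KV cache on a space-separated alphabet.
# window_size=10 with num_sink_tokens=4 means: keep the first 4 tokens
# plus the 6 most recent tokens.
window_size = 10
num_sink_tokens = 4

tokens = list("abcdefghijklmnopqrstuvwxyz")
cache = []

for t in tokens:
    cache.append(t)
    if len(cache) > window_size:
        # Evict the oldest *non-sink* token, always keeping the 4 sink tokens.
        del cache[num_sink_tokens]
    print(f"after '{t}': cache = {cache}")

# e.g. after 'l' the cache is ['a', 'b', 'c', 'd', 'g', 'h', 'i', 'j', 'k', 'l'],
# so 'l' still attends to the sink tokens a b c d.
```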
Did you perhaps mean that if you have an input that is longer than the window size, then the first non-sink tokens will be discarded? I.e. the question becomes "What if the prompt is 2k tokens and the window size is 1024?" In that case, the model won't do well: it will produce normal English text like you would expect, but if the prompt starts with a question followed by 1900 tokens of context, the question itself falls outside the window and is discarded, so the model won't be able to answer it correctly.
I don't, but I think that this would not work well, assuming that I understand you correctly. Any model loaded with
Please let me know if I understood you correctly, and if I answered your question!
Hi Tom, I don't think you understood the issue I am describing. I did mean the sink tokens not being attended to. Consider the following situation: we are using the Mistral model with windowed flash attention, and we have a starting text of, for example, 30k tokens. We do the first forward pass on these tokens before any cache is created (as there are no previous tokens). We want to be able to do this efficiently in one forward pass, as this is part of the draw of using windowed attention in the first place. However, if you are only adding sink tokens in the cache, the attention done within this pass will just use the local windows. You will only start using sink tokens when generating tokens after that, yet you will be attending to the cached KV of the final 4k tokens of these 30k, all of which will have been "poisoned" by being created without sink attention.

There is a separate question of how much doing this large forward pass even matters. Technically, information could propagate through the KV from beyond the local window, but whether this actually happens with Mistral or other such models, and to what extent, is more complicated and would depend on how they were trained, etc. If it doesn't happen at all, then you could just as well take the 4k suffix of the text and only do the initial forward pass on that, and in that case there would also be no point in the windowed flash attention implementation.
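To make the concern concrete, here is a small illustrative sketch with toy sizes (seq_len, window, and num_sinks are placeholders, not the real 30k/4k/4 values, and this is not the repository's code). It compares the attention pattern of a plain windowed prefill with the pattern the sink approach assumes, for the tokens whose KV would later be cached:

```python
import numpy as np

seq_len = 12     # stand-in for the 30k-token prompt
window = 4       # stand-in for the 4k sliding window
num_sinks = 2    # stand-in for the 4 sink tokens

# Plain windowed causal prefill: query i attends to keys i-window+1 .. i.
prefill_mask = np.zeros((seq_len, seq_len), dtype=bool)
for i in range(seq_len):
    prefill_mask[i, max(0, i - window + 1): i + 1] = True

# Sink-style attention: query i additionally attends to the first num_sinks keys.
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
sink_mask = prefill_mask | (causal & (np.arange(seq_len)[None, :] < num_sinks))

# The last `window` positions are the ones whose KV survives into the generation cache.
late = slice(seq_len - window, seq_len)
print("plain prefill: late tokens attend to sinks?", prefill_mask[late, :num_sinks].any())  # False
print("with sinks:    late tokens attend to sinks?", sink_mask[late, :num_sinks].any())     # True
```

In the plain windowed prefill, the KV entries for the final window are produced without ever attending to the sink positions, and those are exactly the entries that later generation steps read from.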
I see! Indeed, I haven't tested this approach when the original input already exceeds the window size. That would definitely be an interesting experiment - so far all of my attempts at this give OOM exceptions.
Oh, I missed this issue; I have made some comments on the subject in #14.
I tested this with a 16-token context window on Mosaic. I observed it maintain coherence beyond the context window, but it does go off track quite quickly.
It seems that in this implementation you are only adding the "sink" token to the cache, and not using it in the original forward pass, so if you are using windowed attention and your input token sequence is longer than the window size, then the later tokens in this pass will not attend to any "sink" tokens.

In your tests I think this issue would mostly be avoided because the original prompt is either short before you start generating tokens, or, in the case of your streaming test, you end up doing a new forward pass separately for each added sub-prompt, which are relatively short. However, it is unclear from this how well it would work if you wanted to do a single forward pass on a long text. Do you have any tests showing that this still works with this implementation?
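One way such a test could look is to check whether tokens beyond the local window actually help predict the end of a long prompt processed in a single pass. The sketch below uses plain transformers; the model name, file path, and window size are placeholders, and running it at 30k tokens would hit the memory limits mentioned above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"   # placeholder; any causal LM with windowed attention
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
).eval()

long_text = open("long_document.txt").read()   # placeholder for the long prompt
input_ids = tokenizer(long_text, return_tensors="pt").input_ids.to(model.device)
window = 4096                                   # the model's sliding-window size

def nll_of_last_tokens(ids, n_last=256):
    # Average negative log-likelihood of the last n_last tokens, given everything before them.
    with torch.no_grad():
        logits = model(ids).logits
    log_probs = torch.log_softmax(logits[:, :-1].float(), dim=-1)
    targets = ids[:, 1:]
    token_nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_nll[:, -n_last:].mean().item()

# (a) single forward pass over the full long prompt
nll_full = nll_of_last_tokens(input_ids)

# (b) forward pass over only the final window-sized suffix
nll_suffix = nll_of_last_tokens(input_ids[:, -window:])

print(f"NLL of final tokens, full prompt in one pass: {nll_full:.3f}")
print(f"NLL of final tokens, window-sized suffix only: {nll_suffix:.3f}")
```

If the two numbers come out essentially the same, information from beyond the local window is not reaching the final tokens in the single prefill pass, which is exactly the scenario where only prefilling the suffix would be just as good.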