Experiments with MPT7b with seqlen > 2048 #14
Comments
Hello! You're very right, well spotted. For the subsequent prompting experiments I indeed didn't increase the seq length to 8k, while I did do this for the perplexity experiments, as can be seen here. I'll rerun the experiments with the configuration set to 8k or 32k or so; I think that'll be quite interesting.
|
I've completed the experiments in 7a437d1. When setting the max_seq_len to 8192, the result is extremely poor responses after 2048 tokens. I hope that clarifies it!
|
I tried this too; I'm aiming to evaluate embedding quality with respect to different window lengths, and I am using the hf implementation. I believe the issue is that the model is encoding the entire context in a single shot. I think there are several approaches that could be used to encode long sequences, the simplest of which would be to truncate the full context to the window length and then autoregressively feed the rest of the sequence through token by token (a rough sketch of this idea is shown below). There is definitely some performance cost with that approach. From what I can see this problem will exist for all models, not just MPT? Please correct me if I'm wrong. Any tips on how to approach implementing streaming inference?
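As a rough illustration of that truncate-then-stream idea, here is a minimal sketch assuming a Hugging Face causal LM with `use_cache` support; the model name, prompt, and window size are placeholders rather than anything from this thread. Note that this naive version still lets the KV cache grow with every step, so on its own it does not get around the configured max_seq_len; that is what the attention-sink cache in the next comment addresses.

```python
# Minimal sketch of "encode a truncated window, then stream the rest token by
# token"; the model name, prompt, and window_size are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mosaicml/mpt-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
model.eval()

window_size = 2048
input_ids = tokenizer("a very long document ...", return_tensors="pt").input_ids

with torch.no_grad():
    # Encode the first window in one shot and keep the KV cache.
    out = model(input_ids[:, :window_size], use_cache=True)
    past = out.past_key_values
    # Stream the remaining tokens through one at a time, reusing the cache.
    # Without cache eviction the cache keeps growing, so the model's
    # max_seq_len limit still applies.
    for i in range(window_size, input_ids.shape[-1]):
        out = model(input_ids[:, i : i + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
```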
Had a very rough crack at it; as expected it's really slow. I also didn't get the logits to match the regular outputs very closely, so there may be some issues. Anyway, this is probably the wrong place to implement it, since it would need to be done per model I guess. Hope it helps to get the idea across at least.

```python
import types

import torch
from tqdm import tqdm

def overwrite_forward(module):
    # Create the new cache. AttentionSinkKVCache and attention_sink_kwargs are
    # assumed to come from the surrounding attention-sink setup.
    module.attention_sink_kv_cache = AttentionSinkKVCache(**attention_sink_kwargs)
    # Keep track of the old forward method, we need it in the wrapped one
    old_forward = module.forward

    # Wrap the forward by overriding the past_key_values using the cache
    def wrapped_forward(self, *args, **kwargs):
        outputs = old_forward(*args, **kwargs)
        outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
        return outputs

    def wrapped_wrapped_forward(self, *args, **kwargs):
        x = args[0]
        attention = kwargs['attention_mask']
        window_size = self.attention_sink_kv_cache.attention_sink_window_size
        t = tqdm(total=x.shape[-1])
        while x.shape[-1] > 0:
            kwargs['use_cache'] = True
            if kwargs.get('past_key_values') is None:
                # First step: encode a full window in one shot
                x_step = x[:, :window_size]
                attn_step = attention[:, :window_size]
                x = x[:, window_size:]
                attention = attention[:, window_size:]
            else:
                # Subsequent steps: feed one token at a time, trimming the mask
                # to the sink tokens plus the sliding window
                x_step = x[:, :1]
                attn_step = torch.cat([attn_step, attention[:, :1]], -1)
                attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
                x = x[:, 1:]
                attention = attention[:, 1:]
            t.update(x_step.shape[-1])
            kwargs['attention_mask'] = attn_step
            output = wrapped_forward(self, x_step, **kwargs)
            kwargs['past_key_values'] = output.past_key_values
            output.past_key_values = None
        return output

    module.forward = types.MethodType(wrapped_wrapped_forward, module)
```

The main idea is here, where we either encode a full window in one shot, or afterwards step through the remaining tokens one at a time:

```python
while x.shape[-1] > 0:
    kwargs['use_cache'] = True
    if kwargs.get('past_key_values') is None:
        x_step = x[:, :window_size]
        attn_step = attention[:, :window_size]
        x = x[:, window_size:]
        attention = attention[:, window_size:]
    else:
        x_step = x[:, :1]
        attn_step = torch.cat([attn_step, attention[:, :1]], -1)
        attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
        x = x[:, 1:]
        attention = attention[:, 1:]
```

I think you could improve performance by creating an overlapped window while encoding. Lots of parameters to tweak here.
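For completeness, a hedged sketch of how the wrapper above might be applied; the model name, the attention_sink_kwargs values, and the choice to wrap model.transformer are assumptions for illustration, not something specified in this thread, and AttentionSinkKVCache must already be imported from wherever the attention-sink cache lives in your setup.

```python
# Hypothetical usage of overwrite_forward; names and values are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mosaicml/mpt-7b"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# Assumed kwargs for AttentionSinkKVCache; adjust to your cache implementation.
attention_sink_kwargs = {"attention_sink_size": 4, "attention_sink_window_size": 1020}

# Wrap the submodule whose forward returns past_key_values (for MPT this is
# assumed to be model.transformer).
overwrite_forward(model.transformer)

inputs = tokenizer("a very long document ...", return_tensors="pt")
with torch.no_grad():
    outputs = model.transformer(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        past_key_values=None,
    )
last_hidden = outputs.last_hidden_state  # hidden states from the final step only
```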
Can you comment on what the RuntimeError was?
You have run mpt7b with a seq len > 8k.
If the mpt7b model config has max_seq_len=2048, then by design the model will throw an error if the seq len exceeds the configured value.
To fix this, simply configure the model with a longer seq len.
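For example, a minimal sketch of doing that through transformers, assuming the stock MPT-7B config (the 8192 value is just illustrative):

```python
# Minimal sketch of configuring MPT-7B with a longer max_seq_len via transformers.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
config.max_seq_len = 8192  # default is 2048; longer inputs raise an error
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b", config=config, trust_remote_code=True
)
```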