
Experiments with MPT7b with seqlen > 2048 #14

Open

vchiley opened this issue Oct 10, 2023 · 4 comments

@vchiley

vchiley commented Oct 10, 2023

Quoting from here:

For MPT-7B-chat, a RuntimeError is encountered for transformers when the input length exceeds 2048.

Can you comment on what the RuntimeError was?
You have run mpt7b with seq len > 8k.

The mpt7b model config has max_seq_len=2048; by design, if the seq len exceeds the configured value, the model will throw an error.
To fix this, simply configure the model with a longer seq len.
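
For what it's worth, a minimal sketch of one way to do that with plain transformers (the model name and the 8192 target are illustrative, not taken from this issue):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Illustrative: raise the configured max_seq_len before loading the weights.
    config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    config.max_seq_len = 8192  # default is 2048; longer inputs otherwise raise the error above
    model = AutoModelForCausalLM.from_pretrained(
        "mosaicml/mpt-7b",
        config=config,
        trust_remote_code=True,
    )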

@tomaarsen
Owner

Hello!

You're very right - well spotted. For the subsequent prompting experiments I indeed didn't increase the seq length to 8k, although I did do so for the perplexity experiments, as can be seen here.
The RuntimeError is indeed just due to hitting 2048 tokens.

I'll rerun the experiments with the configuration set to 8k or 32k or so; I think that'll be quite interesting.

  • Tom Aarsen

@tomaarsen
Owner

I've completed the experiments in 7a437d1. When setting max_seq_len to 8192, transformers produces extremely poor responses after 2048 tokens.

I hope that clarifies it!

  • Tom Aarsen

@Nintorac

I tried this too; I'm aiming to evaluate embedding quality with respect to different window lengths. I am using the HF feature-extraction pipeline.

I believe the issue is that the model is encoding the entire context in a single shot.
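
A rough, de-pipelined sketch of what I understand that single-shot setup to be (the model name, the `long_document` input, and the use of hidden states in place of the pipeline are my own illustrative assumptions):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "mosaicml/mpt-7b"  # illustrative
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

    inputs = tokenizer(long_document, return_tensors="pt")  # `long_document` is a placeholder
    with torch.no_grad():
        # Everything is encoded in one forward pass, so once the input exceeds
        # the configured max_seq_len this is where the error shows up.
        outputs = model(**inputs, output_hidden_states=True)
    embeddings = outputs.hidden_states[-1]  # last-layer token embeddings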

I think there are several approaches that could be used to encode long sequences.

The simplest would be to truncate the full context to the window length and then autoregressively feed the rest of the sequence through token by token. There are definitely some performance issues with that approach.

From what I can see, this problem will exist for all models, not just MPT? Please correct me if I'm wrong. Any tips on how to approach implementing streaming inference?

@Nintorac

Nintorac commented Oct 20, 2023

I had a very rough crack at it; as expected, it's really slow.

I also didn't really get much matching in the logits between the outputs, so there may be some issues.

Anyway, this is probably the wrong place to implement it; it would need to be unique per model, I guess. Hope it helps to get the idea across at least.

    import types

    import torch
    from tqdm import tqdm

    # NOTE: AttentionSinkKVCache and attention_sink_kwargs come from the
    # surrounding attention_sinks injection code that this is meant to hook into.
    def overwrite_forward(module):
        # Create the new cache
        module.attention_sink_kv_cache = AttentionSinkKVCache(**attention_sink_kwargs)

        # Keep track of the old forward method, we need it in the wrapped one
        old_forward = module.forward

        # Wrap the forward by overriding the past_key_values using the cache
        def wrapped_forward(self, *args, **kwargs):
            outputs = old_forward(*args, **kwargs)
            outputs.past_key_values = self.attention_sink_kv_cache(outputs.past_key_values)
            return outputs

        def wrapped_wrapped_forward(self, *args, **kwargs):
            x = args[0]
            attention = kwargs['attention_mask']
            window_size = self.attention_sink_kv_cache.attention_sink_window_size
            t = tqdm(total=x.shape[-1])
            while x.shape[-1] > 0:
                kwargs['use_cache'] = True
                if kwargs.get('past_key_values') is None:
                    # First step: encode a full window in one shot
                    x_step = x[:, :window_size]
                    attn_step = attention[:, :window_size]
                    x = x[:, window_size:]
                    attention = attention[:, window_size:]
                else:
                    # Afterwards: feed the remaining tokens one at a time, keeping
                    # the mask aligned with the (sink + window) cache
                    x_step = x[:, :1]
                    attn_step = torch.cat([attn_step, attention[:, :1]], -1)
                    attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
                    x = x[:, 1:]
                    attention = attention[:, 1:]
                t.update(x_step.shape[-1])

                kwargs['attention_mask'] = attn_step
                output = wrapped_forward(self, x_step, **kwargs)
                kwargs['past_key_values'] = output.past_key_values
            output.past_key_values = None
            return output

        module.forward = types.MethodType(wrapped_wrapped_forward, module)

The main idea is here, where we first encode a full window and afterwards step through the rest token by token:

    while x.shape[-1] > 0:
        kwargs['use_cache'] = True
        if kwargs.get('past_key_values') is None:
            x_step = x[:, :window_size]
            attn_step = attention[:, :window_size]
            x = x[:, window_size:]
            attention = attention[:, window_size:]
        else:
            x_step = x[:, :1]
            attn_step = torch.cat([attn_step, attention[:, :1]], -1)
            attn_step = attn_step[:, -(window_size + self.attention_sink_kv_cache.attention_sink_size + 1):]
            x = x[:, 1:]
            attention = attention[:, 1:]

I think you could improve performance by creating an overlapped window while encoding, something like [x,x,x,x,x,x,o,o,o,o,o,o,o,o,o], where the xs are the window and the os are new tokens. Here the total encoding step length will still be limited to the model's context window, so if your sink window equals the context window then you can at most single-step.

Lots of parameters to tweak here (window_size, encoding_overlap_size); I'm pretty interested to see how embedding quality changes with respect to these.
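
A minimal sketch of that chunked variant, assuming an HF-style causal model whose per-layer KV cache is shaped (batch, heads, seq_len, head_dim); `chunked_encode`, its arguments, and the cache-trimming assumption are my own placeholders rather than anything from the code above:

    import torch

    def chunked_encode(model, input_ids, window_size, sink_size, max_position):
        """Feed chunks of new tokens per forward pass instead of one token at a time.
        The trimmed (sink + window) cache plus each chunk must stay within the
        model's positional limit `max_position`."""
        chunk_size = max(1, max_position - (window_size + sink_size))
        past_key_values, output = None, None
        for start in range(0, input_ids.shape[-1], chunk_size):
            ids_step = input_ids[:, start:start + chunk_size]
            # Length of the cached prefix; cache layout assumed (batch, heads, seq, dim).
            cached = 0 if past_key_values is None else past_key_values[0][0].shape[2]
            attn_step = torch.ones(
                ids_step.shape[0], cached + ids_step.shape[-1],
                dtype=torch.long, device=ids_step.device,
            )
            output = model(
                ids_step,
                attention_mask=attn_step,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # Assumes a hook like the AttentionSinkKVCache above has already trimmed
            # the returned cache back to at most sink_size + window_size entries.
            past_key_values = output.past_key_values
        return output

This only changes how many new tokens go in per forward pass; trimming the cache back down is still assumed to happen in the attention-sink cache hook.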
