

bug fix - remove attn.bias keys from GPT state dict in 'from_pretrine… #122

Open
wants to merge 1 commit into master
Conversation

amnonbleich

Bug fix: remove the attn.bias keys from the GPT state dict in 'from_pretrained', otherwise the assertion fails. If that's not a bug, I'd be happy to hear the reasoning. In addition, the above-mentioned keys are not used elsewhere, only in the assertion.
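
For context, the change amounts to something along these lines (a rough sketch rather than the exact diff; sd is minGPT's own state dict and sd_hf is the Hugging Face one, as used in from_pretrained):

sd = model.state_dict()
# drop the causal-mask buffers, since they have no counterpart in the HF checkpoint
sd = {k: v for k, v in sd.items() if not k.endswith('attn.bias')}
sd_hf = model_hf.state_dict()
keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')]  # existing filter in from_pretrained
assert len(keys) == len(sd)  # the count check that fails without the attn.bias filtering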

@erno123

erno123 commented Jan 16, 2024

The root cause of the problem is that persistent=False is set for the attn.bias keys in the original Hugging Face code (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py, line 133). This means that these keys are not included in the state dictionary on the HF side, while they still are in minGPT's. That's why the assertion fails in line 200 of minGPT/model.py.

So a better solution is to also set the same persistent=False option for the attn.bias keys in line 48 of minGPT/model.py, like this:
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size), persistent=False)
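
(A non-persistent buffer is still created and used in forward exactly as before; it is simply left out of state_dict(), so this change shouldn't affect the model's behavior.)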

Also, the attn.masked_bias keys get the same persistent=False option in the HF code, hence they aren't included in the HF state dictionary. So excluding them in line 196 of minGPT/model.py is unnecessary. And consequently we don't need the keys variable at all; we can directly use sd_hf everywhere.
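
With both buffers registered as non-persistent on the minGPT side, the copy loop in from_pretrained could then iterate over sd_hf directly, roughly like this (a sketch of the idea, not the exact patch; transposed is the existing list of Conv1D weight names that need transposing):

assert len(sd_hf) == len(sd)  # neither state dict exports the mask buffers any more
for k in sd_hf:
    if any(k.endswith(w) for w in transposed):
        # HF stores these as Conv1D weights, so they need a transpose
        assert sd_hf[k].shape[::-1] == sd[k].shape
        with torch.no_grad():
            sd[k].copy_(sd_hf[k].t())
    else:
        assert sd_hf[k].shape == sd[k].shape
        with torch.no_grad():
            sd[k].copy_(sd_hf[k])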
