
Add Typing to Llama Training #9

Open · wants to merge 8 commits into main

Conversation

lennart-finke (Collaborator)

Description

Added jaxtyping and regular Python type hints to train_llama.py.

Related Issue

Should close #4.

How Has This Been Tested?

Executing the script in a Colab instance.

lennart-finke (Collaborator, Author)

Alright, I'll have to work on this some more, as the checks indicate.

lennart-finke (Collaborator, Author)

Now it might go through. If someone has time, though, I recommend double-checking the jaxtyping hints, as I was not entirely sure about them everywhere.
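(For anyone double-checking: a minimal, hypothetical sketch of the style of jaxtyping hint in question. The function, tensor names, and shape strings below are illustrative, not taken from the PR.)

import torch
from jaxtyping import Float, Int

def embed(
    tokens: Int[torch.Tensor, "batch seq"],
    wte: Float[torch.Tensor, "vocab d_model"],
) -> Float[torch.Tensor, "batch seq d_model"]:
    # Indexing the embedding table with token ids yields one d_model vector
    # per (batch, seq) position, which is what the return annotation asserts.
    return wte[tokens]

The shape strings are only enforced at runtime if a checker such as beartype or typeguard is hooked in; otherwise they act as documentation, which is why a manual review of the axis names is still worthwhile.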

danbraunai (Owner) left a comment

Thanks for this. A bunch of comments. By the way, I think you may need to merge main into this branch.

Comment on lines 18 to 41

This implementation is based on
- llm.c, licensed under MIT ((c) 2024 Andrei Karpathy) and
- TransformerLens, licensed under MIT ((c) 2022 TransformerLensOrg).


MIT License:
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
danbraunai (Owner)

I think this was accidentally deleted?

simple_stories_train/train_llama.py (outdated review comment, resolved)
rotary_dim: int = 768 // 12 # i.e. same as d_head
rotary_base: int = 10000
n_ctx: int = 1024
n_key_value_heads: int = (
12 // 4
) # Note that llama 3.1 n_key_value_heads does not scale with n_heads
use_grouped_query_attention: bool = True
danbraunai (Owner)

And these
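(For context on these fields, a hedged sketch of how the grouped-query-attention sizes fit together, assuming n_embd = 768 and n_heads = 12, consistent with the rotary_dim = 768 // 12 line above; the variable names are illustrative.)

n_embd = 768
n_heads = 12
d_head = n_embd // n_heads                      # 64, i.e. the same as rotary_dim
n_key_value_heads = 12 // 4                     # 3, fixed rather than scaled with n_heads
repeat_kv_heads = n_heads // n_key_value_heads  # 4 query heads share each key/value head
kv_proj_out = 2 * n_embd // repeat_kv_heads     # 384 = 2 * n_key_value_heads * d_head

This is also where the 2 * config.n_embd // self.repeat_kv_heads factor in the kv_attn projection quoted below comes from.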

Comment on lines 77 to 78
self.kv_attn = nn.Linear(config.n_embd, 2 * config.n_embd // self.repeat_kv_heads)
self.q_attn = nn.Linear(config.n_embd, config.n_embd)
danbraunai (Owner)

The bias argument has disappeared here, and from other attributes further down as well.
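(A minimal sketch of the restored calls, assuming the same config.attn_bias flag used by the original c_proj line quoted below; not necessarily the exact intended fix.)

self.kv_attn = nn.Linear(
    config.n_embd, 2 * config.n_embd // self.repeat_kv_heads, bias=config.attn_bias
)
self.q_attn = nn.Linear(config.n_embd, config.n_embd, bias=config.attn_bias)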

self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.attn_bias)
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
self.LLMC_RESIDUAL_SCALE_FLAG = 1
danbraunai (Owner)

Why this change? Seems like you want to keep it on the c_proj
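(That is, presumably restoring the original form quoted above, with the bias argument kept and the flag attached to the c_proj module rather than to the attention block itself:)

self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.attn_bias)
self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1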

simple_stories_train/train_llama.py (review comment, resolved)
@@ -576,7 +587,7 @@ def __init__(self, filename_pattern, B, T, process_rank, num_processes):
print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files")

# kick things off
self.current_shard = None
self.current_shard = -1
danbraunai (Owner)

How come this change was made?
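(One plausible motivation, sketched here as a guess rather than taken from the PR: once the attribute is annotated as an int, None no longer satisfies the type checker unless the annotation is widened, so -1 acts as a "no shard loaded yet" sentinel.)

# Sentinel as in the diff above; the annotation is added here for illustration.
self.current_shard: int = -1
# Alternative that keeps the original None while still type-checking:
self.current_shard: int | None = None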

Comment on lines 621 to 626
# -----------------------------------------------------------------------------
# Python -> C bridge utilities for saving params/grads/activations to .bin files


def write_fp32(tensor: torch.Tensor, file: BufferedWriter):
t = tensor.detach().cpu().to(torch.float32)
danbraunai (Owner)

I think we deleted all of the below from main. I'm guessing you hadn't pulled the latest main onto your branch?

lennart-finke (Collaborator, Author)

Yes, looks like it.

Comment on lines 811 to 812
import argparse
import time
danbraunai (Owner)

Best to import at the top of the file.

Comment on lines 1060 to 1063
# -------------------------------------------------------------------------
# PyTorch -> C bridge: save some weights and state for C to load later as reference

# do one forward pass to generate ground truth for our C tests
danbraunai (Owner)

I think we deleted all this too


Successfully merging this pull request may close these issues.

Add typing to train_llama.py