Add/gpt2 (#3)
* add: gpt2
shenxiangzhuang authored May 20, 2024
1 parent d723a43 commit 910b719
Showing 16 changed files with 1,556 additions and 10 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -18,10 +18,10 @@ repos:
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
-  - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
-    hooks:
-      - id: flake8
+  # - repo: https://github.com/pycqa/flake8
+  #   rev: 7.0.0
+  #   hooks:
+  #     - id: flake8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.2
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

-Copyright [yyyy] [name of copyright owner]
+Copyright 2024 Xiangzhuang Shen

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
6 changes: 6 additions & 0 deletions README.md
@@ -2,3 +2,9 @@

ToyLLM is a simple language model that can be used to generate text.
It is based on the [GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html) model.


# Acknowledgements
The project is heavily inspired by the following projects:
- [rasbt/LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch)
- [neelnanda-io/TransformerLens](https://github.com/neelnanda-io/TransformerLens)
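
A minimal quickstart sketch based on the `TextGenerator` added in this commit (no trained checkpoint is loaded here, so the weights are random and the sample output is gibberish):

from toyllm.model.config import gpt_config_124_m
from toyllm.model.generate import TextGenerator

# Weights stay random unless a checkpoint is passed via model_file_path
generator = TextGenerator(model_config=gpt_config_124_m)
print(generator.generate(prompt_text="Hello, I am", top_k=10, temperature=0.9))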
165 changes: 165 additions & 0 deletions dataset/the-verdict.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -1,5 +1,5 @@
site_name: ToyLLM
-repo_url: https://github.com/shenxiangzhuang/mppt
+repo_url: https://github.com/toy-ai/toyllm
repo_name: toy-ai/toyllm
edit_uri: ""
site_description: Toy LLM
646 changes: 643 additions & 3 deletions poetry.lock

Large diffs are not rendered by default.

18 changes: 17 additions & 1 deletion pyproject.toml
@@ -7,8 +7,12 @@ license = "Apache-2.0"
readme = "README.md"

[tool.poetry.dependencies]
-python = ">=3.11"
+python = ">=3.11,<4.0"
torch = "^2.2.0"
tiktoken = "^0.6.0"
numpy = "^1.26.4"
jaxtyping = "^0.2.28"
matplotlib = "^3.8.3"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.5.0"
@@ -27,6 +31,7 @@ black = "^23.11.0"
mypy = "^1.7.1"
ruff = "^0.1.7"
ipython = "^8.22.2"
nvitop = "^1.3.2"


[virtualenvs]
@@ -36,6 +41,17 @@ in-project = true
[tool.isort]
profile = "black"


[tool.ruff]
line-length = 120

[tool.ruff.lint]
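# F722 ("syntax error in forward annotation") is triggered by jaxtyping shape
# annotations like Float[torch.Tensor, "batch_size vocab_size"], hence the ignore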
ignore = ["F722"]

[tool.black]
line-length = 120


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
91 changes: 91 additions & 0 deletions toyllm/dataset.py
@@ -0,0 +1,91 @@
import logging
import pathlib
import urllib.request

import tiktoken
import torch
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)


def get_dataset_dir() -> pathlib.Path:
return pathlib.Path(__file__).parents[1] / "dataset"


class GPTDataset(Dataset):
def __init__(self, txt: str, tokenizer: tiktoken.Encoding, max_length: int, stride: int):
"""
        :param txt: raw text to tokenize and chunk
        :param tokenizer: tokenizer object
        :param max_length: maximum sequence length per training sample
        :param stride: sliding-window stride between consecutive samples
"""
self.tokenizer = tokenizer
self.input_ids = []
self.target_ids = []

# Tokenize the entire text
token_ids = tokenizer.encode(txt)

# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids[i : i + max_length]
target_chunk = token_ids[i + 1 : i + max_length + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
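        # Illustration (toy numbers): with token_ids [0, 1, 2, 3, 4], max_length=2
        # and stride=2, the loop above yields inputs [0, 1], [2, 3] and targets
        # [1, 2], [3, 4]; each target is its input shifted right by one token.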

def __len__(self):
return len(self.input_ids)

def __getitem__(self, idx):
return self.input_ids[idx], self.target_ids[idx]


class GPTDataloader:
def __init__(
self,
tokenizer: tiktoken.Encoding,
max_length: int,
stride: int,
batch_size: int,
):
self.tokenizer = tokenizer
self.max_length = max_length
self.stride = stride
self.batch_size = batch_size

def create_dataloader(self, text: str, shuffle=True, drop_last=True) -> DataLoader:
# Create dataset
dataset = GPTDataset(text, self.tokenizer, self.max_length, self.stride)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, drop_last=drop_last)

return dataloader


def read_simple_text_file() -> str:
file_name = "the-verdict.txt"
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

file_path = get_dataset_dir() / file_name

    # If the file doesn't exist locally, download it first
if not file_path.exists():
logger.info(f"Downloading {url} to {file_path}")
with urllib.request.urlopen(url) as response:
text_data = response.read().decode("utf-8")
with open(file_path, "w", encoding="utf-8") as file:
file.write(text_data)

logger.info(f"Saved {file_path}")
# open the file
with open(file_path, "r", encoding="utf-8") as file:
text_data = file.read()
return text_data


if __name__ == "__main__":
text = read_simple_text_file()
print(len(text))
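
For context, a minimal usage sketch of the pieces above (the tokenizer choice and the hyperparameter values are illustrative, not part of the commit):

import tiktoken

from toyllm.dataset import GPTDataloader, read_simple_text_file

# GPT-2 BPE encoding matches the tiktoken.Encoding type hint above
tokenizer = tiktoken.get_encoding("gpt2")
dataloader = GPTDataloader(tokenizer=tokenizer, max_length=256, stride=256, batch_size=4)
loader = dataloader.create_dataloader(read_simple_text_file())
inputs, targets = next(iter(loader))
# Both are (batch_size, max_length); targets are inputs shifted one token to the right
print(inputs.shape, targets.shape)  # torch.Size([4, 256]) torch.Size([4, 256])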
4 changes: 4 additions & 0 deletions toyllm/device/__init__.py → toyllm/device.py
@@ -1,4 +1,8 @@
import torch


def get_device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")


current_device = get_device()
Empty file added toyllm/model/__init__.py
47 changes: 47 additions & 0 deletions toyllm/model/config.py
@@ -0,0 +1,47 @@
from dataclasses import dataclass


@dataclass
class GPTModelConfig:
"""
GPT Model Architecture Config
:param vocab_size: vocabulary size
:param ctx_len: length of context/block
:param emb_dim: embedding size
:param n_heads: number of attention heads
:param n_layers: number of transformer layers
:param drop_rate: dropout rate
:param qkv_bias: query key value bias terms
"""

vocab_size: int
ctx_len: int
emb_dim: int
n_heads: int
n_layers: int
drop_rate: float
qkv_bias: bool


gpt_config_124_m = GPTModelConfig(
vocab_size=50257,
ctx_len=1024,
emb_dim=768,
n_heads=12,
n_layers=12,
drop_rate=0.1,
qkv_bias=False,
)


@dataclass
class GPTTrainingConfig:
"""
GPT training config: hyperparameters for GPT model training
"""

learning_rate: float = 5e-4
num_epochs: int = 10
batch_size: int = 2
weight_decay: float = 0.1
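
As a sanity check on the `124_m` name, the config fields reproduce the canonical GPT-2 small parameter count. The arithmetic below is illustrative, assuming the standard GPT-2 layout with a weight-tied output head; if the `GPTModel` in this commit keeps a separate (untied) output head, it adds another vocab_size × emb_dim ≈ 38.6M parameters:

from toyllm.model.config import gpt_config_124_m

cfg = gpt_config_124_m
embed = cfg.vocab_size * cfg.emb_dim + cfg.ctx_len * cfg.emb_dim  # token + positional embeddings
per_layer = (
    3 * cfg.emb_dim * cfg.emb_dim              # Q, K, V projections (qkv_bias=False)
    + cfg.emb_dim * cfg.emb_dim + cfg.emb_dim  # attention output projection + bias
    + 2 * (cfg.emb_dim * 4 * cfg.emb_dim)      # feed-forward up/down projections (4x expansion)
    + 4 * cfg.emb_dim + cfg.emb_dim            # feed-forward biases
    + 2 * 2 * cfg.emb_dim                      # two LayerNorms (scale + shift)
)
total = embed + cfg.n_layers * per_layer + 2 * cfg.emb_dim  # + final LayerNorm
print(f"{total / 1e6:.1f}M")  # ~124.4M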
156 changes: 156 additions & 0 deletions toyllm/model/generate.py
@@ -0,0 +1,156 @@
import logging
import pathlib
from typing import Optional

import jaxtyping
import tiktoken
import torch
from typeguard import typechecked as typechecker

from toyllm.device import current_device
from toyllm.model.config import GPTModelConfig, gpt_config_124_m
from toyllm.model.gpt import GPTModel
from toyllm.tokenizer import gpt2_tokenizer, text_to_token_ids, token_ids_to_text

logger = logging.getLogger(__name__)


class TextGenerator:
def __init__(
self,
model_config: Optional[GPTModelConfig] = None,
model_instance: Optional[GPTModel] = None,
model_file_path: Optional[pathlib.Path] = None,
tokenizer: tiktoken.Encoding = gpt2_tokenizer,
seed: int = 42,
):
self.model_config = model_config
self.model_instance = model_instance
self.model_file_path = model_file_path
self.tokenizer = tokenizer
self.seed = seed

self.gpt_model = self.__get_gpt_model()

def __get_gpt_model(self) -> GPTModel:
torch.manual_seed(self.seed)
if self.model_instance is not None:
model = self.model_instance
self.model_config = model.config
elif self.model_config is not None:
model = GPTModel(self.model_config)
            # TODO: load model weights
            if self.model_file_path is not None:
                model.load_state_dict(torch.load(self.model_file_path))
            else:
                logger.warning("Debug mode: using random model weights")
else:
            raise ValueError("Cannot initialize GPTModel: provide either model_instance or model_config")
# disable dropout and so on
model.eval()
if model.device != current_device:
model.to(current_device)
return model

@property
def context_length(self) -> int:
assert self.model_config is not None, "Model config is None"
return self.model_config.ctx_len

def generate(
self,
prompt_text: str,
max_gen_tokens: int = 10,
top_k: Optional[int] = None,
temperature: Optional[float] = None,
) -> str:
"""
:param prompt_text: prompt text
:param max_gen_tokens: maximum number of tokens to generate
        :param top_k: keep only the `top_k` highest-logit candidate tokens to sample from.
            A small `top_k` reduces the randomness of the generated output.
            `top_k` must be greater than 0, e.g. 5 or 10.
:param temperature: "Temperatures greater than 1 will result in more uniformly distributed token probabilities
after applying the softmax; temperatures smaller than 1 will result in
more confident (sharper or more peaky) distributions after applying the softmax"
(https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/01_main-chapter-code/ch05.ipynb)
The default temperature value is 0.6 in llama2.
"""
# prompt text to tokens: (1, n_tokens)
prompt_tokens = text_to_token_ids(prompt_text, self.tokenizer).to(self.gpt_model.device)

for _ in range(max_gen_tokens):
            # Crop the current context if it exceeds the supported context size (ctx_len)
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context

# (batch, n_tokens) --(crop context)--> (batch, n_tokens' = min(ctx_len, n_tokens))
context_text_token_ids = prompt_tokens[:, -self.context_length :]

# Get the predictions
            # use `inference_mode` rather than `no_grad` (https://stackoverflow.com/questions/74191070)
with torch.inference_mode():
# (batch, n_token') --(forward)--> (batch, n_token', vocab_size)
logits = self.gpt_model(context_text_token_ids)

# Focus only on the last time step
# (batch, n_tokens', vocab_size) --(keep last time step token)--> (batch, vocab_size)
logits = logits[:, -1, :]

# logits filter & scale
if top_k is not None:
logits = self._logits_top_k_filter(logits, top_k)
if temperature is not None:
probs = self._logits_temperature_scale(logits, temperature)
# Sample from the scaled multinomial distribution
# (batch, vocab_size)--(keep the max prob token)--> (batch, 1)
next_token_id = torch.multinomial(probs, num_samples=1)
else:
# Get the idx of the vocab entry with the highest logits value
# (batch, vocab_size)--(keep the max prob token)--> (batch, 1)
next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

# Append sampled index to the running sequence
# (batch, n_tokens') --(append next token)--> (batch, n_tokens' + 1)
prompt_tokens = torch.cat((prompt_tokens, next_token_id), dim=1)

generate_text = token_ids_to_text(prompt_tokens)

return generate_text

    @staticmethod
    @jaxtyping.jaxtyped(typechecker=typechecker)
def _logits_top_k_filter(
logits: jaxtyping.Float[torch.Tensor, "batch_size vocab_size"],
top_k: int,
) -> jaxtyping.Float[torch.Tensor, "batch_size vocab_size"]:
"""
ref1: https://github.com/rasbt/LLMs-from-scratch/blob/62fb11d5e0449a6d49bda7337d6cfa5a735718da/ch05/01_main-chapter-code/generate.py#L166-L185
ref2: https://github.com/huggingface/transformers/blob/c4d4e8bdbd25d9463d41de6398940329c89b7fb6/src/transformers/generation_utils.py#L903-L941
ref3: https://github.com/meta-llama/llama/blob/main/llama/generation.py#L188-L192
"""
top_k = min(top_k, logits.size(-1)) # make sure top_k <= vocab size
top_k_logits, _top_k_indexes = torch.topk(logits, k=top_k, dim=-1)
        min_logit = top_k_logits[:, -1:]  # keep the last dim so it broadcasts over (batch_size, vocab_size)
        logits = torch.where(logits < min_logit, torch.tensor(float("-inf")).to(logits.device), logits)
return logits

    @staticmethod
    @jaxtyping.jaxtyped(typechecker=typechecker)
def _logits_temperature_scale(
logits: jaxtyping.Float[torch.Tensor, "batch_size vocab_size"],
temperature: float,
) -> jaxtyping.Float[torch.Tensor, "batch_size vocab_size"]:
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
return probs


if __name__ == "__main__":
text_generator = TextGenerator(model_config=gpt_config_124_m)

prompt_text = "Hello, I am"
generate_text = text_generator.generate(prompt_text=prompt_text, top_k=10, temperature=0.9)
print(generate_text)
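
For intuition, a small standalone sketch of the two sampling transforms above (toy numbers, not part of the commit):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])

# top_k = 2: every logit below the 2nd-largest becomes -inf
top_k_logits, _ = torch.topk(logits, k=2)
filtered = torch.where(logits < top_k_logits[:, -1:], torch.tensor(float("-inf")), logits)
# filtered == [[2.0, 1.0, -inf, -inf]]

# Temperature: T < 1 sharpens the distribution, T > 1 flattens it
for t in (0.5, 1.0, 2.0):
    print(t, torch.softmax(filtered / t, dim=-1))
# T=0.5 -> ~[0.88, 0.12, 0.0, 0.0]; T=2.0 -> ~[0.62, 0.38, 0.0, 0.0]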