feat(tinyshakespeare): add support for the Tiny Shakespeare dataset
AdityaNG committed May 7, 2024
1 parent 8a544dd commit b135dcf
Showing 4 changed files with 131 additions and 12 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -66,8 +66,8 @@ pip install -e .

Use the following dummy script to make sure everything is working as expected
```bash
WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu
WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu
WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200
```

Then make use of the training script
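The full training command lives in README lines collapsed out of this hunk; as a rough sketch built only from flags this commit touches (argument values are illustrative), an invocation could look like:

```bash
WANDB_MODE=offline python3 -m kan_gpt.train --architecture KAN --dataset tinyshakespeare --device cpu --max_iters 200
```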
@@ -97,6 +97,7 @@ python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model
- [x] Auto Save checkpoints
- [x] Auto Save checkpoints to W&B
- [ ] Auto Download model weights from git / huggingface
- [ ] W&B hyperparam sweep script
- [x] Script to load checkpoint in interactive mode
- [ ] Training script to PyTorch Lightning
- [x] Integrate with [efficient-kan](https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py)
113 changes: 107 additions & 6 deletions kan_gpt/dataset.py
@@ -10,12 +10,7 @@

class WebTextDataset(Dataset):
"""
Dataset for the Sort problem. E.g. for problem length 6:
Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
Which will feed into the transformer concatenated as:
input: 0 0 2 1 0 1 0 0 0 1 1
output: I I I I I 0 0 0 1 1 2
where I is "ignore", as the transformer is reading the input sequence
WebText Dataset
"""

def __init__(self, split, model_type, block_size=1024, vocab_size=50257):
@@ -87,3 +82,109 @@ def __getitem__(self, idx):
# y = y.unsqueeze(0)

return x, y


class TinyShakespeareDataset(Dataset):
"""
Tiny Shakespeare dataset
"""

TRAIN = 0.8
VALID = 0.1
TEST = 0.1

def __init__(
self,
split,
model_type,
block_size=1024,
vocab_size=50257,
):
assert split in {"train", "test", "valid"}

self.split = split
self.block_size = block_size
self.vocab_size = vocab_size
self.tokenizer = GPT2Tokenizer.from_pretrained(model_type)

self.tokenized_dataset_path = (
f"datasets/tinyshakespeare/input.{split}.pkl"
)

if not os.path.isfile(self.tokenized_dataset_path):
self.tokenized_dataset = []

self.tsp_path = "datasets/tinyshakespeare/input.txt"

assert os.path.isfile(self.tsp_path)

self.data = open(self.tsp_path, "r").readlines()

# Select data based on split
# First self.TRAIN % is for train
# Next self.VALID % is for validation
# Last self.TEST % is for test
if self.split == "train":
self.data = self.data[
: int((1 - (self.VALID + self.TEST)) * len(self.data))
]
elif self.split == "val":
self.data = self.data[
int((1 - (self.VALID + self.TEST)) * len(self.data)) : int(
((1 - (self.VALID + self.TEST))) * len(self.data)
)
+ int((1 - (self.TEST)) * len(self.data))
]
elif self.split == "test":
self.data = self.data[
int(((1 - (self.VALID + self.TEST))) * len(self.data)) :
]

tokenized_data = []
tokenized_lengths = []

for text in tqdm(
self.data, desc="Tokenizing", total=len(self.data)
):
tokenized = self.tokenizer.encode(
text=text, add_special_tokens=False
)
tokenized_length = len(tokenized)

tokenized_data.append(tokenized)
tokenized_lengths.append(tokenized_length)

self.tokenized_dataset += tokenized

with open(self.tokenized_dataset_path, "wb") as f:
pickle.dump(self.tokenized_dataset, f)

with open(self.tokenized_dataset_path, "rb") as f:
self.tokenized_dataset = pickle.load(f)

def __len__(self):
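        # Each sample needs block_size tokens for x plus the following
        # block_size tokens for y, so start indices are limited to the
        # first len(tokenized_dataset) - 2 * block_size positions.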
return len(self.tokenized_dataset) - 2 * self.block_size

def get_vocab_size(self):
return self.vocab_size

def get_block_size(self):
return self.block_size

def __getitem__(self, idx):

x = self.tokenized_dataset[idx : idx + self.block_size]
y = self.tokenized_dataset[
idx + self.block_size : idx + 2 * self.block_size
]

assert len(x) == self.block_size, f"Unexpected len: {len(x)}"
assert len(y) == self.block_size, f"Unexpected len: {len(y)}"

x = torch.tensor(x, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)

# x = x.unsqueeze(0)
# y = y.unsqueeze(0)

return x, y
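
A minimal usage sketch for the new dataset class (assumes `datasets/tinyshakespeare/input.txt` has already been fetched with the download script added in this commit; the batch size is illustrative):

```python
from torch.utils.data import DataLoader

from kan_gpt.dataset import TinyShakespeareDataset

# The first construction tokenizes the split and caches it to
# datasets/tinyshakespeare/input.train.pkl; later runs load the pickle.
train_dataset = TinyShakespeareDataset("train", "gpt2")

loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([2, 1024]) torch.Size([2, 1024])
```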
19 changes: 15 additions & 4 deletions kan_gpt/train.py
@@ -8,7 +8,7 @@
from wandb.sdk.wandb_run import Run

import wandb
from kan_gpt.dataset import WebTextDataset
from kan_gpt.dataset import TinyShakespeareDataset, WebTextDataset
from kan_gpt.mingpt.model import GPT as MLP_GPT
from kan_gpt.mingpt.trainer import Trainer
from kan_gpt.model import GPT as KAN_GPT
@@ -70,19 +70,25 @@ def main(args):
"learning_rate": args.learning_rate,
"max_iters": args.max_iters,
"num_workers": args.num_workers,
"dataset": args.dataset,
"architecture": args.architecture,
"device": args.device,
}

model_type = args.model_type

if args.dataset == "webtext":
Dataset = WebTextDataset
elif args.dataset == "tinyshakespeare":
Dataset = TinyShakespeareDataset

# print an example instance of the dataset
if args.dummy_dataset:
train_dataset = WebTextDataset("test", "gpt2")
train_dataset = Dataset("test", "gpt2")
else:
train_dataset = WebTextDataset("train", "gpt2")
train_dataset = Dataset("train", "gpt2")

test_dataset = WebTextDataset("test", "gpt2")
test_dataset = Dataset("test", "gpt2")

print("test_dataset: ", len(test_dataset))
print("train_dataset: ", len(train_dataset))
@@ -179,6 +185,11 @@ def batch_end_callback(trainer):
parser.add_argument("--num_workers", default=0)
parser.add_argument("--batch_size", default=64)

parser.add_argument(
"--dataset",
choices=["webtext", "tinyshakespeare"],
default="tinyshakespeare",
)
parser.add_argument(
"--architecture", choices=["MLP", "KAN"], default="KAN"
)
6 changes: 6 additions & 0 deletions scripts/download_tinyshakespeare.sh
@@ -0,0 +1,6 @@
#!/bin/bash

mkdir -p datasets/tinyshakespeare
cd datasets/tinyshakespeare

wget -c https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
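
For reference, `TinyShakespeareDataset` asserts that `datasets/tinyshakespeare/input.txt` exists before tokenizing, so this script needs to run first. Running it from the repository root is an assumption based on the relative paths used in `kan_gpt/dataset.py`:

```bash
# Creates datasets/tinyshakespeare/input.txt, the file the dataset class reads.
bash scripts/download_tinyshakespeare.sh
```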
