---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
      jupytext_version: 1.16.1
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---
import torch.nn as nn
import torch.nn.functional as F
import torch
from dataclasses import dataclass
torch.cuda.is_available()
@dataclass
class ModelArgs:
    batch_size: int = 32
    d_model: int = 512
    hidden_dim: int = 1024 # hidden dim for feed forward layer
    num_blocks: int = 32
    num_q_heads: int = 32
    num_kv_heads: int = 16
    vocab_size: int = -1 # initialized later
    eps: float = 1e-6 # eps for RMSNorm
    max_batch_size: int = 32
    max_seq_len: int = 512
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freq_base: int = 10000
    epochs: int = 5
def compute_freqs(d_model:int, seq_len: int, device: torch.device | str, base: int):
    """
    d_model: embedding dim
    seq_len: sequence length
    device: cuda / cpu
    base: base for exponential of theta values
    """

    assert d_model % 2 == 0, "d_model has to be even"
    
    theta = 1. / (base ** (torch.arange(0, d_model, 2) / d_model)).to(device)
    m = torch.arange(seq_len).to(device)
    freqs = torch.outer(m, theta).float() # Since each m value corresponds to a single token, multiply every value of m by every value of theta, kind of like a nested for loop.
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs) # turn into complex form, z = r*cis(theta), in this case, r is 1
    return freqs_complex

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: torch.device):
    """
    x: input sequence to add positional embedding, (batch, seq_len, emb_dim)
    freqs_complex: frequencies for rotary positional embeddings
    device: cuda / cpu
    """
    # print(x.shape, freqs_complex.shape)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) # pair adjacent channels and view as complex: (batch, seq_len, heads, head_dim/2)
    # print(x_complex.shape)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # print(freqs_complex.shape)
    x_rotated = x_complex * freqs_complex # * is for element-wise multiplication
    # print(x_rotated.shape)
    x_rotated = torch.view_as_real(x_rotated)
    # print(x_rotated.shape)
    x_rotated = x_rotated.reshape(*x.shape)
    # print(x_rotated.shape)
    return x_rotated.type_as(x).to(device)

a = compute_freqs(32, 10, torch.device("cuda"), 10000)
print(a.shape)
b = torch.randn((3, 10, 5, 32)).to(torch.device("cuda"))
apply_rotary_embeddings(b, a, torch.device("cuda"))
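A quick property check, reusing `a` and `b` from the cell above: every complex frequency has unit magnitude, so applying the rotary embedding only rotates pairs of channels and should leave each token's norm unchanged.

```python
# Rotations are norm-preserving, so per-token norms before and after should match.
rotated = apply_rotary_embeddings(b, a, torch.device("cuda"))
print(torch.allclose(rotated.norm(dim=-1), b.norm(dim=-1), atol=1e-4))
```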
class RMSNorm(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.eps = args.eps
        self.gamma = nn.Parameter(torch.ones(args.d_model))

    def rms(self, x: torch.Tensor):
        x = torch.pow(x, 2)
        x = torch.mean(x, dim=-1, keepdim=True)
        x = torch.sqrt(x + self.eps) # eps keeps the denominator away from zero when the features are all zero
        return x

    def forward(self, x: torch.Tensor):
        return x / self.rms(x) * self.gamma

args = ModelArgs(d_model=32)
r = RMSNorm(args)
a = torch.randn((5, 10, 32))
r(a).shape
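As a sanity check on the cell above, the same normalization can be written in one line: divide by the root mean square of the features (plus `eps`) and scale by `gamma`; unlike LayerNorm, no mean is subtracted.

```python
# Manual RMSNorm, reusing `a`, `r`, and `args` from the previous cell.
manual = a / torch.sqrt(a.pow(2).mean(dim=-1, keepdim=True) + args.eps) * r.gamma
print(torch.allclose(r(a), manual))
```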
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.W = nn.Linear(args.d_model, args.hidden_dim)
        self.V = nn.Linear(args.d_model, args.hidden_dim)
        self.f = nn.Linear(args.hidden_dim, args.d_model)
    
    def forward(self, x: torch.Tensor):
        swiglu = F.silu(self.W(x)) * self.V(x)
        return self.f(swiglu)

l = FeedForward(args)
a = torch.randn((5, 10, 32))
l(a).shape
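Written out, the layer above is the SwiGLU feed-forward used in LLaMA, with `W` and `V` as up-projections and `f` as the down-projection:

$$\mathrm{FFN}_{\text{SwiGLU}}(x) = f\big(\mathrm{SiLU}(xW)\odot xV\big)$$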
class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.num_rep = args.num_q_heads // args.num_kv_heads
        
        self.q = nn.Linear(args.d_model, args.num_q_heads * args.d_model, bias=False)
        self.k = nn.Linear(args.d_model, args.num_kv_heads * args.d_model, bias=False)
        self.v = nn.Linear(args.d_model, args.num_kv_heads * args.d_model, bias=False)

        self.out = nn.Linear(args.num_q_heads * args.d_model, args.d_model, bias=False)
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, args.num_kv_heads, args.d_model), device=args.device)
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, args.num_kv_heads, args.d_model), device=args.device)

    
    def repeat_kv(self, x: torch.Tensor, n_rep: int):
        batch_size, seq_len, num_kv_heads, emb_dim = x.shape
        if n_rep == 1:
            return x
        return x[:, :, :, None, :].expand(batch_size, seq_len, num_kv_heads, n_rep, emb_dim).reshape(batch_size, seq_len, num_kv_heads * n_rep, emb_dim)

    
    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor, use_kv_cache=False):
        batch_size, seq_len, _ = x.shape

        q = self.q(x) # (batch_size, seq_len, num_q_heads * d_model) before the head split below
        k = self.k(x)
        v = self.v(x)
        
        q = q.view(batch_size, seq_len, self.args.num_q_heads, self.args.d_model)
        k = k.view(batch_size, seq_len, self.args.num_kv_heads, self.args.d_model)
        v = v.view(batch_size, seq_len, self.args.num_kv_heads, self.args.d_model)

        q = apply_rotary_embeddings(q, freqs_complex, device=self.args.device)
        k = apply_rotary_embeddings(k, freqs_complex, device=self.args.device)
        
        keys, values = k, v
        if use_kv_cache:
            self.cache_k[:batch_size, start_pos: start_pos + seq_len] = k
            self.cache_v[:batch_size, start_pos: start_pos + seq_len] = v

            keys = self.cache_k[:batch_size, :start_pos + seq_len] # all cache including added key
            values = self.cache_v[:batch_size, :start_pos + seq_len]

        keys = self.repeat_kv(keys, self.num_rep)
        values = self.repeat_kv(values, self.num_rep)

        q = q.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 3, 1) # (batch_size, head_num, emb_dim, seq_len) so q @ keys gives attention scores
        values = values.permute(0, 2, 1, 3) # (batch_size, head_num, seq_len, emb_dim)
        
        scores = torch.matmul(q, keys) / (self.args.d_model ** 0.5) # scaled dot-product: (batch_size, head_num, seq_len, seq_len)
        scores = F.softmax(scores.float(), dim=-1).type_as(q)
        
        output = torch.matmul(scores, values) # (batch_size, head_num, seq_len, emb_dim)
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)

        return self.out(output)
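A small smoke test of the grouped-query attention block, using a hypothetical toy configuration on CPU (8 query heads sharing 4 key/value heads, so `repeat_kv` repeats each key/value head twice):

```python
# Hypothetical toy configuration; the output keeps the (batch, seq_len, d_model) shape.
test_args = ModelArgs(d_model=32, num_q_heads=8, num_kv_heads=4, max_batch_size=2,
                      max_seq_len=16, device=torch.device("cpu"))
mha = MultiHeadAttention(test_args)
x = torch.randn(2, 10, 32)
freqs = compute_freqs(32, 10, torch.device("cpu"), 10000)
print(mha(x, start_pos=0, freqs_complex=freqs).shape)  # torch.Size([2, 10, 32])
```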
class EncoderBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = MultiHeadAttention(args)
        self.feed_forward = FeedForward(args)
        self.norm1 = RMSNorm(args)
        self.norm2 = RMSNorm(args)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        h = x + self.attention(self.norm1(x), start_pos, freqs_complex)
        out = h + self.feed_forward(self.norm2(h))
        return out
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.args = args

        self.embedding_layer = nn.Embedding(args.vocab_size, args.d_model)

        self.encoder_layers = nn.ModuleList()
        for _ in range(args.num_blocks):
            self.encoder_layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args)
        self.output = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.freqs_complex = compute_freqs(args.d_model, args.max_seq_len, device=args.device, base=args.freq_base)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        h = self.embedding_layer(tokens)
        freqs_complex = self.freqs_complex[start_pos: start_pos + tokens.shape[1]] # (seq_len, d_model/2) slice for the current positions

        for layer in self.encoder_layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h) # final RMSNorm before the output projection, as in LLaMA
        output = self.output(h).float()
        return output
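A hypothetical end-to-end shape check of the full model with a tiny configuration (vocabulary of 100, two blocks, CPU), just to confirm that the logits come out as (batch, seq_len, vocab_size):

```python
tiny_args = ModelArgs(d_model=32, hidden_dim=64, num_blocks=2, num_q_heads=4, num_kv_heads=2,
                      vocab_size=100, max_seq_len=16, device=torch.device("cpu"))
tiny_model = Transformer(tiny_args)
tokens = torch.randint(0, 100, (2, 8))
print(tiny_model(tokens, start_pos=0).shape)  # torch.Size([2, 8, 100])
```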
import torch
import torch.nn as nn

class TransformerSparseEmbeddings(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.args = args

        # Use a sparse embedding: nn.Embedding with sparse=True produces sparse gradients for the embedding table
        self.embedding_layer = nn.Embedding(args.vocab_size, args.d_model, sparse=True)

        self.encoder_layers = nn.ModuleList()
        for _ in range(args.num_blocks):
            self.encoder_layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args)
        self.output = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.freqs_complex = compute_freqs(args.d_model, args.max_seq_len, device=args.device, base=args.freq_base)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        # nn.Embedding performs the (batch, seq_len) lookup directly, so no offsets or reshaping are needed
        h = self.embedding_layer(tokens)
        freqs_complex = self.freqs_complex[start_pos: start_pos + tokens.shape[1]]

        for layer in self.encoder_layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h) # final RMSNorm before the output projection, as in LLaMA
        output = self.output(h).float()
        return output
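A small check with a hypothetical toy configuration that the embedding table now receives sparse gradients, which is the point of this variant:

```python
sparse_args = ModelArgs(d_model=32, hidden_dim=64, num_blocks=1, num_q_heads=4, num_kv_heads=2,
                        vocab_size=100, max_seq_len=16, device=torch.device("cpu"))
sparse_model = TransformerSparseEmbeddings(sparse_args)
out = sparse_model(torch.randint(0, 100, (2, 8)), start_pos=0)
out.sum().backward()
print(sparse_model.embedding_layer.weight.grad.is_sparse)  # True
```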
import sentencepiece as spm
import pandas as pd
df = pd.read_csv('./DeepLearning/Ja-En-LLaMA/en-ja.bicleaner05.txt', sep="\\t", header=None)
df.head()[[3, 4]]
df.info()
df[3][9]
# using trained tokenizers from http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl/
english_tokenizer = spm.SentencePieceProcessor("./DeepLearning/Ja-En-LLaMA/enja_spm_models/spm.en.nopretok.model")
japanese_tokenizer = spm.SentencePieceProcessor("./DeepLearning/Ja-En-LLaMA/enja_spm_models/spm.ja.nopretok.model")
df[3].isna().unique(), df[4].isna().unique()
english_tokenizer.encode("Go to the original video hierarchy of the conversion source, copy and paste the following is fine. ffmpeg -i sample.mp4 -strict -2 video.webm summary I’ve been using the upload and embed method to Youtube to set up videos on the web.", out_type=str)
japanese_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type=str)
english_tokenizer.vocab_size(), japanese_tokenizer.vocab_size()
from collections import Counter
from torchtext.vocab import vocab


def build_vocab(sentences, tokenizer):
    counter = Counter()
    for sentence in sentences:
        counter.update(tokenizer.encode(sentence, out_type=str))
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>']) # specials: include special tokens in the mapping
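A small usage example with a hypothetical toy corpus: `<unk>` is registered as the default index so that lookups of out-of-vocabulary tokens do not raise.

```python
toy_vocab = build_vocab(["I have a pen."], english_tokenizer)
toy_vocab.set_default_index(toy_vocab['<unk>'])  # unseen tokens map to <unk>
print(len(toy_vocab), toy_vocab['<unk>'], toy_vocab['<pad>'])
```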
print(len(df), len(df) // 10)
df = df.sample(frac=1)
english = df[3].tolist()[:len(df) // 500]
japanese = df[4].tolist()[:len(df) // 500]
assert len(english) == len(japanese)
l = len(english)
train_en = english[:int(0.7 * l)]
val_en = english[int(0.7 * l): int(0.85 * l)]
test_en = english[int(0.85 * l): l]
train_ja = japanese[:int(0.7 * l)]
val_ja = japanese[int(0.7 * l): int(0.85 * l)]
test_ja = japanese[int(0.85 * l): l]
print(len(train_en), len(df))
vocab_ja = build_vocab(japanese, japanese_tokenizer)
vocab_en = build_vocab(english, english_tokenizer)
def data_process(ja, en):
    data = []
    for (raw_ja, raw_en) in zip(ja, en):
        ja_tensor = torch.tensor([vocab_ja[token] for token in japanese_tokenizer.encode(raw_ja.strip("\n"), out_type=str)], dtype=torch.long)
        en_tensor = torch.tensor([vocab_en[token] for token in english_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)], dtype=torch.long)
        data.append((ja_tensor, en_tensor))
    return data
train = data_process(train_ja, train_en)
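Each entry of `train` is a (Japanese ids, English ids) pair of 1-D tensors; peeking at the first one:

```python
ja_ids, en_ids = train[0]
print(ja_ids.shape, en_ids.shape, ja_ids[:5])
```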
a = ModelArgs()
print(vars(a))
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = vocab_ja['<pad>']
BOS_IDX = vocab_ja['<bos>']
EOS_IDX = vocab_ja['<eos>']

def generate_batch(data_batch):
    ja_batch, en_batch = [], []
    for (ja_item, en_item) in data_batch:
        ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    ja_batch = pad_sequence(ja_batch, batch_first=True, padding_value=PAD_IDX) # pad to a common length, (batch, seq_len) to match the model
    en_batch = pad_sequence(en_batch, batch_first=True, padding_value=PAD_IDX)
    return ja_batch, en_batch
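A quick check of the collate function on two training pairs: each language comes back as a (batch, padded_len) tensor with `<bos>`/`<eos>` added.

```python
ja_b, en_b = generate_batch(train[:2])
print(ja_b.shape, en_b.shape)
```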
from torch.utils.data import DataLoader

args = ModelArgs()

train_iter = DataLoader(train, batch_size=args.batch_size, shuffle=True, collate_fn=generate_batch) # type: ignore
len(vocab_ja)
args.vocab_size = len(vocab_ja)
# args = ModelArgs(batch_size=32, d_model=64, hidden_dim=512, num_blocks=8, num_q_heads=32, num_kv_heads=16, vocab_size=100)
print(vars(args))
args = ModelArgs(batch_size=32, d_model=64, hidden_dim=512, num_blocks=8, num_q_heads=32, num_kv_heads=16, vocab_size=len(vocab_ja))
transformer = TransformerSparseEmbeddings(args)
transformer = transformer.to(args.device)
for (i, j) in train_iter:
    print(i.shape, j.shape)
    break
len(train_iter)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Adam rejects the sparse gradients from the sparse embedding table, so the embedding gets its own SparseAdam optimizer.
dense_params = [p for n, p in transformer.named_parameters() if not n.startswith("embedding_layer")]
optimizer = torch.optim.Adam(dense_params, lr=0.0001)
sparse_optimizer = torch.optim.SparseAdam(transformer.embedding_layer.parameters(), lr=0.0001)

def train_epoch(model, train_iter, optimizer, device): # https://chatgpt.com/c/2be13c26-50d8-4583-97b6-397c1fe2d028
    model.train()
    losses = 0
    for i, (x, y) in enumerate(train_iter):
        x = x.to(device)
        y = y.to(device)
        y = y[:, :-1] # drop the last target token (batch-first layout)
        optimizer.zero_grad()
        sparse_optimizer.zero_grad()
        logits = model(x, 0)

        y_out = y[:, 1:] # shift targets by one position for next-token prediction
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), y_out.reshape(-1))
        loss.backward()
        optimizer.step()
        sparse_optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)
def evaluate(model, val_iter, device):
    model.eval()
    losses = 0
    with torch.no_grad(): # no gradients needed during evaluation
        for idx, (src, tgt) in enumerate(val_iter):
            src = src.to(device)
            tgt = tgt.to(device)
            tgt_input = tgt[:, :-1]
            logits = model(src, 0)
            tgt_out = tgt[:, 1:]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
    return losses / len(val_iter)
import time
for epoch in range(1, args.epochs+1):
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer, args.device)
    end_time = time.time()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s"))