Merge pull request #45 from RWKV/v5-batching-support
V5 batching support + extras
PicoCreator authored Nov 16, 2023
2 parents 1e287e9 + 5607e7a commit 75b3ba5
Showing 7 changed files with 787 additions and 69 deletions.
17 changes: 17 additions & 0 deletions RWKV-v5/config-example.yaml
@@ -156,6 +156,17 @@ trainer:
# your GPU processing time %, and avoid idle time for the GPU between batches
target_batch_size: 32

# Microbatching chunks which we split our data by. This substantially increases VRAM usage
# for each GPU step, but also substantially increases the throughput of the training process.
#
# So if you have 16 data samples per batch per GPU, and a microbatch_size of 2, you have 8 substeps.
#
# It is generally recommended to tune this to the highest value you can reasonably support
# on your GPU, as it has a direct impact on your overall tokens / second count.
#
# Typically you tune the microbatch_size first, before tuning the target_batch_size
microbatch_size: 1

# You can alternatively set the accumulate_grad_batches per GPU directly
# (not recommended)
#
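For clarity, here is a minimal sketch (not part of this commit) of the substep arithmetic implied by the comments above; `substeps_per_gpu` is a hypothetical helper, and the assumption that the substep count is simply the per-GPU batch divided by microbatch_size is illustrative only:

# Illustrative sketch only: the substep arithmetic described in the config comments.
# `substeps_per_gpu` is a hypothetical helper, not a function from the trainer.
def substeps_per_gpu(per_gpu_batch_size: int, microbatch_size: int) -> int:
    # each optimizer step is split into per_gpu_batch_size / microbatch_size substeps
    return max(per_gpu_batch_size // microbatch_size, 1)

# The example from the comment above: 16 data samples per batch per GPU, microbatch_size of 2
print(substeps_per_gpu(per_gpu_batch_size=16, microbatch_size=2))  # -> 8 substeps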
@@ -409,6 +420,12 @@ data:
# A minimum of 2 columns is required, with non-empty data, for the merge to occur
# If no match is found, this will fall back to the default prompt/completion or text column,
# or throw an error if the default fallback is not found
#
# IMPORTANT NOTE: as newlines are commonly used for multi_column_suffix, etc.
# you should use single quotes to ensure such values don't get escaped.
# eg. multi_column_suffix: ['\n\n']
#
# See: https://github.com/RWKV/RWKV-infctx-trainer/issues/34
# ---
# multi_column_keys: ['instruction', 'input', 'output']
# multi_column_prefix: ['Instruction:\n', 'Input:\n', 'Output:\n']
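To illustrate the multi-column merge described in the comments above, here is a hedged sketch; `merge_columns` is a hypothetical stand-in, and the exact ordering, separator handling, and fallback behaviour of the real trainer may differ:

# Hedged illustration of the multi-column merge behaviour documented above.
# `merge_columns` is hypothetical; it only mirrors the documented rules.
def merge_columns(row, keys, prefixes, suffixes):
    parts = []
    for key, prefix, suffix in zip(keys, prefixes, suffixes):
        value = row.get(key)
        if value:                                  # skip empty / missing columns
            parts.append(prefix + str(value) + suffix)
    if len(parts) < 2:                             # the merge needs at least 2 non-empty columns
        raise ValueError("fall back to prompt/completion or text column, or error out")
    return "".join(parts)

row = {"instruction": "Summarize the text", "input": "RWKV is an RNN ...", "output": "..."}
text = merge_columns(
    row,
    keys=["instruction", "input", "output"],
    prefixes=["Instruction:\n", "Input:\n", "Output:\n"],
    suffixes=["\n\n", "\n\n", "\n\n"],
)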
63 changes: 59 additions & 4 deletions RWKV-v5/src/data.py
@@ -1,5 +1,6 @@
from lightning import LightningDataModule

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler

@@ -293,10 +294,15 @@ def map_tokenizer(x):
input_ids += column_encodings['input_ids']
token_type_ids += column_encodings['token_type_ids']

# Override the training attention mask if masking is set to false
if len(multi_column_train_mask) < i and multi_column_train_mask[i] is False:
# Configure the attention masks accordingly
if i > len(multi_column_train_mask):
# If the corresponding `multi_column_train_mask` is not set, we assume it is valid training data
attention_mask += ([1] * len(column_encodings['input_ids']))
elif multi_column_train_mask[i] is False:
# If the `multi_column_train_mask` is set, but configured as false, we should not pay attention to it
attention_mask += ([0] * len(column_encodings['input_ids']))
else:
else: # multi_column_train_mask[i] is True
# This means it is true, let's pay attention to it once again
attention_mask += ([1] * len(column_encodings['input_ids']))

# Add the suffix
@@ -494,6 +500,48 @@ def add_length(example):
# Save the dataset to disk
src_dataset.save_to_disk(kargs["data_path"])

# Dataloader collator for merging multiple dataset records together
# we use token 0 for padding, with a learning mask value of 0
def dataloader_collator_fn(records):
# Get the number of records
# (aka the batch size)
records_len = len(records)

# Compute the max length of each field across the records
input_ids_len = 0
token_type_ids_len = 0
attention_mask_len = 0

# Loop through the records and compute the max length
for i in range(records_len):
input_ids_len = max(input_ids_len, len(records[i]["input_ids"]))
token_type_ids_len = max(token_type_ids_len, len(records[i]["token_type_ids"]))
attention_mask_len = max(attention_mask_len, len(records[i]["attention_mask"]))

# First row of the records
first_row = records[0]

# Create the output arrays, with the default 0 values (no learning mask)
out_input_ids = torch.zeros((records_len, input_ids_len), dtype=first_row["input_ids"].dtype)
out_token_type_ids = torch.zeros((records_len, token_type_ids_len), dtype=first_row["token_type_ids"].dtype)
out_attention_mask = torch.zeros((records_len, attention_mask_len), dtype=first_row["attention_mask"].dtype)
out_data_ctx_len = torch.zeros((records_len), dtype=torch.int32)

# Loop through the records and copy the values to the output arrays
for i in range(records_len):
out_input_ids[i][:len(records[i]["input_ids"])] = records[i]["input_ids"]
out_token_type_ids[i][:len(records[i]["token_type_ids"])] = records[i]["token_type_ids"]
out_attention_mask[i][:len(records[i]["attention_mask"])] = records[i]["attention_mask"]
out_data_ctx_len[i] = len(records[i]["input_ids"])

# Build & return the output object
out = {
'input_ids': out_input_ids,
'token_type_ids': out_token_type_ids,
'attention_mask': out_attention_mask,
'data_ctx_len': out_data_ctx_len
}
return out
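A small usage sketch of the collator above, with made-up token values (it relies on the torch import and the function definition shown in this file): two records of different lengths are padded with token 0 and attention mask 0 up to the longest record in the batch, while data_ctx_len keeps each row's original length:

# Illustration only: two variable-length records through dataloader_collator_fn.
records = [
    {
        "input_ids":      torch.tensor([11, 12, 13, 14, 15]),
        "token_type_ids": torch.tensor([0, 0, 0, 0, 0]),
        "attention_mask": torch.tensor([1, 1, 1, 1, 1]),
    },
    {
        "input_ids":      torch.tensor([21, 22, 23]),
        "token_type_ids": torch.tensor([0, 0, 0]),
        "attention_mask": torch.tensor([1, 1, 1]),
    },
]
batch = dataloader_collator_fn(records)
# batch["input_ids"]      -> [[11, 12, 13, 14, 15], [21, 22, 23, 0, 0]]
# batch["attention_mask"] -> [[ 1,  1,  1,  1,  1], [ 1,  1,  1, 0, 0]]
# batch["data_ctx_len"]   -> [5, 3]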

class RWKVDataModule(LightningDataModule):
def __init__(
@@ -595,6 +643,11 @@ def train_dataloader(self):
num_replicas=self.trainer.world_size,
rank=self.trainer.global_rank,
)

microbatch_size = 1
if hasattr(self, "trainer") and hasattr(self.trainer, "microbatch_size"):
microbatch_size = self.trainer.microbatch_size

return DataLoader(
dataset,
sampler=sampler,
@@ -604,7 +657,9 @@ def train_dataloader(self):
# Prefetching 8 batches
prefetch_factor=8,
# Of batch size 1 datasets
batch_size=1,
batch_size=microbatch_size,
# The collation function
collate_fn=dataloader_collator_fn,
# Pinned in GPU memory
pin_memory=True
)
93 changes: 71 additions & 22 deletions RWKV-v5/src/model.py
@@ -135,21 +135,21 @@ def forward(ctx, loss, y, token_amount, currentMask):
#
# See also:
# - checkpointed_step
ctx.save_for_backward(y)
ctx.token_amount = token_amount
ctx.currentMask = currentMask
ctx.save_for_backward(y, token_amount, currentMask)
return loss

@staticmethod
def backward(ctx, grad_output):
y, = ctx.saved_tensors
token_amount = ctx.token_amount
y, token_amount, currentMask = ctx.saved_tensors

# to encourage the logits to be close to 0
factor = 1e-4 / token_amount
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
gy = gy * ctx.currentMask[:, None][None, :]

# We ensure the mask is reshaped accordingly, and apply it against gy
gy = gy * currentMask.reshape(gy.shape[0],gy.shape[1],1) # currentMask[:, None][None, :]
return (grad_output, gy, None, None)
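As a standalone illustration (not part of the diff) of what the rewritten backward adds: a tiny gradient on the largest logit of each unmasked position, nudging logits toward 0. The tensor values below are made up:

import torch

# Toy [batch=1, seq=2, vocab=3] logits; the second position is masked out.
y = torch.tensor([[[0.5, 2.0, -1.0],
                   [3.0, 0.1,  0.2]]])
token_amount = 2
current_mask = torch.tensor([[1.0, 0.0]])

factor = 1e-4 / token_amount
maxx, ids = torch.max(y, -1, keepdim=True)   # max logit per position
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)          # gradient only on the max logit
gy = gy * current_mask.reshape(gy.shape[0], gy.shape[1], 1)
# gy is all zeros except 1e-4 at the max logit of the first (unmasked) position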

### ---
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(self,
grad_cp: bool = True,
bptt_learning: bool = True,
bptt_learning_range: int = -1,
bptt_truncated_learning: bool = False,
bptt_truncated_learning: bool = True,
layerwise_lr: bool = True,
dim_att: Optional[int] = None,
dim_ffn: Optional[int] = None,
@@ -784,7 +784,8 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool):
self._counting_tokens = 0
if self._counting_time_start is None or batch_idx == 0:
self._counting_time_start = time.time()


# Get the input sequence, and attention mask
seq = batch['input_ids']
assert isinstance(seq, torch.Tensor) and seq.ndim == 2
ori_seq_mask = batch['attention_mask']
@@ -793,17 +794,30 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool):
if ori_seq_mask is None or ori_seq_mask.ndim != 2:
ori_seq_mask = torch.ones_like(seq[:, 1:])

# Initialize the total_mask_sum (without computing it yet)
total_mask_sum = 0

# Number of GPUs used in training; note that if it is > 1
# it is required that all operations here stay in sync with
# all other GPUs, as such a "quick return" from this function
# should not be allowed
num_devices = self.trainer.num_devices

### ---
### Positional loss bias handling
### ---

# Get the starting and ending loss bias
loss_bias_start = self.position_loss_bias
loss_bias_end = 2.0 - loss_bias_start

# total_mask_sum
total_mask_sum = torch.sum(ori_seq_mask)

# Skip loss bias calculation, if loss_bias_start is 1.0
if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False):
seq_mask = ori_seq_mask
else:
# Lets get the torch mask sum
total_mask_sum = torch.sum(ori_seq_mask)

# Lets get a linear multiplier for the loss bias
# seq_mask_sum = torch.sum(ori_seq_mask)
bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device)
@@ -818,12 +832,18 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool):
# And save it as seq_mask
seq_mask = final_mask.unsqueeze(0)
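A hedged sketch of the positional loss bias above; the lines that place bias_mask into final_mask are collapsed in this diff, so the indexing below is an assumption about how the linear ramp is applied to the unmasked positions:

import torch

ori_seq_mask = torch.tensor([1., 1., 1., 1., 0., 0.])
loss_bias_start = 0.5
loss_bias_end = 2.0 - loss_bias_start                 # mirrors the code above

total_mask_sum = torch.sum(ori_seq_mask)
bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()))

# Assumption: the ramp is written into the unmasked token positions only
final_mask = torch.zeros_like(ori_seq_mask)
final_mask[ori_seq_mask == 1] = bias_mask
seq_mask = final_mask.unsqueeze(0)
# seq_mask -> [[0.5000, 0.8333, 1.1667, 1.5000, 0.0000, 0.0000]]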

### ---
### Training cutoff logic handling
### ---

# Perform cutoff for training run
if is_training_run:
prev_step = 0

# Avoid using the zip operation, as torch.compile throws an exception on it
# with `zip not reconized as a valid function`
#
# This is skipped if ctx_len_warmup_steps/ctx_len_cutoffs is not set
# ---
# for step, len_cut in zip(self.ctx_len_warmup_steps,
# self.ctx_len_cutoffs):
@@ -846,23 +866,35 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool):
seq_mask[:, :pos] = 0
break
prev_step = step


### ---
### Various size checking, and implementing the core checkpoint_step
### ---

# BPTT, and training steps, and various size fetching
do_bptt_learning = self.bptt_learning and is_training_run
idx, targets = seq[:, :-1], seq[:, 1:]

B, T = idx.shape
C = self.n_embd
total_mask_sum = torch.sum(seq_mask)

# If total_mask_sum, we skip, as there is no tokens of value to learn from anyway
if total_mask_sum == 0:
total_mask_sum = torch.sum(seq_mask)
# Do a quick return if there are no tokens of value to learn from due to full masking
if num_devices > 1 and total_mask_sum == 0:
return 0

# Checkpoint steps
def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
last_wkv_states, prev_steps):
logits, new_shift_states, new_wkv_states = self(
idx, last_shift_states, last_wkv_states)

# Ensure logits, targets, and mask are contiguous
# this is required to avoid a "view is not compatible with size and stride" error
logits = logits.contiguous()
targets = targets.contiguous()
mask = mask.contiguous()

loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
targets.view(-1),
reduction="none")
@@ -876,14 +908,17 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
new_loss = prev_loss + loss
return new_loss, new_shift_states, new_wkv_states, new_steps
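The reduction lines of checkpointed_step are collapsed in this diff, so the following is only a hedged sketch of the masked cross-entropy idea: per-token losses are computed with reduction="none", zeroed out where the mask is 0, and normalized over the unmasked tokens:

import torch
import torch.nn.functional as F

logits  = torch.randn(1, 4, 10)                  # [batch, seq, vocab]
targets = torch.randint(0, 10, (1, 4))
mask    = torch.tensor([[1., 1., 0., 0.]])       # only the first two tokens are learned

per_token = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    targets.view(-1),
    reduction="none",
)
# Assumption: masked positions are dropped and the sum is normalized by the mask total
loss = torch.sum(per_token * mask.view(-1)) / mask.sum().clamp(min=1)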

total_loss = torch.tensor(
0, dtype=self.emb.weight.dtype).requires_grad_()
total_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
steps = 0
states = BlockStateList.create(self.n_layer, B, C,
self.n_head, self.head_size,
seq.device, self.emb.weight.dtype)
segment_count = math.ceil(T / self.ctx_len)

### ---
### Learning process logic (BPTT or not)
### ---

#
# BPTT learning, we split the sequence into segments
# and perform a backward pass for each segment, on its own.
@@ -919,12 +954,12 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
# it also helps ensure the segment cutoff points are more varied, across mixed dataset sizes
# and avoid potentially undesired training behaviour at fixed cutoff points
# (this only applies for segmented learning)
segment_size = min(math.ceil(T / segment_count), self.ctx_len)
segment_size = min(math.ceil(T / segment_count)+1, self.ctx_len)

# Dummy 2D tenros of shape [1,1], are used to do "dummy checkpoint/forward/backprop" to keep everything in sync
# Dummy 2D tensor of shape [1,1], used to do "dummy checkpoint/forward/backprop" to keep everything in sync
dummy_2d_zero = torch.tensor([[0]], dtype=torch.long, device=cur_device)

# Get the max segment count across all GPUs, in the current batch, which is used to keep all devices are in sync
# Get the max segment count across all GPUs, in the current substep, which is used to keep all devices are in sync
# Once a thread has completed all its segments, it will do dummy checkpoint/forward/backprop with one token,
# and stay in sync with the threads that are still working on their segments
#
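A hedged sketch of the segmenting arithmetic described above; the cross-GPU maximum itself is collapsed in this diff (an all-reduce with MAX is one plausible mechanism, assumed rather than taken from the code), so only the per-rank planning is shown:

import math
import torch

def plan_segments(T: int, ctx_len: int):
    # mirrors the code above: split the sequence into at most ctx_len sized segments
    segment_count = math.ceil(T / ctx_len)
    # the +1 keeps cutoff points varied across mixed dataset sizes
    segment_size = min(math.ceil(T / segment_count) + 1, ctx_len)
    return segment_count, segment_size

segment_count, segment_size = plan_segments(T=10_000, ctx_len=4096)  # -> 3 segments of up to 3335 tokens

# A rank that finishes its own segments early keeps stepping with a dummy [[0]]
# token until every rank has processed the batch-wide max segment count.
dummy_2d_zero = torch.tensor([[0]], dtype=torch.long)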
@@ -1118,13 +1153,20 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
global_rank = self.global_rank
global_device_count = self.trainer.num_devices * self.trainer.num_nodes

# Get the total dataset context length
batch_ctx_len = 0
if "data_ctx_len" in batch:
batch_ctx_len = torch.sum(batch["data_ctx_len"]).item()
else:
batch_ctx_len = T * self.trainer.microbatch_size

# Increment the counting tokens, and log it accordingly
self._counting_tokens += T
self._counting_tokens += batch_ctx_len

# Log the line values
wandb.log({
'global_rank': global_rank,
'real_ctx_len': T,
'data_ctx_len': batch_ctx_len / self.trainer.microbatch_size,
'train/loss': total_loss,
f'perf/tokens_total.gpu.{global_rank}': self._counting_tokens,
f'perf/tokens_per_sec.gpu.{global_rank}': self._counting_tokens / max(time.time() - self._counting_time_start, 1),
@@ -1138,8 +1180,15 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
assert not torch.isnan(total_loss), "total_loss is NaN"
return total_loss

#
# Training and validation steps
#
@TCompileBaseline
def training_step(self, batch, batch_idx):

# print("=== BATCH ID SHAPE ===", batch["input_ids"].shape)
# print("=== BATCH AM SHAPE ===", batch["attention_mask"].shape)

total_loss = self.compute_loss(batch, batch_idx, True)

self.log('train/loss', total_loss, prog_bar=True)
