diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index 393c310c..02c93441 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -131,10 +131,6 @@ def is_torch_version_above(required_version):
 def deepspeed_checkpoint(*args, **kwargs):
     return deepspeed.checkpointing.checkpoint(*args, **kwargs)

-@TCompileDisable
-def wkv_op(time_decay, time_first, k, v, wkv_state):
-    return torch.ops.rwkv.wkv(time_decay, time_first, k, v, wkv_state)
-
 ########################################################################################################
 # RWKV: State Blocks
 ########################################################################################################
@@ -178,7 +174,7 @@ def create(N, B, C, n_head, head_size, device, dtype):
     # @ TCompileMax (no difference)
     @staticmethod
     def empty(N, B, C, n_head, head_size, device, dtype):
-        # HEAD nad HEADSIZE
+        # @TODO: confirm if dtype can be changed from .float to dtype=dtype (when bf16)
         wkv_states = torch.empty((N, B, n_head, head_size, head_size),
                                  device=device,
                                  # dtype=dtype)
@@ -202,7 +198,7 @@ def __setitem__(self, layer: int, state: BlockState):


 class RWKV_TimeMix(JITModClass):
-    def __init__(self, layer_id, n_layer, n_embd, dim_att):
+    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att):
         super().__init__()

         self.dim_att = dim_att
@@ -210,9 +206,6 @@ def __init__(self, layer_id, n_layer, n_embd, dim_att):
         self.n_embd = n_embd
         self.layer_id = layer_id

-        head_size = 64
-        n_head = n_embd // head_size
-        assert n_embd % n_head == 0
         self.n_head = n_head
         self.head_size = head_size

@@ -239,8 +232,9 @@ def __init__(self, layer_id, n_layer, n_embd, dim_att):
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

-            # time_first (no longer fancy)
-            self.time_first = nn.Parameter(torch.ones(n_head) * (-3.0))
+            # V5-R2 changes
+            self.time_faaaa = nn.Parameter(torch.ones(n_head) * 0.05)
+            # self.time_first = nn.Parameter(torch.ones(n_head) * (-3.0))

         # self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(n_embd, dim_att, bias=False)
@@ -265,22 +259,16 @@ def _forward_rkv_chunk(self, x, B, TT, last_state: TimeMixState):
         k = self.key(xk).view(B, TT, self.n_head, self.head_size).transpose(1, 2).transpose(-2, -1) # BTC -> BHTS -> BHST
         v = self.value(xv).view(B, TT, self.n_head, self.head_size).transpose(1, 2)                 # BTC -> BHTS

-        # # Enforce bf16 type for kv, as this can be mis init
-        # # when being called directly via inference
-        # if r.dtype != torch.bfloat16:
-        #     r = r.to(torch.bfloat16)
-        # if k.dtype != torch.bfloat16:
-        #     k = k.to(torch.bfloat16)
-        # if v.dtype != torch.bfloat16:
-        #     v = v.to(torch.bfloat16)
-
         return r, k, v

     def _forward_wkbs_chunk(self, T, r, k, v):
         H = self.n_head

         w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1)
-        u = torch.exp(self.time_first.float()).unsqueeze(-1)
+
+        # V5-R2 changes
+        u = self.time_faaaa.float().unsqueeze(-1)
+        # u = torch.exp(self.time_first.float()).unsqueeze(-1)

         ws = w.pow(T).reshape(1, H, 1, 1)
         ind = torch.arange(T-1, -1, -1, device=r.device).unsqueeze(0).repeat(H, 1)
@@ -313,44 +301,15 @@ def _forward_state_chunk(self, r, k, v, w, wk, wb, ws, x_l, last_state: TimeMixS
         if r.dtype == torch.bfloat16 and s.dtype != torch.bfloat16:
             s = s.contiguous().to(torch.bfloat16)

-        # print("")
-        # print("B,H,TT,S", B, H, TT, S)
-        # print("S-zero", torch.zeros(B, H, S, S, device=r.device, dtype=r.dtype).shape)
-        # print("wkv", last_state.wkv_state[:, :, :].shape)
-        # print("")
-
         x = torch.zeros(B, H, TT, S, device=r.device, dtype=r.dtype) # output

-        # # Check if r is bf16 or float
-        # if s.dtype == torch.bfloat16:
-        #     print("s is bf16")
-        # elif s.dtype == torch.float32:
-        #     print("s is float32")
-        # else:
-        #     print("s is neither bf16 nor float32")
-
-        # # Check if r is bf16 or float
-        # if r.dtype == torch.bfloat16:
-        #     print("r is bf16")
-        # elif r.dtype == torch.float32:
-        #     print("r is float32")
-        # else:
-        #     print("r is neither bf16 nor float32")
-
         ########################################################################
         for i in range(TT // T):
+
             rr = r[:, :, i*T:i*T+T, :]
             kk = k[:, :, :, i*T:i*T+T]
             vv = v[:, :, i*T:i*T+T, :]

-            # # Check if r is bf16 or float
-            # if rr.dtype == torch.bfloat16:
-            #     print("rr is bf16")
-            # elif rr.dtype == torch.float32:
-            #     print("rr is float32")
-            # else:
-            #     print("rr is neither bf16 nor float32")
-
             x[:, :, i*T:i*T+T, :] = ((rr @ kk) * w) @ vv + (rr @ s) * wb

             s = ws * s + (kk * wk) @ vv
@@ -359,11 +318,6 @@ def _forward_state_chunk(self, r, k, v, w, wk, wb, ws, x_l, last_state: TimeMixS

         x = x.transpose(1, 2).contiguous().view(B * TT, H*S) # BHTS -> BTHS -> BTC
         x = self.ln_x(x).view(B, TT, H*S)
-        # # Return with logits outputs, and new timemix state
-        # print("")
-        # print("S-right", s.shape)
-        # print("n_layer, n_embd, layer_id", self.n_layer, self.n_embd, self.layer_id)
-        # print("")
         return self.output(x), TimeMixState(x_l, s)

     def _forward_chunk(self, x, last_state: TimeMixState):
@@ -443,7 +397,7 @@ def forward(self, x, last_state: ChannelMixState):


 class Block(nn.Module):
-    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att, dim_ffn):
+    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dropout, dim_att, dim_ffn):
         super().__init__()
         self.layer_id = layer_id

@@ -453,22 +407,41 @@ def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att, dim_ff
         if self.layer_id == 0:
             self.ln0 = nn.LayerNorm(n_embd)

-        self.att = RWKV_TimeMix(layer_id, n_layer, n_embd, dim_att)
+        self.att = RWKV_TimeMix(layer_id, n_layer, n_embd, n_head, head_size, dim_att)
         self.ffn = RWKV_ChannelMix(layer_id, n_layer, n_embd, dim_ffn)

+        # Setup dropout at block level
+        self.dropout = dropout
+        if dropout > 0:
+            self.drop0 = nn.Dropout(p = dropout)
+            self.drop1 = nn.Dropout(p = dropout)
+
     def forward(self, x, last_state: BlockState):
         if self.layer_id == 0:
             x = self.ln0(x)
+
         att_out, att_state = self.att(
             self.ln1(x),
             last_state.time_mix_state,
         )
-        x = x + att_out
-        ffn_out, ffn_state = self.ffn(
-            self.ln2(x),
-            last_state.channel_mix_state,
-        )
-        x = x + ffn_out
+
+        if self.dropout > 0.0:
+            # Handle with dropout
+            x = self.drop0(x + att_out)
+            ffn_out, ffn_state = self.ffn(
+                self.ln2(x),
+                last_state.channel_mix_state,
+            )
+            x = self.drop1(x + ffn_out)
+        else:
+            # Handle without dropout
+            x = x + att_out
+            ffn_out, ffn_state = self.ffn(
+                self.ln2(x),
+                last_state.channel_mix_state,
+            )
+            x = x + ffn_out
+
         return x, BlockState(att_state, ffn_state)


@@ -534,12 +507,17 @@ def __init__(self,
                  lr_final: float = -1.0,
                  lr_period: int = -1,
                  lr_period_type: str = 'epoch',
+                 # Dropout rate
+                 dropout: float = 0.0,
                  # Adam optimizer settings
                  beta1: float = 0.9,
                  beta2: float = 0.99,
                  adam_eps: float = 1.0e-08,
                  weight_decay: float = 0.01,
                  warmup_steps: int = -1,
+                 # loss bias start
+                 position_loss_bias: float = 1.0,
+                 position_loss_bias_in_validation: bool = False,
                  # Backprop settings
                  grad_cp: bool = True,
                  bptt_learning: bool = True,
@@ -608,6 +586,7 @@ def __init__(self,
         self.lr_final = lr_final
         self.lr_period = lr_period
         self.lr_period_type = lr_period_type
+        self.dropout = dropout
         self.warmup_steps = warmup_steps
         self.beta1 = beta1
         self.beta2 = beta2
@@ -619,6 +598,10 @@ def __init__(self,
         self.substep_cuda_cache_clear = substep_cuda_cache_clear
         self.substep_logging = substep_logging

+        # Save the position loss params
+        self.position_loss_bias = position_loss_bias
+        self.position_loss_bias_in_validation = position_loss_bias_in_validation
+
         dim_att = dim_att or n_embd
         dim_ffn = dim_ffn or n_embd * 4
         self.dim_att = dim_att
@@ -652,12 +635,16 @@ def __init__(self,
         #                    is_python_module=False)

         self.blocks = nn.ModuleList([
-            Block(i, n_layer, n_embd, n_head, head_size, dim_att, dim_ffn) for i in range(n_layer)
+            Block(i, n_layer, n_embd, n_head, head_size, dropout, dim_att, dim_ffn) for i in range(n_layer)
         ])

         self.ln_out = nn.LayerNorm(n_embd)
         self.head = nn.Linear(n_embd, vocab_size, bias=False)

+        # Dropout handling
+        if dropout > 0:
+            self.drop0 = nn.Dropout(p = dropout)
+
         # load the state, and GC the original cpu copy
         if model_weights != None:
             self.load_state_dict(model_weights)
@@ -682,6 +669,19 @@ def configure_optimizers(self):

         # Log the learning rate, and various other parameters
         if self.trainer.local_rank == 0:
+
+            # Add the important notes, for informing users of common gotchas
+            print((
+                "#\n"
+                "# RWKV lighting_trainer.py important notes \n"
+                "# https://github.com/RWKV/RWKV-infctx-trainer \n"
+                "#\n"
+                "# - Ensure your host is not running cuda 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues\n"
+                "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for better clarifications\n"
+                "# - When resuming from checkpoint, the estimated time is inaccurate\n"
+                "#"
+            ))
+
             lr_init_e = "{:.3e}".format(lr_init)
             lr_final_e = "{:.3e}".format(lr_final)
             print(f"\n[RWKV.model] Configuring optimizer with\n"+
@@ -707,8 +707,11 @@ def configure_optimizers(self):
                     lr_1x.add(n)
                 elif "time_decay" in n:
                     lr_2x.add(n)
-                elif "time_first" in n:
-                    lr_3x.add(n)
+                # V5-R2 changes
+                elif "time_faaaa" in n:
+                    lr_2x.add(n)
+                # elif "time_first" in n:
+                #     lr_3x.add(n)
                 else:
                     lr_1x.add(n)
             lr_1x = sorted(list(lr_1x))
@@ -883,6 +886,10 @@ def forward(self, idx: torch.Tensor, last_shift_states: torch.Tensor = None,

         x = self.emb(idx)

+        # Handle dropout (input)
+        if self.dropout > 0.0:
+            x = self.drop0(x)
+
         new_states = BlockStateList.empty(self.n_layer, B, self.n_embd,
                                           self.n_head, self.head_size,
                                           x.device, x.dtype)
@@ -966,41 +973,37 @@ def manual_backward(self, loss: torch.Tensor, *args, **kwargs):
     def compute_loss(self, batch, batch_idx, is_training_run: bool):
         seq = batch['input_ids']
         assert isinstance(seq, torch.Tensor) and seq.ndim == 2
-        seq_mask = batch['attention_mask']
+        ori_seq_mask = batch['attention_mask']

         # Check if attent mask is set, if not initialize it
-        if seq_mask is None or seq_mask.ndim != 2:
-            seq_mask = torch.ones_like(seq[:, 1:])
-
-        # Perform cutoff for training run
-        if is_training_run:
-            prev_step = 0
-
-            # Avoid using the zip operation, as torch.compile throws an exception on it
-            # with `zip not reconized as a valid function`
-            # ---
-            # for step, len_cut in zip(self.ctx_len_warmup_steps,
-            #                          self.ctx_len_cutoffs):
-            # ---
-            for i in range(min(len(self.ctx_len_warmup_steps), len(self.ctx_len_cutoffs))):
-                step = self.ctx_len_warmup_steps[i]
-                len_cut = self.ctx_len_cutoffs[i]
-
-                if prev_step <= self.global_step < step and len_cut < seq.shape[
-                        1] - 1:
-                    pos = randint(0, seq.shape[1] - len_cut - 1)
-
-                    # Original
-                    # seq = seq[:, pos:pos + len_cut + 1]
-
-                    # Changed to use masking for prefix cutoff (i do not know if this makes sense)
-                    seq = seq[:, :pos + len_cut + 1]
-                    seq_mask = seq_mask[:, :pos + len_cut + 1]
-                    # Set the attention mask to 0 for the skipped tokens
-                    seq_mask[:, :pos] = 0
-                    break
-                prev_step = step
-
+        if ori_seq_mask is None or ori_seq_mask.ndim != 2:
+            ori_seq_mask = torch.ones_like(seq[:, 1:])
+
+        # Get the starting and ending loss bias
+        loss_bias_start = self.position_loss_bias
+        loss_bias_end = 2.0 - loss_bias_start
+
+        # total_mask_sum
+        total_mask_sum = torch.sum(ori_seq_mask)
+
+        # Skip loss bias calculation, if loss_bias_start is 1.0
+        if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False):
+            seq_mask = ori_seq_mask
+        else:
+            # Lets get a linear multiplier for the loss bias
+            # seq_mask_sum = torch.sum(ori_seq_mask)
+            bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device)
+
+            # Boolean flag of seq_mask > 0
+            seq_mask_index = ori_seq_mask[0] > 0
+
+            # Apply the bias mask only to positive seq_mask values
+            final_mask = torch.zeros(ori_seq_mask.shape[1], device=ori_seq_mask.device)
+            final_mask[seq_mask_index] = ori_seq_mask[0][seq_mask_index] * bias_mask
+
+            # And save it as seq_mask
+            seq_mask = final_mask.unsqueeze(0)
+
         do_bptt_learning = self.bptt_learning and is_training_run

         idx, targets = seq[:, :-1], seq[:, 1:]
@@ -1101,7 +1104,11 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
         if self.trainer.num_devices > 1:
             if self.bptt_learning_range <= 0:
                 # We perform forward/backward on the shared max segment count across all GPUs
-                forward_segment_count = self.trainer.strategy.reduce(segment_count, reduce_op="max").item()
+                forward_segment_count = self.trainer.strategy.reduce(segment_count, reduce_op="max")
+                # Convert to int, if its a torch tensor
+                if isinstance(forward_segment_count, torch.Tensor):
+                    forward_segment_count = forward_segment_count.item()
+                # We perform as many backward pass as we need to be equal or more then bptt_learning_range
                 backward_segment_count = forward_segment_count
             else:
                 # We perform as many forward pass as we need to be equal or more then bptt_learning_range
@@ -1382,7 +1389,8 @@ def decode(self, tokens: list):
     # Forwarding logic, withoout torch._no_grad() context
    def _forward(
            self,
            tokens,
-            stateObj = None
+            stateObj = None,
+            all_logits = False
        ):
        logits_arr = None
@@ -1396,11 +1404,17 @@ def _forward(
            shift_states = stateObj["shift_states"]
            wkv_states = stateObj["wkv_states"]

+        # The all_logits array, if requested
+        all_logits_arr = None
+
        # For each token, process the state, in batches up to ctx_len
        for i in range(0, token_len, self.ctx_len):
+            # Token set
+            token_set = tokens[i:i+self.ctx_len]
+
            # Check if tokens are already tensors
            batch_tokens = torch.tensor(
-                tokens[i:i+self.ctx_len],
+                token_set,
                dtype=torch.long, device=self.device
            ).unsqueeze(0)

@@ -1409,16 +1423,27 @@ def _forward(
                batch_tokens, shift_states, wkv_states
            )

+            # Build the all_logits array
+            if all_logits:
+                if all_logits_arr is None:
+                    all_logits_arr = logits_arr[0]
+                else:
+                    all_logits_arr = torch.cat([all_logits_arr, logits_arr[0]], dim=0)
+
        # Return the logits and state
-        return logits_arr[0][-1], { "shift_states": shift_states, "wkv_states": wkv_states }
+        if all_logits:
+            return all_logits_arr, { "shift_states": shift_states, "wkv_states": wkv_states }
+        else:
+            return logits_arr[0][-1], { "shift_states": shift_states, "wkv_states": wkv_states }
"wkv_states": wkv_states } # Forwarding logic, with torch._no_grad() context def forward( self, tokens:list, - stateObj = None + stateObj = None, + all_logits = False ): with torch.no_grad(): - return self._forward(tokens, stateObj) + return self._forward(tokens, stateObj, all_logits) # Sampling logits def sample_logits(