diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml
index e1b28d29..80e803d3 100644
--- a/.github/workflows/docker-build.yml
+++ b/.github/workflows/docker-build.yml
@@ -1,4 +1,4 @@
-name: Docker Image Publish
+name: Docker Env Image (cuda-11-8)
 
 on:
   push:
@@ -16,7 +16,8 @@ env:
 
 jobs:
   build:
-
+    name: Docker Env Image (cuda-11-8)
+
     runs-on: ubuntu-latest
     permissions:
       contents: read
@@ -27,7 +28,7 @@ jobs:
 
     steps:
       # Get and log the free space
-      - name: Build
+      - name: Get system free space (Before reclaim)
        run: |
          echo "Free space:"
          df -h
@@ -53,7 +54,7 @@ jobs:
           docker-images: true
 
       # Get and log the free space
-      - name: Build
+      - name: Get system free space (After reclaim)
        run: |
          echo "Free space:"
          df -h
@@ -90,7 +91,11 @@ jobs:
        uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-
+
+      - name: downcase IMAGE_NAME
+        run: |
+          echo "IMAGE_NAME_LC=${IMAGE_NAME,,}" >>${GITHUB_ENV}
+
       # Build and push Docker image with Buildx (don't push on PR)
       # https://github.com/docker/build-push-action
      - name: Build and push Docker image
@@ -99,7 +104,7 @@ jobs:
        with:
          context: "{{defaultContext}}:docker/env-cuda-11-8"
          push: ${{ github.event_name != 'pull_request' }} # Don't push on PR
-         tags: "env-cuda-11-8"
+         tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME_LC }}:env-cuda-11-8
          # tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha,src=docker/env-cuda-11-8
diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index 02c93441..393c310c 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -131,6 +131,10 @@ def is_torch_version_above(required_version):
 def deepspeed_checkpoint(*args, **kwargs):
     return deepspeed.checkpointing.checkpoint(*args, **kwargs)
 
+@TCompileDisable
+def wkv_op(time_decay, time_first, k, v, wkv_state):
+    return torch.ops.rwkv.wkv(time_decay, time_first, k, v, wkv_state)
+
 ########################################################################################################
 # RWKV: State Blocks
 ########################################################################################################
@@ -174,7 +178,7 @@ def create(N, B, C, n_head, head_size, device, dtype):
     # @ TCompileMax (no difference)
     @staticmethod
     def empty(N, B, C, n_head, head_size, device, dtype):
-        # @TODO: confirm if dtype can be changed from .flaot to dtype=dtype (when bf16)
+        # HEAD nad HEADSIZE
         wkv_states = torch.empty((N, B, n_head, head_size, head_size),
                                  device=device,
                                  # dtype=dtype)
@@ -198,7 +202,7 @@ def __setitem__(self, layer: int, state: BlockState):
 
 
 class RWKV_TimeMix(JITModClass):
-    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att):
+    def __init__(self, layer_id, n_layer, n_embd, dim_att):
         super().__init__()
 
         self.dim_att = dim_att
@@ -206,6 +210,9 @@ def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att):
         self.n_embd = n_embd
         self.layer_id = layer_id
 
+        head_size = 64
+        n_head = n_embd // head_size
+
         assert n_embd % n_head == 0
         self.n_head = n_head
         self.head_size = head_size
@@ -232,9 +239,8 @@ def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att):
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
 
-            # V5-R2 changes
-            self.time_faaaa = nn.Parameter(torch.ones(n_head) * 0.05)
-            # self.time_first = nn.Parameter(torch.ones(n_head) * (-3.0))
+            # time_first (no longer fancy)
+            self.time_first = nn.Parameter(torch.ones(n_head) * (-3.0))
 
         # self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(n_embd, dim_att, bias=False)
@@ -259,16 +265,22 @@ def _forward_rkv_chunk(self, x, B, TT, last_state: TimeMixState):
         k = self.key(xk).view(B, TT, self.n_head, self.head_size).transpose(1, 2).transpose(-2, -1) # BTC -> BHTS -> BHST
         v = self.value(xv).view(B, TT, self.n_head, self.head_size).transpose(1, 2) # BTC -> BHTS
 
+        # # Enforce bf16 type for kv, as this can be mis init
+        # # when being called directly via inference
+        # if r.dtype != torch.bfloat16:
+        #     r = r.to(torch.bfloat16)
+        # if k.dtype != torch.bfloat16:
+        #     k = k.to(torch.bfloat16)
+        # if v.dtype != torch.bfloat16:
+        #     v = v.to(torch.bfloat16)
+
         return r, k, v
 
     def _forward_wkbs_chunk(self, T, r, k, v):
         H = self.n_head
 
         w = torch.exp(-torch.exp(self.time_decay.float())).unsqueeze(-1)
-
-        # V5-R2 changes
-        u = self.time_faaaa.float().unsqueeze(-1)
-        # u = torch.exp(self.time_first.float()).unsqueeze(-1)
+        u = torch.exp(self.time_first.float()).unsqueeze(-1)
 
         ws = w.pow(T).reshape(1, H, 1, 1)
         ind = torch.arange(T-1, -1, -1, device=r.device).unsqueeze(0).repeat(H, 1)
@@ -301,15 +313,44 @@ def _forward_state_chunk(self, r, k, v, w, wk, wb, ws, x_l, last_state: TimeMixS
         if r.dtype == torch.bfloat16 and s.dtype != torch.bfloat16:
             s = s.contiguous().to(torch.bfloat16)
 
+        # print("")
+        # print("B,H,TT,S", B, H, TT, S)
+        # print("S-zero", torch.zeros(B, H, S, S, device=r.device, dtype=r.dtype).shape)
+        # print("wkv", last_state.wkv_state[:, :, :].shape)
+        # print("")
+
         x = torch.zeros(B, H, TT, S, device=r.device, dtype=r.dtype) # output
 
+        # # Check if r is bf16 or float
+        # if s.dtype == torch.bfloat16:
+        #     print("s is bf16")
+        # elif s.dtype == torch.float32:
+        #     print("s is float32")
+        # else:
+        #     print("s is neither bf16 nor float32")
+
+        # # Check if r is bf16 or float
+        # if r.dtype == torch.bfloat16:
+        #     print("r is bf16")
+        # elif r.dtype == torch.float32:
+        #     print("r is float32")
+        # else:
+        #     print("r is neither bf16 nor float32")
+
         ########################################################################
         for i in range(TT // T):
-
             rr = r[:, :, i*T:i*T+T, :]
             kk = k[:, :, :, i*T:i*T+T]
             vv = v[:, :, i*T:i*T+T, :]
 
+            # # Check if r is bf16 or float
+            # if rr.dtype == torch.bfloat16:
+            #     print("rr is bf16")
+            # elif rr.dtype == torch.float32:
+            #     print("rr is float32")
+            # else:
+            #     print("rr is neither bf16 nor float32")
+
             x[:, :, i*T:i*T+T, :] = ((rr @ kk) * w) @ vv + (rr @ s) * wb
 
             s = ws * s + (kk * wk) @ vv
@@ -318,6 +359,11 @@ def _forward_state_chunk(self, r, k, v, w, wk, wb, ws, x_l, last_state: TimeMixS
         x = x.transpose(1, 2).contiguous().view(B * TT, H*S) # BHTS -> BTHS -> BTC
         x = self.ln_x(x).view(B, TT, H*S)
 
+        # # Return with logits outputs, and new timemix state
+        # print("")
+        # print("S-right", s.shape)
+        # print("n_layer, n_embd, layer_id", self.n_layer, self.n_embd, self.layer_id)
+        # print("")
         return self.output(x), TimeMixState(x_l, s)
 
     def _forward_chunk(self, x, last_state: TimeMixState):
@@ -397,7 +443,7 @@ def forward(self, x, last_state: ChannelMixState):
 
 
 class Block(nn.Module):
-    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dropout, dim_att, dim_ffn):
+    def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dim_att, dim_ffn):
         super().__init__()
         self.layer_id = layer_id
 
@@ -407,41 +453,22 @@ def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dropout, dim_at
         if self.layer_id == 0:
             self.ln0 = nn.LayerNorm(n_embd)
 
-        self.att = RWKV_TimeMix(layer_id, n_layer, n_embd, n_head, head_size, dim_att)
+        self.att = RWKV_TimeMix(layer_id, n_layer, n_embd, dim_att)
         self.ffn = RWKV_ChannelMix(layer_id, n_layer, n_embd, dim_ffn)
 
-        # Setup droupout at block level
-        self.dropout = dropout
-        if dropout > 0:
-            self.drop0 = nn.Dropout(p = dropout)
-            self.drop1 = nn.Dropout(p = dropout)
-
     def forward(self, x, last_state: BlockState):
         if self.layer_id == 0:
             x = self.ln0(x)
-
         att_out, att_state = self.att(
             self.ln1(x),
             last_state.time_mix_state,
         )
-
-        if self.dropout > 0.0:
-            # Handle with dropout
-            x = self.drop0(x + att_out)
-            ffn_out, ffn_state = self.ffn(
-                self.ln2(x),
-                last_state.channel_mix_state,
-            )
-            x = self.drop1(x + ffn_out)
-        else:
-            # Handle without dropout
-            x = x + att_out
-            ffn_out, ffn_state = self.ffn(
-                self.ln2(x),
-                last_state.channel_mix_state,
-            )
-            x = x + ffn_out
-
+        x = x + att_out
+        ffn_out, ffn_state = self.ffn(
+            self.ln2(x),
+            last_state.channel_mix_state,
+        )
+        x = x + ffn_out
         return x, BlockState(att_state, ffn_state)
 
 
@@ -507,17 +534,12 @@ def __init__(self,
                  lr_final: float = -1.0,
                  lr_period: int = -1,
                  lr_period_type: str = 'epoch',
-                 # Dropout rate
-                 dropout: float = 0.0,
                  # Adam optimizer settings
                  beta1: float = 0.9,
                  beta2: float = 0.99,
                  adam_eps: float = 1.0e-08,
                  weight_decay: float = 0.01,
                  warmup_steps: int = -1,
-                 # loss bias start
-                 position_loss_bias: float = 1.0,
-                 position_loss_bias_in_validation: bool = False,
                  # Backprop settings
                  grad_cp: bool = True,
                  bptt_learning: bool = True,
@@ -586,7 +608,6 @@ def __init__(self,
         self.lr_final = lr_final
         self.lr_period = lr_period
         self.lr_period_type = lr_period_type
-        self.dropout = dropout
         self.warmup_steps = warmup_steps
         self.beta1 = beta1
         self.beta2 = beta2
@@ -598,10 +619,6 @@ def __init__(self,
         self.substep_cuda_cache_clear = substep_cuda_cache_clear
         self.substep_logging = substep_logging
 
-        # Save the position loss params
-        self.position_loss_bias = position_loss_bias
-        self.position_loss_bias_in_validation = position_loss_bias_in_validation
-
         dim_att = dim_att or n_embd
         dim_ffn = dim_ffn or n_embd * 4
         self.dim_att = dim_att
@@ -635,16 +652,12 @@ def __init__(self,
         # is_python_module=False)
 
         self.blocks = nn.ModuleList([
-            Block(i, n_layer, n_embd, n_head, head_size, dropout, dim_att, dim_ffn) for i in range(n_layer)
+            Block(i, n_layer, n_embd, n_head, head_size, dim_att, dim_ffn) for i in range(n_layer)
         ])
 
         self.ln_out = nn.LayerNorm(n_embd)
         self.head = nn.Linear(n_embd, vocab_size, bias=False)
 
-        # Dropout handling
-        if dropout > 0:
-            self.drop0 = nn.Dropout(p = dropout)
-
         # load the state, and GC the original cpu copy
         if model_weights != None:
             self.load_state_dict(model_weights)
@@ -669,19 +682,6 @@ def configure_optimizers(self):
 
         # Log the learning rate, and various other parameters
         if self.trainer.local_rank == 0:
-
-            # Add the important notes, for informing users of common gotchas
-            print((
-                "#\n"
-                "# RWKV lighting_trainer.py important notes \n"
-                "# https://github.com/RWKV/RWKV-infctx-trainer \n"
-                "#\n"
-                "# - Ensure your host is not running cuda 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues\n"
-                "# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications\n"
-                "# - When resuming from checkpoint, the estimated time is inaccurate\n"
-                "#"
-            ))
-
             lr_init_e = "{:.3e}".format(lr_init)
             lr_final_e = "{:.3e}".format(lr_final)
             print(f"\n[RWKV.model] Configuring optimizer with\n"+
@@ -707,11 +707,8 @@ def configure_optimizers(self):
                     lr_1x.add(n)
                 elif "time_decay" in n:
                     lr_2x.add(n)
-                # V5-R2 changes
-                elif "time_faaaa" in n:
-                    lr_2x.add(n)
-                # elif "time_first" in n:
-                #     lr_3x.add(n)
+                elif "time_first" in n:
+                    lr_3x.add(n)
                 else:
                     lr_1x.add(n)
             lr_1x = sorted(list(lr_1x))
@@ -886,10 +883,6 @@ def forward(self, idx: torch.Tensor, last_shift_states: torch.Tensor = None,
 
         x = self.emb(idx)
 
-        # Handle dropout (input)
-        if self.dropout > 0.0:
-            x = self.drop0(x)
-
         new_states = BlockStateList.empty(self.n_layer, B, self.n_embd,
                                           self.n_head, self.head_size,
                                           x.device, x.dtype)
@@ -973,37 +966,41 @@ def manual_backward(self, loss: torch.Tensor, *args, **kwargs):
     def compute_loss(self, batch, batch_idx, is_training_run: bool):
         seq = batch['input_ids']
         assert isinstance(seq, torch.Tensor) and seq.ndim == 2
-        ori_seq_mask = batch['attention_mask']
+        seq_mask = batch['attention_mask']
 
         # Check if attent mask is set, if not initialize it
-        if ori_seq_mask is None or ori_seq_mask.ndim != 2:
-            ori_seq_mask = torch.ones_like(seq[:, 1:])
-
-        # Get the starting and ending loss bias
-        loss_bias_start = self.position_loss_bias
-        loss_bias_end = 2.0 - loss_bias_start
-
-        # total_mask_sum
-        total_mask_sum = torch.sum(ori_seq_mask)
-
-        # Skip loss bias calculation, if loss_bias_start is 1.0
-        if loss_bias_start == 1.0 or (is_training_run == False and self.position_loss_bias_in_validation == False):
-            seq_mask = ori_seq_mask
-        else:
-            # Lets get a linear multiplier for the loss bias
-            # seq_mask_sum = torch.sum(ori_seq_mask)
-            bias_mask = torch.linspace(loss_bias_start, loss_bias_end, int(total_mask_sum.item()), device=ori_seq_mask.device)
-
-            # Boolean flag of seq_mask > 0
-            seq_mask_index = ori_seq_mask[0] > 0
-
-            # Apply the bias mask only to positive seq_mask values
-            final_mask = torch.zeros(ori_seq_mask.shape[1], device=ori_seq_mask.device)
-            final_mask[seq_mask_index] = ori_seq_mask[0][seq_mask_index] * bias_mask
-
-            # And save it as seq_mask
-            seq_mask = final_mask.unsqueeze(0)
-
+        if seq_mask is None or seq_mask.ndim != 2:
+            seq_mask = torch.ones_like(seq[:, 1:])
+
+        # Perform cutoff for training run
+        if is_training_run:
+            prev_step = 0
+
+            # Avoid using the zip operation, as torch.compile throws an exception on it
+            # with `zip not reconized as a valid function`
+            # ---
+            # for step, len_cut in zip(self.ctx_len_warmup_steps,
+            #                          self.ctx_len_cutoffs):
+            # ---
+            for i in range(min(len(self.ctx_len_warmup_steps), len(self.ctx_len_cutoffs))):
+                step = self.ctx_len_warmup_steps[i]
+                len_cut = self.ctx_len_cutoffs[i]
+
+                if prev_step <= self.global_step < step and len_cut < seq.shape[
+                        1] - 1:
+                    pos = randint(0, seq.shape[1] - len_cut - 1)
+
+                    # Original
+                    # seq = seq[:, pos:pos + len_cut + 1]
+
+                    # Changed to use masking for prefix cutoff (i do not know if this makes sense)
+                    seq = seq[:, :pos + len_cut + 1]
+                    seq_mask = seq_mask[:, :pos + len_cut + 1]
+                    # Set the attention mask to 0 for the skipped tokens
+                    seq_mask[:, :pos] = 0
+                    break
+                prev_step = step
+
         do_bptt_learning = self.bptt_learning and is_training_run
 
         idx, targets = seq[:, :-1], seq[:, 1:]
@@ -1104,11 +1101,7 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
             if self.trainer.num_devices > 1:
                 if self.bptt_learning_range <= 0:
                     # We perform forward/backward on the shared max segment count across all GPUs
-                    forward_segment_count = self.trainer.strategy.reduce(segment_count, reduce_op="max")
-                    # Convert to int, if its a torch tensor
-                    if isinstance(forward_segment_count, torch.Tensor):
-                        forward_segment_count = forward_segment_count.item()
-                    # We perform as many backward pass as we need to be equal or more then bptt_learning_range
+                    forward_segment_count = self.trainer.strategy.reduce(segment_count, reduce_op="max").item()
                     backward_segment_count = forward_segment_count
                 else:
                     # We perform as many forward pass as we need to be equal or more then bptt_learning_range
@@ -1389,8 +1382,7 @@ def decode(self, tokens: list):
     # Forwarding logic, withoout torch._no_grad() context
     def _forward(
             self, tokens,
-            stateObj = None,
-            all_logits = False
+            stateObj = None
         ):
 
         logits_arr = None
@@ -1404,17 +1396,11 @@ def _forward(
             shift_states = stateObj["shift_states"]
             wkv_states = stateObj["wkv_states"]
 
-        # The all_logits array, if requested
-        all_logits_arr = None
-
         # For each token, process the state, in batches up to ctx_len
         for i in range(0, token_len, self.ctx_len):
-            # Token set
-            token_set = tokens[i:i+self.ctx_len]
-
             # Check if tokens are already tensors
             batch_tokens = torch.tensor(
-                token_set,
+                tokens[i:i+self.ctx_len],
                 dtype=torch.long, device=self.device
             ).unsqueeze(0)
 
@@ -1423,27 +1409,16 @@ def _forward(
                 batch_tokens, shift_states, wkv_states
             )
 
-            # Build the all_logits array
-            if all_logits:
-                if all_logits_arr is None:
-                    all_logits_arr = logits_arr[0]
-                else:
-                    all_logits_arr = torch.cat([all_logits_arr, logits_arr[0]], dim=0)
-
         # Return the logits and state
-        if all_logits:
-            return all_logits_arr, { "shift_states": shift_states, "wkv_states": wkv_states }
-        else:
-            return logits_arr[0][-1], { "shift_states": shift_states, "wkv_states": wkv_states }
+        return logits_arr[0][-1], { "shift_states": shift_states, "wkv_states": wkv_states }
 
     # Forwarding logic, with torch._no_grad() context
     def forward(
             self, tokens:list,
-            stateObj = None,
-            all_logits = False
+            stateObj = None
         ):
         with torch.no_grad():
-            return self._forward(tokens, stateObj, all_logits)
+            return self._forward(tokens, stateObj)
 
     # Sampling logits
     def sample_logits(
diff --git a/docker/env-cuda-11-8/Dockerfile b/docker/env-cuda-11-8/Dockerfile
index 4de3bae8..2a6c38ea 100644
--- a/docker/env-cuda-11-8/Dockerfile
+++ b/docker/env-cuda-11-8/Dockerfile
@@ -2,10 +2,11 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
 
 # Install ninja, and several common stuff
-RUN apt-get update && apt-get install -y ninja-build htop wget curl git
-
-# Install python3 and pip3
-RUN apt-get update && apt-get install -y python3.11 python3-pip
+RUN apt-get update && \
+    apt-get install -y \
+    ninja-build \
+    htop wget curl git vim \
+    python3.11 python3-pip
 
 # Install pytorch
 RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118