diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index 1e97514b..0b5f9408 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -915,21 +915,22 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid
         total_mask_sum = torch.sum(seq_mask)
         avg_mask_sum = ( total_mask_sum / B )
 
-        # Do a quick return, if there is no tokens of value to learn from due to full masking
-        if num_devices > 1 and total_mask_sum == 0:
-            return 0
+        # # Do a quick return, if there is no tokens of value to learn from due to full masking
+        # # DO NOT DO THIS : This causes multi node / multi GPU to go out of sync
+        # if num_devices <= 1 and total_mask_sum == 0:
+        #     return 0
 
         # Checkpoint steps
         def checkpointed_step(idx, targets, mask, last_shift_states,
                               last_wkv_states):
-            # Skip if there is no tokens of value to learn from
-            if idx.shape[1] == 0:
-                # Prepare dummy loss
-                train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
-                sample_loss = train_loss.clone().detach().requires_grad_(False)
-
-                # Return the checkpoint values
-                return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
+            # # Skip if there is no tokens of value to learn from
+            # if idx.shape[1] == 0:
+            #     # Prepare dummy loss
+            #     train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
+            #     sample_loss = train_loss.clone().detach().requires_grad_(False)
+
+            #     # Return the checkpoint values
+            #     return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
 
             # Get the logits, and the new states
             logits, new_shift_states, new_wkv_states = self(
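
Context for the change above: under DistributedDataParallel, every rank must execute the same sequence of collective operations (notably the gradient all-reduce triggered by `backward()`). An early `return 0` taken only on ranks whose batch happens to be fully masked leaves the other ranks blocked waiting on a collective that never comes, which is why the patch comments the shortcuts out rather than keeping them. The sketch below is illustrative only, not code from this repository; `compute_masked_loss_safe` and its shapes are hypothetical, assuming `(B, T)` inputs and a standard torch DDP setup.

```python
import torch
import torch.nn.functional as F

def compute_masked_loss_safe(logits, targets, seq_mask):
    # Hypothetical helper (not from model.py): keep fully masked batches in
    # the autograd graph instead of returning early, so every DDP rank still
    # reaches the same backward() / gradient all-reduce.
    token_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # (B*T, vocab)
        targets.view(-1),                  # (B*T,)
        reduction="none",
    )
    total_mask_sum = torch.sum(seq_mask)
    # clamp avoids division by zero on a fully masked batch; the loss is then
    # exactly 0 but still carries a grad_fn, so backward() runs on all ranks.
    return torch.sum(token_loss * seq_mask.view(-1)) / total_mask_sum.clamp(min=1)

# The removed pattern, by contrast, desynchronizes ranks:
#
#     if num_devices > 1 and total_mask_sum == 0:
#         return 0   # plain int, no grad_fn: this rank skips backward()
#
# Ranks that skip backward() never join the gradient all-reduce their peers
# are blocked on, so multi-node / multi-GPU training hangs or drifts out of
# sync -- the failure mode named in the patch comment.
```

In short, a masked-out batch should contribute a zero loss that still flows through autograd, not a control-flow branch that only some ranks take.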