Skip to content

Commit

Permalink
Merge pull request #11 from RWKV/main-dev-infctx
Browse files Browse the repository at this point in the history
fix accidental removal of ctx_len_cutoffs
  • Loading branch information
PicoCreator authored Aug 21, 2023
2 parents 4f2525d + ed3bfea commit fd6bbfb
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions RWKV-v5/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,35 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool):
# And save it as seq_mask
seq_mask = final_mask.unsqueeze(0)

# Perform cutoff for training run
if is_training_run:
prev_step = 0

# Avoid using the zip operation, as torch.compile throws an exception on it
# with `zip not recognized as a valid function`
# ---
# for step, len_cut in zip(self.ctx_len_warmup_steps,
# self.ctx_len_cutoffs):
# ---
for i in range(min(len(self.ctx_len_warmup_steps), len(self.ctx_len_cutoffs))):
step = self.ctx_len_warmup_steps[i]
len_cut = self.ctx_len_cutoffs[i]

if prev_step <= self.global_step < step and len_cut < seq.shape[
1] - 1:
pos = randint(0, seq.shape[1] - len_cut - 1)

# Original
# seq = seq[:, pos:pos + len_cut + 1]

# Changed to use masking for prefix cutoff (I do not know if this makes sense)
seq = seq[:, :pos + len_cut + 1]
seq_mask = seq_mask[:, :pos + len_cut + 1]
# Set the attention mask to 0 for the skipped tokens
seq_mask[:, :pos] = 0
break
prev_step = step

do_bptt_learning = self.bptt_learning and is_training_run
idx, targets = seq[:, :-1], seq[:, 1:]

Expand Down

0 comments on commit fd6bbfb

Please sign in to comment.