Skip to content

Commit

Permalink
Merge pull request #6 from RWKV/main-dev-infctx
Browse files Browse the repository at this point in the history
Main dev infctx
  • Loading branch information
PicoCreator authored Aug 20, 2023
2 parents ef5a88d + f996fac commit 55c6431
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions RWKV-v5/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,15 @@ def _forward_state_chunk(self, r, k, v, w, wk, wb, ws, x_l, last_state: TimeMixS
x = torch.zeros(B, H, TT, S, device=r.device, dtype=r.dtype) # output

########################################################################
# for i in range(TT // T):
# (optimizing out the for loop, since TT//T is always 1)
i = 1
for i in range(TT // T):

rr = r[:, :, i*T:i*T+T, :]
kk = k[:, :, :, i*T:i*T+T]
vv = v[:, :, i*T:i*T+T, :]
rr = r[:, :, i*T:i*T+T, :]
kk = k[:, :, :, i*T:i*T+T]
vv = v[:, :, i*T:i*T+T, :]

x[:, :, i*T:i*T+T, :] = ((rr @ kk) * w) @ vv + (rr @ s) * wb
x[:, :, i*T:i*T+T, :] = ((rr @ kk) * w) @ vv + (rr @ s) * wb

s = ws * s + (kk * wk) @ vv
s = ws * s + (kk * wk) @ vv
########################################################################

x = x.transpose(1, 2).contiguous().view(B * TT, H*S) # BHTS -> BTHS -> BTC
Expand Down

0 comments on commit 55c6431

Please sign in to comment.