
Commit

torch.compiled version, fp64 chunklen128
SmerkyG committed Feb 2, 2024
1 parent 90059d3 commit 4c88076
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion RWKV-v5/src/module/TimeMix.py
@@ -320,7 +320,7 @@ def _forward_nocuda_optimized(self, x, last_state: tuple[torch.Tensor,torch.Tens
         shift_state_out = x[:,-1]
 
         # 24 is optimal chunk length (longer will use too much memory and cause precision problems or even numerical instability, shorter is inefficient)
-        chunk_len = 24
+        chunk_len = 128
 
         # padding to support fast path for non-exact chunk size multiple sequence lengths
         n_padding = (chunk_len - x.size(-2) % chunk_len) % chunk_len
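Aside on the padding context line kept above: it rounds the sequence length up to the next multiple of chunk_len so the chunked fast path applies even when the length is not an exact multiple. A minimal standalone sketch of that step (assuming a (B, L, C) input layout and torch.nn.functional.pad; a hypothetical helper, not the repository's code) is:

import torch
import torch.nn.functional as F

def pad_to_chunk_multiple(x: torch.Tensor, chunk_len: int = 128) -> torch.Tensor:
    # Hypothetical helper, not from the repo: pad the time dimension (dim -2)
    # so its length becomes an exact multiple of chunk_len.
    n_padding = (chunk_len - x.size(-2) % chunk_len) % chunk_len
    if n_padding > 0:
        # F.pad pads the last dim by (0, 0) and the time dim by (0, n_padding)
        x = F.pad(x, (0, 0, 0, n_padding))
    return x

# A length-300 sequence is padded up to 384 = 3 * 128.
x = torch.randn(1, 300, 64)
assert pad_to_chunk_multiple(x, 128).size(-2) == 384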
7 changes: 5 additions & 2 deletions RWKV-v5/src/module/rwkv_inner.py
@@ -5,7 +5,10 @@
 from torch import Tensor
 
 # 24 is optimal chunk length (longer will use too much memory and cause precision problems or even numerical instability, shorter is inefficient)
-def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24,precision_dtype:torch.dtype=torch.float32):
+@torch.jit.ignore
+@torch.compile
+def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24):
+    precision_dtype:torch.dtype=torch.float64
     """
     expects
     r : (B,H,L,K)
@@ -38,7 +41,7 @@ def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24,precision_dtype:torch.dtype=t
     if precision_dtype == torch.float32:
         precision_min_val = 0.005 # good for fp32 (1.175e-38 ^ (1/16.0) < 0.00426)
     else: #elif precision_dtype == torch.float64:
-        precision_min_val = 1e-10 # good for fp64 (1.7e-308 ^ (1/16.0) < 5.8e-20)
+        precision_min_val = 0.005 # to match fp32
     w = w.clamp(precision_min_val)
 
     # calculate cumulative decay in log space where it won't overflow
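A quick check of the clamp arithmetic in the hunk above: w is decayed multiplicatively within a chunk, so it is floored at a value whose repeated powers stay above the smallest normal float32. The sketch below uses standard torch only, with the exponent 16 taken from the 1/16 power in the code comment (not from chunk_len), and simply reproduces the quoted numbers:

import torch

fp32_min_normal = torch.finfo(torch.float32).tiny    # ~1.1755e-38
print(fp32_min_normal ** (1.0 / 16.0))               # ~0.00426, as in the comment

precision_min_val = 0.005
# Clamping w at 0.005 keeps w**16 (~1.53e-37) above the float32 underflow threshold.
print(precision_min_val ** 16 > fp32_min_normal)     # True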
