Skip to content

Commit

Permalink
fixed rmsnorm bug. Changed division to multiply since using torch.rsqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
vcoykendall committed Jun 26, 2023
1 parent b1ac9da commit df0667a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llmfoundry/models/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def forward(self, x):


def rms_norm(x, weight=None, eps=1e-5):
output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
if weight is not None:
return output * weight
return output
Expand Down

0 comments on commit df0667a

Please sign in to comment.