def rms_norm(x, weight=None, eps=1e-5):
    """Apply root-mean-square normalization over the last dimension of ``x``.

    Every vector along the last axis is scaled by
    ``1 / sqrt(mean(x ** 2) + eps)``, i.e. by the reciprocal of its
    (eps-stabilized) root mean square.

    Args:
        x: Input tensor; the mean square is taken over dim ``-1`` with
            ``keepdim=True`` so the scale broadcasts back against ``x``.
        weight: Optional tensor multiplied element-wise into the normalized
            output (it must broadcast against ``x``). Skipped when ``None``.
        eps: Constant added to the mean square before the reciprocal
            square root, so an all-zero vector does not divide by zero.

    Returns:
        The normalized tensor, multiplied by ``weight`` when one is given.
    """
    # torch.rsqrt already yields 1 / sqrt(.), so the input must be
    # MULTIPLIED by it. Dividing by rsqrt would scale x up by its RMS
    # instead of normalizing it.
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    output = x * inv_rms
    if weight is not None:
        return output * weight
    return output