diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index c22689a212..db6629e73c 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -306,7 +306,7 @@ def lion_step_unfused(grads: torch.Tensor, beta2: float, weight_decay: float = 0) -> torch.Tensor: # f32 cast to match fused impl + for compatibility with f32 grads or weights - momentums = momentums.to(torch.float32) + momentums = momentums.to(dtype=torch.float32) grads = grads.to(dtype=torch.float32) update = momentums.lerp(grads, 1 - beta1).sign_()