diff --git a/megatron/training.py b/megatron/training.py
index 79f39ccc2e..0aeaabeba5 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1032,6 +1032,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         if args.log_optimizer_states_to_tensorboard and optimizer is not None:
             opt_stats = [0.0] * 8
             opt_stats_2 = [0.0] * 4
+
+            #TODO(billishyahao): Remove me after bf16_optimizer promotes its state.
+            if not hasattr(optimizer, "state"):
+                assert hasattr(optimizer, "optimizer"), f"Optimizer must have optimizer property."
+                optimizer.state = optimizer.optimizer.state
+
             for _, group in enumerate(optimizer.param_groups):
                 for _, param in enumerate(group['params']):
                     opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item())**2
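
For context, here is a minimal, self-contained sketch of the pattern the patch relies on. The `WrappedOptimizer` class below is hypothetical (not Megatron or DeepSpeed code); it only mimics a wrapper, such as the bf16 optimizer, that exposes the underlying `torch.optim.Adam` through an `.optimizer` attribute and forwards `param_groups` but not `state`. Aliasing `state` as in the patch lets the existing TensorBoard logging loop keep reading the Adam moments unchanged.

```python
import torch


class WrappedOptimizer:
    """Hypothetical wrapper (for illustration only) that hides the inner `.state`."""

    def __init__(self, inner):
        self.optimizer = inner                  # underlying torch.optim.Adam
        self.param_groups = inner.param_groups  # param_groups is forwarded, state is not


params = [torch.nn.Parameter(torch.randn(4, 4))]
inner = torch.optim.Adam(params)
(params[0] ** 2).sum().backward()
inner.step()  # populates inner.state[param]['exp_avg_sq']

optimizer = WrappedOptimizer(inner)

# Same fallback as in the patch: alias the inner optimizer's state onto the wrapper.
if not hasattr(optimizer, "state"):
    assert hasattr(optimizer, "optimizer"), "Optimizer must have an `optimizer` attribute."
    optimizer.state = optimizer.optimizer.state

# The logging loop from training_log() can now read the Adam second moments as before.
opt_stats_0 = 0.0
for group in optimizer.param_groups:
    for param in group["params"]:
        opt_stats_0 += torch.norm(optimizer.state[param]["exp_avg_sq"]).item() ** 2
print(opt_stats_0)
```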