diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8b776fa0b..555a34f81 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5652,11 +5652,11 @@ def aten_native_batch_norm( # three outputs when training_mode=True and one when it is False. if training is True: norm, mean, var = _aten_native_batch_norm_training_onnx( - input, weight, bias, running_mean, running_var, axes, training, momentum, eps + input, weight, bias, running_mean, running_var, axes, momentum=momentum, eps=eps ) else: norm, mean, var = _aten_native_batch_norm_inference_onnx( - input, weight, bias, running_mean, running_var, axes, training, momentum, eps + input, weight, bias, running_mean, running_var, axes, momentum=momentum, eps=eps ) return norm, mean, var