Fix and simplify native_batch_norm | fix(torchlib) (#1289)
- Merged the `_native_batch_norm_legit_functional` implementation into the private helpers shared with `native_batch_norm`.
- Fixed the implementation: `momentum` was not being replaced with `1 - momentum` to account for the ONNX spec's opposite convention (see the sketch below).
- Updated the test decorator to support skipping by device type; the affected tests are now skipped only on CPU.

TODO:
- Fix results on GPU

Fixes #1256
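
For context, PyTorch and ONNX weight `momentum` oppositely when updating running statistics, which is why `1 - momentum` is passed through to `BatchNormalization`. A minimal arithmetic sketch of the two conventions (`batch_stat` and `running_stat` are illustrative values, not library API):

```python
import math

momentum = 0.1
batch_stat, running_stat = 2.0, 1.0

# PyTorch convention: momentum weights the *new* batch statistic.
pytorch_updated = (1 - momentum) * running_stat + momentum * batch_stat  # 1.1

# ONNX BatchNormalization convention: momentum weights the *old* running statistic.
onnx_updated = momentum * running_stat + (1 - momentum) * batch_stat  # 1.9

# Passing 1 - momentum as the ONNX momentum reconciles the two conventions.
onnx_momentum = 1 - momentum
onnx_with_swap = onnx_momentum * running_stat + (1 - onnx_momentum) * batch_stat
assert math.isclose(onnx_with_swap, pytorch_updated)
assert not math.isclose(onnx_updated, pytorch_updated)
```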
justinchuby authored Mar 8, 2024
1 parent f9f2fa2 commit 3b59a74
Showing 5 changed files with 137 additions and 113 deletions.
191 changes: 84 additions & 107 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4237,7 +4237,7 @@ def aten_instance_norm(
         running_mean,
         running_var,
         epsilon=eps,
-        momentum=1 - momentum,
+        momentum=1.0 - momentum,
         training_mode=False,
     )
     return op.Reshape(norm, op.Shape(input))
@@ -5648,17 +5648,31 @@ def aten_native_batch_norm(
     sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
     running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
 
-    # Have to split to 2 private functions, because training_function return 3 outputs
-    # While inference_function return 1 output
-    if training is True:
-        norm, mean, var = _aten_native_batch_norm_training_onnx(
-            input, weight, bias, running_mean, running_var, axes, training, momentum, eps
+    # We have to split to two private functions, because BatchNormalization returns
+    # three outputs when training_mode=True and one when it is False.
+    if training:
+        norm, input_mean, input_rstd, _, _ = _aten_native_batch_norm_training_onnx(
+            input,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            axes,
+            momentum=1.0 - momentum,
+            eps=eps,
         )
     else:
-        norm, mean, var = _aten_native_batch_norm_inference_onnx(
-            input, weight, bias, running_mean, running_var, training, momentum, eps
+        norm, input_mean, input_rstd, _, _ = _aten_native_batch_norm_inference_onnx(
+            input,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            momentum=1.0 - momentum,
+            eps=eps,
         )
-    return norm, mean, var
+
+    return norm, input_mean, input_rstd
 
 
 @torch_op("aten::native_batch_norm", private=True)
@@ -5669,22 +5683,25 @@ def _aten_native_batch_norm_training_onnx(
     running_mean: TFloat,
     running_var: TFloat,
     axes: INT64,
-    training: bool,
     momentum: float,
     eps: float,
-) -> Tuple[TFloat, TFloat, TFloat]:
-    # Assert(training is True)
-    norm, running_mean, running_var = op.BatchNormalization(
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    """Batch normalization training mode.
+
+    NOTE: momentum in PyTorch is 1.0-momentum in ONNX.
+    When calling this function be sure to pass 1.0-momentum when momentum is obtained from PyTorch.
+    """
+    norm, running_mean, _ = op.BatchNormalization(
         input,
         weight,
         bias,
         running_mean,
         running_var,
         epsilon=eps,
         momentum=momentum,
-        training_mode=training,
+        training_mode=True,
     )
-    # Compute var and rstd
+    # Compute mean and rstd
     # Mean, var, and rstd computation and results are expected to be
     # in higher precision when inputs are float16.
     upcast_input = op.Cast(input, to=FLOAT.dtype)
@@ -5695,7 +5712,19 @@ def _aten_native_batch_norm_training_onnx(
     rstd = op.Div(1.0, op.Sqrt(var + eps))
     # Get mean again with size = [1, C]
     mean = op.ReduceMean(upcast_input, axes, keepdims=False)
-    return norm, mean, rstd
+
+    # Compute the running var the PyTorch way
+    # https://github.com/pytorch/pytorch/blob/5cc511f72fe073bbd8c10d796d72dce67f5cd5c4/torch/_decomp/decompositions.py#L1646
+
+    n = op.Cast(op.Size(input) / op.Shape(input)[1], to=FLOAT.dtype)
+    unbiased_var = var * (n / (n - 1.0))
+
+    # NOTE: momentum in ONNX is 1.0-momentum in PyTorch
+    new_running_var = (
+        op.CastLike((1.0 - momentum) * unbiased_var, running_var) + momentum * running_var
+    )
+
+    return norm, mean, rstd, running_mean, new_running_var
 
 
 @torch_op("aten::native_batch_norm", private=True)
@@ -5705,11 +5734,14 @@ def _aten_native_batch_norm_inference_onnx(
     bias: TFloat,
     running_mean: TFloat,
     running_var: TFloat,
-    training: bool,
     momentum: float,
     eps: float,
-) -> Tuple[TFloat, TFloat, TFloat]:
-    # Assert(training is False)
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
+    """Batch normalization inference mode.
+
+    NOTE: momentum in PyTorch is 1.0-momentum in ONNX.
+    When calling this function be sure to pass 1.0-momentum when momentum is obtained from PyTorch.
+    """
     norm = op.BatchNormalization(
         input,
         weight,
@@ -5718,13 +5750,15 @@ def _aten_native_batch_norm_inference_onnx(
         running_var,
         epsilon=eps,
         momentum=momentum,
-        training_mode=training,
+        training_mode=False,
     )
-    # NOTE: mean and var are omitted in inference mode
-    # Cannot return 2 dup output, so have to do twice with different variable name
-    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
-    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
-    return norm, empty_mean, empty_var
+    # CUDA and CPU gives different shapes:
+    # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457
+    # We use CUDA's output here
+    invstd = op.Div(1.0, op.Sqrt(running_var + eps))
+    # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475
+    # TODO(justinchuby): Make sure the output types are correct
+    return norm, running_mean, invstd, running_mean, running_var
 
 
 # TODO: This op is using duplicated code from aten_native_batch_norm,
@@ -5760,92 +5794,35 @@ def aten__native_batch_norm_legit_functional(
     sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
     running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
 
-    # Have to split to 2 private functions, because training_function return 3 outputs
-    # While inference_function return 1 output
-    if training is True:
-        norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
-            input, weight, bias, running_mean, running_var, axes, training, momentum, eps
+    # We have to split to two private functions, because BatchNormalization returns
+    # three outputs when training_mode=True and one when it is False.
+    if training:
+        norm, input_mean, input_rstd, running_mean, running_var = (
+            _aten_native_batch_norm_training_onnx(
+                input,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                axes,
+                momentum=1.0 - momentum,
+                eps=eps,
+            )
         )
     else:
-        (
-            norm,
-            mean,
-            var,
-            new_mean,
-            new_var,
-        ) = _aten__native_batch_norm_inference_functional_onnx(
-            input, weight, bias, running_mean, running_var, training, momentum, eps
+        norm, input_mean, input_rstd, running_mean, running_var = (
+            _aten_native_batch_norm_inference_onnx(
+                input,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                momentum=1.0 - momentum,
+                eps=eps,
+            )
         )
-    return norm, mean, var, new_mean, new_var
-
-
-@torch_op("aten::_native_batch_norm_legit_functional", private=True)
-def _aten__native_batch_norm_training_functional_onnx(
-    input: TFloat,
-    weight: TFloat,
-    bias: TFloat,
-    running_mean: TFloat,
-    running_var: TFloat,
-    axes: INT64,
-    training: bool,
-    momentum: float,
-    eps: float,
-) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
-    # Assert(training is True)
-    norm, running_mean, running_var = op.BatchNormalization(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        epsilon=eps,
-        momentum=momentum,
-        training_mode=training,
-    )
-    # Compute var and rstd
-    # Mean, var, and rstd computation and results are expected to be
-    # in higher precision when inputs are float16.
-    upcast_input = op.Cast(input, to=FLOAT.dtype)
-    mean = op.ReduceMean(upcast_input, axes)
-    input_sub_mean = op.Sub(upcast_input, mean)
-    sqr = op.Mul(input_sub_mean, input_sub_mean)
-    var = op.ReduceMean(sqr, axes, keepdims=False)
-    rstd = op.Div(1.0, op.Sqrt(var + eps))
-    # Get mean again with size = [1, C]
-    mean = op.ReduceMean(upcast_input, axes, keepdims=False)
-    # NOTE: Fixed to be FLOAT dtype
-    running_mean = op.Cast(running_mean, to=FLOAT.dtype)
-    running_var = op.Cast(running_var, to=FLOAT.dtype)
-    return norm, mean, rstd, running_mean, running_var
-
-
-@torch_op("aten::_native_batch_norm_legit_functional", private=True)
-def _aten__native_batch_norm_inference_functional_onnx(
-    input: TFloat,
-    weight: TFloat,
-    bias: TFloat,
-    running_mean: TFloat,
-    running_var: TFloat,
-    training: bool,
-    momentum: float,
-    eps: float,
-) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
-    # Assert(training is False)
-    norm = op.BatchNormalization(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        epsilon=eps,
-        momentum=momentum,
-        training_mode=training,
-    )
-    # NOTE: mean and var are ommited in inference mode
-    # Cannot return 2 dup output, so have to do twice with different variable name
-    empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
-    empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
-    return norm, empty_mean, empty_var, running_mean, running_var
+    return norm, input_mean, input_rstd, running_mean, running_var
 
 
 def aten_native_batch_norm_backward(
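
A note on the running-variance update added to `_aten_native_batch_norm_training_onnx` above: following PyTorch's decomposition (linked in the diff), normalization uses the biased batch variance (divided by `n`), while the running variance is updated from the unbiased estimate (divided by `n - 1`), where `n` is the number of elements per channel. A minimal NumPy sketch of that correction (shapes and values are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical input: N=4, C=2, L=8 -> 32 elements contribute to each channel stat.
x = rng.standard_normal((4, 2, 8)).astype(np.float32)
n = x.size / x.shape[1]

biased_var = x.var(axis=(0, 2))              # divides by n; used to normalize
unbiased_var = biased_var * (n / (n - 1.0))  # divides by n - 1; used for running stats

momentum = 0.1  # PyTorch-style momentum
running_var = np.ones(2, dtype=np.float32)
# PyTorch convention: the new (unbiased) statistic is weighted by momentum.
running_var = (1 - momentum) * running_var + momentum * unbiased_var
```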
7 changes: 5 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -998,7 +998,8 @@ def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad
         if args[0] is not None and args[1] is not None:
             yield opinfo_core.SampleInput(
                 sample.input,
-                args=(args[2], args[3], args[0], args[1], training, momentum, eps),
+                args=(args[2], args[3], args[0], args[1]),
+                kwargs={"training": training, "momentum": momentum, "eps": eps},
             )
 
 
@@ -1019,7 +1020,9 @@ def sample_inputs__native_batch_norm_legit_no_stats(
         eps = sample.kwargs.get("eps", 1e-5)
         if args[0] is not None and args[1] is None:
             yield opinfo_core.SampleInput(
-                sample.input, args=(args[2], args[3], training, momentum, eps)
+                sample.input,
+                args=(args[2], args[3]),
+                kwargs={"training": training, "momentum": momentum, "eps": eps},
             )
 
 
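
The OpInfo changes above move `training`, `momentum`, and `eps` out of the positional `args` tuple into `kwargs`, so they bind to the op's attribute parameters by name. A sketch of the resulting sample, assuming the same `opinfo_core` import used in `extra_opinfo.py` (tensors are hypothetical placeholders):

```python
import torch
from torch.testing._internal.opinfo import core as opinfo_core

# Hypothetical placeholders for the sample's tensor operands.
weight = bias = torch.ones(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

sample = opinfo_core.SampleInput(
    torch.randn(2, 3, 4),                                     # input stays positional
    args=(weight, bias, running_mean, running_var),           # tensors stay positional
    kwargs={"training": True, "momentum": 0.1, "eps": 1e-5},  # scalars bind by name
)
```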
16 changes: 14 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -69,7 +69,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
 
 
 def _should_skip_xfail_test_sample(
-    op_name: str, sample, dtype: torch.dtype
+    op_name: str, sample, dtype: torch.dtype, device_type: str
 ) -> Tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
@@ -84,6 +84,12 @@ def _should_skip_xfail_test_sample(
         if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
             # Not applicable for this dtype
            continue
+        if (
+            decorator_meta.device_type is not None
+            and decorator_meta.device_type != device_type
+        ):
+            # Not applicable for this device_type
+            continue
         if decorator_meta.matcher(sample):
             return decorator_meta.test_behavior, decorator_meta.reason
     return None, None
@@ -200,7 +206,13 @@ def run_test_output_match(
             ),
             kwargs=repr(cpu_sample.kwargs),
         ):
-            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)
+            try:
+                device_type = cpu_sample.args[0].device.type
+            except (AttributeError, IndexError):
+                device_type = "cpu"
+            test_behavior, reason = _should_skip_xfail_test_sample(
+                op.name, cpu_sample, dtype, device_type
+            )
 
             with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                 input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
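
For reference, the device-type probe added to `run_test_output_match` inspects the sample's first positional argument and falls back to `"cpu"` when it is absent or not a tensor. A small standalone sketch of that logic (`_Sample` is a hypothetical stand-in for an OpInfo sample):

```python
import torch

class _Sample:
    """Hypothetical stand-in for an OpInfo sample input."""

    args = (torch.tensor([1.0]),)

try:
    device_type = _Sample.args[0].device.type  # e.g. "cpu" or "cuda"
except (AttributeError, IndexError):
    device_type = "cpu"  # first arg missing or not a tensor

print(device_type)  # -> "cpu" for the tensor above
```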
8 changes: 8 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -70,6 +70,7 @@ class DecorateMeta:
     variant_name: str
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
+    device_type: Optional[str]
     reason: str
     test_behavior: str
     matcher: Optional[Callable[[Any], bool]] = None
@@ -85,6 +86,7 @@ def xfail(
     *,
     reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
+    device_type: Optional[str] = None,
     matcher: Optional[Callable[[Any], Any]] = None,
     enabled_if: bool = True,
     test_class_name: Optional[str] = None,
@@ -96,6 +98,7 @@ def xfail(
         variant_name: Optional OpInfo variant_test_name.
         reason: The reason for the failure.
         dtypes: The dtypes to expect the failure.
+        device_type: Device type. E.g. "cpu", "cuda".
         matcher: A function that matches the test sample input. It is used only when
             the xfail is in the SKIP_XFAIL_SUBTESTS list.
         enabled_if: Whether the xfail is enabled.
@@ -107,6 +110,7 @@ def xfail(
         variant_name=variant_name,
         decorator=unittest.expectedFailure,
         dtypes=dtypes,
+        device_type=device_type,
         matcher=matcher,
         reason=reason,
         enabled_if=enabled_if,
@@ -121,6 +125,7 @@ def skip(
     *,
     reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
+    device_type: Optional[str] = None,
     matcher: Optional[Callable[[Any], Any]] = None,
     enabled_if: bool = True,
     test_class_name: Optional[str] = None,
@@ -132,6 +137,7 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         reason: The reason for skipping.
         dtypes: The dtypes to skip.
+        device_type: Device type. E.g. "cpu", "cuda".
         matcher: A function that matches the test sample input. It is used only when
             the skip is in the SKIP_XFAIL_SUBTESTS list.
         enabled_if: Whether the skip is enabled.
@@ -143,6 +149,7 @@ def skip(
         variant_name=variant_name,
         decorator=unittest.skip(f"Skip: {reason}"),
         dtypes=dtypes,
+        device_type=device_type,
         reason=reason,
         matcher=matcher,
         enabled_if=enabled_if,
@@ -174,6 +181,7 @@ def add_decorate_info(
             decorate_meta.test_class_name or test_class_name,
             base_test_name,
             dtypes=decorate_meta.dtypes,
+            device_type=decorate_meta.device_type,
             active_if=decorate_meta.enabled_if,
         )
         decorators.append(new_decorator)
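
With `device_type` plumbed through `DecorateMeta`, `xfail`, `skip`, and `add_decorate_info`, a test-data entry can now target a single device. A hypothetical usage (the op name, matcher, and reason here are illustrative, not the commit's actual entry in `ops_test_data.py`):

```python
skip(
    "_native_batch_norm_legit_functional",
    device_type="cpu",
    reason="results mismatch on CPU (illustrative)",
    matcher=lambda sample: sample.kwargs.get("training") is True,
)
```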