TorchLib function authoring guide

Updated: November 2023
Authors: @titaiwangms @justinchuby

Principles

TorchLib functions are pure data. This means we avoid defining runtime behavior as code in the functions.

Check native_functions.yaml

The primary objective of torchlib is to transform a PyTorch model into an ONNX model. To do this, you first need to understand the signature of the function in PyTorch, specifically the ATen operator you are implementing. The comprehensive list of these native functions is defined in PyTorch's native_functions.yaml file.

- func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> Return
  variants: function, method
  dispatch:
    CPU: func_cpu
    CUDA: func_cuda

Pay close attention to the ArgType of each argument: every distinct ArgType corresponds to a different TypeVar in torchlib.
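
For example, the entry for aten::gather (implemented later in this guide) begins with the following func line, which matches the docstring copied into the function body below; the remaining fields of the entry are omitted here:

- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor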

Implement OnnxFunction/TracedOnnxFunction

The decorator: torch_op

The torch_op decorator registers the function with the torchlib framework.

def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function supports complex.
    """
    ...
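
For illustration, a function that implements several overloads can be registered with a tuple of names. The sketch below is adapted from the aten::add overloads; it assumes the usual torchlib imports (torch_op, op, and the TReal type alias) are in scope and may differ from the actual implementation in torchlib:

@torch_op(("aten::add.Tensor", "aten::add.Scalar"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    # Scale `other` by alpha before the addition; CastLike keeps the dtypes consistent.
    alpha = op.CastLike(alpha, other)
    other = op.Mul(other, alpha)
    return op.Add(self, other)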

The trace_only option complements script() for functions that need complex control flow, using the TracedOnnxFunction class. Instead of compiling the entire function into an OnnxFunction, TracedOnnxFunction keeps it as a regular Python function and traces it at export time, which makes it possible to handle control flow that script() cannot express.

Function signature

  • Name the function after the namespace it comes from, for example aten_abs or prims_abs.
  • Annotate the inputs and attributes correctly, following native_functions.yaml.

Introduce or reuse a single TypeVar in tensor_typing that corresponds to the ArgType declared in native_functions.yaml. Typically, inputs are tensor types and attributes are primitive types, but the exact choice is decided case by case, because the OnnxFunction implementation must also satisfy the type requirements of the ONNX operators it uses.
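
As a sketch, a constrained TypeVar and a signature that uses it could look like the following. TExample and its member types are illustrative only, not the actual definitions in tensor_typing, and the import paths for torch_op and op are assumptions:

from typing import TypeVar

from onnxscript import DOUBLE, FLOAT, FLOAT16, INT32, INT64
from onnxscript.function_libs.torch_lib.registration import torch_op  # assumed import path
from onnxscript.onnx_opset import opset18 as op  # assumed import path

# An illustrative constrained TypeVar in the spirit of those in tensor_typing.
TExample = TypeVar("TExample", DOUBLE, FLOAT, FLOAT16, INT32, INT64)


@torch_op("aten::abs")
def aten_abs(self: TExample) -> TExample:
    """abs(Tensor self) -> Tensor"""
    return op.Abs(self)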

Function body

OnnxFunction

When scripting the function, every computation in the body must be expressed with ONNX operators. A prefix of the form opset{version} (conventionally aliased to op) indicates which opset an operator comes from. OnnxFunction also offers partial support for control flow, such as 'if' statements and 'for' loops, together with conveniences like automatic wrapping of constants and of basic arithmetic.

@torch_op("aten::gather")
def aten_gather(
    self: TReal,
    dim: int,
    index: TInt,
    sparse_grad: bool = False,  # pylint: disable=unused-argument
) -> TReal:
    """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""

    # When (index) is empty, return (self)
    if op.Size(op.Shape(index)) == 0:  # Support control-flow
        result = self
    else:
        if op.Size(op.Shape(self)) == 0:  # The literal 0 is auto-wrapped into an op.Constant
            self = op.Reshape(self, op.Constant(value_ints=[-1]))
        if op.Size(index) == 0:  # == is auto-wrapped into op.Equal()
            result = op.CastLike(index, self)
        else:
            index = op.Cast(index, to=INT64.dtype)
            result = op.GatherElements(self, index, axis=dim)
    return result

TracedOnnxFunction

This kind of function is a plain Python function that wraps and composes OnnxFunctions. It is needed when the restricted coding capabilities of OnnxFunction cannot express the operator's logic, for example when the code relies on unsupported constructs such as dictionaries, len(), or 'None' checks.

@torch_op("aten::layer_norm", trace_only=True)
def aten_layer_norm(
    input: TReal,
    normalized_shape: INT64,
    weight: Optional[TReal] = None,
    bias: Optional[TReal] = None,
    eps: float = 1e-05,
) -> TReal:
    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""

    # trace_only to use Python to obtain start_axis
    start_axis = -len(normalized_shape)

    if weight is None:  # Unsupported None check
        one = op.Constant(value_float=1.0)
        weight = op.Expand(one, op.Shape(input, start=start_axis))

    if bias is None:  # Unsupported None check
        zero = op.Constant(value_float=0.0)
        bias = op.Expand(zero, op.Shape(input, start=start_axis))

    return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps)  # covers a private OnnxFunction


@torch_op("aten::layer_norm", private=True)
def _aten_layer_norm_onnx(
    input: TReal,
    weight: TReal,
    bias: TReal,
    axis: int,
    eps: float = 1e-05,
) -> TReal:
    """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""

    # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
    result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
    return result

Op Test System

To make sure each OnnxFunction/TracedOnnxFunction is implemented correctly, we provide op-level correctness tests.

How to add a new operator test

The tests use PyTorch's OpInfo mechanism to generate test cases for each operator. You can find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804

  1. To enable test cases for an operator, add a TorchLibOpInfo entry to TORCH_LIB_OPINFO in ops_test_data.py (see the sketch after this list). Explicitly specify trace_only if the op is trace_only. Specify complex if the function is designed for complex inputs.

    The op_info_name in TorchLibOpInfo needs to be unique in the TORCH_LIB_OPINFO list, but complex=True ops can share the same name with non-complex ops because they are tested separately.

  2. Add .skip and/or .xfail to skip or xfail tests. Prefer xfail over skip when possible, because that allows us to monitor the behavior and update the test once it passes.

    2a. If a test now fails with an unexpected pass (xpass) because a previously failing case has been fixed, remove the corresponding xfail.

  3. If the sample inputs of the OpInfo need to be adjusted to fit the aten signature, create an input wrangler function. See _mean_input_wrangler for an example.

  4. To test different ONNX functions that are registered as overloads of the same op, use ops_test_common.duplicate_opinfo to create new OpInfos with new names and map each one to a single overload.
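
The sketch below condenses steps 1-3. The entries, keyword arguments, and the wrangler body are paraphrased for illustration (the xfail dtype and reason are hypothetical); consult ops_test_data.py for the real definitions and import paths:

from typing import Any

import numpy as np
import torch

# TorchLibOpInfo and TORCH_LIB_OPINFO live in ops_test_data.py; core_ops refers to
# the torchlib core ops module (import paths assumed here).


def _mean_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    # Step 3: adjust the OpInfo sample inputs so they match the aten signature
    # expected by the ONNX function (body paraphrased).
    if "dim" in kwargs:
        kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
    return args, kwargs


TORCH_LIB_OPINFO = (
    # Step 1: one TorchLibOpInfo entry per tested op; mark trace_only/complex ops explicitly.
    TorchLibOpInfo("gather", core_ops.aten_gather),
    TorchLibOpInfo("layer_norm", core_ops.aten_layer_norm, trace_only=True),
    # Steps 2 and 3: attach an input wrangler and xfail a known failure so we
    # notice when it starts passing.
    TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).xfail(
        dtypes=(torch.float16,),
        reason="hypothetical: not implemented for float16",
    ),
)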

Example PRs

Use https://github.com/microsoft/onnxscript/pull/1260/files and https://github.com/microsoft/onnxscript/pull/1284 as examples for implementing an operator and creating tests for it.