Merge branch 'main' into rama/const-prop
gramalingam authored Oct 11, 2024
2 parents 18862b9 + 12f9209 commit 2b497fb
Showing 9 changed files with 98 additions and 99 deletions.
72 changes: 26 additions & 46 deletions onnxscript/_framework_apis/torch_2_5.py
@@ -17,17 +17,10 @@
import pathlib
from typing import Callable

import onnx

from onnxscript import ir, optimizer
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data

# Internal flag. Will go away.
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0"
)


@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
@@ -83,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
"""Save the model with external data. The model is unchanged after saving."""

# TODO(#1835): Decide if we want to externalize large attributes as well
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor

else:
destination_path = pathlib.Path(model_path)
# Create the directory if it does not exist
data_path = f"{destination_path.name}.data"
proto = ir.serde.serialize_model(model)
onnx.save_model(
proto,
model_path,
save_as_external_data=True,
location=data_path,
)
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
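With the environment-variable gate (`TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR`) and the `onnx.save_model` fallback removed, the IR-based external-data path is now unconditional. A minimal usage sketch, assuming `ir.load` is the read-side counterpart of the `ir.save` call used inside the function, with placeholder file names:

from onnxscript import ir
from onnxscript._framework_apis.torch_2_5 import save_model_with_external_data

# Placeholder file names for illustration only; ir.load is assumed to mirror
# the ir.save call shown in the function above.
model = ir.load("model.onnx")
save_model_with_external_data(model, "model_with_weights.onnx")
# Tensor payloads are written to "model_with_weights.onnx.data" next to the
# destination file; the original in-memory initializer values are restored
# after serialization, so `model` itself is left unchanged.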
26 changes: 26 additions & 0 deletions onnxscript/_framework_apis/torch_2_6.py
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.6."""

from __future__ import annotations

__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]
from onnxscript import ir, optimizer
from onnxscript._framework_apis.torch_2_5 import (
check_model,
convert_version,
get_torchlib_ops,
save_model_with_external_data,
)


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
return model
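The new module re-exports the 2.5 entry points and overrides only `optimize`, which runs the IR optimizer in place and returns the model. A hedged sketch of how a caller might chain these APIs; the helper function and path are illustrative, not part of the module:

from onnxscript import ir
from onnxscript._framework_apis import torch_2_6


def export_for_torch_2_6(model: ir.Model, path: str) -> None:
    """Illustrative helper: optimize, validate, and save with external data."""
    model = torch_2_6.optimize(model)  # wraps optimizer.optimize_ir(model)
    torch_2_6.check_model(model)  # re-exported unchanged from torch_2_5
    torch_2_6.save_model_with_external_data(model, path)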
59 changes: 25 additions & 34 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -39,7 +39,6 @@
RealType,
TFloat,
TFloatHighPrecision,
TFloatOrBFloat16,
TInt,
TReal,
TRealOrUInt8,
@@ -2031,12 +2030,6 @@ def aten_convolution(
stride = (stride, stride)
strides = list(stride)

if bias is None:
weight_dim_0 = op.Shape(weight, start=0, end=1)
bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
zero = op.CastLike(0.0, input)
bias = op.Expand(zero, bias_shape)

result = _aten_convolution_onnx(
input,
weight,
@@ -3564,14 +3557,14 @@ def aten_flipud(self: TensorType) -> TensorType:


@torch_op("aten::floor", traceable=True)
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_floor(self: TFloat) -> TFloat:
"""floor(Tensor self) -> Tensor"""

return op.Floor(self)


@torch_op("math::floor", traceable=True)
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
def python_math_floor(self: TFloat) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)
@@ -4533,7 +4526,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:


@torch_op("aten::isinf")
def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isinf(self: TFloat) -> BOOL:
"""isinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4535,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isnan")
def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
def aten_isnan(self: TFloat) -> BOOL:
"""isnan(Tensor self) -> Tensor"""

return op.IsNaN(self)


@torch_op("aten::isneginf")
def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
def aten_isneginf(self: TFloat) -> BOOL:
"""isneginf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4551,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isposinf")
def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isposinf(self: TFloat) -> BOOL:
"""isposinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
@@ -4778,42 +4771,42 @@ def aten_linspace(


@torch_op("aten::log", traceable=True)
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log(self: TFloat) -> TFloat:
"""log(Tensor self) -> Tensor"""

return op.Log(self)


@torch_op("aten::log10", traceable=True)
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log10(self: TFloat) -> TFloat:
"""log10(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))


@torch_op("aten::log1p")
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log1p(self: TFloat) -> TFloat:
"""log1p(Tensor self) -> Tensor"""

return op.Log(op.Add(self, 1.0))


@torch_op("aten::log2", traceable=True)
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log2(self: TFloat) -> TFloat:
"""log2(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))


@torch_op("aten::logaddexp", traceable=True)
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""

return op.Log(op.Add(op.Exp(self), op.Exp(other)))


@torch_op("aten::logaddexp2", traceable=True)
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
@@ -4822,7 +4815,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOr


@torch_op("aten::logcumsumexp", traceable=True)
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
"""logcumsumexp(Tensor self, int dim) -> Tensor"""

if IsScalar(self):
@@ -4908,12 +4901,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:


@torch_op("aten::logit", private=True)
def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def _aten_logit_onnx(self: TFloat) -> TFloat:
return op.Log(op.Div(self, op.Sub(1.0, self)))


@torch_op("aten::logit", private=True)
def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
eps = op.CastLike(eps, self)
one = op.CastLike(1.0, self)
temporary_self = op.Where(self <= one - eps, self, one - eps)
@@ -4923,7 +4916,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat


@torch_op("aten::logit", trace_only=True)
def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
"""logit(Tensor self, float? eps=None) -> Tensor"""
if eps is None:
return _aten_logit_onnx(self)
@@ -6041,9 +6034,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:


@torch_op("aten::native_dropout", trace_only=True)
def aten_native_dropout(
input: TFloatOrBFloat16, p: float, train: bool = True
) -> Tuple[TFloatOrBFloat16, BOOL]:
def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

result, mask = op.Dropout(input, p, train)
@@ -7055,7 +7046,7 @@ def aten_real(self: TensorType) -> TensorType:


@torch_op("aten::reciprocal")
def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

return op.Reciprocal(self)
@@ -7074,7 +7065,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

# TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7346,7 @@ def aten_rrelu(


@torch_op("aten::rsqrt", traceable=True)
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_rsqrt(self: TFloat) -> TFloat:
"""rsqrt(Tensor self) -> Tensor"""

return op.Reciprocal(op.Sqrt(self))
@@ -7562,7 +7553,7 @@ def aten_sgn(self: TensorType) -> TensorType:


@torch_op("aten::sigmoid", traceable=True)
def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sigmoid(self: TFloat) -> TFloat:
"""sigmoid(Tensor self) -> Tensor"""

return op.Sigmoid(self)
@@ -7724,7 +7715,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:


@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
@@ -7741,7 +7732,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB


@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
@@ -7812,7 +7803,7 @@ def aten_split_with_sizes_copy(


@torch_op("aten::sqrt", traceable=True)
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sqrt(self: TFloat) -> TFloat:
"""sqrt(Tensor self) -> Tensor"""

return op.Sqrt(self)
@@ -8402,7 +8393,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:


@torch_op("aten::trunc")
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""

# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
11 changes: 5 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -25,7 +25,6 @@
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
@@ -364,13 +363,13 @@ def aten_conv_depthwise3d(

@torch_op("aten::cross_entropy_loss", traceable=True)
def aten_cross_entropy_loss(
self: TFloatOrBFloat16,
self: TFloat,
target: IntType,
weight: Optional[TFloatOrBFloat16] = None,
weight: Optional[TFloat] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
) -> TFloatOrBFloat16:
) -> TFloat:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

if reduction == 0: # "none"
@@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te


@torch_op("aten::leaky_relu")
def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""

return op.LeakyRelu(self, alpha=negative_slope)
@@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:


@torch_op("aten::log_sigmoid")
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log_sigmoid(self: TFloat) -> TFloat:
"""log_sigmoid(Tensor self) -> Tensor"""

return op.Log(op.Sigmoid(self))
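The wholesale swap of `TFloatOrBFloat16` for `TFloat` across core.py and nn.py only works if the `TFloat` constraint in `tensor_typing` already admits bfloat16. The snippet below is an assumed sketch of that alias, not a copy of `tensor_typing.py`:

# Assumed definition; the real one lives in
# onnxscript/function_libs/torch_lib/tensor_typing.py.
from typing import TypeVar, Union

from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16

# If BFLOAT16 is part of the bound, TFloat subsumes the old
# TFloatOrBFloat16 alias, so the separate name can be retired.
TFloat = TypeVar("TFloat", bound=Union[BFLOAT16, FLOAT16, FLOAT, DOUBLE])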