Merge branch 'main' into titaiwang/make_trace_function_promote_constant
titaiwangms authored Oct 23, 2024
2 parents ed7636f + 2b60939 commit 6b0b324
Showing 10 changed files with 229 additions and 71 deletions.
113 changes: 70 additions & 43 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -42,6 +42,7 @@
TInt,
TReal,
TRealOrUInt8,
TRealUnlessFloat16OrInt8,
TRealUnlessInt16OrInt8,
TTensor,
TTensor2,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:

@torch_op("aten::arange", trace_only=True)
def aten_arange(
end: float,
end: TRealUnlessFloat16OrInt8,
dtype: int = -1,
layout: str = "",
device: str = "",
@@ -549,10 +550,9 @@ def aten_arange(
"""arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

if dtype == -1 or dtype is None:
if isinstance(end, int):
result = op.Range(0, end, 1)
else:
result = op.Range(0.0, end, 1.0)
zero = op.CastLike(0.0, end)
one = op.CastLike(1.0, end)
result = op.Range(zero, end, one)
elif _range_supported(dtype):
end = op.Cast(end, to=dtype)
zero = op.Cast(0, to=dtype)
@@ -563,7 +563,7 @@ def aten_arange(
# because the input dtype may be e.g. bfloat16 / int8 etc.
# which Range does not support. The output type is ensured because the output
# is casted to the specified dtype.
end = op.Constant(value_float=float(end))
end = op.Cast(end, to=FLOAT.dtype)
zero = op.Constant(value_float=0.0)
one = op.Constant(value_float=1.0)
result = op.Cast(op.Range(zero, end, one), to=dtype)
@@ -573,8 +573,8 @@

@torch_op("aten::arange.start", trace_only=True)
def aten_arange_start(
start: float,
end: float,
start: TRealUnlessFloat16OrInt8,
end: TRealUnlessFloat16OrInt8,
dtype: int = -1,
layout: str = "",
device: str = "",
@@ -583,12 +583,8 @@ def aten_arange_start(
"""arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

if dtype == -1 or dtype is None:
if isinstance(start, int) and isinstance(end, int):
result = op.Range(start, end, 1)
else:
start = float(start)
end = float(end)
result = op.Range(start, end, 1.0)
one = op.CastLike(1.0, end)
result = op.Range(start, end, one)
elif _range_supported(dtype):
end = op.Cast(end, to=dtype)
start = op.Cast(start, to=dtype)
@@ -599,46 +595,78 @@ def aten_arange_start(
# because the input dtype may be e.g. bfloat16 / int8 etc.
# which Range does not support. The output type is ensured because the output
# is casted to the specified dtype.
end = op.Constant(value_float=float(end))
start = op.Constant(value_float=float(start))
end = op.Cast(end, to=FLOAT.dtype)
start = op.Cast(start, to=FLOAT.dtype)
one = op.Constant(value_float=1.0)
result = op.Cast(op.Range(start, end, one), to=dtype)

return result
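
The `op.CastLike(0.0, end)` / `op.CastLike(1.0, end)` calls above are the constant-promotion pattern this branch is about: the Python constants are cast to whatever dtype `end` carries, so `op.Range` sees three operands of one type whether `end` arrives as a Python scalar or a dynamic tensor. A rough numpy analogue of that promotion (an illustration only, not the onnxscript API):

```python
import numpy as np

def arange_like(start, end, step):
    # Promote the Python scalars to the dtype of `end`, mirroring op.CastLike.
    dtype = np.asarray(end).dtype
    start = np.asarray(start, dtype=dtype)
    step = np.asarray(step, dtype=dtype)
    return np.arange(start, end, step, dtype=dtype)

print(arange_like(0, np.int64(5), 1))        # int64 result: [0 1 2 3 4]
print(arange_like(0, np.float32(2.5), 0.5))  # float32 result: [0.  0.5 1.  1.5 2. ]
```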


def _adjust_args_for_arange_int_dtype(
start: float,
end: float,
step: float,
) -> Tuple[float, float, float]:
if start < 0:
start = math.ceil(start)
if step < 0:
start = math.floor(start)
start: TRealUnlessFloat16OrInt8,
end: TRealUnlessFloat16OrInt8,
step: TRealUnlessFloat16OrInt8,
) -> Tuple[FLOAT, FLOAT, FLOAT]:
zero = op.Cast(0.0, to=FLOAT.dtype)
start = op.Cast(start, to=FLOAT.dtype)
end = op.Cast(end, to=FLOAT.dtype)
step = op.Cast(step, to=FLOAT.dtype)

return float(start), float(end), float(step)
start = op.Where(op.Less(start, zero), op.Ceil(start), start)
start = op.Where(op.Less(step, zero), op.Floor(start), start)

return (start, end, step)


@torch_op("aten::arange.start_step", trace_only=True)
def aten_arange_start_step(
start: float,
end: float,
step: float = 1.0,
start: TRealUnlessFloat16OrInt8,
end: TRealUnlessFloat16OrInt8,
step: TRealUnlessFloat16OrInt8 = 1.0,
dtype: int = -1,
layout: str = "",
device: str = "",
pin_memory: bool = False,
) -> TensorType:
"""arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

if dtype == -1 or dtype is None:
if isinstance(start, int) and isinstance(end, int):
result = op.Range(start, end, int(step))
if dtype == -1:
# TODO: Because this is a trace_only function, the inputs are not promoted to
# Tensor until it hits ONNX ops. However, if it's dynamic, it should be
# Tensor at this point.
# https://github.com/microsoft/onnxscript/issues/1914
if isinstance(start, (int, float)):
start_is_int = isinstance(start, int)
else:
start = float(start)
end = float(end)
step = float(step)
start_is_int = start.dtype in {
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if isinstance(end, (int, float)):
end_is_int = isinstance(end, int)
else:
end_is_int = end.dtype in {
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if isinstance(step, (int, float)):
step_is_int = isinstance(step, int)
else:
step_is_int = step.dtype in {
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if start_is_int and end_is_int and step_is_int:
result = op.Range(start, end, step)
else:
# to float
start = op.Cast(start, to=FLOAT.dtype)
end = op.Cast(end, to=FLOAT.dtype)
step = op.Cast(step, to=FLOAT.dtype)
result = op.Range(start, end, step)
elif _integral_to_be_adjusted(dtype):
# PyTorch arange op handles these integral types differently from INT64,
@@ -647,18 +675,18 @@ def aten_arange_start_step(
start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
result = op.Cast(op.Range(start, end, step), to=dtype)
elif dtype == INT64.dtype:
end = int(end)
start = int(start)
step = int(step)
end = op.Cast(end, to=dtype)
start = op.Cast(start, to=dtype)
step = op.Cast(step, to=dtype)
result = op.Range(start, end, step)
else:
# Cast input to float if dtype is not supported by Range,
# because the input dtype may be e.g. bfloat16,
# which Range does not support. The output type is ensured because the output
# is casted to the specified dtype.
end = float(end)
start = float(start)
step = float(step)
end = op.Cast(end, to=FLOAT.dtype)
start = op.Cast(start, to=FLOAT.dtype)
step = op.Cast(step, to=FLOAT.dtype)
result = op.Cast(op.Range(start, end, step), to=dtype)

return result
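
As the TODO above notes, a trace-only function may be handed plain Python scalars or values that are already tensors, so the `*_is_int` checks have to look at both the Python type and the `.dtype`. A hypothetical stand-alone version of that check (numpy stands in for the traced tensor values; the helper name is invented for the example):

```python
import numpy as np

INT_DTYPES = {np.dtype(np.int16), np.dtype(np.int32), np.dtype(np.int64)}

def is_integral(value) -> bool:
    if isinstance(value, (int, float)):
        # Plain Python scalar: its Python type decides.
        return isinstance(value, int)
    # Tensor-like value: inspect its dtype instead.
    return np.asarray(value).dtype in INT_DTYPES

print(is_integral(3), is_integral(3.0), is_integral(np.array(3, dtype=np.int64)))
# True False True
```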
@@ -4735,8 +4763,8 @@ def aten_linear_backward(

@torch_op("aten::linspace", trace_only=True)
def aten_linspace(
start: float,
end: float,
start: TFloat,
end: TFloat,
steps: int,
dtype: int = FLOAT.dtype,
layout: str = "",
@@ -4754,7 +4782,6 @@ def aten_linspace(
if steps == 1:
return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)

# TODO(justinchuby): Simplify the logic knowing start and end are floats
rg = aten_arange_start(0, steps, dtype=dtype)
start = op.Cast(start, to=dtype)
end = op.Cast(end, to=dtype)
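
The remainder of `aten_linspace` is elided in this hunk; assuming it follows the usual arange-based construction (an assumption, not shown here), the output is `start + rg * (end - start) / (steps - 1)`. A plain-numpy sketch of that formula:

```python
import numpy as np

def linspace_via_arange(start: float, end: float, steps: int) -> np.ndarray:
    # rg = 0, 1, ..., steps - 1, scaled and shifted onto [start, end].
    rg = np.arange(steps, dtype=np.float32)
    return start + rg * ((end - start) / (steps - 1))

print(linspace_via_arange(0.0, 1.0, 5))  # [0.   0.25 0.5  0.75 1.  ]
print(np.allclose(linspace_via_arange(2.0, 10.0, 9), np.linspace(2, 10, 9)))  # True
```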
24 changes: 21 additions & 3 deletions onnxscript/ir/_core.py
@@ -70,6 +70,7 @@
_enums.DataType.FLOAT8E5M2FNUZ,
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
)
)

@@ -182,7 +183,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
When the dtype is not one of the numpy native dtypes, the value needs to be:
- ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
- ``uint8`` for uint4.
- ``uint8`` for uint4 or float4.
- ``uint8`` for 8-bit data types.
- ``uint16`` for bfloat16
@@ -213,6 +214,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
raise TypeError(
f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
)
if dtype == _enums.DataType.FLOAT4E2M1:
if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
raise TypeError(
f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
)
return

try:
@@ -256,6 +262,8 @@ def _maybe_view_np_array_with_ml_dtypes(
return array.view(ml_dtypes.int4)
if dtype == _enums.DataType.UINT4:
return array.view(ml_dtypes.uint4)
if dtype == _enums.DataType.FLOAT4E2M1:
return array.view(ml_dtypes.float4_e2m1fn)
return array


@@ -431,7 +439,11 @@ def tobytes(self) -> bytes:
"""
# TODO(justinchuby): Support DLPack
array = self.numpy()
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
if self.dtype in {
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
}:
# Pack the array into int4
array = _type_casting.pack_int4(array)
else:
Expand Down Expand Up @@ -609,7 +621,11 @@ def _load(self):
)
# Handle the byte order correctly by always using little endian
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
if self.dtype in {
_enums.DataType.INT4,
_enums.DataType.UINT4,
_enums.DataType.FLOAT4E2M1,
}:
# Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
dt = np.dtype(np.uint8).newbyteorder("<")
count = self.size // 2 + self.size % 2
@@ -622,6 +638,8 @@ def _load(self):
self._array = _type_casting.unpack_int4(self._array, shape)
elif self.dtype == _enums.DataType.UINT4:
self._array = _type_casting.unpack_uint4(self._array, shape)
elif self.dtype == _enums.DataType.FLOAT4E2M1:
self._array = _type_casting.unpack_float4e2m1(self._array, shape)
else:
self._array = self._array.reshape(shape)

36 changes: 32 additions & 4 deletions onnxscript/ir/_core_test.py
@@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self):
("int4", np.int8, ir.DataType.INT4),
("int4_uint8", np.uint8, ir.DataType.INT4),
("uint4", np.uint8, ir.DataType.UINT4),
("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1),
]
)
def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.DataType):
@@ -131,34 +132,48 @@ def test_tobytes(self):
tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
self.assertEqual(tensor.tobytes(), array.tobytes())

def test_tobtyes_returns_packed_data_for_int4(self):
def test_tobytes_returns_packed_data_for_int4(self):
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")

def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self):
def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self):
array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.INT4)
self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")

def test_tobtyes_returns_packed_data_for_uint4(self):
def test_tobytes_returns_packed_data_for_uint4(self):
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")

def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self):
def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self):
array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.UINT4)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")

def test_tobytes_returns_packed_data_for_float4e2m1(self):
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")

def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self):
array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
# Test odd sized array
assert len(array) % 2 == 1
tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1)
self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
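
The expected byte strings in these packing tests encode two 4-bit values per byte, low nibble first. A small sketch (mirroring what `_type_casting.pack_int4` is assumed to do) that reproduces the bytes asserted above:

```python
import numpy as np

def pack_nibbles(values: np.ndarray) -> bytes:
    vals = values.astype(np.uint8) & 0x0F  # keep the low 4 bits of each value
    if vals.size % 2 == 1:
        # Pad odd-length input with a zero nibble.
        vals = np.concatenate([vals, np.zeros(1, dtype=np.uint8)])
    packed = (vals[1::2] << 4) | vals[0::2]  # low nibble is stored first
    return packed.astype(np.uint8).tobytes()

print(pack_nibbles(np.array([0, 1, 2, 7, 15])))                        # b'\x10r\x0f'
print(pack_nibbles(np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)))  # b'\xf8\x10r\x01'
```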

def test_metadata(self):
array = np.random.rand(1, 2).astype(np.float32)
tensor = _core.Tensor(array)
@@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype):
# about permission errors
del tensor

def test_external_tensor_float4e2m1(self):
expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn)
tensor_proto = ir.serde.serialize_tensor(
ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1)
)
with tempfile.TemporaryDirectory() as temp_dir:
_to_external_tensor(tensor_proto, temp_dir, "tensor.bin")
tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir)
np.testing.assert_array_equal(tensor.numpy(), expected_array)
# Close the mmap file by deleting the reference to tensor so Windows doesn't complain
# about permission errors
del tensor

def test_external_tensor_empty_tensor(self):
expected_array = np.array([], dtype=np.float32)
tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array))
9 changes: 9 additions & 0 deletions onnxscript/ir/_enums.py
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
FLOAT8E5M2FNUZ = 20
UINT4 = 21
INT4 = 22
FLOAT4E2M1 = 23

@classmethod
def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -121,6 +122,7 @@ def __str__(self) -> str:
DataType.FLOAT8E5M2FNUZ: 1,
DataType.UINT4: 0.5,
DataType.INT4: 0.5,
DataType.FLOAT4E2M1: 0.5,
}
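
These fractional item sizes (half a byte per element for the 4-bit types, now including FLOAT4E2M1) are why `_load` above reads `size // 2 + size % 2` bytes: an odd element count needs one padding nibble. A tiny illustration (the helper name is made up for the example):

```python
import math

def nbytes(num_elements: int, bytes_per_element: float) -> int:
    # Round up so an odd number of 4-bit elements still gets its padding nibble.
    return math.ceil(num_elements * bytes_per_element)

print(nbytes(5, 0.5))  # 3: five 4-bit values pack into three bytes (cf. b"\x10r\x0f")
print(nbytes(4, 0.5))  # 2
print(nbytes(3, 1))    # 3: one-byte types are unaffected
```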


Expand Down Expand Up @@ -150,5 +152,12 @@ def __str__(self) -> str:
np.dtype(ml_dtypes.uint4): DataType.UINT4,
}

# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
_NP_TYPE_TO_DATA_TYPE.update(
{np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
if hasattr(ml_dtypes, "float4_e2m1fn")
else {}
)

# ONNX DataType to Numpy dtype.
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
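
The guarded update above keeps older ml_dtypes releases working: when `float4_e2m1fn` is absent, the FLOAT4E2M1 entry is simply never registered. A minimal sketch of the same feature-detection pattern (the dictionary names here are illustrative, not the module's actual tables):

```python
import numpy as np
import ml_dtypes

np_to_ir = {np.dtype(ml_dtypes.bfloat16): "BFLOAT16"}
np_to_ir.update(
    {np.dtype(ml_dtypes.float4_e2m1fn): "FLOAT4E2M1"}
    if hasattr(ml_dtypes, "float4_e2m1fn")
    else {}
)
ir_to_np = {v: k for k, v in np_to_ir.items()}  # reverse mapping, as above
print("FLOAT4E2M1" in ir_to_np)  # True only on ml_dtypes releases that define float4_e2m1fn
```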
(The diffs for the remaining six changed files are not shown in this view.)