[IR] Fix broadcast_to_matmul (#1542)
- Check whether the shape tensor is constant before using it in the
logic, exiting early if it is not.
- Handle cases where the input is 1-D or 0-D.

Thanks @borisfom for the proposed fix!

Fix #1541
justinchuby authored May 20, 2024
1 parent d71b74f commit 69ae7f4
Showing 4 changed files with 210 additions and 90 deletions.
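
For context, the rewrite rule this commit fixes targets Reshape → MatMul → Reshape chains that can collapse into a single MatMul whenever MatMul's own broadcasting already produces the requested output shape. A minimal NumPy sketch of that equivalence (the shapes are illustrative, not taken from the repository's tests):

```python
import numpy as np

# Illustrative shapes: the Reshape ops around MatMul are redundant because
# MatMul's broadcasting already produces the requested output shape.
a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(2, 3, 5, 6)
shape_c = (2, 3, 4, 6)  # shape fed to the trailing Reshape

# Reshaped path: flatten the batch dims, matmul, reshape back.
reshaped = np.reshape(a, (6, 4, 5)) @ np.reshape(b, (6, 5, 6))
reshaped = np.reshape(reshaped, shape_c)

# Broadcast path: a plain matmul already yields shape_c.
direct = a @ b
assert direct.shape == shape_c
np.testing.assert_allclose(direct, reshaped)
```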
71 changes: 42 additions & 29 deletions docs/tutorial/rewriter/examples/broadcast_matmul.py
@@ -9,7 +9,6 @@

import logging

import numpy as np
import onnx

import onnxscript
@@ -65,71 +64,81 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
def check_if_not_need_reshape(
context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
"""Condition to check if we need to replace the pattern.
If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
1. Input shapes check: input_a and input_b should be broadcastable
2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
If the above are true, then we don't need the reshapes.
Returns:
True if we need to replace the pattern, False otherwise.
"""
del context # Reserved for future extensions

input_a_shape = input_a.shape
input_b_shape = input_b.shape
# TODO: Get a helper func to get const_value
shape_c_value = _ir_utils.propagate_const_value(shape_c)
shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr]
if shape_c is None:
return False
if not isinstance(shape_c, np.ndarray):
logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
_ir_utils.propagate_const_value(shape_c)
shape_c_tensor = shape_c.const_value
if shape_c_tensor is None:
logger.info("The value 'shape_c' is not statically known.")
return False
if len(shape_c.shape) != 1:

if len(shape_c_tensor.shape) != 1:
logger.info(
"Unexpected final shape. The shape of 'shape' value is %s",
shape_c.shape,
shape_c_tensor.shape,
)
return False
shape_c_list = shape_c.tolist()

# NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
# information. So, we need to check if the shape is None and return False.
if input_a_shape is None or input_b_shape is None or shape_c is None:
if input_a_shape is None or input_b_shape is None:
logger.info("Shape information is not available for the inputs and outputs.")
return False
input_a_shape = list(input_a_shape)
input_b_shape = list(input_b_shape)
input_a_shape = input_a_shape.numpy()
input_b_shape = input_b_shape.numpy()
shape_c = shape_c_tensor.numpy().tolist()

a_rank = len(input_a_shape)
b_rank = len(input_b_shape)

dim_a = len(input_a_shape)
dim_b = len(input_b_shape)
# TODO(justinchuby): Check shape size

# 1. Check if input shapes are broadcastable
# 1.a. If the first input is 1-D, check whether
# the dim matches the last second dim of the second input.
mimic_matmul_broadcast_behavior = False
if dim_a < 2:
if a_rank < 2:
if b_rank < 2:
logger.info("Optimization of dot product is not supported yet.")
return False
if input_a_shape[-1] != input_b_shape[-2]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_a_shape = [1, *input_a_shape]
dim_a = len(input_a_shape)
a_rank = len(input_a_shape)
mimic_matmul_broadcast_behavior = True
# 1.b. If the second input is 1-D, check whether
# the dim matches the last dim of the first input.
if dim_b < 2:
if b_rank < 2:
if input_b_shape[-1] != input_a_shape[-1]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_b_shape = [*input_b_shape, 1]
dim_b = len(input_b_shape)
b_rank = len(input_b_shape)
mimic_matmul_broadcast_behavior = True
# 1.c. If both inputs are at least 2-D, check whether
# the last dimension of the first input matches the second
# last dimension of the second input, and shape[:-2] are
# broadcastable.
input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
input_b_shape_except_last_dim = input_b_shape[:-1]
broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
for idx, (dim_from_a, dim_from_b) in enumerate(
@@ -149,23 +158,27 @@ def check_if_not_need_reshape(

# 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
# Prepend the broadcast_matmul_output_shape with the longer shape of input
if dim_a > dim_b:
if a_rank > b_rank:
longer_shape = input_a_shape
shorter_shape = input_b_shape
else:
longer_shape = input_b_shape
shorter_shape = input_a_shape
broadcast_matmul_output_shape = (
longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
)
if mimic_matmul_broadcast_behavior and dim_b == 2:
broadcast_matmul_output_shape = [
*longer_shape[: -len(shorter_shape)],
*broadcast_matmul_output_shape,
]
if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
# If input_b is expanded to 2-D, then we need to remove the last dimension
broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
if mimic_matmul_broadcast_behavior and dim_a == 2:
if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
# If input_a is expanded to 2-D, then we need to remove the first dimension
# of input_a, which would be the -2nd dimension of the output shape.
broadcast_matmul_output_shape.pop(-2)
if shape_c_list != broadcast_matmul_output_shape:
if shape_c != broadcast_matmul_output_shape:
logger.info(
"Final output shape is not the same. Expected %s vs actual %s",
shape_c_list,
shape_c,
broadcast_matmul_output_shape,
)
return False
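
The 1-D handling added above follows the standard MatMul rank-promotion rules: a 1-D first input gets a leading 1, a 1-D second input gets a trailing 1, and the promoted dimension is removed from the output shape again. A standalone sketch of that shape rule (`matmul_output_shape` is a hypothetical helper written here for illustration; it is not part of the change):

```python
import numpy as np

def matmul_output_shape(shape_a, shape_b):
    # Hypothetical helper: restates the NumPy/ONNX MatMul rank-promotion
    # rules that the rewritten check mirrors.
    shape_a, shape_b = list(shape_a), list(shape_b)
    promoted_a = promoted_b = False
    if len(shape_a) == 1:
        shape_a = [1, *shape_a]  # prepend a row dimension
        promoted_a = True
    if len(shape_b) == 1:
        shape_b = [*shape_b, 1]  # append a column dimension
        promoted_b = True
    if shape_a[-1] != shape_b[-2]:
        raise ValueError("shapes are not MatMul compatible")
    batch = np.broadcast_shapes(tuple(shape_a[:-2]), tuple(shape_b[:-2]))
    out = [*batch, shape_a[-2], shape_b[-1]]
    if promoted_b:
        out.pop(-1)                        # drop the appended column dimension
    if promoted_a:
        out.pop(-1 if promoted_b else -2)  # drop the prepended row dimension
    return out

# A 1-D first input against a batched matrix: MatMul drops the promoted dim.
assert matmul_output_shape([5], [2, 5, 3]) == [2, 3]
assert list(np.matmul(np.ones(5), np.ones((2, 5, 3))).shape) == [2, 3]
```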
47 changes: 30 additions & 17 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,8 @@

from __future__ import annotations

import typing

import numpy as np

from onnxscript import ir
@@ -10,24 +12,35 @@


def propagate_const_value(ir_value: ir.Value) -> ir.Value:
"""Temporary method to propagate a constant value to the IR value."""
node = ir_value.producer()
if ir_value.const_value is None and node is not None and node.op_type == "Constant":
attr_names = [
"value_float",
"value_int",
"value_string",
"value",
"value_floats",
"value_ints",
"value_strings",
]
for attr_name in attr_names:
attr_value = node.attributes.get(attr_name)
if attr_value is not None:
# TODO: RefAttr should be also supported?
if isinstance(attr_value, ir.Attr):
ir_value.const_value = attr_value.value # type: ignore[union-attr]
break
if node is None:
return ir_value
if node.op_type != "Constant":
return ir_value
attr_name, attr_value = next(iter(node.attributes.items()))
if attr_value is None or not isinstance(attr_value, ir.Attr):
return ir_value

const_value: ir.TensorProtocol
if attr_name in {"value_float", "value_floats"}:
const_value = ir.Tensor(
np.array(attr_value.value, dtype=np.float32), name=ir_value.name
)
elif attr_name in {"value_int", "value_ints"}:
const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name)
elif attr_name in {"value_string", "value_strings"}:
const_value = ir.StringTensor(
np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
)
elif attr_name == "value":
const_value = typing.cast(ir.TensorProtocol, attr_value.value)
else:
return ir_value

ir_value.const_value = const_value
ir_value.shape = const_value.shape # type: ignore
ir_value.dtype = const_value.dtype
return ir_value


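
After this change, `propagate_const_value` materializes the Constant node's attribute as a tensor on the IR value. A rough usage sketch, assuming `onnx.helper` for model construction and `ir.serde.deserialize_model` for conversion to the IR (names and shapes are illustrative, not taken from the repository's tests):

```python
import onnx
import onnx.helper
from onnxscript import ir
from onnxscript.rewriter import _ir_utils

# A graph whose only node is a Constant carrying a value_ints attribute.
constant_node = onnx.helper.make_node(
    "Constant", inputs=[], outputs=["shape_c"], value_ints=[2, 3, 4, 6]
)
graph = onnx.helper.make_graph(
    [constant_node],
    "g",
    inputs=[],
    outputs=[onnx.helper.make_tensor_value_info("shape_c", onnx.TensorProto.INT64, [4])],
)
ir_model = ir.serde.deserialize_model(onnx.helper.make_model(graph))
shape_c_value = ir_model.graph.outputs[0]

# Before propagation no constant is attached; afterwards const_value holds an
# int64 tensor built from the value_ints attribute, and shape/dtype are set.
assert shape_c_value.const_value is None
_ir_utils.propagate_const_value(shape_c_value)
print(shape_c_value.const_value.numpy())  # [2 3 4 6]
```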
(Diff for the remaining changed files not shown.)
