Skip to content

Commit

Permalink
Bug fixes in split-sequence optimization (#1872)
Browse files Browse the repository at this point in the history
A couple of bugs in the optimization for split-sequence:

* Handle the case where there is only one split-value (as the op-builder
returns a single IR value instead of a list of IR values in this case).
* Use 1D constant [axis] instead of scalar axis in Squeeze op.
  • Loading branch information
gramalingam authored Sep 19, 2024
1 parent 0de44ba commit b0ca0c3
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,16 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
else:
return None

# If Split returns a single value, we need to wrap it into a list.
if isinstance(split_values, ir.Value):
split_values = [split_values]

keepdims = _get_int_attribute(node, "keepdims", 1)
if keepdims is None:
return None
if keepdims == 0:
# squeeze the split dimension if keepdims is 0
axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"])
axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"])
squeezed_values = []
for i in range(num_outputs):
squeezed = op.Squeeze(
Expand Down

0 comments on commit b0ca0c3

Please sign in to comment.