Skip to content

Commit

Permalink
Add a condiction when we try to remove a transpose node. (#2272)
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Zhang <[email protected]>
  • Loading branch information
fatcat-z committed Nov 28, 2023
1 parent ae4c39e commit 07dfaa8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,14 @@ def _add_handler(self, trans, node):
return True
return self._handle_node_having_branches(trans, node)

def _output_node_has_single_consumer_node(self, node):
output_node = self._g.get_node_by_name(node.output[0])
return output_node and output_node.output and self._nodes_has_single_consumer_node([output_node])

def _transpose_handler(self, trans, node):
perm = trans.get_attr_value("perm")
perm_inv = invert_perm(perm)
if is_tranpose_of_type(node, perm_inv):
if is_tranpose_of_type(node, perm_inv) and self._output_node_has_single_consumer_node(node):
for g in {self._g, node.graph}:
g.replace_all_inputs(node.output[0], trans.input[0]) # ops=g.get_nodes()

Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


version = '1.15.1'
git_version = 'dc6155b52a137d858456fcc6bc720c327eec5612'
git_version = 'ae4c39ed3bdab7edf487d73d5892a573684d1d6a'

0 comments on commit 07dfaa8

Please sign in to comment.