Skip to content

Commit

Permalink
Fix catch_errors flag and check axes of expand_dims in einsum optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
acmore committed Dec 14, 2023
1 parent 07dfaa8 commit e791c8b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
g = process_tf_graph(tf_graph, const_node_values=const_node_values,
custom_op_handlers=custom_op_handlers, **kwargs)
if constants.ENV_TF2ONNX_CATCH_ERRORS in os.environ:
catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
catch_errors = os.environ.get(constants.ENV_TF2ONNX_CATCH_ERRORS).upper() == "TRUE"
else:
catch_errors = not large_model
onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers)
Expand Down
2 changes: 2 additions & 0 deletions tf2onnx/optimizer/einsum_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ def _compute_output_row_expand_dims(self, row, row2=None, ab=False):
self._check_row_(row, True)
self._check_arg_('axes', tuple)
axes = self.kwargs['axes']
if not axes:
raise RuntimeError("Parameter axes of expand_dims should not be empty.")
for axis in axes:
if not isinstance(axis, tuple):
raise TypeError(
Expand Down

0 comments on commit e791c8b

Please sign in to comment.