From 740228296c5337e74b5968b7573435050846af21 Mon Sep 17 00:00:00 2001 From: acmore Date: Thu, 14 Dec 2023 15:00:48 +0800 Subject: [PATCH] Fix catch_errors flag and check axes of expand_dims in einsum optimizer Signed-off-by: acmore --- tf2onnx/convert.py | 2 +- tf2onnx/optimizer/einsum_optimizer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 6ee66c096..843973149 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -168,7 +168,7 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path g = process_tf_graph(tf_graph, const_node_values=const_node_values, custom_op_handlers=custom_op_handlers, **kwargs) if constants.ENV_TF2ONNX_CATCH_ERRORS in os.environ: - catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE" + catch_errors = os.environ.get(constants.ENV_TF2ONNX_CATCH_ERRORS).upper() == "TRUE" else: catch_errors = not large_model onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers) diff --git a/tf2onnx/optimizer/einsum_optimizer.py b/tf2onnx/optimizer/einsum_optimizer.py index 48eda026c..b13ada19a 100644 --- a/tf2onnx/optimizer/einsum_optimizer.py +++ b/tf2onnx/optimizer/einsum_optimizer.py @@ -373,6 +373,8 @@ def _compute_output_row_expand_dims(self, row, row2=None, ab=False): self._check_row_(row, True) self._check_arg_('axes', tuple) axes = self.kwargs['axes'] + if not axes: + raise RuntimeError("Parameter axes of expand_dims should not be empty.") for axis in axes: if not isinstance(axis, tuple): raise TypeError(