Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[experimental] Support for dynamic Tile, dynamic Reshape #623

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.20.8
ghcr.io/pinto0309/onnx2tf:1.20.9

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.20.8
docker.io/pinto0309/onnx2tf:1.20.9

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.20.8'
__version__ = '1.20.9'
6 changes: 4 additions & 2 deletions onnx2tf/ops/Reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def make_node(

# If Reshape's shape contains zeros, get the deformed shape from the output shape
if isinstance(reshape_shape, list) and reshape_shape.count(0) > 0:
new_shape = [-1 if isinstance(s, str) else int(s) for s in output_shape]
before_tensor_shapes = tf.shape(tf_layers_dict[graph_node_input_1.name]['tf_node'])
new_shape = [before_tensor_shapes[idx] if isinstance(s, str) else int(s) for idx, s in enumerate(output_shape)]
reshape_shape = new_shape
elif isinstance(reshape_shape, np.ndarray) and np.count_nonzero(reshape_shape == 0) > 0:
new_shape = [-1 if isinstance(s, str) else int(s) for s in output_shape]
before_tensor_shapes = tf.shape(tf_layers_dict[graph_node_input_1.name]['tf_node'])
new_shape = [before_tensor_shapes[idx] if isinstance(s, str) else int(s) for idx, s in enumerate(output_shape)]
reshape_shape = new_shape

onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
Expand Down
228 changes: 129 additions & 99 deletions onnx2tf/ops/Tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,114 +155,144 @@ def define_tile(
)

tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_1.shape))))
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_2))))
tensor_2_candidate_for_transpositions = None

for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
try:
# Build TF dummy model
input_1 = tf_keras.Input(
shape=validation_data_1.shape[1:],
batch_size=validation_data_1.shape[0] \
if isinstance(validation_data_1.shape[0], int) else None,
name='dummy_input_1',
dtype=validation_data_1.dtype,
)
input_2 = validation_data_2
dummy_tile = define_tile(
target_input_tensor_1=input_1,
target_perm_1=list(tensor_1_candidate_for_transposition),
target_input_tensor_2=input_2,
target_perm_2=list(tensor_2_candidate_for_transposition),
target_name=graph_node.name,
**kwargs
)
# Verify that the output shape matches that of ONNX
# If the combination of each value of a dimension is not correct,
# invalidate the normal processing judgment.
onnx_output_shape_prod = np.prod([dim if not isinstance(dim, str) else -1 for dim in onnx_output_shape])
tile_output_shapes = list(dummy_tile.shape)
tile_output_shape_prod = np.prod([dim if dim is not None else -1 for dim in tile_output_shapes])
if onnx_output_shape_prod != tile_output_shape_prod:
del input_1
del input_2
del dummy_tile
continue

# Perform simple accuracy verification
# Terminate when the error is less than 1e-3
if onnx_tensor_infos:
try:
# Search for the axis with the smallest error
val_model = tf_keras.Model(
inputs=[
input_1,
],
outputs=[
dummy_tile,
],
)
if isinstance(input_tensor_2, int):
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(input_tensor_2)))
elif isinstance(input_tensor_2, np.ndarray) and hasattr(input_tensor_2, '__len__'):
tiles_tensor_length = len(input_tensor_2)
if tiles_tensor_length > 1:
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(input_tensor_2))))
else:
tensor_2_candidate_for_transpositions = list(itertools.permutations(input_tensor_2))
elif tf_keras.backend.is_keras_tensor(input_tensor_2) and hasattr(input_tensor_2.shape, '__len__'):
tiles_tensor_length = len(input_tensor_2.shape)
if tiles_tensor_length > 1:
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(tiles_tensor_length)))
else:
# Dynamic Tensor
pass
else:
# Unknown
pass

# TF dummy inference
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
model=val_model,
inputs=[
input_1,
],
verification_datas=[
validation_data_1,
],
)
if tensor_2_candidate_for_transpositions is not None:
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
try:
# Build TF dummy model
input_1 = tf_keras.Input(
shape=validation_data_1.shape[1:],
batch_size=validation_data_1.shape[0] \
if isinstance(validation_data_1.shape[0], int) else None,
name='dummy_input_1',
dtype=validation_data_1.dtype,
)
input_2 = validation_data_2
dummy_tile = define_tile(
target_input_tensor_1=input_1,
target_perm_1=list(tensor_1_candidate_for_transposition),
target_input_tensor_2=input_2,
target_perm_2=list(tensor_2_candidate_for_transposition),
target_name=graph_node.name,
**kwargs
)
# Verify that the output shape matches that of ONNX
# If the combination of each value of a dimension is not correct,
# invalidate the normal processing judgment.
onnx_output_shape_prod = np.prod([dim if not isinstance(dim, str) else -1 for dim in onnx_output_shape])
tile_output_shapes = list(dummy_tile.shape)
tile_output_shape_prod = np.prod([dim if dim is not None else -1 for dim in tile_output_shapes])
if onnx_output_shape_prod != tile_output_shape_prod:
del input_1
del input_2
del dummy_tile
del val_model
continue

# Perform simple accuracy verification
# Terminate when the error is less than 1e-3
if onnx_tensor_infos:
try:
# Search for the axis with the smallest error
val_model = tf_keras.Model(
inputs=[
input_1,
],
outputs=[
dummy_tile,
],
)

# Validation
onnx_tf_output_pairs = {
(oi[0], ti[0]): (oi[1], ti[1]) \
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
}
"""
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
max_abs_err,
]
# TF dummy inference
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
model=val_model,
inputs=[
input_1,
],
verification_datas=[
validation_data_1,
],
)
del input_1
del input_2
del dummy_tile
del val_model

# Validation
onnx_tf_output_pairs = {
(oi[0], ti[0]): (oi[1], ti[1]) \
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
}
"""
check_results = onnx_tf_tensor_validation(
output_pairs=onnx_tf_output_pairs,
rtol=0.0,
atol=0.0,
)
result_err = sum([val[2] for val in check_results.values()])
if result_err < min_abs_err:
min_abs_err = result_err
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
if min_abs_err < 1e-3:
break
except Exception as ex1:
pass
except Exception as ex2:
pass
else:
continue
break
"""
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
max_abs_err,
]
}
"""
check_results = onnx_tf_tensor_validation(
output_pairs=onnx_tf_output_pairs,
rtol=0.0,
atol=0.0,
)
result_err = sum([val[2] for val in check_results.values()])
if result_err < min_abs_err:
min_abs_err = result_err
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
if min_abs_err < 1e-3:
break
except Exception as ex1:
pass
except Exception as ex2:
pass
else:
continue
break

# Generation of TF OP
tf_layers_dict[graph_node_output.name]['tf_node'] = \
define_tile(
target_input_tensor_1=input_tensor_1,
target_perm_1=min_abs_err_perm_1,
target_input_tensor_2=input_tensor_2,
target_perm_2=min_abs_err_perm_2,
target_name=graph_node.name,
**kwargs
)
if tensor_2_candidate_for_transpositions is not None:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
define_tile(
target_input_tensor_1=input_tensor_1,
target_perm_1=min_abs_err_perm_1,
target_input_tensor_2=input_tensor_2,
target_perm_2=min_abs_err_perm_2,
target_name=graph_node.name,
**kwargs
)
else:
# Dynamic Tensor
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.tile(
input=input_tensor_1 \
if not isinstance(input_tensor_1, np.ndarray) \
else tf.convert_to_tensor(input_tensor_1),
multiples=tf.convert_to_tensor([dim for dim in input_tensor_2]),
)

# Post-process transpose
tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
Expand Down
Loading