Skip to content

Commit

Permalink
feat: add support for ExtractImagePatches
Browse files Browse the repository at this point in the history
Signed-off-by: Nanoskript <[email protected]>
  • Loading branch information
nanoskript committed Jun 16, 2023
1 parent b27aa05 commit 2da669d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
matrix_diag_part = tf.compat.v1.matrix_diag_part
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
fake_quant_with_min_max_vars = tf.quantization.fake_quant_with_min_max_vars
extract_image_patches = tf.image.extract_patches
elif Version(tf.__version__) >= Version("1.13"):
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
Expand All @@ -94,6 +95,7 @@
matrix_diag_part = tf.compat.v1.matrix_diag_part
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
fake_quant_with_min_max_vars = tf.compat.v1.quantization.fake_quant_with_min_max_vars
extract_image_patches = tf.compat.v1.extract_image_patches
else:
conv2d_backprop_input = tf.nn.conv2d_backprop_input
conv3d_transpose = tf.nn.conv3d_transpose
Expand All @@ -111,6 +113,7 @@
is_inf = tf.is_inf
floormod = tf.floormod
matrix_diag_part = tf.matrix_diag_part
extract_image_patches = tf.extract_image_patches


def make_xval(shape):
Expand Down Expand Up @@ -6248,5 +6251,21 @@ def func(tensor, indices, updates):
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})

def test_extract_image_patches(self):
    """Convert extract_image_patches for several dilation rates.

    Each rate pair is combined with every configuration yielded by
    get_conv_getdata() and the converted graph is checked via _run_test_case.
    """
    dilation_pairs = [[1, 1], [1, 4], [4, 1], [3, 3]]
    for dilation in dilation_pairs:
        # The op takes a 4-element rates attribute; pad the pair with 1 on both ends.
        full_rates = [1] + dilation + [1]
        for _, padding, x_shape, sizes, strides in get_conv_getdata():
            x_val = make_xval(x_shape)

            def func(x):
                patches = extract_image_patches(x, sizes=sizes, strides=strides, rates=full_rates,
                                                padding=padding, name=_TFOUTPUT)
                return patches

            self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

if __name__ == '__main__':
unittest_main()
2 changes: 2 additions & 0 deletions tf2onnx/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2
from tf2onnx.rewriter.gru_tf2_rewriter import rewrite_gru_tf2
from tf2onnx.rewriter.fused_op_rewriter import rewrite_fused_ops
from tf2onnx.rewriter.extract_image_patches_rewriter import rewrite_extract_image_patches


__all__ = [
Expand Down Expand Up @@ -53,4 +54,5 @@
"rewriter_lstm_tf2",
"rewrite_gru_tf2",
"rewrite_fused_ops",
"rewrite_extract_image_patches",
]
86 changes: 86 additions & 0 deletions tf2onnx/rewriter/extract_image_patches_rewriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0


"""
tf2onnx.rewriter.extract_image_patches_rewriter - Rewrites ExtractImagePatches into supported operations.
"""

import numpy as np
from tf2onnx import utils
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher


# pylint: disable=missing-docstring

def _reshape_target(dims):
    """Build an int64 Reshape target from ``dims``.

    Unknown entries (``None`` or negative) are normalised to -1. A Reshape target may
    contain at most one inferred (-1) dimension, so anything more is rejected here with
    a clear error instead of emitting a graph that only fails when it is executed.
    """
    target = [-1 if dim is None or dim < 0 else dim for dim in dims]
    utils.make_sure(target.count(-1) <= 1,
                    "ExtractImagePatches with more than one unknown dimension is unsupported.")
    return np.array(target, dtype=np.int64)


def rewrite_extract_image_patches(g, ops):
    """Replace every ExtractImagePatches node with Transpose/Reshape/Conv2D nodes.

    The input is rearranged to [N * C, H, W, 1] and convolved with an identity kernel of
    shape [size_rows, size_cols, 1, K] (K = size_rows * size_cols), reusing the original
    strides, rates (as dilations) and padding. The result is then rearranged to the
    original output shape [N, ?H, ?W, C * K].

    Args:
        g: graph being rewritten; nodes are added to and removed from it in place.
        ops: nodes to search for ExtractImagePatches.

    Returns:
        The graph's node list after rewriting (``g.get_nodes()``).

    Raises:
        Whatever ``utils.make_sure`` raises when the input/output shape is missing or not
        rank 4, the output is empty, or a required Reshape target would contain more
        than one unknown dimension.
    """
    pattern = OpTypePattern("ExtractImagePatches", name="extract_image_patches")
    matcher = GraphMatcher(pattern)
    for match_result in list(matcher.match_ops(ops)):
        operation = match_result.get_op("extract_image_patches")
        input_shape = g.get_shape(operation.input[0])
        output_shape = operation.output_shapes[0]

        sizes = operation.get_attr_value("ksizes")
        strides = operation.get_attr_value("strides")
        rates = operation.get_attr_value("rates")
        padding = operation.get_attr_str("padding")

        # Our constraints. All validation happens before the graph is touched so a
        # rejected node leaves the graph unchanged.
        utils.make_sure(input_shape is not None and len(input_shape) == 4,
                        "ExtractImagePatches requires a rank 4 input with a known shape.")
        utils.make_sure(output_shape is not None and len(output_shape) == 4,
                        "ExtractImagePatches requires a rank 4 output with a known shape.")
        utils.make_sure(0 not in output_shape, "Empty ExtractImagePatches output is unsupported.")
        [_, size_rows, size_cols, _] = sizes
        k = size_rows * size_cols

        batch, in_rows, in_cols, channels = input_shape
        _, out_rows, out_cols, _ = output_shape

        # N * C is only meaningful when both factors are known; multiplying an unknown
        # (-1) dimension would yield an invalid negative size such as -C, so let Reshape
        # infer the merged dimension instead.
        batch_known = batch is not None and batch >= 0
        channels_known = channels is not None and channels >= 0
        merged = batch * channels if batch_known and channels_known else -1

        merged_shape = _reshape_target([merged, in_rows, in_cols, 1])
        split_shape = _reshape_target([batch, channels, out_rows, out_cols, k])
        final_shape = _reshape_target(output_shape)

        # Transform input into [N * C, H, W, 1].
        channels_first = g.make_node("Transpose", inputs=operation.input, attr=dict(perm=[0, 3, 1, 2]))
        merged_shape_const = g.make_const(utils.make_name("new_shape"), merged_shape)
        transformed_input = g.make_node("Reshape", inputs=[
            channels_first.output[0],
            merged_shape_const.output[0],
        ])

        # Create identity kernel: a [K, K] identity matrix viewed as [rows, cols, 1, K].
        eye_size = g.make_const(utils.make_name("eye_size"), np.array([k, k], dtype=np.int64))
        eye_template = g.make_node("ConstantOfShape", inputs=[eye_size.output[0]])
        eye = g.make_node("EyeLike", inputs=[eye_template.output[0]])
        kernel_shape_const = g.make_const(utils.make_name("new_shape"),
                                          np.array([size_rows, size_cols, 1, k], dtype=np.int64))
        identity_kernel = g.make_node("Reshape", inputs=[
            eye.output[0],
            kernel_shape_const.output[0],
        ])

        # Convolve into [N * C, ?H, ?W, K].
        convolution = g.make_node("Conv2D", inputs=[transformed_input.output[0], identity_kernel.output[0]],
                                  attr=dict(strides=strides, dilations=rates, padding=padding, data_format="NHWC"),
                                  shapes=[[merged, out_rows, out_cols, k]],
                                  dtypes=operation.output_dtypes, skip_conversion=False)

        # Transform into [N, ?H, ?W, C * K]: split N * C apart, move C next to K, merge.
        split_shape_const = g.make_const(utils.make_name("new_shape"), split_shape)
        split_channels = g.make_node("Reshape", inputs=[
            convolution.output[0],
            split_shape_const.output[0],
        ])
        channels_last = g.make_node("Transpose", inputs=[split_channels.output[0]],
                                    attr=dict(perm=[0, 2, 3, 4, 1]))
        final_shape_const = g.make_const(utils.make_name("new_shape"), final_shape)
        output_node = g.make_node("Reshape", inputs=[
            channels_last.output[0],
            final_shape_const.output[0],
        ])

        # Replace node.
        g.replace_all_inputs(operation.output[0], output_node.output[0])
        g.remove_node(operation.name)
    return g.get_nodes()
1 change: 1 addition & 0 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ def compat_handler(ctx, node, **kwargs):
rewriter_lstm_tf2,
rewrite_gru_tf2,
rewrite_single_direction_lstm,
rewrite_extract_image_patches,
# bi-directional
rewrite_bi_direction_lstm,
rewrite_single_direction_gru,
Expand Down

0 comments on commit 2da669d

Please sign in to comment.