Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for ExtractImagePatches #2188

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

nanoskript
Copy link

Closes: #436

This rewrite is based on this comment: #436 (comment) with changes to make it more general and translatable into tf2onnx.

Equivalent TensorFlow function and automated test script (expand)
import tensorflow as tf
import numpy as np
from hypothesis import given, strategies as st, settings, assume


def our_extract_image_patches(sizes, strides, rates, padding):
    # TensorFlow's constraints.
    assert sizes[0] == 1 and sizes[3] == 1
    assert strides[0] == 1 and strides[3] == 1
    assert rates[0] == 1 and rates[3] == 1
    assert padding in ["SAME", "VALID"]

    # Extract size.
    [_, size_rows, size_cols, _] = sizes

    @tf.function
    def function(tensor):
        # Input shape of [N, H, W, C].
        tensor_shape = tensor.shape

        # Transpose and reshape to [N * C, H, W, 1].
        tensor = tf.transpose(tensor, perm=[0, 3, 1, 2])
        tensor = tf.reshape(tensor, [
            tensor_shape[0] * tensor_shape[3],
            tensor_shape[1],
            tensor_shape[2],
            1,
        ])

        # Convolve with identity kernel into [N * C, ?H, ?W, K].
        k = size_rows * size_cols
        kernel = tf.reshape(tf.eye(k), [size_rows, size_cols, 1, k])
        convolution = tf.nn.conv2d(tensor, kernel, strides=strides, padding=padding, dilations=rates)

        # Reshape into [N, C, ?H, ?W, K].
        reshaped = tf.reshape(convolution, [
            tensor_shape[0],
            tensor_shape[3],
            convolution.shape[1],
            convolution.shape[2],
            k,
        ])

        # Transpose and reshape into [N, ?H, ?W, C * K].
        patches = tf.transpose(reshaped, perm=[0, 2, 3, 4, 1])
        return tf.reshape(patches, [
            tensor_shape[0],
            convolution.shape[1],
            convolution.shape[2],
            tensor_shape[3] * k,
        ])

    return function


def tf_extract_image_patches(sizes, strides, rates, padding):
    @tf.function
    def function(tensor):
        return tf.image.extract_patches(
            tensor,
            sizes=sizes,
            strides=strides,
            rates=rates,
            padding=padding,
        )

    return function


@settings(max_examples=5000)
@given(
    st.lists(st.integers(min_value=1, max_value=20), min_size=4, max_size=4),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.sampled_from(["VALID", "SAME"]),
)
def test_equal(shape, size_rows, size_cols, stride_rows, stride_cols, dil_rows, dil_cols, padding):
    sizes = [1, size_rows, size_cols, 1]
    strides = [1, stride_rows, stride_cols, 1]
    rates = [1, dil_rows, dil_cols, 1]

    try:
        tensor = tf.cast(tf.reshape(tf.range(np.prod(shape)), shape), dtype=tf.float32)
        tfs = tf_extract_image_patches(sizes, strides, rates, padding)(tensor)

        if 0 in tfs.shape:
            # We cannot handle operations that produce empty outputs.
            assume(False)
    except ValueError:
        # Ignore input if TensorFlow would fail.
        assume(False)
        return

    ours = our_extract_image_patches(sizes, strides, rates, padding)(tensor)
    assert tf.reduce_all(tf.math.equal(tfs, ours)).numpy()

Output from pytest convolve.py --hypothesis-show-statistics (no failures):

convolve.py::test_equal:

  - during generate phase (70.74 seconds):
    - Typical runtimes: ~ 1-14 ms, of which < 1ms in data generation
    - 5000 passing examples, 0 failing examples, 3913 invalid examples

  - Stopped because settings.max_examples=5000

@nanoskript nanoskript force-pushed the add-extract-image-patches branch 2 times, most recently from b4e1d24 to 2da669d Compare June 16, 2023 12:58
@nanoskript nanoskript marked this pull request as ready for review June 16, 2023 13:02
@fatcat-z
Copy link
Collaborator

fatcat-z commented Jul 30, 2023

Thanks you for putting that solution into this PR, and it looks great!

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op. Each rewriter will search the ONNX graph following a given pattern. Once the pattern is matched, those involved onnx ops will be replaced with some other ops for an optimization in further inference.

In this case, ExtractImagePatches is just a tf op which is not supported by tf2onnx yet. So, your implementations should be put into nn.py file instead of adding a rewriter. Please add it into nn.py, just like adding a new tf op support.

Please feel free to refer to this comment for more details.

@nanoskript
Copy link
Author

Hi @fatcat-z,

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op.

I'm not entirely sure if this is true. From my understanding, the rewriters are ran before each operation is converted into an ONNX operation:

run_rewriters(g, rewriters, continue_on_error)

where line 622 performs the conversion (?). There do appear to be late rewriters that run after the mapping occurs, but in general, it seems like the rewriting and optimization steps are separate.

I chose to implement this as a rewrite in order to avoid duplicating the construction of the Conv2D node but if you would still prefer for this to be implemented in nn.py, please let me know.

@fatcat-z
Copy link
Collaborator

fatcat-z commented Aug 2, 2023

Hi @fatcat-z,

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op.

I'm not entirely sure if this is true. From my understanding, the rewriters are ran before each operation is converted into an ONNX operation:

run_rewriters(g, rewriters, continue_on_error)

where line 622 performs the conversion (?). There do appear to be late rewriters that run after the mapping occurs, but in general, it seems like the rewriting and optimization steps are separate.
I chose to implement this as a rewrite in order to avoid duplicating the construction of the Conv2D node but if you would still prefer for this to be implemented in nn.py, please let me know.

No, graphs_from_tf() function will transfer the tf graph to onnx graph meaning each tf op has been converted to onnx op, if possible. Afterwards, process_parsed_graph() will be called to finish those rewriters and optimizations.

Yes, please implement this as an op in nn.py instead of creating a new rewriter. Thanks.

@ruihu102
Copy link

Is this new operator going to be merged into main?

@nanoskript
Copy link
Author

Is this new operator going to be merged into main?

I don't think this operator can be considered to be new but I'm aiming to get the requested changes done sometime within the week.

@nanoskript
Copy link
Author

Sorry for the delay! I've implemented this operation inside of nn.py instead of as a rewriter. Let me know if anything else needs to be changed!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ValueError: tensorflow op ExtractImagePatches is not supported
3 participants