Add aten_convolution_backward function #1707

Open · wants to merge 22 commits into base: main
Changes from 1 commit
67 changes: 66 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -2093,6 +2093,7 @@
    return result


@torch_op("aten::convolution_backward", trace_only=True)
def aten_convolution_backward(
    grad_output: TensorType,
    input: TensorType,
@@ -2108,7 +2109,71 @@
) -> tuple[TensorType, TensorType, TensorType]:
    """convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"""

    raise NotImplementedError()
    # Compute weight.grad: dW_t = X_t (*) dZ_t (cross-correlation over the batch dim)
    input_t = op.Transpose(input, perm=[1, 0, 2, 3])
    dz_t = op.Transpose(grad_output, perm=[1, 0, 2, 3])
    dw_t = op.Conv(input_t, dz_t)
    dw = op.Transpose(dw_t, perm=[1, 0, 2, 3])
    # Compute bias.grad: reduce grad_output over every dim except the channel dim
    axes = op.Constant(value_ints=[0, 2, 3])
    db = op.ReduceSum(grad_output, axes, keepdims=0)
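    # Shape walk-through with the sizes assumed in the comments below:
    # input (20, 16, 50, 40) -> input_t (16, 20, 50, 40);
    # grad_output (20, 13, 48, 38) -> dz_t (13, 20, 48, 38).
    # Conv now runs over the batch dim as if it were channels, so
    # dw_t = (16, 13, 3, 3) and dw = (13, 16, 3, 3): each dW entry is
    # grad_output cross-correlated with input, summed over the batch.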


    # Compute input.grad: dX = dZ (zero-padded) (*) W_rot180
    # Assume: grad_output = (20, 13, 48, 38)
    z_height = op.Shape(grad_output, start=2, end=3)  # 48
    z_width = op.Shape(grad_output, start=3, end=4)  # 38


    # TODO: dilate grad_output first when stride != 1, e.g.:
    # if stride[0] != 1:
    #     dz_height = z_height * stride[0] - stride[0] + 1
    #     dz_width = z_width * stride[1] - stride[1] + 1
    #     pos = _help(z_height, dz_width, stride)  # or build the indices inline:
    #     pos = []
    #     for j in range(z_height):
    #         for i in range(0, dz_width, stride[1]):
    #             pos.append(i + j * dz_width * stride[0])
    #     index_tensor = op.Constant(value_ints=pos)
    #     index_tensor = op.Reshape(index_tensor, z_shape)
    #     # This should not work, because kernel_shape is an attribute and must be static:
    #     dz = op.MaxUnpool(grad_output, index_tensor,
    #                       kernel_shape=[dz_height - z_height + 1, dz_width - z_width + 1])

    # Compute the padding size
    # Assume: input = (20, 16, 50, 40)
    x_height = op.Shape(input, start=2, end=3)  # 50
    x_width = op.Shape(input, start=3, end=4)  # 40

    # Assume: weight = (13, 16, 3, 3)
    w_height = op.Shape(weight, start=2, end=3)  # 3
    w_width = op.Shape(weight, start=3, end=4)  # 3
    # Per-side padding: (x - z + w - 1) / 2
    tmp_int = x_height - z_height + w_height - 1  # 50 - 48 + 3 - 1 = 4
    tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
    pad_height = op.Cast(op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype)  # 4 / 2 = 2
    tmp_int = x_width - z_width + w_width - 1  # 40 - 38 + 3 - 1 = 4
    tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
    pad_width = op.Cast(op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype)  # 4 / 2 = 2
    pads = op.Concat(  # [0, 0, 2, 2, 0, 0, 2, 2]
        op.Constant(value_ints=[0]), op.Constant(value_ints=[0]), pad_height, pad_width,  # begins of dims 0-3
        op.Constant(value_ints=[0]), op.Constant(value_ints=[0]), pad_height, pad_width,  # ends of dims 0-3
        axis=0,
    )
    dz_pad = op.Pad(grad_output, pads)  # enlarge grad_output to (20, 13, 52, 42)
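    # With stride 1 and no forward padding, x - z = w - 1, so each spatial side is
    # padded by w - 1 and the final Conv becomes a "full" correlation: its output,
    # (52 - 3 + 1) x (42 - 3 + 1) = 50 x 40, matches the input size.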

    # Transpose weight from (13, 16, 3, 3) to (16, 13, 3, 3)
    w_transpose = op.Transpose(weight, perm=[1, 0, 2, 3])
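    # The channel swap is needed because each input-channel gradient sums over all
    # output channels: dX[:, c] = sum_k dZ[:, k] (*) rot180(W[k, c]).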

    # Rotate w_transpose (16, 13, 3, 3) by 180 degrees, i.e. np.rot90(w, 2) on the spatial axes
    w_shape_0 = op.Shape(w_transpose, start=0, end=1)  # 16
    w_shape_1 = op.Shape(w_transpose, start=1, end=2)  # 13
    w_shape_2 = op.Constant(value_ints=[1])  # 1
    w_shape_3 = op.Constant(value_ints=[-1])  # -1
    w_shape_new = op.Concat(w_shape_0, w_shape_1, w_shape_2, w_shape_3, axis=0)  # (16, 13, 1, -1)
    w_new = op.Reshape(w_transpose, w_shape_new)  # flatten the spatial dims to (16, 13, 1, 9)

    # Reverse the values in the last dim (axis 3), e.g. [1, 2, 3, ..., 9] -> [9, ..., 3, 2, 1]
    starts = op.Constant(value_ints=[-1])
    ends = op.Constant(value_ints=[-1000])
    axes = op.Constant(value_ints=[3])
    steps = op.Constant(value_ints=[-1])
    w_slice = op.Slice(w_new, starts, ends, axes, steps)  # w_new[:, :, :, -1:-1000:-1]
    weight_rot180 = op.Reshape(w_slice, op.Shape(w_transpose))  # back to (16, 13, 3, 3)
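    # Reversing the flattened H*W block flips both spatial axes at once, which is
    # exactly the 180-degree rotation of each 3x3 kernel.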

    # dX = dZ (zero-padded) (*) W_rot180
    dx = op.Conv(dz_pad, weight_rot180)

    # TODO: when dx is bigger than input (e.g. 29x29 vs. 28x28), the last row and
    # column of dx need to be dropped.
    return dx, dw, db



def aten_convolution_backward_overrideable(
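
For reference, a minimal PyTorch sketch (an illustration, not part of this diff) checking the identities the new function relies on for stride 1, no forward padding, and groups=1: dW is the batch-wise cross-correlation of input with grad_output, dX is the "full" cross-correlation of grad_output with the 180-degree-rotated, channel-swapped weight, and db is a plain reduction. Shapes mirror the comments in the diff.

import torch
import torch.nn.functional as F

x = torch.randn(20, 16, 50, 40, dtype=torch.float64, requires_grad=True)
w = torch.randn(13, 16, 3, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(13, dtype=torch.float64, requires_grad=True)
z = F.conv2d(x, w, b)  # (20, 13, 48, 38)
dz = torch.randn_like(z)
z.backward(dz)

# dW: cross-correlate input with grad_output, batch and channel dims swapped
dw = F.conv2d(x.detach().transpose(0, 1), dz.transpose(0, 1)).transpose(0, 1)
assert torch.allclose(dw, w.grad)

# dX: pad grad_output by w - 1 = 2 per side ("full" correlation), then
# correlate with the 180-degree-rotated, channel-swapped weight
w_rot180 = torch.flip(w.detach(), dims=(2, 3)).transpose(0, 1)  # (16, 13, 3, 3)
dx = F.conv2d(F.pad(dz, (2, 2, 2, 2)), w_rot180)  # (20, 16, 50, 40)
assert torch.allclose(dx, x.grad)

# db: reduce grad_output over every dim except the output channels
assert torch.allclose(dz.sum(dim=(0, 2, 3)), b.grad)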