
Add aten_convolution_backward function #1707

Open · wants to merge 22 commits into base: main
Conversation

@xiaowuhu (Contributor) commented Jun 25, 2024

Roadmap:

  1. Ideally we would reuse the ConvGrad function that already exists in the onnxruntime-training library, but its domain is com.microsoft, so we cannot get its schema and therefore cannot run the operator in onnxscript.
  2. Alternatively we could implement this with col2im and im2col, but ONNX only provides Col2Im, NOT Im2Col.
  3. So I use the following trick instead:

A. Compute dW

$$ dW = X * dZ $$

To do this, transpose X with perm [1,0,2,3] and dZ with perm [1,0,2,3], run a regular op.Conv on them to get dW, and then transpose the result back with perm [1,0,2,3]:

```
        N C H W        N C H W
X       8 3 7 6   ->   3 8 7 6
dZ      8 2 4 2   ->   2 8 4 2
dW      2 3 4 5   <-   3 2 4 5
```
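Below is a minimal, hypothetical sketch (not the PR's code) that checks this trick with PyTorch, using `F.conv2d` as a stand-in for `op.Conv`; it assumes stride 1, no padding, no dilation, and a single group:

```python
import torch
import torch.nn.functional as F

# Shapes from the table above.
N, C_in, H, W_in = 8, 3, 7, 6
C_out, kH, kW = 2, 4, 5

x = torch.randn(N, C_in, H, W_in, requires_grad=True)
w = torch.randn(C_out, C_in, kH, kW, requires_grad=True)
z = F.conv2d(x, w)         # z: [8, 2, 4, 2]
dz = torch.randn_like(z)
z.backward(dz)             # populates x.grad and w.grad for reference

# dW via the transpose trick: treat the channel axis as the batch axis.
x_t = x.detach().permute(1, 0, 2, 3)          # [8, 3, 7, 6] -> [3, 8, 7, 6]
dz_t = dz.permute(1, 0, 2, 3)                 # [8, 2, 4, 2] -> [2, 8, 4, 2]
dw = F.conv2d(x_t, dz_t).permute(1, 0, 2, 3)  # [3, 2, 4, 5] -> [2, 3, 4, 5]

assert torch.allclose(dw, w.grad, atol=1e-3)
```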

B. Compute dX

It is similar but more complicated:

$$ dX = dZ^{pad} * W^{rot180} $$

where $dZ^{pad}$ is dZ zero-padded on all sides (see step 1 below).

```
        N C H W        N  C  H  W
dZ      8 2 4 2   ->   8  2 10 10
W       2 3 4 5   ->   3  2  4  5
dX      8 3 7 6   <-   8  3  7  6
```
  1. Add zero padding around dZ.
  2. Transpose W with perm [1,0,2,3], then rotate W by 180 degrees.
  3. Run a regular Conv operation to get dX, as sketched below.
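Continuing the hypothetical sketch above (reusing `x`, `w`, `dz`, `kH`, `kW`), the dX trick can be checked the same way:

```python
# Step 1: zero-pad dZ by (kH - 1, kW - 1) on each side: [8, 2, 4, 2] -> [8, 2, 10, 10].
dz_pad = F.pad(dz, (kW - 1, kW - 1, kH - 1, kH - 1))
# Step 2: transpose W with perm [1, 0, 2, 3], then rotate the kernel 180 degrees.
w_rot = w.detach().permute(1, 0, 2, 3).flip([2, 3])  # [2, 3, 4, 5] -> [3, 2, 4, 5]
# Step 3: a regular convolution now recovers the input gradient.
dx = F.conv2d(dz_pad, w_rot)                         # dx: [8, 3, 7, 6]

assert torch.allclose(dx, x.grad, atol=1e-3)
```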

To Do list:

  1. If the forward conv stride != 1, dZ must first be dilated (zeros inserted between its elements) before computing dX in step B; see the sketch after this list.
  2. The forward conv may compute a fractional output size, e.g. 13.5, which is floored to 13; the backward pass then produces a dX of size (29x29) when it should actually be (28x28). This needs Slice(dX, 1, 1, x_height, x_width).
  3. When dW comes out bigger than W, we need to Slice(dW, 0, 0, weight_height, weight_width).
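A small, hypothetical sketch of the dilation mentioned in item 1 (illustrative names only, forward stride of 2):

```python
import torch

# dZ from a stride-2 forward conv; dilate it by inserting zeros between elements
# so that a stride-1 convolution can be used for the backward pass.
dz = torch.arange(1.0, 5.0).reshape(1, 1, 2, 2)  # [[1, 2], [3, 4]]
stride = (2, 2)
n, c, h, w = dz.shape
dz_dilated = torch.zeros(
    n, c, h * stride[0] - stride[0] + 1, w * stride[1] - stride[1] + 1
)
dz_dilated[:, :, :: stride[0], :: stride[1]] = dz
# dz_dilated[0, 0] == [[1., 0., 2.],
#                      [0., 0., 0.],
#                      [3., 0., 4.]]
```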

Comment on lines +2125 to +2139
```python
# if stride[0] != 1:  # dilation
#     dz_height = z_height * stride[0] - stride[0] + 1
#     dz_width = z_width * stride[1] - stride[1] + 1
#     pos = _help(z_height, dz_width, stride)
#     pos = []
#     for j in range(z_height):
#         for i in range(0, dz_width, stride[1]):
#             pos.append(i + j * dz_width * stride[0])

#     index_tensor = op.Constant(value_ints=pos)
#     index_tensor = op.Reshape(index_tensor, z_shape)
#     # this should not work because the kernel_shape is attribute
#     dz = op.MaxUnpool(grad_output, index_tensor, kernel_shape=[dz_height - z_height + 1, dz_width - z_width + 1])

# # Computing padding size
```
Check notice · Code scanning / CodeQL

Commented-out code (Note): this comment appears to contain commented-out code.
codecov bot commented Jun 25, 2024

Codecov Report

Attention: Patch coverage is 86.95652% with 9 lines in your changes missing coverage. Please review.

Project coverage is 75.23%. Comparing base (c57e9e7) to head (1eb33c3).

| Files | Patch % | Lines |
|---|---|---|
| onnxscript/tools/training_helper.py | 79.16% | 2 Missing and 3 partials ⚠️ |
| onnxscript/function_libs/torch_lib/ops/core.py | 91.11% | 3 Missing and 1 partial ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1707      +/-   ##
==========================================
- Coverage   75.24%   75.23%   -0.01%
==========================================
  Files         242      242
  Lines       25861    25923      +62
  Branches     4660     4671      +11
==========================================
+ Hits        19458    19504      +46
- Misses       5517     5528      +11
- Partials      886      891       +5
```


xiaowuhu assigned xadupre and fatcat-z and unassigned xiaowuhu on Jun 25, 2024
github-actions bot commented Jun 25, 2024

Test Results

```
26 files      -1          26 suites   -1         ⏱️ 2h 27m 9s  (-56m 19s)
11,805 tests  +3,446      9,556 ✅ +2,688      2,247 💤 +757        1 ❌ +1   1 🔥 ±0
381,311 runs  -142,687   83,485 ✅ -22,528   297,818 💤 -120,166   7 ❌ +7   1 🔥 ±0
```

For more details on these failures and errors, see this check.

Results for commit 1eb33c3. ± Comparison against base commit c57e9e7.

This pull request skips 1 test.
docs.test.test_documentation_examples.TestDocumentationExample ‑ test_documentation_examples

♻️ This comment has been updated with latest results.

@xadupre (Member) commented Jun 25, 2024

Is it possible to add a unit test?

Comment on lines +25 to +33
```python
def train_loop(
    model: Any,
    *args,
    loss_fn: Any | None = None,
    optimizer: Any | None = None,
    dump_onnx_models: bool = False,
    dump_prefix: str = "dump_train_loop",
    dump_clean_first: bool = True,
) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
```
Check notice · Code scanning / CodeQL

Returning tuples with varying lengths (Note): train_loop returns a tuple of size 2 and a tuple of size 3.
@xiaowuhu (Contributor, Author) commented Jul 3, 2024

> Is it possible to add a unit test?

Added.

xiaowuhu added a commit that referenced this pull request Jul 4, 2024
Depends on #1707, will add a unit test after #1707 is merged.
xiaowuhu requested a review from justinchuby on July 4, 2024 at 05:21
xiaowuhu requested a review from justinchuby on July 4, 2024 at 06:32

```python
class TestBackward(unittest.TestCase):
    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
```
Collaborator

Suggested change (delete this line):

```python
@unittest.skipIf(not has_transformers(), reason="transformers is missing")
```

Comment on lines +12 to +13
```python
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.llama
```
Collaborator

Suggested change (delete these imports):

```python
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.llama
```

I wonder why ruff doesn't warn about the unused imports.
