
Adds ATen fallback for scaled_dot_product_attention #21107

Merged · 35 commits · Jul 22, 2024

Conversation

prathikr (Contributor) commented Jun 19, 2024

Description

Introduces an ATen fallback for torch.nn.functional.scaled_dot_product_attention. This operator was introduced in torch 2.0 and has since received many updates, including a memory-efficient attention implementation for V100 machines. The current torchscript exporter exports a subgraph for attention that does not provide the same memory savings as PyTorch's memory-efficient attention kernel. Allowing a fallback to the PyTorch ATen op for attention helps mitigate memory spikes for models that rely on memory-efficient attention.
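
For reference, a minimal sketch of how the operator targeted by this fallback is invoked in PyTorch. This example is illustrative only and is not taken from the PR; the tensor shapes, dtype, and device choice are assumptions.

```python
# Illustrative sketch (not from this PR): calling the operator whose ATen
# fallback this change enables. torch.nn.functional.scaled_dot_product_attention
# was added in torch 2.0 and dispatches to a fused kernel (flash or
# memory-efficient attention) when hardware and inputs allow it.
import torch
import torch.nn.functional as F

# Assumed example dimensions: (batch, num_heads, sequence_length, head_dim)
batch, heads, seq_len, head_dim = 2, 8, 1024, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

query = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
key = torch.randn_like(query)
value = torch.randn_like(query)

# The fused kernels avoid materializing the full (seq_len x seq_len)
# attention matrix, which is the memory saving the description refers to.
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

Exporting this call as an ATen op, rather than decomposing it into a subgraph of primitive ONNX ops, is what lets ONNX Runtime Training keep dispatching to PyTorch's fused attention kernel and retain that memory saving.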

Motivation and Context

Memory issues arose when integrating ONNX Runtime Training with AML Stable Diffusion.

@prathikr prathikr requested review from pengwa and centwang July 12, 2024 21:32
@prathikr prathikr requested a review from pengwa July 17, 2024 18:37
@prathikr prathikr merged commit 11ad299 into main Jul 22, 2024
93 of 98 checks passed
@prathikr prathikr deleted the prathikrao/attn-aten-fallback branch July 22, 2024 23:37