Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR tensorflow#15919: [GPU] Use CuDnnThunk for FMHA.
Imported from GitHub PR openxla/xla#15919 CuDnnThunk currently used for GEMM fusions is capable of executing arbitrary cuDNN graphs. Moving FMHA to use it lets remove lots of specialized runtime code. The overview of the change is: - cuda_dnn.cc: At cuDNN graph construction assign tensor UIDs using their order in HLO to match the CuDnnThunk calling convention instead of using custom constants. - cuda_dnn.h/cc: Move dropout seed / offset / increment to the CudnnGraph properties and handle them accordingly during graph execution. - Rename cudnn_workspace_rewriter to cudnn_custom_call_compiler and let it set workspace as it did before + compile and serialize graphs just like cudnn_fusion_compiler aiming CuDnnThunks already does. - Move the remainders of the MHA config / descriptor logic to cudnn_custom_call_compiler from the deleted fused_mha_runner. - ir_emitter_unnested.cc: Remove MHA-specific logic, create CuDnnThunks for MHA custom calls the same universal way that works for cuDNN GEMM fusions. - Delete no more necessary special thunks, runners, lazy ops, command buffer commands. Copybara import of the project: -- 5d5b046a6ee8771e33b6c6b0f41d380205277129 by Ilia Sergachev <[email protected]>: [GPU] Use CuDnnThunk for FMHA. CuDnnThunk currently used for GEMM fusions is capable of executing arbitrary cuDNN graphs. Moving FMHA to use it lets remove lots of specialized runtime code. The overview of the change is: - cuda_dnn.cc: At cuDNN graph construction assign tensor UIDs using their order in HLO to match the CuDnnThunk calling convention instead of using custom constants. - cuda_dnn.h/cc: Move dropout seed / offset / increment to the CudnnGraph properties and handle them accordingly during graph execution. - Rename cudnn_workspace_rewriter to cudnn_custom_call_compiler and let it set workspace as it dif before + compile and serialize graphs just like cudnn_fusion_compiler aiming CuDnnThunks already does. - Move the remainders of the MHA config / descriptor logic to cudnn_custom_call_compiler from the deleted fused_mha_runner. - ir_emitter_unnested.cc: Remove MHA-specific logic, create CuDnnThunks for MHA custom calls the same universal way that works for cuDNN GEMM fusions. - Delete no more necessary special thunks, runners, lazy ops, command buffer commands. Merging this change closes tensorflow#15919 PiperOrigin-RevId: 661991045
- Loading branch information