PR tensorflow#15919: [GPU] Use CuDnnThunk for FMHA.
Imported from GitHub PR openxla/xla#15919

CuDnnThunk, which is currently used for GEMM fusions, is capable of executing arbitrary cuDNN graphs. Moving FMHA onto it lets us remove a lot of specialized runtime code.

The overview of the change is:
 - cuda_dnn.cc: During cuDNN graph construction, assign tensor UIDs based on their order in the HLO to match the CuDnnThunk calling convention, instead of using custom constants (a rough sketch of this convention follows the list).
 - cuda_dnn.h/cc: Move the dropout seed / offset / increment into the CudnnGraph properties and handle them accordingly during graph execution.
 - Rename cudnn_workspace_rewriter to cudnn_custom_call_compiler and let it set the workspace as it did before, and additionally compile and serialize graphs, just as cudnn_fusion_compiler, which already targets CuDnnThunks, does.
 - Move the remaining MHA config / descriptor logic from the deleted fused_mha_runner into cudnn_custom_call_compiler.
 - ir_emitter_unnested.cc: Remove the MHA-specific logic and create CuDnnThunks for MHA custom calls in the same universal way that already works for cuDNN GEMM fusions.
 - Delete the now-unnecessary specialized thunks, runners, lazy ops, and command buffer commands.
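Below is a rough, hypothetical sketch of the positional UID convention referenced in the first bullet. This is not code from this PR; the struct, the helper name, and the starting UID are assumptions. The idea is simply that graph tensors get UIDs in the order of the custom call's operands followed by its results, so CuDnnThunk can bind the corresponding buffers in the same order at run time.

// Hypothetical sketch of positional UID assignment for a cuDNN custom call.
// Assumption: UIDs follow the order of operands, then results, so CuDnnThunk
// can bind the matching buffer slices in the same order when it executes the
// serialized graph.
#include <cstdint>
#include <vector>

struct TensorBinding {
  int64_t uid;        // UID given to the cuDNN graph tensor.
  int64_t hlo_index;  // Position of the operand/result on the HLO custom call.
  bool is_output;
};

std::vector<TensorBinding> AssignUidsByHloOrder(int64_t num_operands,
                                                int64_t num_results) {
  std::vector<TensorBinding> bindings;
  int64_t next_uid = 1;  // Assumed starting UID; the real convention may differ.
  for (int64_t i = 0; i < num_operands; ++i) {
    bindings.push_back({next_uid++, i, /*is_output=*/false});
  }
  for (int64_t i = 0; i < num_results; ++i) {
    bindings.push_back({next_uid++, i, /*is_output=*/true});
  }
  return bindings;
}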
Copybara import of the project:

--
5d5b046a6ee8771e33b6c6b0f41d380205277129 by Ilia Sergachev <[email protected]>:

[GPU] Use CuDnnThunk for FMHA.

Merging this change closes tensorflow#15919

PiperOrigin-RevId: 661991045
sergachev authored and tensorflower-gardener committed Aug 12, 2024
1 parent 71b9752 commit b12aa80
Showing 23 changed files with 842 additions and 2,703 deletions.
30 changes: 2 additions & 28 deletions third_party/xla/xla/service/gpu/BUILD
@@ -313,7 +313,6 @@ cc_library(
":execution_stream_assignment",
":gpu_asm_opts_util",
":gpu_conv_runner",
":gpu_fused_mha_runner",
":gpu_norm_runner",
":hlo_fusion_analysis",
":ir_emission_utils",
@@ -356,9 +355,9 @@ cc_library(
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:convolution_thunk",
"//xla/service/gpu/runtime:copy_thunk",
"//xla/service/gpu/runtime:cudnn_thunk",
"//xla/service/gpu/runtime:custom_call_thunk",
"//xla/service/gpu/runtime:fft_thunk",
"//xla/service/gpu/runtime:fused_mha_thunk",
"//xla/service/gpu/runtime:gemm_thunk",
"//xla/service/gpu/runtime:gpublas_lt_matmul_thunk",
"//xla/service/gpu/runtime:infeed_thunk",
@@ -1045,31 +1044,6 @@ cc_library(
]),
)

cc_library(
name = "gpu_fused_mha_runner",
srcs = ["gpu_fused_mha_runner.cc"],
hdrs = ["gpu_fused_mha_runner.h"],
deps = [
":backend_configs_cc",
":cublas_cudnn",
":stream_executor_util",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/stream_executor",
"//xla/stream_executor:dnn",
"//xla/stream_executor:lazy_op_runner",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "cusolver_context",
srcs = if_gpu_is_configured(["cusolver_context.cc"]),
@@ -1779,6 +1753,7 @@ cc_library(
"//xla/service/gpu/transforms:conv_padding_legalization",
"//xla/service/gpu/transforms:conv_rewriter",
"//xla/service/gpu/transforms:cublas_pad_for_gemms",
"//xla/service/gpu/transforms:cudnn_custom_call_compiler",
"//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
"//xla/service/gpu/transforms:cudnn_fused_mha_rewriter",
"//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion",
@@ -1787,7 +1762,6 @@
"//xla/service/gpu/transforms:cudnn_pad_for_convolutions",
"//xla/service/gpu/transforms:cudnn_simplify_padding",
"//xla/service/gpu/transforms:cudnn_vectorize_convolutions",
"//xla/service/gpu/transforms:cudnn_workspace_rewriter",
"//xla/service/gpu/transforms:dot_sparsity_rewriter",
"//xla/service/gpu/transforms:gpusolver_rewriter",
"//xla/service/gpu/transforms:sort_rewriter",
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -2150,8 +2150,8 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
}};
BinaryMap dnn_compiled_graphs;
if (stream_exec) {
TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec,
&dnn_compiled_graphs));
TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec,
&dnn_compiled_graphs));
}

const DebugOptions& debug_opts = module->config().debug_options();
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -171,10 +171,10 @@ class GpuCompiler : public LLVMCompiler {
return absl::OkStatus();
}

// Runs cuDNN fusion compiler pass.
virtual absl::Status RunCudnnFusionCompilerPass(
HloModule* module, se::StreamExecutor* stream_exec,
BinaryMap* dnn_compiled_graphs) {
// Runs cuDNN fusion and custom call compiler passes.
virtual absl::Status RunCudnnCompilerPasses(HloModule* module,
se::StreamExecutor* stream_exec,
BinaryMap* dnn_compiled_graphs) {
return absl::OkStatus();
}
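
For context, here is a hedged sketch of how a CUDA-specific GpuCompiler subclass would presumably override the renamed hook to run both cuDNN passes, mirroring the PR description. This is not part of the diff above: the pass class names (CuDnnFusionCompiler, CuDnnCustomCallCompiler) and their constructor signatures are assumptions.

// Hypothetical override in a CUDA-specific GpuCompiler subclass. Sketch only:
// the pass class names and constructor signatures are assumed, not taken from
// this diff.
absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
    HloModule* module, se::StreamExecutor* stream_exec,
    BinaryMap* dnn_compiled_graphs) {
  // Compile and serialize cuDNN graphs for GEMM fusions, as before.
  CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs);
  TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status());
  // Do the same for the FMHA custom calls handled by the renamed pass.
  CuDnnCustomCallCompiler custom_call_compiler(*stream_exec,
                                               *dnn_compiled_graphs);
  return custom_call_compiler.Run(module).status();
}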

