From a001c9279bee80f6a54dda3bddffe805e6c229eb Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 4 Aug 2023 01:55:31 -0700 Subject: [PATCH] [XLA:GPU] Expose a flag to allow disabling region-based live range analysis. PiperOrigin-RevId: 553733609 --- tensorflow/compiler/xla/debug_options_flags.cc | 9 +++++++++ tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 13 ++++++++++--- tensorflow/compiler/xla/xla.proto | 4 +++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 2e871ff51ed681..94d577a9380100 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -170,6 +170,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); + + opts.set_xla_gpu_copy_insertion_use_region_analysis(true); return opts; } @@ -1189,6 +1191,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Dumps autotuned Triton fusions to the directory specified by " "xla_dump_to or stdout. Each fusion is dumped only once, as an optimized " "HLO.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_copy_insertion_use_region_analysis", + bool_setter_for( + &DebugOptions::set_xla_gpu_copy_insertion_use_region_analysis), + debug_options->xla_gpu_copy_insertion_use_region_analysis(), + "If true, use the new fine-grain region-based live range interference" + " analysis in the copy insertion optimization pass.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 72aa066be84475..d0f0233e40fc44 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -856,6 +856,8 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { + const DebugOptions& debug_options = hlo_module->config().debug_options(); + // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -880,9 +882,14 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { } pipeline.AddPass(GetCanShareBuffer()); - constexpr int64_t kNoRegionBasedLiveRangeAnalysisLimit = -1; - pipeline.AddPass(GetCanShareBuffer(), - kNoRegionBasedLiveRangeAnalysisLimit); + if (debug_options.xla_gpu_copy_insertion_use_region_analysis()) { + constexpr int64_t kNoRegionBasedLiveRangeAnalysisLimit = -1; + pipeline.AddPass(GetCanShareBuffer(), + kNoRegionBasedLiveRangeAnalysisLimit); + } else { + pipeline.AddPass(GetCanShareBuffer()); + } + // We are using a sub-pipeline here, so that the verifier only runs after both // GpuHorizontalLoopFusion and HloDCE. auto& sub_pipeline = diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 7e60a3ffb03f15..d7516ff28b693a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -595,7 +595,9 @@ message DebugOptions { bool xla_gpu_dump_autotuned_triton_fusions = 232; - // Next id: 236 + bool xla_gpu_copy_insertion_use_region_analysis = 236; + + // Next id: 237 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.