Skip to content

Commit

Permalink
[XLA:GPU] Expose a flag to allow disabling region-based live range analysis.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 553733609
  • Loading branch information
bchetioui authored and tensorflower-gardener committed Aug 4, 2023
1 parent 144e78d commit a001c92
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
9 changes: 9 additions & 0 deletions tensorflow/compiler/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0);
opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1);

opts.set_xla_gpu_copy_insertion_use_region_analysis(true);
return opts;
}

Expand Down Expand Up @@ -1189,6 +1191,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Dumps autotuned Triton fusions to the directory specified by "
"xla_dump_to or stdout. Each fusion is dumped only once, as an optimized "
"HLO."));
flag_list->push_back(tsl::Flag(
"xla_gpu_copy_insertion_use_region_analysis",
bool_setter_for(
&DebugOptions::set_xla_gpu_copy_insertion_use_region_analysis),
debug_options->xla_gpu_copy_insertion_use_region_analysis(),
"If true, use the new fine-grain region-based live range interference"
" analysis in the copy insertion optimization pass."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
13 changes: 10 additions & 3 deletions tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,8 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module,
// Modifies the given HLO module so that it will be accepted by IrEmitter.
// Unlike optimization passes, the passes are necessary for correctness.
Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
const DebugOptions& debug_options = hlo_module->config().debug_options();

// In some cases, we have to place the result of an instruction in a temporary
// buffer. For instance, the buffer that holds an external parameter is
// assumed immutable at this point, and should not be reused for output
Expand All @@ -880,9 +882,14 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
}
pipeline.AddPass<LoopScheduleLinearizer>(GetCanShareBuffer());

constexpr int64_t kNoRegionBasedLiveRangeAnalysisLimit = -1;
pipeline.AddPass<CopyInsertion>(GetCanShareBuffer(),
kNoRegionBasedLiveRangeAnalysisLimit);
if (debug_options.xla_gpu_copy_insertion_use_region_analysis()) {
constexpr int64_t kNoRegionBasedLiveRangeAnalysisLimit = -1;
pipeline.AddPass<CopyInsertion>(GetCanShareBuffer(),
kNoRegionBasedLiveRangeAnalysisLimit);
} else {
pipeline.AddPass<CopyInsertion>(GetCanShareBuffer());
}

// We are using a sub-pipeline here, so that the verifier only runs after both
// GpuHorizontalLoopFusion and HloDCE.
auto& sub_pipeline =
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/compiler/xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,9 @@ message DebugOptions {

bool xla_gpu_dump_autotuned_triton_fusions = 232;

// Next id: 236
bool xla_gpu_copy_insertion_use_region_analysis = 236;

// Next id: 237

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit a001c92

Please sign in to comment.