Skip to content

Commit

Permalink
[XLA:GPU] Enable --xla_gpu_enable_pipelined_{all_gather,all_reduce,reduce_scatter} by default.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 665219784
  • Loading branch information
golechwierowicz authored and tensorflower-gardener committed Aug 20, 2024
1 parent 14f860e commit fe13321
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 8 additions & 5 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_highest_priority_async_stream(true);

opts.set_xla_gpu_enable_pipelined_collectives(false);
opts.set_xla_gpu_enable_pipelined_all_reduce(false);
opts.set_xla_gpu_enable_pipelined_all_gather(false);
opts.set_xla_gpu_enable_pipelined_reduce_scatter(false);
opts.set_xla_gpu_enable_pipelined_all_reduce(true);
opts.set_xla_gpu_enable_pipelined_all_gather(true);
opts.set_xla_gpu_enable_pipelined_reduce_scatter(true);
opts.set_xla_gpu_enable_pipelined_p2p(false);

opts.set_xla_gpu_run_post_layout_collective_pipeliner(false);
Expand Down Expand Up @@ -1447,8 +1447,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_enable_pipelined_collectives",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_collectives),
debug_options->xla_gpu_enable_pipelined_collectives(),
"Enable pipelining of collective instructions (all-reduce, all-gather, "
"and reduce-scatter)."));
"Enable pipelining of collective instructions. It has the same effect "
"as setting xla_gpu_enable_pipelined_all_reduce, "
"xla_gpu_enable_pipelined_all_gather, "
"xla_gpu_enable_pipelined_reduce_scatter and "
"xla_gpu_enable_pipelined_p2p flags to true."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_pipelined_all_reduce",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_reduce),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2956,6 +2956,7 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) {

HloModuleConfig ref_config;
debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false);
ref_config.set_debug_options(debugoptions);
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_ref, ref_config));
Expand All @@ -2965,6 +2966,7 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) {
HloModuleConfig opt_config;
debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
opt_config.set_debug_options(debugoptions);
debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false);
TF_ASSERT_OK_AND_ASSIGN(auto module_with_adddress_computation_flag,
ParseAndReturnVerifiedModule(hlo_ref, opt_config));
TF_ASSERT_OK_AND_ASSIGN(
Expand Down

0 comments on commit fe13321

Please sign in to comment.