diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc index 519b495e76be30..3ffd74e9e594b7 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc @@ -213,10 +213,13 @@ class GemmDimensionAdapter { return GemmDimensionAdapter{*dot, std::move(analysis)}; } - bool DimensionsAndStrides(const HloInstruction& hlo, - const TritonFusionAnalysis::Scope scope, - std::vector& dimensions, - std::vector& strides) { + struct Result { + std::vector sizes; + std::vector strides; + }; + + std::optional DimensionsAndStrides( + const HloInstruction& hlo, const TritonFusionAnalysis::Scope scope) { const DotDimensionNumbers& dims = dot_.dot_dimension_numbers(); // GEMM fusions require a specific canonical order of dimensions. constexpr int kBatchDimensionIndex = 0; @@ -253,29 +256,33 @@ class GemmDimensionAdapter { case TritonFusionAnalysis::Scope::META: LOG(FATAL) << "Unsupported scope."; } - dimensions.reserve(dim_indices.size()); - strides.reserve(dim_indices.size()); + + Result result; + result.sizes.reserve(dim_indices.size()); + result.strides.reserve(dim_indices.size()); + for (const int index : dim_indices) { const auto* spec = analysis_.IterSpec(scope, &hlo, index); if (spec == nullptr) { - dimensions.push_back(1); - strides.push_back(strides.empty() ? 1 : strides.back()); + result.sizes.push_back(1); + result.strides.push_back( + result.strides.empty() ? 1 : result.strides.back()); continue; } else { if (spec->size() == 1) { // The dimension is not split, nothing to do. } else if (spec->size() == 2) { if (FusionLevel(hlo) < 3) { - return false; + return std::nullopt; } if (!dims.lhs_batch_dimensions().empty()) { VLOG(8) << "Noncontracting dimension split is not compatible with " "batch dimensions."; - return false; + return std::nullopt; } if (index != lhs_noncontracting_index) { VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } switch (scope) { case TritonFusionAnalysis::Scope::LHS: @@ -285,40 +292,40 @@ class GemmDimensionAdapter { if (lhs_noncontracting_split_ != spec->back().count) { VLOG(8) << "Output non-contracting dimension has to be split " "the same way as the LHS input one if it is split."; - return false; + return std::nullopt; } break; default: VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } // Assign the major part of the noncontracting dimension to the // unused batch one. - CHECK_EQ(dimensions[kBatchDimensionIndex], 1); - dimensions[kBatchDimensionIndex] = spec->back().count; - strides[kBatchDimensionIndex] = spec->back().stride; + CHECK_EQ(result.sizes[kBatchDimensionIndex], 1); + result.sizes[kBatchDimensionIndex] = spec->back().count; + result.strides[kBatchDimensionIndex] = spec->back().stride; } else { VLOG(8) << "The dimension is split multiple times."; - return false; + return std::nullopt; } - dimensions.push_back(spec->front().count); - strides.push_back(spec->front().stride); + result.sizes.push_back(spec->front().count); + result.strides.push_back(spec->front().stride); } } if (lhs_noncontracting_split_ > 1 && scope == TritonFusionAnalysis::Scope::OUTPUT && - dimensions[kBatchDimensionIndex] == 1) { + result.sizes[kBatchDimensionIndex] == 1) { // LHS input noncontracting dimension is split but the corresponding // output one is not. Assign part of the output one to the unused batch // dimension. - dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_; - dimensions[kOutputLHSNonContractingDimensionIndex] /= + result.sizes[kBatchDimensionIndex] = lhs_noncontracting_split_; + result.sizes[kOutputLHSNonContractingDimensionIndex] /= lhs_noncontracting_split_; - strides[kBatchDimensionIndex] = - strides[kOutputLHSNonContractingDimensionIndex] * - dimensions[kOutputLHSNonContractingDimensionIndex]; + result.strides[kBatchDimensionIndex] = + result.strides[kOutputLHSNonContractingDimensionIndex] * + result.sizes[kOutputLHSNonContractingDimensionIndex]; } - return true; + return result; } private: @@ -397,8 +404,7 @@ absl::StatusOr> HloFusionToCuDnnGraph( return std::nullopt; } auto add_parameter = [&](const HloInstruction& parameter, - std::vector& dimensions, - std::vector strides) { + const GemmDimensionAdapter::Result& dims) { const std::optional data_type = ToCudnnDataType(parameter.shape().element_type()); if (!data_type.has_value()) { @@ -407,8 +413,8 @@ absl::StatusOr> HloFusionToCuDnnGraph( } hlo_to_cudnn[¶meter] = graph.tensor( graph::Tensor_attributes() - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims.sizes) + .set_stride(dims.strides) .set_data_type(*data_type) .set_name(std::string(parameter.name())) .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number()))); @@ -419,14 +425,13 @@ absl::StatusOr> HloFusionToCuDnnGraph( TritonFusionAnalysis::Scope::OUTPUT}) { for (const HloInstruction* parameter : adapter->analysis_.ScopeParameters(scope)) { - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*parameter, scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } - if (!add_parameter(*parameter, dimensions, strides)) { + if (!add_parameter(*parameter, *dims)) { return std::nullopt; } } @@ -507,19 +512,19 @@ absl::StatusOr> HloFusionToCuDnnGraph( // setting output of the unary shapes results in the rejection of // the cuDNN graph. if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) { - const auto scope = adapter->analysis_.QueryInstructionScope(*hlo); - std::vector dimensions; - std::vector strides; + const std::optional scope = + adapter->analysis_.QueryInstructionScope(*hlo); if (!scope.has_value()) { LOG(FATAL) << "No scope for instruction: " << hlo->ToShortString(); } - if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*hlo, *scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported hlo for querying dimensions: " << hlo->ToShortString(); } else { - hlo_to_cudnn[hlo]->set_dim(dimensions); + hlo_to_cudnn[hlo]->set_dim(dims->sizes); } } } else if (hlo->operand_count() == 2) { @@ -563,17 +568,17 @@ absl::StatusOr> HloFusionToCuDnnGraph( if (instructions.back()->shape().IsTuple()) { output = instructions.back()->operand(0); } - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides( - *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*output, + TritonFusionAnalysis::Scope::OUTPUT); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } hlo_to_cudnn[output] ->set_output(true) - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims->sizes) + .set_stride(dims->strides) .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count())); if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) { json dump;