PR tensorflow#15934: [GPU][NFC] Refactor cuDNN fusion compiler.
Imported from GitHub PR openxla/xla#15934

The added structure Result will be used to add support for slicing.
Copybara import of the project:

--
f29f47186debf3aa3dc63b9717e23d35607cc000 by Ilia Sergachev <[email protected]>:

[GPU][NFC] Refactor cuDNN fusion compiler.

The added structure Result will be used to add support for slicing.

Merging this change closes tensorflow#15934

PiperOrigin-RevId: 662822349
sergachev authored and tensorflower-gardener committed Aug 14, 2024
1 parent b90e9fc commit 660973f
Showing 1 changed file with 52 additions and 47 deletions.
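
For context, the diff below replaces a bool-returning DimensionsAndStrides that filled two output parameters with a version returning std::optional<Result>, where Result bundles the sizes and strides vectors. A minimal standalone sketch of that pattern follows; the names Layout and Describe are illustrative placeholders, not the actual XLA API.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Bundles what used to be two separate output parameters.
struct Layout {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Before: bool Describe(int rank, std::vector<int64_t>& sizes,
//                       std::vector<int64_t>& strides);
// After: std::nullopt signals the unsupported case, and a successful
// result carries both vectors together, so a later addition (such as
// slicing information) only touches the struct and its producer.
std::optional<Layout> Describe(int rank) {
  if (rank <= 0) {
    return std::nullopt;  // unsupported input, mirroring the nullopt returns
  }
  Layout result;
  result.sizes.assign(rank, 1);    // placeholder dimension sizes
  result.strides.assign(rank, 1);  // placeholder strides
  return result;
}

int main() {
  const std::optional<Layout> layout = Describe(2);
  if (!layout.has_value()) return 1;
  std::cout << "rank " << layout->sizes.size() << "\n";
  return 0;
}

Returning the aggregate by value also makes the failure path explicit at each call site via has_value(), which is exactly how the updated call sites in the diff consume it.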
third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc (52 additions, 47 deletions)
@@ -213,10 +213,13 @@ class GemmDimensionAdapter {
     return GemmDimensionAdapter{*dot, std::move(analysis)};
   }
 
-  bool DimensionsAndStrides(const HloInstruction& hlo,
-                            const TritonFusionAnalysis::Scope scope,
-                            std::vector<int64_t>& dimensions,
-                            std::vector<int64_t>& strides) {
+  struct Result {
+    std::vector<int64_t> sizes;
+    std::vector<int64_t> strides;
+  };
+
+  std::optional<Result> DimensionsAndStrides(
+      const HloInstruction& hlo, const TritonFusionAnalysis::Scope scope) {
     const DotDimensionNumbers& dims = dot_.dot_dimension_numbers();
     // GEMM fusions require a specific canonical order of dimensions.
     constexpr int kBatchDimensionIndex = 0;
@@ -253,29 +256,33 @@ class GemmDimensionAdapter {
       case TritonFusionAnalysis::Scope::META:
         LOG(FATAL) << "Unsupported scope.";
     }
-    dimensions.reserve(dim_indices.size());
-    strides.reserve(dim_indices.size());
+
+    Result result;
+    result.sizes.reserve(dim_indices.size());
+    result.strides.reserve(dim_indices.size());
+
     for (const int index : dim_indices) {
       const auto* spec = analysis_.IterSpec(scope, &hlo, index);
       if (spec == nullptr) {
-        dimensions.push_back(1);
-        strides.push_back(strides.empty() ? 1 : strides.back());
+        result.sizes.push_back(1);
+        result.strides.push_back(
+            result.strides.empty() ? 1 : result.strides.back());
         continue;
       } else {
         if (spec->size() == 1) {
           // The dimension is not split, nothing to do.
         } else if (spec->size() == 2) {
           if (FusionLevel(hlo) < 3) {
-            return false;
+            return std::nullopt;
           }
           if (!dims.lhs_batch_dimensions().empty()) {
             VLOG(8) << "Noncontracting dimension split is not compatible with "
                        "batch dimensions.";
-            return false;
+            return std::nullopt;
           }
           if (index != lhs_noncontracting_index) {
             VLOG(8) << "Only LHS noncontracting dimension can be split.";
-            return false;
+            return std::nullopt;
           }
           switch (scope) {
             case TritonFusionAnalysis::Scope::LHS:
@@ -285,40 +292,40 @@
               if (lhs_noncontracting_split_ != spec->back().count) {
                 VLOG(8) << "Output non-contracting dimension has to be split "
                            "the same way as the LHS input one if it is split.";
-                return false;
+                return std::nullopt;
               }
               break;
             default:
               VLOG(8) << "Only LHS noncontracting dimension can be split.";
-              return false;
+              return std::nullopt;
           }
           // Assign the major part of the noncontracting dimension to the
           // unused batch one.
-          CHECK_EQ(dimensions[kBatchDimensionIndex], 1);
-          dimensions[kBatchDimensionIndex] = spec->back().count;
-          strides[kBatchDimensionIndex] = spec->back().stride;
+          CHECK_EQ(result.sizes[kBatchDimensionIndex], 1);
+          result.sizes[kBatchDimensionIndex] = spec->back().count;
+          result.strides[kBatchDimensionIndex] = spec->back().stride;
         } else {
           VLOG(8) << "The dimension is split multiple times.";
-          return false;
+          return std::nullopt;
         }
-        dimensions.push_back(spec->front().count);
-        strides.push_back(spec->front().stride);
+        result.sizes.push_back(spec->front().count);
+        result.strides.push_back(spec->front().stride);
       }
     }
     if (lhs_noncontracting_split_ > 1 &&
        scope == TritonFusionAnalysis::Scope::OUTPUT &&
-        dimensions[kBatchDimensionIndex] == 1) {
+        result.sizes[kBatchDimensionIndex] == 1) {
       // LHS input noncontracting dimension is split but the corresponding
       // output one is not. Assign part of the output one to the unused batch
       // dimension.
-      dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_;
-      dimensions[kOutputLHSNonContractingDimensionIndex] /=
+      result.sizes[kBatchDimensionIndex] = lhs_noncontracting_split_;
+      result.sizes[kOutputLHSNonContractingDimensionIndex] /=
           lhs_noncontracting_split_;
-      strides[kBatchDimensionIndex] =
-          strides[kOutputLHSNonContractingDimensionIndex] *
-          dimensions[kOutputLHSNonContractingDimensionIndex];
+      result.strides[kBatchDimensionIndex] =
+          result.strides[kOutputLHSNonContractingDimensionIndex] *
+          result.sizes[kOutputLHSNonContractingDimensionIndex];
     }
-    return true;
+    return result;
   }
 
  private:
@@ -397,8 +404,7 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
     return std::nullopt;
   }
   auto add_parameter = [&](const HloInstruction& parameter,
-                           std::vector<int64_t>& dimensions,
-                           std::vector<int64_t> strides) {
+                           const GemmDimensionAdapter::Result& dims) {
     const std::optional<fe::DataType_t> data_type =
         ToCudnnDataType(parameter.shape().element_type());
     if (!data_type.has_value()) {
@@ -407,8 +413,8 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
     }
     hlo_to_cudnn[&parameter] = graph.tensor(
         graph::Tensor_attributes()
-            .set_dim(dimensions)
-            .set_stride(strides)
+            .set_dim(dims.sizes)
+            .set_stride(dims.strides)
             .set_data_type(*data_type)
             .set_name(std::string(parameter.name()))
             .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number())));
@@ -419,14 +425,13 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
         TritonFusionAnalysis::Scope::OUTPUT}) {
     for (const HloInstruction* parameter :
          adapter->analysis_.ScopeParameters(scope)) {
-      std::vector<int64_t> dimensions;
-      std::vector<int64_t> strides;
-      if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions,
-                                         strides)) {
+      const std::optional<GemmDimensionAdapter::Result> dims =
+          adapter->DimensionsAndStrides(*parameter, scope);
+      if (!dims.has_value()) {
         VLOG(3) << "Unsupported dimensions.";
         return std::nullopt;
       }
-      if (!add_parameter(*parameter, dimensions, strides)) {
+      if (!add_parameter(*parameter, *dims)) {
         return std::nullopt;
       }
     }
@@ -507,19 +512,19 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
           // setting output of the unary shapes results in the rejection of
           // the cuDNN graph.
           if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
-            const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
-            std::vector<int64_t> dimensions;
-            std::vector<int64_t> strides;
+            const std::optional<TritonFusionAnalysis::Scope> scope =
+                adapter->analysis_.QueryInstructionScope(*hlo);
             if (!scope.has_value()) {
               LOG(FATAL) << "No scope for instruction: "
                          << hlo->ToShortString();
             }
-            if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
-                                               strides)) {
+            const std::optional<GemmDimensionAdapter::Result> dims =
+                adapter->DimensionsAndStrides(*hlo, *scope);
+            if (!dims.has_value()) {
              VLOG(3) << "Unsupported hlo for querying dimensions: "
                      << hlo->ToShortString();
            } else {
-              hlo_to_cudnn[hlo]->set_dim(dimensions);
+              hlo_to_cudnn[hlo]->set_dim(dims->sizes);
            }
          }
        } else if (hlo->operand_count() == 2) {
@@ -563,17 +568,17 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
   if (instructions.back()->shape().IsTuple()) {
     output = instructions.back()->operand(0);
   }
-  std::vector<int64_t> dimensions;
-  std::vector<int64_t> strides;
-  if (!adapter->DimensionsAndStrides(
-          *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) {
+  const std::optional<GemmDimensionAdapter::Result> dims =
+      adapter->DimensionsAndStrides(*output,
+                                    TritonFusionAnalysis::Scope::OUTPUT);
+  if (!dims.has_value()) {
     VLOG(3) << "Unsupported dimensions.";
     return std::nullopt;
   }
   hlo_to_cudnn[output]
       ->set_output(true)
-      .set_dim(dimensions)
-      .set_stride(strides)
+      .set_dim(dims->sizes)
+      .set_stride(dims->strides)
       .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count()));
   if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) {
     json dump;
