Skip to content

Commit

Permalink
Support transposes with more than 3 dimensions.
Browse files Browse the repository at this point in the history
We tile the most minor dimension of the transpose operand and the dimension
that becomes the most minor dimension in the output. This was already the
case for the transposes we supported, but some code assumed that the transpose
will be a direct swap. This change removes this assumption.

PiperOrigin-RevId: 665225200
  • Loading branch information
akuegel authored and tensorflower-gardener committed Aug 20, 2024
1 parent fe13321 commit 1ba6635
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 32 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ xla_test(
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down
22 changes: 16 additions & 6 deletions third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis)
block_sizes_.back() = block_size_;
block_sizes_[permutation_.back()] = block_size_;
}
output_block_sizes_ = Permute(block_sizes_, permutation_);
block_counts_.resize(block_sizes_.size());
for (int64_t i = 0; i < block_sizes_.size(); ++i) {
block_counts_[i] = CeilOfRatio(input_shape_[i], block_sizes_[i]);
Expand Down Expand Up @@ -198,9 +199,16 @@ LaunchDimensions MlirTransposeFusion::launch_dimensions() const {

IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing(
bool read, mlir::MLIRContext* ctx) const {
auto thread_offsets = GetThreadOffsets(ctx);
auto thread_offsets = GetThreadOffsets(/*read=*/true, ctx);
if (!read) {
absl::c_copy(Permute(thread_offsets, permutation_), thread_offsets.begin());
// Regarding shared memory indexing, the permutation we need to apply is
// just a swap of the two dimensions that are tiled.
if (MostMinorDimensionUnchanged()) {
std::swap(thread_offsets[thread_offsets.size() - 2],
thread_offsets[permutation_[permutation_.size() - 2]]);
} else {
std::swap(thread_offsets.back(), thread_offsets[permutation_.back()]);
}
}
std::vector<int64_t> dim_var_sizes(6, 1);
dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] =
Expand Down Expand Up @@ -395,7 +403,7 @@ absl::Status MlirTransposeFusion::EmitEntryFunction(
}

llvm::SmallVector<mlir::AffineExpr, 4> MlirTransposeFusion::GetThreadOffsets(
mlir::MLIRContext* ctx) const {
bool read, mlir::MLIRContext* ctx) const {
auto thread = mlir::getAffineDimExpr(
KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx);
auto loop = mlir::getAffineSymbolExpr(0, ctx);
Expand All @@ -406,7 +414,8 @@ llvm::SmallVector<mlir::AffineExpr, 4> MlirTransposeFusion::GetThreadOffsets(
auto minor_dim = mlir::getAffineSymbolExpr(2, ctx);
linear_index = linear_index * input_shape_.back() + minor_dim;
}
return DelinearizeInBoundsIndex(linear_index, block_sizes_);
return DelinearizeInBoundsIndex(linear_index,
read ? block_sizes_ : output_block_sizes_);
}

IndexingMap MlirTransposeFusion::GetIndexing(bool input,
Expand All @@ -418,10 +427,11 @@ IndexingMap MlirTransposeFusion::GetIndexing(bool input,
if (!input) {
absl::c_copy(Permute(block_ids, permutation_), block_ids.begin());
}
auto thread_offsets = GetThreadOffsets(ctx);
auto thread_offsets = GetThreadOffsets(input, ctx);
const auto& permuted_block_sizes = input ? block_sizes_ : output_block_sizes_;
llvm::SmallVector<AffineExpr, 3> offsets;
for (auto [block_id, block_size, thread] :
llvm::zip(block_ids, block_sizes_, thread_offsets)) {
llvm::zip(block_ids, permuted_block_sizes, thread_offsets)) {
offsets.push_back(block_id * block_size + thread);
}
std::vector<int64_t> dim_var_sizes(6, 1);
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/transpose_mlir.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,14 @@ class MlirTransposeFusion : public MlirFusionEmitterBase {
mlir::MLIRContext* ctx) const;
IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const;
llvm::SmallVector<mlir::AffineExpr, 4> GetThreadOffsets(
mlir::MLIRContext* ctx) const;
bool read, mlir::MLIRContext* ctx) const;
bool MostMinorDimensionUnchanged() const;

TransposeDescription transpose_;
absl::InlinedVector<int64_t, 3> permutation_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> block_sizes_; // In input elements.
std::vector<int64_t> output_block_sizes_;
std::vector<int64_t> block_counts_;
int vector_size_;
int block_size_;
Expand Down
87 changes: 87 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mlir/IR/MLIRContext.h"
#include "xla/error_spec.h"
#include "xla/service/gpu/fusions/mlir_emitter_test_base.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
Expand Down Expand Up @@ -144,6 +145,72 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201_SimplifiedTo021) {
)"));
}

TEST_F(MlirTransposeFusionTest, Transpose_ThreadIndexing1302) {
auto kHloString = R"(
HloModule Transpose
%fused_computation {
%param_0 = f32[19, 16, 16, 144] parameter(0)
ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0),
dimensions={1,3,0,2}
}
ENTRY main {
%param = f32[19, 16, 16, 144] parameter(0)
ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput,
calls=%fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
auto* root = module->entry_computation()->root_instruction();
auto analysis = HloFusionAnalysis::Create(*root, device_info_);

MlirTransposeFusion fusion(analysis);
EXPECT_THAT(
fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1] -> (
d3 floordiv 80,
(d3 floordiv 5) mod 16,
d0 floordiv 32 + s0 * 4,
(d3 mod 5) * 32 + d0 mod 32
)
domain:
d0 in [0, 127]
d1 in [0, 0]
d2 in [0, 0]
d3 in [0, 1519]
d4 in [0, 0]
d5 in [0, 0]
s0 in [0, 3]
s1 in [0, 0]
(d3 mod 5) * 32 + d0 mod 32 in [0, 143]
)"));
EXPECT_THAT(
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1] -> (
(d3 floordiv 5) mod 16,
(d3 mod 5) * 32 + s0 * 4 + d0 floordiv 32,
d3 floordiv 80,
d0 mod 32
)
domain:
d0 in [0, 127]
d1 in [0, 0]
d2 in [0, 0]
d3 in [0, 1519]
d4 in [0, 0]
d5 in [0, 0]
s0 in [0, 7]
s1 in [0, 0]
(d3 mod 5) * 8 + s0 in [0, 35]
d0 mod 32 in [0, 15]
)"));
}

TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule module
Expand Down Expand Up @@ -464,6 +531,26 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D) {
calls=%fused_computation
}
)";

TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}

TEST_F(MlirTransposeFusionTest, Transpose_4D) {
auto kHloString = R"(
HloModule Transpose
%fused_computation {
%param_0 = f32[19, 16, 16, 144] parameter(0)
ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0),
dimensions={1,3,0,2}
}
ENTRY main {
%param = f32[19, 16, 16, 144] parameter(0)
ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput,
calls=%fused_computation
}
)";
TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}
Expand Down
45 changes: 26 additions & 19 deletions third_party/xla/xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -598,36 +598,43 @@ static std::optional<TransposeDescription> FindTiledLogicalTranspose(
// call GetNormalizedLogicalTransposeShape here.
absl::InlinedVector<int64_t, 3> permutation(instr.dimensions().begin(),
instr.dimensions().end());
// A real transpose needs at least 2 transpose dimensions.
if (permutation.size() < 2) {
return std::nullopt;
}
absl::InlinedVector<int64_t, 3> dimensions(instr.shape().dimensions().begin(),
instr.shape().dimensions().end());
int64_t operand_most_minor_dim =
instr.operand(0)->shape().dimensions().back();
if (permutation == absl::InlinedVector<int64_t, 3>{0, 2, 1} ||
(IsMlirTransposeEmitterEnabled(instr) &&
permutation == absl::InlinedVector<int64_t, 3>{1, 0})) {
if ((dimensions[dimensions.size() - 2] >= kMinDimensionToTransposeTiled &&
dimensions.back() >= kMinDimensionToTransposeTiled) ||
(dimensions[dimensions.size() - 2] >= kMinDimensionToTransposeTiled2 &&
dimensions.back() >= kMinDimensionToTransposeTiled2 &&
dimensions[dimensions.size() - 2] * dimensions.back() >=
kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{&instr, dimensions, permutation};
}
} else if (permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}) {
if ((dimensions[0] >= kMinDimensionToTransposeTiled &&
dimensions[2] >= kMinDimensionToTransposeTiled) ||
(dimensions[0] >= kMinDimensionToTransposeTiled2 &&
dimensions[2] >= kMinDimensionToTransposeTiled2 &&
dimensions[0] * dimensions[2] >=
permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}) {
if ((dimensions.back() >= kMinDimensionToTransposeTiled &&
operand_most_minor_dim >= kMinDimensionToTransposeTiled) ||
(dimensions.back() >= kMinDimensionToTransposeTiled2 &&
operand_most_minor_dim >= kMinDimensionToTransposeTiled2 &&
dimensions.back() * operand_most_minor_dim >=
kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{&instr, dimensions, permutation};
}
} else if (IsMlirTransposeEmitterEnabled(instr)) {
if (permutation == absl::InlinedVector<int64_t, 3>{1, 0, 2}) {
if (permutation.back() == dimensions.size() - 1) {
operand_most_minor_dim =
instr.operand(0)->shape().dimensions(dimensions.size() - 2);
auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
if (byte_width * dimensions[2] <= kMaxBytesInMostMinorDimension &&
byte_width * dimensions[2] * std::min(dimensions[0], dimensions[1]) >=
if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension &&
byte_width * dimensions.back() *
std::min(operand_most_minor_dim,
dimensions[dimensions.size() - 2]) >=
kMinDimensionToTransposeTiled) {
return TransposeDescription{&instr, dimensions, permutation};
}
} else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled &&
dimensions.back() >= kMinDimensionToTransposeTiled) ||
(operand_most_minor_dim >= kMinDimensionToTransposeTiled2 &&
dimensions.back() >= kMinDimensionToTransposeTiled2 &&
operand_most_minor_dim * dimensions.back() >=
kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{&instr, dimensions, permutation};
}
}
return std::nullopt;
Expand Down
46 changes: 46 additions & 0 deletions third_party/xla/xla/service/gpu/ir_emission_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,52 @@ ENTRY entry {
EXPECT_FALSE(result.has_value());
}

TEST_F(IrEmissionUtilsTest, FindTiledLogical2103Transpose) {
const char* hlo = R"(
HloModule module
ENTRY entry {
p = f32[33,48,32,2]{3,2,1,0} parameter(0)
ROOT t = f32[32,48,33,2]{3,2,1,0} transpose(p), dimensions={2,1,0,3}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));
auto& debug_options = module->mutable_config().mutable_debug_options();
debug_options.set_xla_gpu_mlir_emitter_level(3);

HloInstruction* tr = module->entry_computation()->root_instruction();

auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
EXPECT_EQ(result->dimensions, InlinedVector({32, 48, 33, 2}));
EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0, 3}));
}

TEST_F(IrEmissionUtilsTest, FindTiledLogical1320Transpose) {
const char* hlo = R"(
HloModule module
ENTRY entry {
p = f32[33,48,32,34]{3,2,1,0} parameter(0)
ROOT t = f32[48,34,32,33]{3,2,1,0} transpose(p), dimensions={1,3,2,0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));
auto& debug_options = module->mutable_config().mutable_debug_options();
debug_options.set_xla_gpu_mlir_emitter_level(3);

HloInstruction* tr = module->entry_computation()->root_instruction();

auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
EXPECT_EQ(result->dimensions, InlinedVector({48, 34, 32, 33}));
EXPECT_EQ(result->permutation, InlinedVector({1, 3, 2, 0}));
}

TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) {
const char* hlo = R"(
HloModule module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ TEST_F(InstructionFusionTest,
TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0));
HloInstruction* transpose2 = builder.AddInstruction(
HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {}));
Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 32});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, operand_shape, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(operand_shape, HloOpcode::kExp, param));
HloInstruction* transpose2 =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {32, 64}), exp1, {1, 0}));

auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
Expand Down

0 comments on commit 1ba6635

Please sign in to comment.