Skip to content

Commit

Permalink
PR tensorflow#9150: [NVIDIA XLA GPU] Make CUTLASS gemmDus fusion acce…
Browse files Browse the repository at this point in the history
…pt optional bitcast

Imported from GitHub PR openxla/xla#9150

The CUTLASS gemm-DUS fusion used to require an intermediate bitcast node. This constraint is removed, since in some cases a 2-d GEMM will directly update a 2-d weight.
Copybara import of the project:

--
610efc47b040a3ce9d1a2a2ec5fad8a5688cb172 by TJ <[email protected]>:

cutlass gemm dus fusion supports optional bitcast

Merging this change closes tensorflow#9150

PiperOrigin-RevId: 604573353
  • Loading branch information
Tixxx authored and tensorflower-gardener committed Feb 6, 2024
1 parent 6e1fb80 commit e382119
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 7 deletions.
28 changes: 21 additions & 7 deletions third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ struct GemmWithDynamicSlice {
explicit GemmWithDynamicSlice(HloDynamicUpdateSliceInstruction* update_slice)
: update_slice(update_slice) {}

std::vector<HloInstruction*> Instrs() { return {dot, bitcast, update_slice}; }
std::vector<HloInstruction*> Instrs() {
  // The result bitcast is not always present (a 2-d dot may update a 2-d
  // operand directly), so include it only when it was actually matched.
  if (bitcast != nullptr) {
    return {dot, bitcast, update_slice};
  }
  return {dot, update_slice};
}

HloInstruction* dot = nullptr;
HloInstruction* bitcast = nullptr; // result bitcast
Expand Down Expand Up @@ -152,14 +158,20 @@ static absl::StatusOr<GemmWithUpcast> MatchGemmWithUpcast(
return absl::InternalError("unsupported gemm with upcasing");
}

// Wraps `pattern` so that it matches either directly or through a bitcast of
// it. When the bitcast form matches, the bitcast instruction is captured into
// `optional_bitcast`; otherwise `*optional_bitcast` is left untouched.
template <typename Pattern>
auto OptionalBitcast(HloInstruction** optional_bitcast, Pattern pattern) {
  auto bitcasted = m::Bitcast(optional_bitcast, pattern);
  return m::AnyOf<HloInstruction>(std::move(bitcasted), std::move(pattern));
}

// Returns matched GEMM with result used to update a slice.
static absl::StatusOr<GemmWithDynamicSlice> MatchGemmWithDynamicUpdateSlice(
HloDynamicUpdateSliceInstruction* update_slice) {
GemmWithDynamicSlice match(update_slice);

if (!Match(
const_cast<HloInstruction*>(update_slice->operand(1)),
m::Bitcast(&match.bitcast, m::Dot(&match.dot, m::Op(), m::Op())))) {
if (!Match(const_cast<HloInstruction*>(update_slice->operand(1)),
OptionalBitcast(&match.bitcast,
m::Dot(&match.dot, m::Op(), m::Op())))) {
return absl::InternalError("failed to match update slice instr");
}

Expand Down Expand Up @@ -204,9 +216,12 @@ CutlassGemmWithDynamicUpdateSlicePattern::TryMatch(
match.AddReplacement(matched->dot, [=](HloFusionInstruction* fusion) {
HloComputation* parent = fusion->parent();
auto* dus = Cast<HloDynamicUpdateSliceInstruction>(matched->update_slice);
bool has_bitcast = matched->bitcast != nullptr;
const Shape dus_shape =
has_bitcast ? matched->bitcast->shape() : matched->dot->shape();
auto* slice = parent->AddInstruction(HloInstruction::CreateDynamicSlice(
matched->bitcast->shape(), fusion, dus->index_operands(),
matched->bitcast->shape().dimensions()));
dus_shape, fusion, dus->index_operands(), dus_shape.dimensions()));

return parent->AddInstruction(
HloInstruction::CreateBitcast(matched->dot->shape(), slice));
});
Expand Down Expand Up @@ -337,7 +352,6 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion {
// Mapping to a buffer that holds output slice offset.
auto* offset =
Cast<HloParameterInstruction>(matched.update_slice->operand(2));

kernel::gemm_universal::DynamicSliceIndices slices;
slices.out = offset->parameter_number();

Expand Down
114 changes: 114 additions & 0 deletions third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,52 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) {
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

// Verifies that the CUTLASS gemm+DUS rewriter pattern matches when the dot
// result feeds the dynamic-update-slice directly, i.e. without an
// intermediate bitcast (here both the dot result and the updated operand are
// already 2-d, so no reshape is needed).
TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) {
// Input: a 2-d dot whose result is written into a slice of p0 — note there is
// no bitcast between the dot and the dynamic-update-slice.
const char* hlo = R"(
HloModule test
ENTRY %main (p0: f32[4,2], p1: f32[2,2], i: s32[]) -> f32[4,2] {
%p0 = f32[4,2]{1,0} parameter(0)
%p1 = f32[2,2]{1,0} parameter(1)
%i = s32[] parameter(2)
%dot = f32[2,2]{1,0} dot(%p1, %p1),
lhs_contracting_dims={1},
rhs_contracting_dims={0}
ROOT %r = f32[4,2]{1,0} dynamic-update-slice(%p0, %dot, %i, %i)
}
)";

// Expected: the dot and the DUS are fused into a single kCustom fusion backed
// by the "cutlass_gemm_with_dynamic_update_slice" custom kernel.
const char* expected = R"(
; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} {
; CHECK-DAG: [[P1:%[^ ]+]] = f32[4,2]{1,0} parameter
; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter
; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]])
; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter
; CHECK: ROOT [[DUS:%[^ ]+]] = f32[4,2]{1,0} dynamic-update-slice([[P1]], [[DOT]], [[P2]], [[P2]])
; CHECK: }
; CHECK: ENTRY %main {{.*}} {
; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[4,2]{1,0} fusion
; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice,
; CHECK: backend_config={
; CHECK: "kind":"__custom_fusion",
; CHECK: "custom_fusion_config":{
; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice"
; CHECK: }
; CHECK: }
; CHECK: }
)";

// Run the rewriter with only the gemm+DUS pattern registered and check the
// resulting HLO against the FileCheck expectations above.
CustomKernelFusionPatternRegistry patterns;
patterns.Emplace<CutlassGemmWithDynamicUpdateSlicePattern>();

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
CustomKernelFusionRewriter pass(&device, &patterns);
RunAndFilecheckHloRewrite(hlo, std::move(pass), expected);
}

//===----------------------------------------------------------------------===//
// Run And Compare Tests
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -373,4 +419,72 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) {
/*run_hlo_passes=*/false));
}

// End-to-end numeric check: runs the bitcast-free gemm+DUS custom fusion and
// compares its results against an equivalent cuBLAS-based module.
TEST_F(CutlassFusionTest,
RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) {
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};

// Reference module: cuBLAS gemm followed by a plain dynamic-update-slice.
const char* hlo_text_cublas = R"(
HloModule cublas
ENTRY e {
p0 = bf16[16,8]{1,0} parameter(0)
p1 = bf16[8,8]{1,0} parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
gemm.tuple = (bf16[8,8]{1,0}, s8[0]{0}) custom-call(p1, p1),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}}
gemm = bf16[8,8]{1,0} get-tuple-element(gemm.tuple), index=0
ROOT r = bf16[16,8]{1,0} dynamic-update-slice(p0, gemm, p2, p3)
}
)";

// Module under test: the same computation expressed as a
// cutlass_gemm_with_dynamic_update_slice custom fusion. The fusion body has
// no bitcast between the dot and the DUS, which is the case this test covers.
const char* hlo_text_custom_fusion = R"(
HloModule cutlass
cutlass_gemm {
p0.1 = bf16[8,8]{1,0} parameter(0)
p1.1 = bf16[16,8]{1,0} parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
dot.1 = bf16[8,8]{1,0} dot(p0.1, p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
r.1 = bf16[16,8]{1,0} dynamic-update-slice(p1.1, dot.1, p2, p3)
workspace = u8[1024]{0} custom-call(),
custom_call_target="__custom_kernel_fusion$workspace",
api_version=API_VERSION_TYPED_FFI
ROOT tuple = (bf16[16,8]{1,0}, u8[1024]{0}) tuple(r.1, workspace)
}
ENTRY e {
p0 = bf16[16,8]{1,0} parameter(0)
p1 = bf16[8,8]{1,0} parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
r.0 = (bf16[16,8]{1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom,
calls=%cutlass_gemm,
backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}}
ROOT %get-tuple-element = bf16[16,8]{1,0} get-tuple-element(r.0), index=0
})";

// Deterministic inputs: p1 is the gemm operand, p0 is the tensor being
// updated; p2/p3 are the slice start indices (row 0, column 1).
Array2D<bfloat16> p0_arr(16, 8); // bf16[16,8]
Array2D<bfloat16> p1_arr(8, 8); // bf16[8,8]
p1_arr.Each([](int64_t i, int64_t j, bfloat16* out) {
*out = bfloat16{1.0f * i * j};
});

Array<int32_t> p2_arr({}, 0);
Array<int32_t> p3_arr({}, 1);

auto p0 = LiteralUtil::CreateFromArray(p0_arr);
auto p1 = LiteralUtil::CreateFromArray(p1_arr);
auto p2 = LiteralUtil::CreateFromArray(p2_arr);
auto p3 = LiteralUtil::CreateFromArray(p3_arr);

// Both modules are run as-is (no HLO passes) and must agree within
// error_spec.
EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion,
{&p0, &p1, &p2, &p3}, error_spec,
/*run_hlo_passes=*/false));
}

} // namespace xla::gpu

0 comments on commit e382119

Please sign in to comment.