Enabled BF16 pattern and added a test
mahmoud-abuzaina committed Oct 5, 2023
1 parent 6ed8e02 commit e67bb7f
Showing 2 changed files with 87 additions and 2 deletions.
15 changes: 13 additions & 2 deletions third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc
@@ -29,6 +29,11 @@ namespace cpu {
namespace {
namespace m = match;

// Matches an f32 `convert` whose operand is a bf16 instruction, and binds
// that bf16 operand to *instr so callers can check whether two converts
// share the same source.
auto ConvertPattern(HloInstruction** instr) {
  return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16))
      .WithElementType(PrimitiveType::F32);
}

HloInstruction* FindLayerNormScale(HloInstruction* instr) {
  HloInstruction* scale = nullptr;
  auto scalePattern = m::Multiply().WithBinaryOperandsAnyOrder(
@@ -205,8 +210,14 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor {
*actual);
})));

if (Match(slicesource2, empirical_expectations) && prod_l == prod_c &&
    prod_c == prod_r && prod_l == prod_s) {
HloInstruction *src1, *src2;
if (Match(slicesource2, empirical_expectations) &&
    // Float32 pattern check
    ((prod_l == prod_c && prod_c == prod_r && prod_l == prod_s) ||
     // Bfloat16 pattern check
     (prod_l == prod_c && prod_c == prod_r &&
      Match(prod_l, ConvertPattern(&src1)) &&
      Match(prod_s, ConvertPattern(&src2)) && src1 == src2))) {
  HloInstruction* ln_call =
      instr->AddInstruction(HloInstruction::CreateCustomCall(
          prod_shape, {prod_r, scale, shift}, "__onednn$layernorm"));
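The gating condition above is what this commit changes. Previously the rewrite only fired when the compared producer instructions were all one and the same instruction (the float32 case). With a bf16 input, as in the test below, the statistics are computed on f32 values produced by separate `convert`s of the same bf16 instruction, so those producers are distinct HLO instructions even though they carry the same data; the new branch accepts that case by requiring both converts to resolve to a single bf16 source. A minimal sketch of that same-source check, reusing only calls already present in this file (the helper name and the lhs_f32/rhs_f32 parameters are illustrative, not part of the commit):

// True iff both f32 values are `convert`s of one and the same bf16
// instruction, mirroring the bfloat16 branch added above.
bool SameBf16Source(HloInstruction* lhs_f32, HloInstruction* rhs_f32) {
  HloInstruction* src1 = nullptr;
  HloInstruction* src2 = nullptr;
  return Match(lhs_f32, ConvertPattern(&src1)) &&
         Match(rhs_f32, ConvertPattern(&src2)) && src1 == src2;
}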
74 changes: 74 additions & 0 deletions third_party/xla/xla/tests/onednn_layer_norm_test.cc
@@ -96,5 +96,79 @@ TEST_F(LayerNormTest, SimpleTest) {
)");
}

TEST_F(LayerNormTest, SimpleTestBF16) {
const char* layer_norm_module_str = R"(
HloModule layer_norm_bf16.test, entry_computation_layout={(f32[768]{0}, f32[768]{0}, bf16[16,128,768]{2,1,0})->bf16[16,128,768]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
region_0.16 {
Arg_0.17 = f32[] parameter(0)
Arg_1.18 = f32[] parameter(1)
ROOT add.19 = f32[] add(Arg_0.17, Arg_1.18)
}
ENTRY main.53 {
Arg_2.3 = bf16[16,128,768]{2,1,0} parameter(2), sharding={replicated}
convert.31 = f32[16,128,768]{2,1,0} convert(Arg_2.3)
convert.11 = f32[16,128,768]{2,1,0} convert(Arg_2.3)
reshape.13 = f32[1,16,128,768]{3,2,1,0} reshape(convert.11)
multiply.12 = f32[16,128,768]{2,1,0} multiply(convert.11, convert.11)
reshape.14 = f32[1,16,128,768]{3,2,1,0} reshape(multiply.12)
concatenate.15 = f32[2,16,128,768]{3,2,1,0} concatenate(reshape.13, reshape.14), dimensions={0}
constant.10 = f32[] constant(0)
reduce.20 = f32[2,16,128]{2,1,0} reduce(concatenate.15, constant.10), dimensions={3}, to_apply=region_0.16
constant.8 = f32[] constant(768)
broadcast.9 = f32[2,16,128]{2,1,0} broadcast(constant.8), dimensions={}
divide.21 = f32[2,16,128]{2,1,0} divide(reduce.20, broadcast.9)
slice.22 = f32[1,16,128]{2,1,0} slice(divide.21), slice={[0:1], [0:16], [0:128]}
reshape.29 = f32[16,128,1]{2,1,0} reshape(slice.22)
broadcast.32 = f32[16,128,1]{2,1,0} broadcast(reshape.29), dimensions={0,1,2}
reshape.33 = f32[16,128]{1,0} reshape(broadcast.32)
broadcast.34 = f32[16,128,768]{2,1,0} broadcast(reshape.33), dimensions={0,1}
subtract.35 = f32[16,128,768]{2,1,0} subtract(convert.31, broadcast.34)
slice.24 = f32[1,16,128]{2,1,0} slice(divide.21), slice={[1:2], [0:16], [0:128]}
reshape.25 = f32[16,128]{1,0} reshape(slice.24)
reshape.23 = f32[16,128]{1,0} reshape(slice.22)
multiply.26 = f32[16,128]{1,0} multiply(reshape.23, reshape.23)
subtract.27 = f32[16,128]{1,0} subtract(reshape.25, multiply.26)
constant.6 = f32[] constant(0)
broadcast.7 = f32[16,128]{1,0} broadcast(constant.6), dimensions={}
maximum.28 = f32[16,128]{1,0} maximum(subtract.27, broadcast.7)
reshape.30 = f32[16,128,1]{2,1,0} reshape(maximum.28)
constant.4 = f32[] constant(1e-06)
broadcast.5 = f32[16,128,1]{2,1,0} broadcast(constant.4), dimensions={}
add.36 = f32[16,128,1]{2,1,0} add(reshape.30, broadcast.5)
rsqrt.37 = f32[16,128,1]{2,1,0} rsqrt(add.36)
broadcast.39 = f32[16,128,1]{2,1,0} broadcast(rsqrt.37), dimensions={0,1,2}
reshape.40 = f32[16,128]{1,0} reshape(broadcast.39)
broadcast.41 = f32[16,128,768]{2,1,0} broadcast(reshape.40), dimensions={0,1}
Arg_1.2 = f32[768]{0} parameter(1), sharding={replicated}
reshape.38 = f32[1,1,768]{2,1,0} reshape(Arg_1.2)
broadcast.42 = f32[1,1,768]{2,1,0} broadcast(reshape.38), dimensions={0,1,2}
reshape.43 = f32[768]{0} reshape(broadcast.42)
broadcast.44 = f32[16,128,768]{2,1,0} broadcast(reshape.43), dimensions={2}
multiply.45 = f32[16,128,768]{2,1,0} multiply(broadcast.41, broadcast.44)
multiply.46 = f32[16,128,768]{2,1,0} multiply(subtract.35, multiply.45)
Arg_0.1 = f32[768]{0} parameter(0), sharding={replicated}
reshape.47 = f32[1,1,768]{2,1,0} reshape(Arg_0.1)
broadcast.48 = f32[1,1,768]{2,1,0} broadcast(reshape.47), dimensions={0,1,2}
reshape.49 = f32[768]{0} reshape(broadcast.48)
broadcast.50 = f32[16,128,768]{2,1,0} broadcast(reshape.49), dimensions={2}
add.51 = f32[16,128,768]{2,1,0} add(multiply.46, broadcast.50)
ROOT convert.52 = bf16[16,128,768]{2,1,0} convert(add.51)
}
)";

EXPECT_TRUE(RunAndCompare(layer_norm_module_str, ErrorSpec{1e-2, 1e-2}));
MatchOptimizedHlo(layer_norm_module_str,
R"(
; CHECK: custom_call_target="__onednn$layernorm",
; CHECK: backend_config={
; CHECK-DAG: "onednn_layer_norm_config":{
; CHECK-DAG: "fused_ops":"SCALE_AND_SHIFT"
; CHECK-DAG: }
; CHECK: }
)");
}

} // namespace
} // namespace xla
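The new test checks two things: RunAndCompare executes the module on the test and reference backends and compares the results within ErrorSpec{1e-2, 1e-2}, a tolerance loose enough for bfloat16 inputs, and MatchOptimizedHlo then verifies that the optimized HLO contains the __onednn$layernorm custom call with "fused_ops":"SCALE_AND_SHIFT" in its onednn_layer_norm_config, i.e. that the bf16 pattern was actually rewritten. Assuming the usual Bazel target layout for this directory (the target name below is an assumption, not taken from the commit), the new case can be run on its own with something like:

bazel test //xla/tests:onednn_layer_norm_test --test_filter=LayerNormTest.SimpleTestBF16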
