[CPU] Fixed BF16 Matmul inference precision (#22995)
The CPU plugin uses the EnforceInferencePrecision routine for BF16 precision
mark-up. Its logic assumes that only the activations precision is changed
before a MatMul op, while the weights precision is kept unchanged. Since
dnnlFCTypeMapping was missing the optimized configuration for BF16
activations with FP32 weights, execution always happened in FP32 precision
even if the user manually set infer_precision=bf16.
This bug is not visible on FP16 IRs (since a BF16+FP16 config is present),
so only FP32 IRs are affected. Since save_model and ovc apply FP16
compression by default, the issue mostly affects pipelines that use a model
directly after a convert_model call.

Cherry-picks: #22994
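
For illustration, a minimal sketch of a pipeline that hits this bug, using
the public OpenVINO 2.0 C++ API (the model path is a hypothetical
placeholder):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // An FP32 IR, e.g. one produced by convert_model and used directly,
    // without the default FP16 compression applied by save_model/ovc.
    auto model = core.read_model("model_fp32.xml");

    // Explicitly request BF16 execution. Before this fix, a MatMul whose
    // weights were still FP32 had no matching {bf16 src, f32 wei} row in
    // dnnlFCTypeMapping, so the node silently executed in FP32.
    auto compiled = core.compile_model(
        model, "CPU", ov::hint::inference_precision(ov::element::bf16));
    return 0;
}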
dmitry-gorokhov committed Feb 22, 2024
1 parent 43bc502 commit b54f753
Showing 1 changed file with 2 additions and 2 deletions.
@@ -40,7 +40,7 @@ static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp,
 // clang-format off
 static const TypeMapping dnnlFCTypeMapping {
     // {src, wei, bia, dst}                        pt<src, wei, bias, dst>
-    {{_bf16, _bf16, _any, _bf16 | _f32},           pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), use<3>())},
     {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), use<3>())},
     // integer precision outputs are not supported for float precision inputs
     {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8}, pt(bypass(), bypass(), use<0>(), use<0>())},
@@ -63,7 +63,7 @@ static const MappingNotation dnnlConvolutionMappingNotation {
 
 static const TypeMapping dnnlConvolutionTypeMapping {
     // {src, wei, bia, dst}                        pt<src, wei, bias, dst>
-    {{_bf16, _bf16, _any, _bf16 | _f32},           pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), use<3>())},
     {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), use<3>())},
     // integer precision outputs are not supported for float precision inputs
     {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8}, pt(bypass(), bypass(), use<0>(), use<0>())},
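
For readers unfamiliar with the table, a hedged reading of the changed row;
the per-field meaning is inferred from the table's own
pt<src, wei, bias, dst> comment, not confirmed against the executor sources:

    {{_bf16,           // src: BF16 activations
      _bf16 | _f32,    // wei: BF16 weights, or (new) weights left in FP32
      _any,            // bia: any bias precision accepted
      _bf16 | _f32},   // dst: BF16 or FP32 output
     pt(bypass(),      // src: keep the incoming precision as-is
        bypass(),      // wei: keep the incoming precision as-is
        use<3>(),      // bias: follow the resolved dst (index 3) precision
        use<3>())},    // dst: keep its own resolved precision

With the added _f32 alternative for weights, a BF16-marked MatMul whose
weights are still FP32 now matches this optimized row instead of falling
through to a pure FP32 configuration.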
