[CPU] FullyConnected weights compression: mxfp4 (wei=f4e2m1, scales=f8e8m0) support
dmitry-gorokhov committed Jul 31, 2024
1 parent 36eebc2 commit 0bc0f57
Showing 21 changed files with 262 additions and 104 deletions.
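For context: mxfp4 is the 4-bit floating-point microscaling format from the OCP MX specification. Weights are stored as f4e2m1 values (1 sign, 2 exponent, 1 mantissa bit) and each group of elements (typically 32) shares one f8e8m0 scale, an exponent-only power-of-two factor. A minimal decoding sketch under those definitions; the helper functions below are illustrative and not part of the commit:

```cpp
#include <cmath>
#include <cstdint>

// f4e2m1 has no inf/nan codes; the 8 non-negative values are enumerable.
static float decode_f4e2m1(uint8_t nibble) {
    static const float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float v = lut[nibble & 0x7];
    return (nibble & 0x8) ? -v : v;  // top bit of the nibble is the sign
}

// f8e8m0 is an unsigned exponent-only type: value = 2^(x - 127); 0xFF encodes NaN.
static float decode_f8e8m0(uint8_t code) {
    if (code == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(code) - 127);
}

// Dequantized weight = Convert(f4e2m1 code) * Convert(f8e8m0 group scale),
// which is exactly the Multiply subgraph matched by the passes below.
static float dequantize_mxfp4(uint8_t wei_code, uint8_t scale_code) {
    return decode_f4e2m1(wei_code) * decode_f8e8m0(scale_code);
}
```

Since f4e2m1 can represent only {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, all dynamic range comes from the shared e8m0 scale. Note there is no zero point, which is why the Subtract branch of the dequantization pattern is absent for mxfp4.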
@@ -22,7 +22,7 @@ namespace pass {
 class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass {
 public:
     OPENVINO_RTTI("MarkDequantizationSubgraph", "0");
-    MarkDequantizationSubgraph(const element::TypeVector& precisions, const bool fold_subtract_const = false);
+    MarkDequantizationSubgraph(const element::TypeVector& precisions, const bool fold_subtract_const = false, const bool fold_multiply_const = true);
 };
 }  // namespace pass
 }  // namespace ov
@@ -14,7 +14,8 @@
 #include "transformations/utils/utils.hpp"

 ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions,
-                                                                 const bool fold_subtract_const) {
+                                                                 const bool fold_subtract_const,
+                                                                 const bool fold_multiply_const) {
     // Dequantization subgraph may have two forms: with and without Subtract
     //
     //     Input                                 Input
@@ -28,10 +29,11 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
     auto input_pattern = pattern::any_input();
     auto convert_pattern = pattern::wrap_type<ov::op::v0::Convert>({input_pattern}, pattern::consumers_count(1));
     auto zero_point_pattern = pattern::any_input();
+    auto scale_pattern = pattern::any_input();
     auto subtract_pattern = pattern::wrap_type<ov::op::v1::Subtract>({convert_pattern, zero_point_pattern});
-    auto multiply_pattern = pattern::wrap_type<ov::op::v1::Multiply>({subtract_pattern, pattern::any_input()});
+    auto multiply_pattern = pattern::wrap_type<ov::op::v1::Multiply>({subtract_pattern, scale_pattern});
     auto multiply_no_subtract_pattern =
-        pattern::wrap_type<ov::op::v1::Multiply>({convert_pattern, pattern::any_input()});
+        pattern::wrap_type<ov::op::v1::Multiply>({convert_pattern, scale_pattern});
     auto root = std::make_shared<pattern::op::Or>(OutputVector{multiply_pattern, multiply_no_subtract_pattern});

     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) -> bool {
@@ -99,6 +101,18 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::

         // mark Multiply as dequantization node
         ov::mark_as_dequantization_node(multiply);
+        auto scale = multiply->get_input_node_shared_ptr(1);
+        if (ov::is_type<ov::op::v0::Convert>(scale) &&
+            ov::is_type<ov::op::v0::Constant>(scale->get_input_node_ptr(0))) {
+            if (!fold_multiply_const) {
+                ov::disable_constant_folding(scale);
+                ov::unmark_as_decompression(scale);
+                ov::enable_keep_const_precision(scale->get_input_node_shared_ptr(0));
+            } else {
+                ov::enable_constant_folding(scale);
+                ov::disable_keep_const_precision(scale->get_input_node_shared_ptr(0));
+            }
+        }

         return false;
     };
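The new fold_multiply_const flag decides whether the scale branch (Constant → Convert) is folded into an f32 constant or kept intact for the backend. A sketch of registering the pass so that mxfp4 scales stay compressed; the header path and precision list here are assumptions, not taken from this commit:

```cpp
#include "openvino/pass/manager.hpp"
#include "transformations/mark_dequantization_subgraph.hpp"  // assumed header path

void mark_weight_decompression(ov::pass::Manager& manager) {
    // Assumed precision list; element::f4e2m1 is the one relevant to mxfp4.
    const ov::element::TypeVector precisions{ov::element::u8, ov::element::i8,
                                             ov::element::u4, ov::element::i4,
                                             ov::element::nf4, ov::element::f4e2m1};
    // fold_subtract_const = false, fold_multiply_const = false: per the callback
    // above, the scale Convert gets disable_constant_folding() and the scale
    // Constant keeps its original (e.g. f8e8m0) precision.
    manager.register_pass<ov::pass::MarkDequantizationSubgraph>(precisions, false, false);
}
```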
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/dnnl_extension_utils.cpp
@@ -35,6 +35,8 @@ uint8_t DnnlExtensionUtils::sizeOfDataType(dnnl::memory::data_type dataType) {
     case dnnl::memory::data_type::nf4:
     case dnnl::memory::data_type::s4:
     case dnnl::memory::data_type::u4:
+    case dnnl::memory::data_type::f8_e8m0:
+    case dnnl::memory::data_type::f4_e2m1:
         return 1;
     case dnnl::memory::data_type::undef:
         return 0;
@@ -66,6 +68,10 @@ dnnl::memory::data_type DnnlExtensionUtils::ElementTypeToDataType(const ov::elem
         return memory::data_type::s4;
     case ov::element::u4:
         return memory::data_type::u4;
+    case ov::element::f8e8m0:
+        return memory::data_type::f8_e8m0;
+    case ov::element::f4e2m1:
+        return memory::data_type::f4_e2m1;
     case ov::element::undefined:
         return memory::data_type::undef;
     default: {
@@ -98,6 +104,10 @@ ov::element::Type DnnlExtensionUtils::DataTypeToElementType(const dnnl::memory::
         return ov::element::i4;
     case memory::data_type::u4:
         return ov::element::u4;
+    case memory::data_type::f8_e8m0:
+        return ov::element::f8e8m0;
+    case memory::data_type::f4_e2m1:
+        return ov::element::f4e2m1;
     case memory::data_type::undef:
         return ov::element::undefined;
     default: {
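With these mappings the new types round-trip between ov::element and dnnl::memory::data_type. An illustrative, test-style check (the ov::intel_cpu namespace and header name are assumed):

```cpp
#include <cassert>
#include "dnnl_extension_utils.h"  // assumed header for DnnlExtensionUtils

void check_mxfp4_type_mapping() {
    using ov::intel_cpu::DnnlExtensionUtils;  // namespace assumed
    assert(DnnlExtensionUtils::ElementTypeToDataType(ov::element::f4e2m1) ==
           dnnl::memory::data_type::f4_e2m1);
    assert(DnnlExtensionUtils::DataTypeToElementType(dnnl::memory::data_type::f8_e8m0) ==
           ov::element::f8e8m0);
    // Like the other sub-byte types (u4/i4/nf4), f4_e2m1 reports a one-byte
    // storage granularity rather than its true 4-bit width.
    assert(DnnlExtensionUtils::sizeOfDataType(dnnl::memory::data_type::f4_e2m1) == 1);
}
```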
7 changes: 4 additions & 3 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp
@@ -615,12 +615,13 @@ static MemoryPtr prepackDecompressionParams(const MemoryCPtr& paramsPtr,
     return dstMem;
 }

-void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose) {
+void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision) {
     if (scales_ptr == nullptr)
         return;

-    auto scalesMem = prepackDecompressionParams(scales_ptr, needTranspose, ov::element::f32, engine);
-    attr.set_scales_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(scalesMem->getStaticDims()));
+    auto scalesMem = prepackDecompressionParams(scales_ptr, needTranspose, dstPrecision, engine);
+    attr.set_scales_dims(DNNL_ARG_WEIGHTS,
+        DnnlExtensionUtils::convertToDnnlDims(scalesMem->getStaticDims()), DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
     cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = std::move(scalesMem);
     dnnlArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
         cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS]->getPrimitive();
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/dnnl_postops_composer.h
@@ -31,7 +31,7 @@ class DnnlPostOpsComposer {
                        const bool hasBias,
                        const dnnl::memory::data_type outDataType);
     DnnlPrimitiveAttrs compose();
-    void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose);
+    void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision);
     void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision);
     void setDynamicQuantizationParams(uint64_t groupSize);
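Taken together, the .cpp and .h changes let callers forward the scales' real precision instead of the previously hard-coded f32. A hypothetical call site (composer and scalesMemory are placeholder names):

```cpp
// Hypothetical call site: for mxfp4 the scales memory holds f8e8m0 data,
// which can now be passed through to oneDNN without upconversion to f32.
const auto scalePrecision = scalesMemory->getDesc().getPrecision();  // e.g. ov::element::f8e8m0
composer.appendDecompressionScales(scalesMemory, /*needTranspose=*/false, scalePrecision);
```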
29 changes: 23 additions & 6 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -289,7 +289,8 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
 }

 void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
-    std::set<ov::element::Type> supportedWeightsPrecisions{ov::element::u8, ov::element::i8, ov::element::nf4, ov::element::u4, ov::element::i4};
+    std::set<ov::element::Type> supportedWeightsPrecisions{
+        ov::element::u8, ov::element::i8, ov::element::nf4, ov::element::u4, ov::element::i4, ov::element::f4e2m1};
     const std::set<ov::element::Type> supportedDataPrecisions{ov::element::f32, ov::element::bf16};
     auto expectedNode = [](NodePtr node, Type expectedType) {
         return node->getType() == expectedType && node->getChildEdges().size() == 1;
@@ -329,16 +330,24 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
     }

     CPU_GRAPH_OPTIMIZER_SCOPE(FuseFCAndWeightsDecompression);
-    const auto multiplyConstNode = multiplyNode->getParentEdgeAt(1)->getParent();
+    const auto mulParent1 = multiplyNode->getParentEdgeAt(1)->getParent();
+    NodePtr multiplyParent, multiplyConvertNode, multiplyConstNode;
+    multiplyParent = mulParent1;
+    if (multiplyParent->getType() == Type::Convert) {
+        multiplyConvertNode = multiplyParent;
+        multiplyParent = multiplyConvertNode->getParentEdgeAt(0)->getParent();
+    }
+    multiplyConstNode = multiplyParent;
     if (multiplyConstNode->getType() != Type::Input) {
         SKIP_FUSION_FOR_NODE(fcNode);
     }
+    const bool withMultiplyConvert = multiplyConvertNode != nullptr;

-    const auto mulParent = multiplyNode->getParentEdgeAt(0)->getParent();
-    const bool withSubtract = mulParent->getAlgorithm() == Algorithm::EltwiseSubtract;
+    const auto mulParent0 = multiplyNode->getParentEdgeAt(0)->getParent();
+    const bool withSubtract = mulParent0->getAlgorithm() == Algorithm::EltwiseSubtract;
     NodePtr subtractNode, subtractConvertNode, subtractConstNode;
     if (withSubtract) {
-        subtractNode = mulParent;
+        subtractNode = mulParent0;
         if (!expectedNode(subtractNode, Type::Eltwise)) {
             SKIP_FUSION_FOR_NODE(fcNode);
         }
@@ -354,7 +363,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
     }

     const bool withSubtractConvert = subtractConvertNode != nullptr;
-    const auto convertNode = withSubtract ? subtractNode->getParentEdgeAt(0)->getParent() : mulParent;
+    const auto convertNode = withSubtract ? subtractNode->getParentEdgeAt(0)->getParent() : mulParent0;
     if (!expectedNode(convertNode, Type::Convert)) {
         SKIP_FUSION_FOR_NODE(fcNode);
     }
@@ -461,6 +470,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
         fcNode->addOriginalLayer(subtractNode->getOriginalLayers());
     if (withSubtractConvert)
         fcNode->addOriginalLayer(subtractConvertNode->getOriginalLayers());
+    if (withMultiplyConvert)
+        fcNode->addOriginalLayer(multiplyConvertNode->getOriginalLayers());

     const auto& weightsPrecision = weightsNode->getOriginalOutputPrecisionAtPort(0);
     if (withTranspose) {
@@ -511,6 +522,12 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
         graph.RemoveEdge(subtractConvertNode->getParentEdgeAt(0));
     }
     graph.RemoveEdge(multiplyNode->getParentEdgeAt(1));
+    if (withMultiplyConvert) {
+        // MultiplyConvert is removed only if there are no other consumers (e.g. CompressedGather)
+        const auto& restChilds = multiplyConvertNode->getChildEdges();
+        if (restChilds.empty())
+            graph.RemoveEdge(multiplyConvertNode->getParentEdgeAt(0));
+    }

     graph.DropNode(convertNode);
     if (withSubtract)
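For orientation, a hypothetical ov graph snippet that this fusion now accepts for mxfp4; the shapes, the group size of 32, and the helper name are illustrative only:

```cpp
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/multiply.hpp"

std::shared_ptr<ov::Node> make_mxfp4_weight_branch() {
    // Grouped weights [OC, IC / 32, 32] stored as f4e2m1 codes (a single value is
    // broadcast here just to keep the sketch short).
    auto weights = ov::op::v0::Constant::create(ov::element::f4e2m1, ov::Shape{1024, 32, 32}, {0.5f});
    auto weights_f32 = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f32);  // convertNode
    // One f8e8m0 scale per 32-element group; this extra Convert is the new multiplyConvertNode.
    auto scales = ov::op::v0::Constant::create(ov::element::f8e8m0, ov::Shape{1024, 32, 1}, {1.0f});
    auto scales_f32 = std::make_shared<ov::op::v0::Convert>(scales, ov::element::f32);
    // mxfp4 has no zero point, so there is no Subtract branch (withSubtract == false).
    return std::make_shared<ov::op::v1::Multiply>(weights_f32, scales_f32);  // multiplyNode
}
```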
(Diffs for the remaining changed files are not shown.)