From 74acaf64255c063fbd085f79fb90a87eedc4efc7 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Wed, 10 Jul 2024 17:35:47 -0700
Subject: [PATCH] [DirectML] Broadcast NC-dims for Tensors A&B in DynamicQuantizeMatMul (#21298)

### Description
[DirectML] Broadcast NC-dims for Tensors A&B in DynamicQuantizeMatMul

DynamicQuantizeMatMul accepts input tensors in NCHW format, but DirectML requires that input tensors share the same batch (N) and channel (C) dimensions. Tensors A and B must therefore be broadcast (where possible) to the corresponding NC dims of the output.

### Motivation and Context
Certain models that use DynamicQuantizeMatMul crash when the NC dims are meant to be broadcast.

---------

Co-authored-by: Sheil Kumar
---
 .../DmlOperatorDynamicQuantizeMatMul.cpp      | 50 +++++++++++++++++---
 1 file changed, 45 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp
index c6a87da705a99..32d6af73aae8d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp
@@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
                 kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0)
             );
         }
 
+        MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::A).tensorDataType;
         MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType;
 
+        gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
         std::vector<uint32_t> ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A);
-        std::vector<uint32_t> ExpectedAScaleTensorShape = {1, 1, 1, 1};
-        std::vector<uint32_t> ExpectedAZeroPointTensorShape = {1, 1, 1, 1};
+        std::vector<uint32_t> BTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::B);
+        std::vector<uint32_t> ExpectedAScaleTensorShape(outputSizes.size(), 1);
+        std::vector<uint32_t> ExpectedAZeroPointTensorShape(outputSizes.size(), 1);
+        ML_CHECK_VALID_ARGUMENT(outputSizes.size() >= 4);
+        ML_CHECK_VALID_ARGUMENT(ATensorShape.size() >= 2);
+        ML_CHECK_VALID_ARGUMENT(BTensorShape.size() >= 2);
+        ML_CHECK_VALID_ARGUMENT(ATensorShape.size() + 2 >= outputSizes.size());
+        ML_CHECK_VALID_ARGUMENT(BTensorShape.size() + 2 >= outputSizes.size());
+        std::vector<uint32_t> AShapeBroadcasted(outputSizes.begin(), outputSizes.end());
+        std::copy(ATensorShape.end() - (outputSizes.size() - 2),
+                  ATensorShape.end(),
+                  AShapeBroadcasted.begin() + 2);
+        std::vector<uint32_t> BShapeBroadcasted(outputSizes.begin(), outputSizes.end());
+        std::copy(BTensorShape.end() - (outputSizes.size() - 2),
+                  BTensorShape.end(),
+                  BShapeBroadcasted.begin() + 2);
         // output edges between DynQL and MMItoFloat node
         TensorDesc intermediateQuantizedATensorDesc = TensorDesc(
             BDatatype,
-            gsl::make_span(ATensorShape),
+            gsl::make_span(AShapeBroadcasted),
             gsl::make_span(ATensorShape),
             TensorAxis::DoNotCoerce,
             TensorAxis::W,
@@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
             0 // guaranteedBaseOffsetAlignment
         );
 
+        TensorDesc broadcastedATensorDesc = TensorDesc(
+            ADatatype,
+            AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
+            ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
+            TensorAxis::DoNotCoerce,
+            TensorAxis::W,
+            TensorAxis::RightAligned,
+            NchwDimensionCount, // minDimensionCount
+            0 // guaranteedBaseOffsetAlignment
+        );
+
+        TensorDesc broadcastedBTensorDesc = TensorDesc(
+            BDatatype,
+            BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
+            BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
+            TensorAxis::DoNotCoerce,
+            TensorAxis::W,
+            TensorAxis::RightAligned,
+            NchwDimensionCount, // minDimensionCount
+            0 // guaranteedBaseOffsetAlignment
+        );
+
+        DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc();
+        DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc();
@@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
         std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
 
         DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
-        dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A];
+        dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;
@@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
         matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor;
         matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor;
         matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor;
-        matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B];
+        matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc;
         matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
         matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr;
         matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;
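
The broadcast rule the patch applies is: start from the operator's output sizes (whose leading N and C dims are the broadcast targets) and copy the input tensor's trailing matrix dims back in. The standalone C++ sketch below mirrors the `AShapeBroadcasted`/`BShapeBroadcasted` computation; `BroadcastNcDims` and the example shape values are hypothetical illustrations, not part of the patch.

```cpp
#include <algorithm>
#include <iostream>
#include <vector>
#include <cstdint>

// Hypothetical standalone helper mirroring the AShapeBroadcasted /
// BShapeBroadcasted computation in the diff: take the leading batch and
// channel dims from the output sizes, and keep the trailing matrix dims
// from the input tensor's own shape. The patch guards the preconditions
// (output rank >= 4, input rank >= 2, input rank + 2 >= output rank)
// with ML_CHECK_VALID_ARGUMENT; they are assumed to hold here.
std::vector<uint32_t> BroadcastNcDims(
    const std::vector<uint32_t>& inputShape,
    const std::vector<uint32_t>& outputSizes)
{
    const auto trailingDims = outputSizes.size() - 2;
    std::vector<uint32_t> broadcasted(outputSizes.begin(), outputSizes.end());
    std::copy(inputShape.end() - trailingDims,
              inputShape.end(),
              broadcasted.begin() + 2);
    return broadcasted;
}

int main()
{
    // Example shapes: A is [1, 1, 5, 16] but the output carries NC dims
    // [2, 3] -- the situation that previously crashed. The result keeps
    // A's 5x16 matrix dims and broadcasts N and C to match the output.
    std::vector<uint32_t> aShape = {1, 1, 5, 16};
    std::vector<uint32_t> outputSizes = {2, 3, 5, 32};

    for (uint32_t dim : BroadcastNcDims(aShape, outputSizes))
    {
        std::cout << dim << ' '; // prints: 2 3 5 16
    }
    std::cout << '\n';
}
```

With a 4-D output, `outputSizes.size() - 2` is 2, so exactly the two GEMM dims survive from the input; the check `ATensorShape.size() + 2 >= outputSizes.size()` guarantees the input has enough trailing dims for the copy.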