[DirectML] Broadcast NC-dims for Tensors A&B in DynamicQuantizeMatMul (#21298)

### Description
[DirectML] Broadcast NC-dims for Tensors A&B in DynamicQuantizeMatMul
The DynamicQuantizeMatMul operator accepts input tensors in NCHW format, and DirectML requires that input tensors share the same batch and channel dimensions. Tensors A and B should therefore be broadcast, when possible, to the corresponding NC dimensions of the output.
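
As an illustration of the shape arithmetic this change performs, here is a minimal standalone sketch. The helper name `BroadcastToOutputNC` and the example shapes are hypothetical, not from the repository; the copy logic mirrors the `std::copy` calls in the diff below.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the codebase): broadcast an input shape to
// the output's leading batch/channel dims. The leading NC dims are inherited
// from the output shape; the trailing matrix dims are right-aligned copies
// of the input's own trailing dims.
//
// Preconditions (the patch enforces these via ML_CHECK_VALID_ARGUMENT):
//   outputShape.size() >= 4, inputShape.size() >= 2,
//   inputShape.size() + 2 >= outputShape.size().
std::vector<uint32_t> BroadcastToOutputNC(
    const std::vector<uint32_t>& inputShape,
    const std::vector<uint32_t>& outputShape)
{
    // Start from the full output shape, then overwrite everything past the
    // NC dims with the input's trailing dims.
    std::vector<uint32_t> broadcasted(outputShape.begin(), outputShape.end());
    std::copy(inputShape.end() - (outputShape.size() - 2),
              inputShape.end(),
              broadcasted.begin() + 2);
    return broadcasted;
}

// Example: B of shape {4, 5} with output shape {2, 8, 3, 5} becomes
// {2, 8, 4, 5}, i.e. the same B matrix is reused across N = 2 and C = 8.
```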

### Motivation and Context
Certain models that use DynamicQuantizeMatMul crash when the NC dimensions need to be broadcast.

---------

Co-authored-by: Sheil Kumar <[email protected]>
smk2007 and Sheil Kumar committed Aug 21, 2024
1 parent 3871274 commit 74acaf6
Showing 1 changed file with 45 additions and 5 deletions.
@@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
                 kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0)
             );
         }
+        MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::A).tensorDataType;
+        MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType;
 
+        gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
         std::vector<uint32_t> ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A);
-        std::vector<uint32_t> ExpectedAScaleTensorShape = {1, 1, 1, 1};
-        std::vector<uint32_t> ExpectedAZeroPointTensorShape = {1, 1, 1, 1};
         std::vector<uint32_t> BTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::B);
+        std::vector<uint32_t> ExpectedAScaleTensorShape(outputSizes.size(), 1);
+        std::vector<uint32_t> ExpectedAZeroPointTensorShape(outputSizes.size(), 1);
+        ML_CHECK_VALID_ARGUMENT(outputSizes.size() >= 4);
+        ML_CHECK_VALID_ARGUMENT(ATensorShape.size() >= 2);
+        ML_CHECK_VALID_ARGUMENT(BTensorShape.size() >= 2);
+        ML_CHECK_VALID_ARGUMENT(ATensorShape.size() + 2 >= outputSizes.size());
+        ML_CHECK_VALID_ARGUMENT(BTensorShape.size() + 2 >= outputSizes.size());
+        std::vector<uint32_t> AShapeBroadcasted(outputSizes.begin(), outputSizes.end());
+        std::copy(ATensorShape.end() - (outputSizes.size() - 2),
+                  ATensorShape.end(),
+                  AShapeBroadcasted.begin() + 2);
+        std::vector<uint32_t> BShapeBroadcasted(outputSizes.begin(), outputSizes.end());
+        std::copy(BTensorShape.end() - (outputSizes.size() - 2),
+                  BTensorShape.end(),
+                  BShapeBroadcasted.begin() + 2);
 
         // output edges between DynQL and MMItoFloat node
         TensorDesc intermediateQuantizedATensorDesc = TensorDesc(
             BDatatype,
-            gsl::make_span(ATensorShape),
+            gsl::make_span(AShapeBroadcasted),
             gsl::make_span(ATensorShape),
             TensorAxis::DoNotCoerce,
             TensorAxis::W,
@@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
             0 // guaranteedBaseOffsetAlignment
         );
 
+        TensorDesc broadcastedATensorDesc = TensorDesc(
+            ADatatype,
+            AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
+            ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
+            TensorAxis::DoNotCoerce,
+            TensorAxis::W,
+            TensorAxis::RightAligned,
+            NchwDimensionCount, // minDimensionCount
+            0 // guaranteedBaseOffsetAlignment
+        );
+
+        TensorDesc broadcastedBTensorDesc = TensorDesc(
+            BDatatype,
+            BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
+            BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
+            TensorAxis::DoNotCoerce,
+            TensorAxis::W,
+            TensorAxis::RightAligned,
+            NchwDimensionCount, // minDimensionCount
+            0 // guaranteedBaseOffsetAlignment
+        );
+
+        DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc();
+        DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc();
         DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc();
@@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
         std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
 
         DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
-        dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A];
+        dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
         dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;
@@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
         matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor;
         matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor;
         matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor;
-        matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B];
+        matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc;
         matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
         matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr;
         matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;
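The `broadcastedATensorDesc` and `broadcastedBTensorDesc` descriptors above hand `TensorDesc` both the desired (post-broadcast) dimensions and the original ones, with `TensorAxis::RightAligned`. The sketch below shows how right-aligned broadcasting is conventionally realized with zero strides; it illustrates the general technique and is an assumption about the effect, not a verified account of `TensorDesc` internals.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only (assumed semantics, not the codebase's TensorDesc): compute
// element strides for `original` right-aligned against `desired`, giving
// every broadcast dimension a stride of 0 so the same data is re-read
// across that dimension. Assumes the shapes are broadcast-compatible,
// i.e. dims differ only where the original dim is 1.
std::vector<uint32_t> RightAlignedBroadcastStrides(
    const std::vector<uint32_t>& desired,
    std::vector<uint32_t> original)
{
    // Right-align the original shape by padding it with leading 1s.
    original.insert(original.begin(), desired.size() - original.size(), 1u);

    std::vector<uint32_t> strides(desired.size());
    uint32_t stride = 1;
    for (size_t i = desired.size(); i-- > 0; )
    {
        strides[i] = (original[i] == desired[i]) ? stride : 0;
        stride *= original[i];
    }
    return strides;
}

// Example: desired {2, 8, 4, 5} vs. original {4, 5} yields strides
// {0, 0, 5, 1}: reads repeat the same 4x5 matrix for every N*C batch.
```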
