Updating QDQ to support Float8E4M3FN (#16550)
### Description
Naive update of the quantization tools to support Float8E4M3FN for Gemm.
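
For context, a minimal sketch of how the updated tooling can be driven end to end (a sketch only: it assumes the `QuantType.QFLOAT8E4M3FN` enum value and the distribution-based calibration this change targets; the model path and data reader below are hypothetical):

```python
import numpy as np
from onnxruntime.quantization import (
    CalibrationDataReader,
    CalibrationMethod,
    QuantFormat,
    QuantType,
    quantize_static,
)

class RandomDataReader(CalibrationDataReader):
    """Hypothetical calibration reader feeding a few random batches."""

    def __init__(self, input_name="X", shape=(1, 64), n_batches=8):
        self._batches = iter(
            [{input_name: np.random.randn(*shape).astype(np.float32)} for _ in range(n_batches)]
        )

    def get_next(self):
        return next(self._batches, None)

quantize_static(
    "model_with_gemm.onnx",      # hypothetical input model containing Gemm
    "model_with_gemm_f8.onnx",
    RandomDataReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QFLOAT8E4M3FN,
    weight_type=QuantType.QFLOAT8E4M3FN,
    calibrate_method=CalibrationMethod.Distribution,
)
```
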
xadupre authored Aug 8, 2023
1 parent 063e905 commit d0316ee
Showing 19 changed files with 2,479 additions and 232 deletions.
800 changes: 800 additions & 0 deletions docs/python/notebooks/quantization_f8.ipynb

Large diffs are not rendered by default.

37 changes: 21 additions & 16 deletions onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
@@ -220,11 +220,11 @@ REGISTER_QUANTIZELINEAR(Float8E5M2FNUZ)
 REGISTER_QUANTIZELINEAR_VERSIONED(int8_t)
 REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t)
 
-template <typename OutputType>
-void ParQuantizeLinear(const float* Input,
+template <typename InputType, typename OutputType>
+void ParQuantizeLinear(const InputType* Input,
                        OutputType* Output,
                        size_t N,
-                       float Scale,
+                       InputType Scale,
                        size_t bd,
                        const OutputType* ZeroPoint,
                        bool saturate,
@@ -236,11 +236,22 @@ void ParQuantizeLinear(const float* Input,
     ParQuantizeLinearStd(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : (OutputType)0, thread_pool);
 #if !defined(DISABLE_FLOAT8_TYPES)
   } else {
-    ParQuantizeLinearSat(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : OutputType(static_cast<float>(0), true), saturate, thread_pool);
+    ParQuantizeLinearSat(Input, Output, N, Scale, ZeroPoint != nullptr ? ZeroPoint[bd] : OutputType(static_cast<InputType>(static_cast<float>(0)), true), saturate, thread_pool);
   }
 #endif
 }
 
+template <typename T, typename InT>
+void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const T* zero_point, T* output, int64_t N, int64_t broadcast_dim, int64_t block_size, bool saturate) {
+  for (size_t n = 0; n < static_cast<size_t>(N); n++) {
+    for (size_t bd = 0; bd < static_cast<size_t>(broadcast_dim); bd++) {
+      ParQuantizeLinear(input, output, static_cast<size_t>(block_size), scale[bd], bd, zero_point, saturate, ctx->GetOperatorThreadPool());
+      input += block_size;
+      output += block_size;
+    }
+  }
+}
+
 // formula is Y = X / Scale + ZeroPoint
 template <typename T>
 Status QuantizeLinear<T>::Compute(OpKernelContext* ctx) const {
@@ -256,20 +267,14 @@ Status QuantizeLinear<T>::Compute(OpKernelContext* ctx) const {
   PrepareForQDQ(x.Shape(), y_scale, y_zero_point, axis_, N, broadcast_dim, block_size);
 
   const T* zero_point = y_zero_point != nullptr ? y_zero_point->Data<T>() : nullptr;
-  if (x.IsDataType<float>()) {
-    const float* scale = y_scale.Data<float>();
-    const float* input = x.Data<float>();
-    T* output = y.MutableData<T>();
+  T* output = y.MutableData<T>();
 
-    for (size_t n = 0; n < static_cast<size_t>(N); n++) {
-      for (size_t bd = 0; bd < static_cast<size_t>(broadcast_dim); bd++) {
-        ParQuantizeLinear(input, output, static_cast<size_t>(block_size), scale[bd], bd, zero_point, saturate_, ctx->GetOperatorThreadPool());
-        input += block_size;
-        output += block_size;
-      }
-    }
+  if (x.IsDataType<float>()) {
+    ComputeLoop<T, float>(ctx, x.Data<float>(), y_scale.Data<float>(), zero_point, output, N, broadcast_dim, block_size, saturate_);
+  } else if (x.IsDataType<MLFloat16>()) {
+    ComputeLoop<T, MLFloat16>(ctx, x.Data<MLFloat16>(), y_scale.Data<MLFloat16>(), zero_point, output, N, broadcast_dim, block_size, saturate_);
   } else {
-    ORT_THROW("Quantization from float16 is not supported yet for CPU provider.");
+    ORT_THROW("Unsupported input type.");
   }
 
   return Status::OK();
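
The refactored Compute above delegates both float and MLFloat16 inputs to ComputeLoop, which walks the tensor as N outer slices of broadcast_dim channels, each owning one contiguous block_size run quantized with its own scale and zero point. A NumPy sketch of that per-axis layout under hypothetical shapes, using the standard-path formula Y = X / Scale + ZeroPoint with clamping (exact rounding behavior differs per kernel path):

```python
import numpy as np

# Hypothetical shapes: N=2 outer slices, broadcast_dim=3 channels,
# block_size=4 contiguous elements per channel.
x = np.random.randn(2, 3, 4).astype(np.float32)
scale = np.array([0.1, 0.2, 0.4], dtype=np.float32)
zero_point = np.array([0, 10, -5], dtype=np.int32)

y = np.empty(x.shape, dtype=np.int8)
for n in range(x.shape[0]):
    for bd in range(x.shape[1]):  # mirrors ComputeLoop's inner loop over channels
        ival = np.rint(x[n, bd] / scale[bd]).astype(np.int32) + zero_point[bd]
        y[n, bd] = np.clip(ival, -128, 127).astype(np.int8)
```
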
53 changes: 53 additions & 0 deletions onnxruntime/core/util/qmath.h
@@ -132,6 +132,35 @@ ParQuantizeLinearStd(const float* Input,
   });
 }
 
+// This implementation could be more efficient; however, the cast from float16 to other types
+// usually happens on the GPU.
+template <typename OutputType>
+#if !defined(DISABLE_FLOAT8_TYPES)
+typename std::enable_if<!boost::mp11::mp_contains<element_type_lists::AllFloat8, OutputType>::value, void>::type
+#else
+void
+#endif
+ParQuantizeLinearStd(const MLFloat16* Input,
+                     OutputType* Output,
+                     size_t N,
+                     MLFloat16 Scale,
+                     OutputType ZeroPoint,
+                     concurrency::ThreadPool* thread_pool) {
+  constexpr std::ptrdiff_t block_size = 128;
+  const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size;
+  const TensorOpCost unit_cost{static_cast<double>(block_size * sizeof(MLFloat16)), static_cast<double>(block_size * sizeof(uint8_t)), static_cast<double>(block_size) * 2.0};
+  concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    auto begin_idx = begin * block_size;
+    auto end_idx = std::min(static_cast<std::ptrdiff_t>(N), end * block_size);
+    float fscale = Scale.ToFloat();
+    for (; begin_idx != end_idx; ++begin_idx) {
+      int32_t ival = static_cast<int32_t>(Input[begin_idx].ToFloat() / fscale) + ZeroPoint;
+      Output[begin_idx] = static_cast<OutputType>(std::min(static_cast<int32_t>(std::numeric_limits<OutputType>::max()),
+                                                           std::max(static_cast<int32_t>(std::numeric_limits<OutputType>::lowest()), ival)));
+    }
+  });
+}
+
 #if !defined(DISABLE_FLOAT8_TYPES)
 
 template <typename OutputFloat8Type>
@@ -155,6 +184,30 @@ ParQuantizeLinearSat(const float* Input,
   });
 }
 
+// This implementation converts float16 to float and then quantizes.
+// It is not efficient and is mostly added to enable unit tests on CPU;
+// this case usually runs on the GPU.
+template <typename OutputFloat8Type>
+typename std::enable_if<boost::mp11::mp_contains<element_type_lists::AllFloat8, OutputFloat8Type>::value, void>::type
+ParQuantizeLinearSat(const MLFloat16* Input,
+                     OutputFloat8Type* Output,
+                     size_t N,
+                     MLFloat16 Scale,
+                     const OutputFloat8Type& /* ORT_UNUSED_PARAMETER(ZeroPoint) */,
+                     bool saturate,
+                     concurrency::ThreadPool* thread_pool) {
+  constexpr std::ptrdiff_t block_size = 128;
+  const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size;
+  const TensorOpCost unit_cost{static_cast<double>(block_size * sizeof(MLFloat16)), static_cast<double>(block_size * sizeof(uint8_t)), static_cast<double>(block_size) * 2.0};
+  concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    auto begin_idx = begin * block_size;
+    auto end_idx = std::min(static_cast<std::ptrdiff_t>(N), end * block_size);
+    for (; begin_idx < end_idx; ++begin_idx) {
+      Output[begin_idx] = OutputFloat8Type(Input[begin_idx].ToFloat() / Scale.ToFloat(), saturate);
+    }
+  });
+}
+
 #endif
 
 }  // namespace onnxruntime
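
On the float8 path above, the zero point is deliberately unused, and when saturate is true the conversion clamps to the largest finite Float8E4M3FN magnitude (448). A NumPy sketch of that behavior (an illustration only, stopping short of the final bit-level cast):

```python
import numpy as np

F8E4M3FN_MAX = 448.0  # largest finite Float8E4M3FN value

def quantize_sat_sketch(x_fp16, scale_fp16, saturate=True):
    # Mirrors the new ParQuantizeLinearSat overload: upcast float16 to
    # float32, divide by the scale, saturate; no zero point is applied.
    y = x_fp16.astype(np.float32) / np.float32(scale_fp16)
    if saturate:
        y = np.clip(y, -F8E4M3FN_MAX, F8E4M3FN_MAX)
    return y  # the cast to Float8E4M3FN bits is omitted here

print(quantize_sat_sketch(np.float16([1.0, 200.0]), np.float16(0.25)))
# -> [4., 448.]  (200 / 0.25 = 800 saturates to 448)
```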