[JS/WebGPU] Add GatherBlockQuantized op support #21734

Merged
merged 48 commits into from
Aug 26, 2024
Changes from 36 commits
3e2e54c
Added GatherBlockQuantized operator.
satyajandhyala Aug 9, 2024
6c34932
Remove templatization
satyajandhyala Aug 13, 2024
c9d465c
Added int4x2 and uint4x2
satyajandhyala Aug 13, 2024
2321960
Revert "Remove templatization"
satyajandhyala Aug 13, 2024
4db3348
Fixed script to look for ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME
satyajandhyala Aug 13, 2024
5c3f33e
Updated the doc
satyajandhyala Aug 13, 2024
cb63d14
Added test cases.
satyajandhyala Aug 14, 2024
6c27002
Added more GatherBlockQuantized op functionality.
satyajandhyala Aug 14, 2024
3677d60
Calculate zero-point array index.
satyajandhyala Aug 14, 2024
07ceae5
Split signed and unsigned test cases, not group.
satyajandhyala Aug 14, 2024
fe80212
Reapply "Remove templatization"
satyajandhyala Aug 14, 2024
b29d9b2
lint
satyajandhyala Aug 14, 2024
dd6f95a
Add missing semicolon
satyajandhyala Aug 14, 2024
253d409
trim error message
satyajandhyala Aug 14, 2024
8da8bd9
Inserted missing line space.
satyajandhyala Aug 14, 2024
1ea1f98
Test using indices input with dims > 1
satyajandhyala Aug 14, 2024
1b6accd
updated tensor_helper.cc
satyajandhyala Aug 15, 2024
f75dcbd
Updated hint
satyajandhyala Aug 15, 2024
a0355a8
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 15, 2024
4f016ca
format
satyajandhyala Aug 15, 2024
219e2b0
Replaced ternary operator with if-else
satyajandhyala Aug 15, 2024
56554da
Use vec instead of array to unpack data and use built-in function unp…
satyajandhyala Aug 15, 2024
9e9d15e
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 16, 2024
31a62e2
Add code to verify that the indices input is valid.
satyajandhyala Aug 15, 2024
bf3abaa
Added (u)int4 in tensor-impl.ts
satyajandhyala Aug 16, 2024
963814a
Commented out indices validation code.
satyajandhyala Aug 16, 2024
5fbb497
Added (u)int4.
satyajandhyala Aug 16, 2024
d065cea
test related changes
satyajandhyala Aug 16, 2024
6a57c34
fixed unused variable.
satyajandhyala Aug 16, 2024
163866a
Reverted changes to tensor_helper.cc
satyajandhyala Aug 16, 2024
a519237
Updated expected output to match that of wasm.
satyajandhyala Aug 16, 2024
b93ca99
Use indicesGet/indicesSet to access index out of indices
satyajandhyala Aug 16, 2024
60d6ba8
typo
satyajandhyala Aug 16, 2024
0a7387d
renamed dequantize-linear_int4.jsonc as dequantize-linear-int4.jsonc
satyajandhyala Aug 16, 2024
9b5eac4
Indices should be normalized before indexing. Added a test case.
satyajandhyala Aug 17, 2024
d62058a
format JSONC
satyajandhyala Aug 17, 2024
2162f92
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 17, 2024
18a3906
Avoid producing presentKey/presentValue outputs if pastKey/pastValue …
satyajandhyala Aug 17, 2024
9e22564
Don't treat empty inputs as undefined in MHA. Let Attention deal with…
satyajandhyala Aug 17, 2024
c01f721
Feed pastKey/pastValue inputs down to the functions that generate sha…
satyajandhyala Aug 17, 2024
c8b187f
Added a test case with zero-size pastKey/pastValue inputs that requir…
satyajandhyala Aug 17, 2024
d0b0627
Added back the assumption comment.
satyajandhyala Aug 17, 2024
d8aeee1
ShapeUtis.size should return 0 if the tensor dims is empty instead of 1.
satyajandhyala Aug 18, 2024
7520dc2
Revert "ShapeUtis.size should return 0 if the tensor dims is empty in…
satyajandhyala Aug 18, 2024
5b706fe
Merge branch 'sajandhy/wepu_fix_attention_output' of github.com:micro…
satyajandhyala Aug 18, 2024
52e2821
Skip Indices shape from the cache key.
satyajandhyala Aug 19, 2024
62e2665
format
satyajandhyala Aug 19, 2024
c562099
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 26, 2024
4 changes: 3 additions & 1 deletion js/common/lib/tensor-impl.ts
@@ -140,7 +140,9 @@ export class Tensor implements TensorInterface {
type !== 'int64' &&
type !== 'uint32' &&
type !== 'uint8' &&
- type !== 'bool'
+ type !== 'bool' &&
+ type !== 'uint4' &&
+ type !== 'int4'
) {
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -48,6 +48,7 @@ Do not modify directly.*
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherBlockQuantized | com.microsoft(1+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| Gelu | ai.onnx(20+); com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum';
import { expand } from './ops/expand';
import { fastGelu } from './ops/fast-gelu';
import { gather, parseGatherAttributes } from './ops/gather';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
@@ -96,6 +97,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
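The registration above pairs an operator name with an implementation and an optional attribute parser. A minimal, self-contained sketch of that lookup pattern follows. The types and names here (`Ctx`, `RunFn`, `ParseFn`, `RULES`, `dispatch`) are hypothetical stand-ins for illustration and are not the onnxruntime definitions.

```typescript
// Hypothetical stand-in types; the real ComputeContext and OperatorImplementation differ.
type Ctx = { log: string[] };
type RunFn = (context: Ctx, attributes?: unknown) => void;
type ParseFn = (raw: Record<string, unknown>) => unknown;
// An implementation plus an optional attribute parser, as in the rule table above.
type Rule = [RunFn, ParseFn?];

const RULES: Map<string, Rule> = new Map<string, Rule>([
  // An operator without attributes registers only its implementation.
  ['Gelu', [(ctx) => { ctx.log.push('gelu'); }]],
  // An operator with attributes also registers a parser for them.
  ['GatherBlockQuantized', [
    (ctx, attrs) => { ctx.log.push(`gbq:${JSON.stringify(attrs)}`); },
    (raw) => ({ blockSize: raw.blockSize as number }),
  ]],
]);

// Looks up the rule by operator type, parses attributes if a parser exists, then runs it.
function dispatch(opType: string, raw: Record<string, unknown>, ctx: Ctx): void {
  const rule = RULES.get(opType);
  if (!rule) {
    throw new Error(`no rule registered for ${opType}`);
  }
  const [run, parse] = rule;
  run(ctx, parse ? parse(raw) : undefined);
}
```

Under these assumptions, an operator type missing from the map fails at lookup time, which is why a new operator needs both the import and the map entry.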
5 changes: 4 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -243,7 +243,10 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s
throw new Error('bool must be vec4');
}
return ['u32', 'vec4<bool>'];

case DataType.int4:
return 'i32';
case DataType.uint4:
return 'u32';
default:
throw new Error(`Unknown data type: ${type}`);
}
193 changes: 193 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts
@@ -0,0 +1,193 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
createTensorShapeVariables,
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglValueType,
UniformsArrayType,
} from './common';

export interface GatherBlockQuantizedAttributes extends AttributeWithCacheKey {
gatherAxis: number;
quantizeAxis: number;
blockSize: number;
}

export const validateInputs = (inputs: readonly TensorView[], attributes: GatherBlockQuantizedAttributes): void => {
if (inputs.length < 3 || inputs.length > 4) {
throw new Error('GatherBlockQuantized requires 3 or 4 inputs.');
}
const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputs[0].dims.length);
const blockSize = attributes.blockSize;
const data = inputs[0];
const scales = inputs[2];
const zeroPoint = inputs.length === 4 ? inputs[3] : undefined;
if (
scales.dims.length !== data.dims.length ||
!data.dims
.map((d, i) => (i === quantizeAxis ? Math.ceil(d / blockSize) === scales.dims[i] : d === scales.dims[i]))
.reduce((a, b) => a && b, true)
) {
throw new Error(
'Scales must have the same rank as the input tensor and the dims should match except on quantizeAxis.',
);
}
// TODO Uncomment the following check once the test case creation code is fixed to create data correctly aligned.
// const indices = inputs[1];
// const validIndex = (index: number) => index >= 0 && index < data.dims[attributes.gatherAxis];
// if (indices.dataType === DataType.int32 && indices.getInt32Array().some((v) => !validIndex(v)) ||
// indices.dataType === DataType.int64 && indices.getBigInt64Array().some((v) => !validIndex(Number(v)))) {
// throw new Error('Indices must be within the bounds of the gatherAxis.');
// }
if (zeroPoint) {
if (zeroPoint.dataType !== data.dataType) {
throw new Error('Zero point must have the same data type as the input tensor.');
}
if (
zeroPoint.dims.length !== scales.dims.length ||
!zeroPoint.dims.map((d, i) => d === scales.dims[i]).reduce((a, b) => a && b, true)
) {
throw new Error(
'Zero point must have the same rank as the input tensor and the dims should match except on quantizeAxis.',
);
}
}
};

const createGatherBlockQuantizedProgramInfo = (
inputs: readonly TensorView[],
attributes: GatherBlockQuantizedAttributes,
): ProgramInfo => {
const inputShape = inputs[0].dims;
const indicesShape = inputs[1].dims;
const inputRank = inputShape.length;
const gatherAxis = ShapeUtil.normalizeAxis(attributes.gatherAxis, inputRank);
const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputRank);
const outputShape = inputShape.slice(0);
outputShape.splice(gatherAxis, 1, ...indicesShape);
const outputSize = ShapeUtil.size(outputShape);
const outputType = inputs[2].dataType;
const inputType = inputs[0].dataType;
const isSigned = inputType === DataType.int4; // input data type is either int4 or uint4.
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: quantizeAxis },
{ type: DataType.uint32, data: gatherAxis },
{ type: DataType.uint32, data: attributes.blockSize },
...createTensorShapeVariables(...inputs.map((input, _) => input.dims), outputShape),
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const zeroPoint =
inputs.length > 3 ? inputVariable('zeroPoint', inputs[3].dataType, inputs[3].dims.length) : undefined;
const output = outputVariable('output', outputType, outputShape.length);
const inputVariables = [data, indices, scales];
if (zeroPoint) {
inputVariables.push(zeroPoint);
}
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'quantize_axis', type: 'u32' },
{ name: 'gather_axis', type: 'u32' },
{ name: 'block_size', type: 'u32' },
];
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
let output_indices = ${output.offsetToIndices('global_idx')};
var indices_indices = ${indices.type.indices}(0);
${(() => {
if (indicesShape.length > 1) {
return `
for (var i: u32 = 0; i < ${indicesShape.length}; i++) {
let index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')};
${indices.indicesSet('indices_indices', 'i', 'index')};
}`;
} else {
return `indices_indices = ${output.indicesGet('output_indices', 'uniforms.gather_axis')};`;
}
})()};
var data_indices = ${data.type.indices}(0);
for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
let index = ${output.indicesGet('output_indices', 'i')};
${data.indicesSet('data_indices', 'i', 'index')};
}
var index_from_indices = ${indices.getByIndices('indices_indices')};
if (index_from_indices < 0) {
index_from_indices += ${inputShape[gatherAxis]};
}
${data.indicesSet('data_indices', 'uniforms.gather_axis', 'u32(index_from_indices)')};
for (var i = uniforms.gather_axis + 1; i < ${outputShape.length}; i++) {
let index = ${output.indicesGet('output_indices', `i + ${indicesShape.length} - 1`)};
${data.indicesSet('data_indices', 'i', 'index')};
}
let data_offset = ${data.indicesToOffset('data_indices')};
let data_index = data_offset % 8;
// Convert 4-bit packed data to 8-bit packed data.
let packed_4bit_quantized_data = ${data.getByOffset('data_offset / 8')};
let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
let quantized_data_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_quantized_data));
let quantized_data = quantized_data_vec[data_index / 2];
var scale_indices = data_indices;
let quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size;
${scales.indicesSet('scale_indices', 'uniforms.quantize_axis', 'quantize_axis_index')};
var scale = ${scales.getByIndices('scale_indices')};
${(() => {
if (!zeroPoint) {
return 'var zero_point = 0';
} else {
return `
let zero_point_indices = scale_indices;
let zero_point_offset = ${zeroPoint.indicesToOffset('zero_point_indices')};
let zero_point_index = zero_point_offset % 8;
let packed_4bit_zero_points = ${zeroPoint.getByOffset('zero_point_offset / 8')};
let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;
let zero_point_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_zero_points));
let zero_point = zero_point_vec[zero_point_index / 2];`;
}
})()};
let dequantized_data = ${tensorTypeToWsglValueType(outputType)}(quantized_data - zero_point) * scale;
${output.setByOffset('global_idx', 'dequantized_data')};
}`;
};
return {
name: 'GatherBlockQuantized',
shaderCache: {
hint: `${attributes.cacheKey};${inputs.map((input, _) => input.dims.join('_')).join(';')}`,
inputDependencies: Array.from({ length: inputs.length }, (_v, _i) => 'rank'),
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: outputType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
};
};

export const gatherBlockQuantized = (context: ComputeContext, attributes: GatherBlockQuantizedAttributes): void => {
const inputs = context.inputs;
validateInputs(inputs, attributes);
context.compute(createGatherBlockQuantizedProgramInfo(context.inputs, attributes));
};

export const parseGatherBlockQuantizedAttributes = (
attributes: Record<string, unknown>,
): GatherBlockQuantizedAttributes =>
createAttributeWithCacheKey({
blockSize: attributes.blockSize as number,
gatherAxis: attributes.gatherAxis as number,
quantizeAxis: attributes.quantizeAxis as number,
});
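The shader in this file reads 4-bit elements that are packed eight per 32-bit word, then dequantizes each one as `(quantized - zero_point) * scale`, with the scale chosen by dividing the quantize-axis index by the block size. A CPU-side sketch of that arithmetic may make the indexing easier to follow. The function names are hypothetical and are not part of onnxruntime, and the sign extension for `int4` is an assumption (the conventional two's-complement reading) rather than something taken from the shader, so it should be checked against the operator tests.

```typescript
// Illustrative mirror of the shader's index arithmetic. All names are hypothetical.

/** Reads the 4-bit element at `elementOffset` from data packed eight per 32-bit word. */
function readPacked4Bit(words: Uint32Array, elementOffset: number, signed: boolean): number {
  const word = words[Math.floor(elementOffset / 8)];
  const indexInWord = elementOffset % 8;
  // Same two steps as the shader: shift by one nibble for odd elements,
  // then keep only the low nibble of every byte.
  const lowNibbles = (word >>> (4 * (indexInWord % 2))) & 0x0f0f0f0f;
  // Pick the byte that now holds the element (what indexing the unpacked vec4 does).
  const nibble = (lowNibbles >>> (8 * Math.floor(indexInWord / 2))) & 0xff;
  // Assumption: int4 values are two's complement, so 8..15 map to -8..-1.
  return signed && nibble >= 8 ? nibble - 16 : nibble;
}

/** Index of the scale (and zero point) block for a position on the quantize axis. */
function blockIndex(indexOnQuantizeAxis: number, blockSize: number): number {
  return Math.floor(indexOnQuantizeAxis / blockSize);
}

/** Linear dequantization of one element. */
function dequantize(quantized: number, zeroPoint: number, scale: number): number {
  return (quantized - zeroPoint) * scale;
}
```

The shift-and-mask pair is equivalent to `(word >>> (4 * indexInWord)) & 0xf`; the two-step form is kept here only because it matches how the shader feeds `unpack4xU8` and `unpack4xI8`.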
4 changes: 3 additions & 1 deletion js/web/lib/wasm/wasm-common.ts
@@ -236,7 +236,9 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB
type === 'int64' ||
type === 'uint32' ||
type === 'uint8' ||
- type === 'bool';
+ type === 'bool' ||
+ type === 'uint4' ||
+ type === 'int4';

/**
* Map string data location to integer value
1 change: 1 addition & 0 deletions js/web/script/generate-webgpu-operator-md.ts
@@ -26,6 +26,7 @@ const MATCHERS = [
/class ONNX_OPERATOR_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersionStart>\d+),\s*(?<opsetVersionEnd>\d+),\s*(?<type>\w+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<type>\w+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<type1>\w+),\s*(?<type2>\w+),\s*(?<op>\w+)\)/g,
];
/* eslint-enable max-len */
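To see what the added matcher captures, the sketch below applies the same pattern to one kernel declaration. The sample declaration string and the `parseTwoTyped` helper are invented for illustration and are not copied from the onnxruntime sources. The `g` flag is dropped here so that `exec` stays stateless.

```typescript
// Same pattern as the TWO_TYPED matcher above, without the `g` flag.
const TWO_TYPED =
  /class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<type1>\w+),\s*(?<type2>\w+),\s*(?<op>\w+)\)/;

// Hypothetical declaration text, made up for this example.
const sample =
  'class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);';

// Extracts the fields an operator table would need from one declaration, if it matches.
function parseTwoTyped(text: string): { op: string; opsetDomain: string; opsetVersion: number } | undefined {
  const groups = TWO_TYPED.exec(text)?.groups;
  if (!groups) {
    return undefined;
  }
  return {
    op: groups.op,
    opsetDomain: groups.opsetDomain,
    opsetVersion: Number(groups.opsetVersion),
  };
}
```

The two type arguments are captured as `type1` and `type2` but are not needed to identify the operator, so this sketch ignores them.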
