[js/webgpu] Optimize conv1d by conv2d #19388

Merged
merged 9 commits on Aug 23, 2024
Changes from 5 commits
24 changes: 20 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -182,6 +182,7 @@ export const createConv2DMatMulProgramInfo = (
dimInner: number,
hasBias: boolean,
sequentialAccessByThreads: boolean,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const isChannelsLast = attributes.format === 'NHWC';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
@@ -290,6 +291,18 @@ export const createConv2DMatMulProgramInfo = (
elementsSize[2],
t,
)}
${conv2dCommonSnippet(
isChannelsLast,
fitAOuter,
fitBOuter,
fitInner,
hasBias,
attributes,
elementsSize[0],
elementsSize[1],
elementsSize[2],
t,
)}
${
isVec4
? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner)
@@ -309,13 +322,16 @@ export const createConv2DMatMulProgramInfo = (
return {
name: 'Conv2DMatMul',
shaderCache: {
hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
tileAOuter
};${tileBOuter};${tileInner}`,
hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${tileAOuter};${tileBOuter};${tileInner}`,
inputDependencies,
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
outputs: [
{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType,
},
],
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
programUniforms,
}),
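The recurring change above is the new optional `squeezeOutputShapeFunction` parameter: the shader still computes a 4-D conv2d result, and only the dims reported in `getRunData` are squeezed. A minimal sketch of why dropping the size-1 axis is safe (illustrative values, not from the PR):

```typescript
// Removing a size-1 axis changes neither the element count nor the
// row-major layout, so the same GPU buffer can be reported with 3-D dims.
const conv2dOut = [1, 1, 2, 4]; // NHWC output of the expanded conv2d: [N, 1, W', C]
const squeeze = (shape: readonly number[]): number[] => [shape[0], shape[2], shape[3]];
const size = (shape: readonly number[]): number => shape.reduce((a, b) => a * b, 1);
console.log(squeeze(conv2dOut)); // [1, 2, 4] — the conv1d view of the same data
console.log(size(conv2dOut) === size(squeeze(conv2dOut))); // true
```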
22 changes: 10 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -110,13 +110,9 @@ export const makeMatMulPackedVec4Source = (
workPerThread[0] === 4
)
) {
throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${
innerElementSize
} and workPerThread[1] ${workPerThread[1]} must be 4.
throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4.
Otherwise, innerElementSize ${innerElementSize} must be 3 or 4.
tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${
tileInner
} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`);
tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`);
}
return `
var<workgroup> mm_Asub: array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
@@ -227,11 +223,7 @@ export const makeMatMulPackedSource = (
!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0)
) {
throw new Error(
`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${
workgroupSize[1]
}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${
workgroupSize[0]
}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`,
`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`,
);
}
const rowPerThreadA = tileAHight / workgroupSize[1];
@@ -470,6 +462,7 @@ export const createMatmulProgramInfo = (
outputShape: readonly number[],
reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
@@ -562,7 +555,12 @@ export const createMatmulProgramInfo = (
inputDependencies,
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
outputs: [
{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType,
},
],
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
programUniforms,
}),
8 changes: 7 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -145,6 +145,7 @@ export const createGroupedConvVectorizeProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvAttributes,
outputShape: readonly number[],
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
const components = getMaxComponents(outputShape[3]);
@@ -234,7 +235,12 @@ export const createGroupedConvVectorizeProgramInfo = (
inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'],
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
outputs: [
{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType,
},
],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
59 changes: 38 additions & 21 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -152,21 +152,23 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
};
};

const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);

const conv2d = (
context: ComputeContext,
inputs: readonly TensorView[],
attributes: ConvAttributes,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): void => {
// check attributes

// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
const isChannelsLast = attributes.format === 'NHWC';
if (attributes.group !== 1) {
// NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
// GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
// Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases:
// [webgpu]Conv - conv - vectorize group - B
// [webgpu]Conv - conv - vectorize group - D
const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
const disableGroupedConvVectorize = false;
if (
enableGroupedConvVectorize &&
!disableGroupedConvVectorize &&
isChannelsLast &&
inputs[1].dims[0] === attributes.group &&
inputs[1].dims[1] === 1 &&
@@ -177,7 +179,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
inputs[0].dims,
inputs[1].dims,
attributes.dilations,
adjustedAttributes.pads,
attributes.pads,
attributes.strides,
isChannelsLast,
);
@@ -194,11 +196,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
if (inputs.length === 3) {
convInputs.push(inputs[2]);
}
context.compute(createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {
inputs: convInputs,
});
context.compute(
createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
{ inputs: convInputs },
);
} else {
context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
}
return;
}
@@ -214,7 +217,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
inputs[0].dims,
inputs[1].dims,
attributes.dilations,
adjustedAttributes.pads,
attributes.pads,
attributes.strides,
isChannelsLast,
);
@@ -280,12 +283,26 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// Tune the threshold.
if (N < 8 && K < 8) {
context.compute(
createNaiveMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
createNaiveMatmulProgramInfo(
matmulInputs,
attributes,
outputShape,
matmulOutputShape,
isChannelsLast,
squeezeOutputShapeFunction,
),
{ inputs: matmulInputs },
);
} else {
context.compute(
createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
createMatmulProgramInfo(
matmulInputs,
attributes,
outputShape,
matmulOutputShape,
isChannelsLast,
squeezeOutputShapeFunction,
),
{ inputs: matmulInputs },
);
}
@@ -320,13 +337,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
context.compute(
createConv2DMatMulProgramInfo(
convInputs,
adjustedAttributes,
attributes,
outputShape,
dimAOuter,
dimBOuter,
dimInner,
hasBias,
sequentialAccessByThreads,
squeezeOutputShapeFunction,
),
{ inputs: convInputs },
);
@@ -357,10 +375,8 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
{ ...attributes, pads, strides, dilations, kernelShape },
inputs,
);
context.compute(
createGroupedConvProgramInfo(inputs, adjustedAttributes, (outputShape) =>
isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [],
),
conv2d(context, inputs, adjustedAttributes, (outputShape) =>
isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]],
);
};
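`conv1d` now delegates to `conv2d` with a squeeze callback instead of always taking the grouped-conv path, so the matmul and vectorized kernels become reachable for 1-D convolutions. The reshape on the way in mirrors the squeeze on the way out; a sketch of the shape round trip (helper names are illustrative, not from the PR):

```typescript
// Expand the 3-D conv1d input by inserting a size-1 spatial axis ...
const expandShape = (dims: readonly number[], isChannelLast: boolean): number[] =>
  isChannelLast
    ? [dims[0], 1, dims[1], dims[2]] // NHWC: [N, W, C]  -> [N, 1, W, C]
    : [dims[0], dims[1], 1, dims[2]]; // NCHW: [N, C, W] -> [N, C, 1, W]

// ... and squeeze that axis back out of the 4-D conv2d output,
// matching the callback passed to conv2d above.
const squeezeShape = (shape: readonly number[], isChannelLast: boolean): number[] =>
  isChannelLast
    ? [shape[0], shape[2], shape[3]] // [N, 1, W', C] -> [N, W', C]
    : [shape[0], shape[1], shape[3]]; // [N, C, 1, W'] -> [N, C, W']

console.log(squeezeShape(expandShape([1, 1, 3], false), false)); // [1, 1, 3]
```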

@@ -396,6 +412,7 @@ export const conv = (context: ComputeContext, attributes: ConvAttributes): void
} else if (context.inputs[0].dims.length === 5) {
conv3d(context, context.inputs, attributes);
} else {
conv2d(context, context.inputs, attributes);
const adjustedAttributes = getAdjustedConvAttributes(attributes, context.inputs);
conv2d(context, context.inputs, adjustedAttributes);
}
};
14 changes: 10 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -32,6 +32,7 @@ export const createNaiveMatmulProgramInfo = (
outputShape: readonly number[],
reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
@@ -120,9 +121,9 @@ export const createNaiveMatmulProgramInfo = (

for (let j = 0; j < aComponents; j++) {
calcStr += `
values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${
i
}]);\n`;
values[${i}] = fma(${b.type.value}(a_data${
aComponents === 1 ? '' : `[${j}]`
}), b_data${j}, values[${i}]);\n`;
}
}
return calcStr;
@@ -168,7 +169,12 @@ export const createNaiveMatmulProgramInfo = (
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
outputs: [
{
dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
dataType: inputs[0].dataType,
},
],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -83,14 +83,14 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
return {
name: 'Transpose',
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
getRunData: (inputs) => {
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
...createTensorShapeVariables(inputs[0].dims, outputShape),
...createTensorShapeVariables(inputTensor.dims, outputShape),
],
};
},
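The transpose change is a small cleanup riding along: `getRunData` drops its unused runtime `inputs` parameter and reads the closure-captured `inputTensor` instead, which refers to the same tensor because the program info is built fresh for each call. A simplified sketch of the closure pattern (types reduced for illustration):

```typescript
// getRunData can rely on values captured when the program info is created,
// since createTransposeProgramInfo is invoked per dispatch.
const makeTransposeInfo = (inputDims: readonly number[], perm: number[]) => {
  const outputDims = perm.map((axis) => inputDims[axis]); // fixed at creation time
  return { getRunData: () => ({ dims: outputDims }) };
};
console.log(makeTransposeInfo([2, 3, 4], [2, 0, 1]).getRunData().dims); // [4, 2, 3]
```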
69 changes: 69 additions & 0 deletions js/web/test/data/ops/conv1d.jsonc
@@ -0,0 +1,69 @@
[
{
"name": "conv 1D without bias addition A",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30],
"dims": [1, 1, 3],
"type": "float32"
},
{
"data": [1, 2],
"dims": [1, 1, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [50, 80],
"dims": [1, 1, 2],
"type": "float32"
}
]
}
]
},
{
"name": "conv 1D with bias addition A",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40],
"dims": [1, 2, 2],
"type": "float32"
},
{
"data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
"dims": [4, 2, 2],
"type": "float32"
},
{
"data": [0.1, 0.2, 0.3, 0.4],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [100.1, 100.2, 100.3, 100.4],
"dims": [1, 4, 1],
"type": "float32"
}
]
}
]
}
]
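Both cases check out by hand: the first is a plain sliding window, and in the second every output channel sums all eight products of ones against the input (10 + 20 + 30 + 40 = 100) before the per-channel bias is added. A quick sanity check of the first case (independent of the PR):

```typescript
// Naive 1-D convolution, no padding, stride 1:
// input [10, 20, 30] with kernel [1, 2] -> [10*1 + 20*2, 20*1 + 30*2] = [50, 80].
const conv1dNaive = (x: number[], w: number[]): number[] =>
  Array.from({ length: x.length - w.length + 1 }, (_, i) =>
    w.reduce((acc, wk, k) => acc + wk * x[i + k], 0),
  );
console.log(conv1dNaive([10, 20, 30], [1, 2])); // [50, 80]
```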
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
@@ -1347,6 +1347,7 @@
"concat_zero-sized.jsonc",
"cast.jsonc",
"conv.jsonc",
"conv1d.jsonc",
"conv3dncdhw.jsonc",
"cos.jsonc",
"div.jsonc",