Skip to content

Commit

Permalink
[js/webgpu] Use the naive convTranspose when in/out channels are both…
Browse files Browse the repository at this point in the history
… 1 (microsoft#18658)

### Description
With this change, convTranspose with input0 [1, 18, 32, 1], input1 [1,
1, 16, 16] becomes 0.59ms from 6.64ms.
  • Loading branch information
qjia7 authored Dec 4, 2023
1 parent a5b2291 commit 5353adc
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,20 @@ const convTranspose2d =
(context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
const isChannelsLast = attributes.format === 'NHWC';
const hasBias = inputs.length === 3;
if (adjustedAttributes.group !== 1) {
const outputShape = adjustedAttributes.outputShape;
const outChannels = outputShape[isChannelsLast ? 3 : 1];
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
// Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's
// not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit
// utilization rate is very low.
if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) {
context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes));
return;
}
const outputShape = adjustedAttributes.outputShape;
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
const outChannels = outputShape[isChannelsLast ? 3 : 1];
const weightHeight = inputs[1].dims[2];
const weightWidth = inputs[1].dims[3];
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];

const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
Expand All @@ -240,6 +242,7 @@ const convTranspose2d =

// STEP.2: prepare reshaped inputs
const convTransposeInputs = [inputs[0], transposedWeight];
const hasBias = inputs.length === 3;
if (hasBias) {
if (!isChannelsLast && inputs[2].dims.length === 1) {
convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
Expand Down

0 comments on commit 5353adc

Please sign in to comment.