Skip to content

Commit

Permalink
[js/web] JSEP Expand fix for inputs with rank < 2 (microsoft#16829)
Browse files Browse the repository at this point in the history
### Description
If Expand inputs has rank < 2, `inputIndicesHelper` and
`outputIndicesHelper` create indices as u32 instead if array<u32> and
`calculateInputIndex` throws an error



### Motivation and Context
I've encountered this error while making StableDiffusion work with JSEP
  • Loading branch information
dakenf authored Aug 3, 2023
1 parent 757c42c commit acb9e56
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const dataType = 'f32';

const isl = inputShape.length;
const osl = outputShape.length;
const calculateInputIndexImpl = (): string => `
fn calculateInputIndex(outputIndices: array<u32, ${outputShape.length}>) -> array<u32,${inputShape.length}> {
fn calculateInputIndex(outputIndices: ${outputIndicesHelper.iType}) -> ${inputIndicesHelper.iType} {
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
for (var i = 0; i < ${inputShape.length}; i++) {
for (var i = 0; i < ${isl}; i++) {
if (inputShape[i] == 1) {
inputIndices[i] = 0;
// TODO: IndicesHelper should offer uniform way to get/set indices for all ranks
inputIndices${isl >= 2 ? '[i]' : ''} = 0;
} else {
inputIndices[i] = outputIndices[i + ${outputShape.length - inputShape.length}];
inputIndices${isl >= 2 ? '[i]' : ''} = ${osl > 1 ? `outputIndices[i + ${osl - isl}]` : 'outputIndices'};
}
}
return inputIndices;
Expand Down

0 comments on commit acb9e56

Please sign in to comment.