diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 8840ef97b4279..9d801872192fd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -371,7 +371,7 @@ const createAttentionProbsProgramInfo = ( ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1; + const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1 && pastKey; const presentKeyShape = presentKey ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] : undefined; @@ -534,7 +534,7 @@ const createVxAttentionScoreProgramInfo = ( const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && context.outputCount > 1; + const presentValue = params.kvNumHeads == null && context.outputCount > 1 && pastValue; const presentValueShape = presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined; @@ -668,7 +668,8 @@ export const applyAttention = ( parameters: AttentionParameters, attributes: AttentionAttrs, ) => { - const outputCount = context.outputCount; + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists + const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;