Skip to content

Commit

Permalink
Avoid producing presentKey/presentValue outputs if pastKey/pastValue …
Browse files Browse the repository at this point in the history
…don't exists.
  • Loading branch information
satyajandhyala committed Aug 17, 2024
1 parent d79e3c5 commit 18a3906
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ const createAttentionProbsProgramInfo = (
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1 && pastKey;
const presentKeyShape = presentKey
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
: undefined;
Expand Down Expand Up @@ -534,7 +534,7 @@ const createVxAttentionScoreProgramInfo = (
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
const presentValue = params.kvNumHeads == null && context.outputCount > 1 && pastValue;
const presentValueShape = presentValue
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
: undefined;
Expand Down Expand Up @@ -668,7 +668,8 @@ export const applyAttention = (
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
const outputCount = context.outputCount;
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;

Expand Down

0 comments on commit 18a3906

Please sign in to comment.