diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 0000000000000..6a76f7bcdbcb0 --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,7 @@ +name: "CodeQL config" +queries: + - uses: security-extended + - uses: security-and-quality +paths-ignore: + - tests + - build \ No newline at end of file diff --git a/.github/workflows/linux_training.yml b/.github/workflows/linux_training.yml new file mode 100644 index 0000000000000..51af6cd20de7d --- /dev/null +++ b/.github/workflows/linux_training.yml @@ -0,0 +1,55 @@ +name: orttraining-linux-ci-pipeline +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + orttraining-linux-ci-pipeline: + runs-on: ubuntu-24.04 + permissions: + actions: read + contents: read + security-events: write + steps: + - uses: actions/checkout@v4 + - run: | + python3 -m pip install -r tools/ci_build/github/linux/python/requirements.txt + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: ./.github/codeql/codeql-config.yml + languages: 'cpp' + - run: | + set -e -x + rm -rf build + python3 tools/ci_build/build.py --build_dir build --config Release --enable_training --skip_submodule_sync --parallel --update --build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: sarif-results + upload: failure-only + + - name: filter-sarif + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: sarif-results/cpp.sarif + output: sarif-results/cpp.sarif + + - name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: sarif-results/cpp.sarif \ No newline at end of file diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index f696264aeead7..bf0f1dffb83ee 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -78,7 +78,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | -| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes | +| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index d13136d252d2a..37eb0e0edc67c 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -163,6 +163,69 @@ export class WebNNBackend { return id; } + // Register WebNN Constant operands from external data. + public registerMLConstant( + externalFilePath: string, + dataOffset: number, + dataLength: number, + builder: MLGraphBuilder, + desc: MLOperandDescriptor, + mountedFiles: Map | undefined, + ): MLOperand { + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (!mountedFiles) { + throw new Error('External mounted files are not available.'); + } + + let filePath = externalFilePath; + if (externalFilePath.startsWith('./')) { + filePath = externalFilePath.substring(2); + } + const fileData = mountedFiles.get(filePath); + if (!fileData) { + throw new Error(`File with name ${filePath} not found in preloaded files.`); + } + + if (dataOffset + dataLength > fileData.byteLength) { + throw new Error('Out of bounds: data offset and length exceed the external file data size.'); + } + + const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer; + let bufferView: ArrayBufferView; + switch (desc.dataType) { + case 'float32': + bufferView = new Float32Array(buffer); + break; + case 'float16': + bufferView = new Uint16Array(buffer); + break; + case 'int32': + bufferView = new Int32Array(buffer); + break; + case 'uint32': + bufferView = new Uint32Array(buffer); + break; + case 'int64': + bufferView = new BigInt64Array(buffer); + break; + case 'uint64': + bufferView = new BigUint64Array(buffer); + break; + case 'int8': + bufferView = new Int8Array(buffer); + break; + case 'uint8': + bufferView = new Uint8Array(buffer); + break; + default: + throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`); + } + + LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`); + + return builder.constant(desc, bufferView); + } + public flush(): void { // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index fe824a5c4558a..09c786daa3fcd 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; -import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention'; +import { groupQueryAttention } from './ops/group-query-attention'; import { instanceNorm } from './ops/instance-norm'; import { layerNorm } from './ops/layer-norm'; import { matMul } from './ops/matmul'; @@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], - ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]], + ['GroupQueryAttention', [groupQueryAttention]], ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 832f6e132901e..6a78c8ae3b190 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU import { getMaxComponents, + IndicesHelper, inputVariable, outputVariable, ShaderHelper, @@ -65,14 +66,17 @@ export interface AttentionParameters { broadcastResPosBias: boolean; passPastInKv: boolean; qkvFormat: AttentionQkvFormat; - isPastkvBSNH?: boolean; + softcap?: number; + doRotary?: number; + rotaryInterLeaved?: number; + sommoothSoftmax?: number; + localWindowsSize?: number; } export interface AttentionAttrs { numHeads: number; - kvNumHeads?: number; - isUnidirectional?: number; - maskFilterValue?: number; + isUnidirectional: number; + maskFilterValue: number; scale: number; doRotary: number; qkvHiddenSizes: number[]; @@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => { - const components = getMaxComponents(d); +const initVarStub = ( + seqLensInput: IndicesHelper | undefined, + totalSequenceLengthInput: IndicesHelper | undefined, + initPastSequenceLength: boolean, +) => { + // In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input + if (totalSequenceLengthInput && seqLensInput) { + return ` + let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')}); + let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length); + let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input; + let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input; + total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1; + var past_sequence_length: u32 = 0; + if (is_first_prompt == false) { + past_sequence_length = total_sequence_length - sequence_length; + } + `; + } else { + return ` + ${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''}; + let present_sequence_length = total_sequence_length; + `; + } +}; + +const createInPlaceSoftmaxProgramInfo = ( + input: TensorView, + batchSize: number, + numHeads: number, + pastSequenceLength: number, + sequenceLength: number, + totalSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, +) => { + // Set components to 1 if seqLens is specified, i.e. GroupQueryAttention. + const components = getMaxComponents(seqLens ? 1 : totalSequenceLength); let WG = 64; - const dComp = d / components; - if (dComp < WG) { + const totalSequenceLengthComp = totalSequenceLength / components; + if (totalSequenceLengthComp < WG) { WG = 32; } - const elementsPerThread = Math.ceil(d / components / WG); + const elementsPerThread = Math.ceil(totalSequenceLength / components / WG); const programUniforms: ProgramUniform[] = [ - { type: DataType.float, data: 1 / d }, - { type: DataType.uint32, data: dComp }, + { type: DataType.uint32, data: batchSize }, + { type: DataType.uint32, data: numHeads }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: sequenceLength }, + { type: DataType.uint32, data: totalSequenceLengthComp }, { type: DataType.uint32, data: elementsPerThread }, ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); + const inputHelpers = [inputHelper]; + const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputHelper) { + inputHelpers.push(seqLensInputHelper); + } + + const totalSequenceLengthInputHelper = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputHelper) { + inputHelpers.push(totalSequenceLengthInputHelper); + } const elemValueType = tensorTypeToWsglValueType(input.dataType); const uniforms: UniformsArrayType = [ - { name: 'd_inv', type: 'f32' }, - { name: 'd_comp', type: 'u32' }, + { name: 'batch_size', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'past_sequence_length', type: 'u32' }, + { name: 'sequence_length', type: 'u32' }, + { name: 'total_sequence_length', type: 'u32' }, { name: 'elements_per_thread', type: 'u32' }, ]; return ` var thread_max: array; var thread_sum: array; - ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)} ${shaderHelper.mainStart([WG, 1, 1])} + let batchIdx = workgroup_id.z / uniforms.num_heads; + let headIdx = workgroup_id.z % uniforms.num_heads; + let sequence_length = uniforms.sequence_length; + var total_sequence_length = uniforms.total_sequence_length; + ${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)} let local_offset = local_idx * uniforms.elements_per_thread; - let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; - + let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; + let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; var thread_max_vector = ${f32Type}(-3.402823e+38f); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } thread_max[local_idx] = ${(() => { @@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } var sum_vector = ${f32Type}(0); - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { sum_vector += exp(${f32Type}(x[offset + i]) - max_value); } thread_sum[local_idx] = ${(() => { @@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number } if (sum == 0) { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { - x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv)); + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { + x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length)); } } else { - for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { + for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { var f32input = ${f32Type}(x[offset + i]); x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum); } } + ${ + seqLens + ? ` + for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) { + x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0)); + }` + : '' + }; }`; }; @@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number name: 'AttentionProbsSoftmax', shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies }, getShaderSource, - getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), + getRunData: () => ({ + outputs: [], + dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + programUniforms, + }), }; }; @@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = ( pastKey: TensorView | undefined, attentionBias: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, pastSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey; + const presentKey = outputCount > 1 && pastKey; + const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads; const presentKeyShape = presentKey - ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] + ? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize] : undefined; - + const nReps = parameters.nReps ? parameters.nReps : 1; // TODO: handle mask - const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale; const components = getMaxComponents(parameters.headSize); const vectorizedHeadSize = parameters.headSize / components; const TILE_SIZE = 12; @@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: vectorizedHeadSize }, { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.uint32, data: parameters.headSize }, { type: DataType.float, data: alpha }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; @@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; if (presentKey) { outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default }); @@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } + const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputVariable) { + inputVars.push(seqLensInputVariable); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputVariable) { + inputVars.push(totalSequenceLengthInputVariable); + } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; if (presentKey) { @@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'alpha', type: 'f32' as UniformDataElementType }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = ( ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} // x holds the N and y holds the M - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; + let batchIdx = workgroup_id.z / uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; - let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - ${(() => { - if (feedPastKey && presentKey) { - return ` - let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; - let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; - } else { - return ` - let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; - } - })()} - ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.N; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; + let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''}; + let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; + ${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { @@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = ( ${(() => { if (feedPastKey && presentKey) { return ` - if (n + local_id.y < uniforms.past_sequence_length) { + if (n + local_id.y < past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; - } else { - tileK[idx] = - key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; + } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; }`; } else { - return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + return ` + if (n + local_id.y < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + }`; } })()} ${ - presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : '' + presentKey + ? `if (n + local_id.y < present_sequence_length) { + present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx]; + }` + : '' } } workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); + value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); } workgroupBarrier(); } - let headOffset = headIdx * uniforms.M * uniforms.N; - if (global_id.y < uniforms.M && global_id.x < uniforms.N) { + if (global_id.y < uniforms.M && global_id.x < total_sequence_length) { + let headOffset = workgroup_id.z * uniforms.M * uniforms.N; let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; var sum: f32 = ${(() => { switch (components) { @@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = ( pastValue: TensorView | undefined, params: AttentionParameters, pastSequenceLength: number, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue; + const presentValue = outputCount > 1 && pastValue; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; const presentValueShape = presentValue - ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] + ? [params.batchSize, kvNumHeads, totalSequenceLength, params.headSize] : undefined; const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; const TILE_SIZE = 12; @@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: params.vHeadSize }, { type: DataType.uint32, data: params.numHeads }, + { type: DataType.uint32, data: params.headSize }, { type: DataType.uint32, data: repeatedVHiddenSize }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0; @@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; if (presentValue) { outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); @@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); } + const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLens) { + inputVars.push(seqLensInputVariable!); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInput) { + inputVars.push(totalSequenceLengthInputVariable!); + } const output = outputVariable('output', probs.dataType, outputShape); const outputVars = [output]; if (presentValue) { @@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'v_hidden_size', type: 'u32' }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let batchIdx = workgroup_id.z / uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; let m = global_id.y; let n = global_id.x; - - let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; - ${(() => { - if (feedPastValue && presentValue) { - return ` - let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; - let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; - `; - } else { - return ` - let offsetB = headIdx * uniforms.N * uniforms.K + n; - `; - } - })()} - ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.K; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch + ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''}; + let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; + ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (m < uniforms.M && w + local_id.x < uniforms.K) { @@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = ( ${(() => { if (feedPastValue && presentValue) { return ` - if (w + local_id.y < uniforms.past_sequence_length) { - tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; - } else { - tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N]; + if (w + local_id.y < past_sequence_length) { + tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; + } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N]; } `; } else { return ` - tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N]; - `; + if (w + local_id.y < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N]; + }`; } })()} - ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''} + ${ + presentValue + ? ` + if (w + local_id.y < present_sequence_length) { + present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx]; + }` + : '' + } } workgroupBarrier(); - for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x]; + for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) { + value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x]; } workgroupBarrier(); } // we need to transpose output from BNSH_v to BSND_v - let batchIdx = workgroup_id.z / uniforms.num_heads; - let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads; if (m < uniforms.M && n < uniforms.N) { let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size - + currentBatchHeadNumber * uniforms.N + n; + + headIdx * uniforms.N + n; output[outputIdx] = value; } }`; @@ -671,23 +804,29 @@ export const applyAttention = ( pastValue: TensorView | undefined, attentionBiasInput: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { - // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); - const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; + const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const attentionBias = attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; const inputsK = [q, k]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { + if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { inputsK.push(pastKey); } if (attentionBias) { inputsK.push(attentionBias); } - + if (seqLens) { + inputsK.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsK.push(totalSequenceLengthInput); + } // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( @@ -697,31 +836,55 @@ export const applyAttention = ( pastKey, attentionBias, parameters, - attributes, pastSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] }, + { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] }, )[0]; // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( probs, - parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.batchSize, + parameters.numHeads, + pastSequenceLength, + parameters.sequenceLength, totalSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: [probs], outputs: [] }, + { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] }, ); - // Run AttrionScore + // Run AttentionScore const inputsV = [probs, v]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { + if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { inputsV.push(pastValue); } - context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), { - inputs: inputsV, - outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0], - }); + if (seqLens) { + inputsV.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsV.push(totalSequenceLengthInput); + } + context.compute( + createVxAttentionScoreProgramInfo( + outputCount, + probs, + v, + pastValue, + parameters, + pastSequenceLength, + seqLens, + totalSequenceLengthInput, + ), + { + inputs: inputsV, + outputs: outputCount > 1 ? [0, 2] : [0], + }, + ); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { @@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs): undefined, context.inputs[5], params, - attributes, ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 56291c037b7da..bbe25460d6fd3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -1,31 +1,49 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { ShapeUtil } from '../../util'; import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; -import { - applyAttention, - AttentionAttrs, - AttentionMaskType, - AttentionParameters, - AttentionQkvFormat, -} from './attention'; -import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; -import { createTileProgramInfo } from './tile'; +import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; - -export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { +export interface GroupQueryAttentionAttributes { + numHeads: number; + kvNumHeads: number; + scale: number; + softcap: number; + doRotary: number; + rotaryInterleaved: number; + smoothSoftmax: boolean; + localWindowSize: number; +} + +export const validateInputs = ( + inputs: readonly TensorView[], + attributes: GroupQueryAttentionAttributes, +): AttentionParameters => { + if (attributes.doRotary && inputs.length <= 7) { + throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified'); + } const query = inputs[0]; const key = inputs[1]; const value = inputs[2]; const pastKey = inputs[3]; const pastValue = inputs[4]; - + if (attributes.localWindowSize !== -1) { + throw new Error('Local attention is not supported'); + } + if (attributes.softcap !== 0) { + throw new Error('Softcap is not supported'); + } + if (attributes.rotaryInterleaved !== 0) { + throw new Error('Rotary interleaved is not supported'); + } + if (attributes.smoothSoftmax) { + throw new Error('Smooth softmax is not supported'); + } // Abbreviation and Meanings: // B: batch_size // S: sequence_length (input sequence length of query) @@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = + let hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; - let maxSequenceLength = 0; - const headSize = Math.floor(hiddenSize / attributes.numHeads); + const packedQKV = !key || key.dims.length === 0; + const headSize = !packedQKV + ? Math.floor(hiddenSize / attributes.numHeads) + : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads)); + if (packedQKV) { + hiddenSize = headSize * attributes.numHeads; + } const hasPastKey = pastKey && pastKey.dims.length !== 0; const hasPastValue = pastValue && pastValue.dims.length !== 0; - // TODO : this should be from attributes. - const isPastkvBSNH = true; + // Currenly the onnxruntime GQA specification only support key/value BNSH format. + const isPastkvBSNH = + hasPastKey && + pastKey.dims.length === 4 && + pastKey.dims[0] === batchSize && + pastKey.dims[1] !== attributes.kvNumHeads && + pastKey.dims[2] === attributes.kvNumHeads && + pastKey.dims[3] === headSize; + + if (isPastkvBSNH) { + throw new Error('BSNH pastKey/pastValue is not supported'); + } if (hasPastKey && hasPastValue) { if (pastKey.dims.length !== 4) { throw new Error('Input "past_key" is expected to have 4 dimensions'); @@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (pastValue.dims.length !== 4) { throw new Error('Input "past_value" is expected to have 4 dimensions'); } - if (isPastkvBSNH) { - // For BSNH - pastSequenceLength = pastKey.dims[1]; - maxSequenceLength = pastKey.dims[1]; - } else { - // For BNSH - pastSequenceLength = pastKey.dims[2]; - maxSequenceLength = pastKey.dims[2]; - } + pastSequenceLength = pastKey.dims[2]; } else if (hasPastKey || hasPastValue) { throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } - let qkvFormat: AttentionQkvFormat; - if (key) { + let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH; + if (key && key.dims.length > 0) { if (query.dims.length !== 3) { throw new Error('Input "query" is expected to have 3 dimensions when key is given'); } @@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (query.dims[2] % key.dims[2] !== 0) { throw new Error('Dimension 2 of "query" should be a multiple of "key"'); } - qkvFormat = AttentionQkvFormat.qkvBSNH; kvSequenceLength = key.dims[1]; } else if (key.dims.length === 5) { if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { @@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent if (value) { throw new Error('Expect "value" be none when "key" has packed kv format.'); } - qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; } else { // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - - qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } } else { @@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const maskType: AttentionMaskType = AttentionMaskType.none; let passPastInKv = false; - let vHiddenSize = hiddenSize; - if (value) { + let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize; + if (value && value.dims.length > 0) { if (value.dims.length !== 3 && value.dims.length !== 4) { throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } @@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const seqlLens = inputs.length > 4 ? inputs[5] : undefined; + if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { + throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); + } + const totalSequenceLength = -1; + const maxSequenceLength = -1; const broadcastResPosBias = false; return { @@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent hiddenSize, vHiddenSize, headSize, - vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!), + vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads), numHeads: attributes.numHeads, kvNumHeads: attributes.kvNumHeads, - nReps: attributes.numHeads / attributes.kvNumHeads!, + nReps: attributes.numHeads / attributes.kvNumHeads, pastPresentShareBuffer: false, maskType, scale: attributes.scale, broadcastResPosBias, passPastInKv, qkvFormat, - isPastkvBSNH, }; }; -const createConcatProgramInfo = ( - a: TensorView, - b: TensorView | undefined, - dataType: DataType, - params: AttentionParameters, -): ProgramInfo => { - const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; - const component = 4; - const outputSize = ShapeUtil.size(outputShape) / component; - const presentSequenceLength = params.totalSequenceLength; - const output = outputVariable('present_kv', dataType, outputShape.length, component); - const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); - const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; - - const H = Math.ceil(params.headSize / component); - const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 }; - - const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank']; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: params.pastSequenceLength }, - { type: DataType.uint32, data: params.kvSequenceLength }, - { type: DataType.uint32, data: params.totalSequenceLength }, - ]; - - const inputs = [inputA]; - if (inputB) { - programUniforms.push( - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b!.dims), - ...createTensorShapeVariables(outputShape), - ); - inputs.push(inputB); - } else { - programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); - } - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'past_seqlen', type: 'u32' }, - { name: 'new_seqlen', type: 'u32' }, - { name: 'present_seqlen', type: 'u32' }, - ]; - - const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; - var past_head_stride = uniforms.past_seqlen * H; - if (is_bsnh) { - past_head_stride = H; - } - let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset];`; - const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; - let new_row_stride = num_heads * H; - let new_head_stride = H; - let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset];`; - const concatStr = b - ? `if (s < past_seqlen) { - ${pastStr} - } else if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }` - : `if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }`; - - // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. - const getShaderSource = (shaderHelper: ShaderHelper) => ` - - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)} - ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var indices = ${output.offsetToIndices('global_idx')}; - let h = local_id.x; - let n = local_id.y; - let s = workgroup_id.x; - let b = workgroup_id.y; - let num_heads = ${params.kvNumHeads!}u; - let H = ${H}u; - - let present_seqlen = uniforms.present_seqlen; - let present_batch_stride = present_seqlen * num_heads * H; - var row_stride = H; - let is_bsnh = ${params.isPastkvBSNH}; - - if (is_bsnh) { - row_stride = num_heads * H; - } - var present_head_stride = present_seqlen * H; - if (is_bsnh) { - present_head_stride = H; - } - - let past_seqlen = uniforms.past_seqlen; - - let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - ${concatStr} - }`; - - return { - name: 'ConcatPastNew', - shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType }], - dispatchGroup: dispatch, - programUniforms, - }), - getShaderSource, - }; -}; - -export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({ ...attributes }); - const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); -const maybeExpandAndTransposeToBNSH = ( - context: ComputeContext, - input: TensorView, - pastKV: TensorView | undefined, - params: AttentionParameters, - outputIndex: number, -) => { +const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => { let reshapedInput = input; const numHeads = params.kvNumHeads!; - const nReps = params.nReps!; if (input.dims.length === 3 && params.kvSequenceLength !== 0) { reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); - } - - if (pastKV) { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), { - inputs: [reshapedInput, pastKV], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } else { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), { - inputs: [reshapedInput], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } - if (nReps !== 1) { - reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), { + reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], })[0]; - reshapedInput = reshapedInput.reshape([ - params.batchSize, - params.totalSequenceLength, - numHeads * nReps, - params.headSize, - ]); } - return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { - inputs: [reshapedInput], - outputs: [-1], - })[0]; + return reshapedInput; }; -export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { +export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { throw new Error('Packed QKV is not implemented'); @@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti throw new Error('Packed KV is not implemented'); } + const q = context.inputs[0]; + const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; + const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined; + const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; + const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; + const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined; + const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; + + // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead. + + const splitAttributes: SplitAttributes = createAttributeWithCacheKey({ + axis: 2, + numOutputs: 3, + splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize], + }); + const [query, key, value] = + !k && !v + ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) + : [q, k!, v!]; + const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, - context.inputs[0], + query, undefined, 0, ); - const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; - const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; - const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1); - const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2); - applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes); + applyAttention( + context, + Q, + maybeTransposeToBNSH(context, key, params), + maybeTransposeToBNSH(context, value, params), + undefined, + undefined, + pastKey, + pastValue, + undefined, + params, + seqLens, + totalSequenceLengthInput, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 1a31253905694..db7a4b8e68b79 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio ); if (kvBNSH) { - return applyAttention( - context, - Q, - key, - value, - keyPaddingMask, - undefined, - pastKey, - pastValue, - attentionBias, - params, - attributes, - ); + return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); } if (!key || !value) { throw new Error('key and value must be provided'); @@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize, ); - applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); + applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 1dc3a206cf94b..8c39505734e41 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { }`; }; -const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { +export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType; diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc index 2a4b265078456..036069f43eb54 100644 --- a/js/web/test/data/ops/group-query-attention.jsonc +++ b/js/web/test/data/ops/group-query-attention.jsonc @@ -1,6 +1,316 @@ [ { - "name": "GroupQueryAttention Basic", + "name": "GroupQueryAttention 0", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 1", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // past key, BS* + { + "data": [40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // past value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // seqlens_k, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length, unimplemented + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 2", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 72, 73, 74, 75, 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 3", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 4", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -12,44 +322,293 @@ "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 3, 16], + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + // key, BS* + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // value, BS* + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // past key, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // past value, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], + "type": "float32" + }, + { + // present key, BNSH + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 8, 9, 10, 11, 12, + 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 64, 65, 66, 67, 68, 69, 70, 71, 80, 81, 82, 83, 84, 85, 86, 87, 56, 57, + 58, 59, 60, 61, 62, 63, 72, 73, 74, 75, 76, 77, 78, 79, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 2, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 5", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 6", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 7", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], "type": "float32" }, // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], "dims": [1, 3, 8], "type": "float32" }, // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], "dims": [1, 3, 8], "type": "float32" }, // past key, BS* { - "data": null, + "data": [96, 97, 98, 99, 100, 101, 102, 103], + "dims": [1, 1, 1, 8], "type": "float32" }, // past value, BS* { - "data": null, + "data": [104, 105, 106, 107, 108, 109, 110, 111], + "dims": [1, 1, 1, 8], "type": "float32" }, // seqlens_k, unimplemented { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, // total_sequence_length, unimplemented { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -57,22 +616,28 @@ "outputs": [ { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, - 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21 + 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, + 109, 110, 111 ], - "dims": [1, 3, 16], + "dims": [1, 3, 8], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present key, BNSH + "data": [ + 96, 97, 98, 99, 100, 101, 102, 103, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 66, 67, 68, 69, 70, 71 + ], + "dims": [1, 1, 4, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present value, BNSH + "data": [ + 104, 105, 106, 107, 108, 109, 110, 111, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -80,13 +645,12 @@ ] }, { - "name": "GroupQueryAttention Scale", + "name": " GroupQueryAttention 8", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" }, - { "name": "scale", "data": 2.0, "type": "float" } + { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ { @@ -94,38 +658,43 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -135,35 +704,34 @@ "outputs": [ { "data": [ - 1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, - 1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731 + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 4, 8], + "dims": [1, 1, 32], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], + "dims": [1, 2, 1, 8], "type": "float32" } ] } ] }, - { - "name": "GroupQueryAttention, different sequence length", + "name": "GroupQueryAttention 9", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, + { "name": "num_heads", "data": 2, "type": "int" }, { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ @@ -171,39 +739,41 @@ "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 4, 8], + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -212,23 +782,20 @@ ], "outputs": [ { - "data": [ - 1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823, - 1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, - 1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1, - 1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815 - ], - "dims": [1, 4, 8], + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 1, 8], "type": "float32" }, { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], "type": "float32" } ] @@ -236,12 +803,164 @@ ] }, { - "name": "GroupQueryAttention Basic, q k v same head number", + "name": "GroupQueryAttention 10", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 4, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 16], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 11", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 2, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [2], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 12", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -249,45 +968,49 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // key, BS* { "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, + // value, BS* { "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 32], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -297,26 +1020,28 @@ "outputs": [ { "data": [ - 1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, - 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 16], + "dims": [1, 1, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, - 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 3, 4, 4], + "dims": [1, 1, 1, 32], "type": "float32" }, { + // present value, BNSH "data": [ - 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1, - 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 3, 4, 4], + "dims": [1, 1, 1, 32], "type": "float32" } ] @@ -324,12 +1049,12 @@ ] }, { - "name": "GroupQueryAttention, no past kv, used as reference", + "name": "GroupQueryAttention 13", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -337,50 +1062,51 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, + // key, BS* { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, + // value, BS* { "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 8], + "dims": [1, 4, 8], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [4], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [4], "dims": [1], "type": "int32" } @@ -388,29 +1114,28 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 16], + "dims": [1, 4, 8], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" }, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 ], - "dims": [1, 7, 2, 4], + "dims": [1, 1, 4, 8], "type": "float32" } ] @@ -418,12 +1143,12 @@ ] }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1", + "name": "GroupQueryAttention PackedQKV 14", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -431,52 +1156,41 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 ], - "dims": [1, 7, 16], + "dims": [1, 1, 32], "type": "float32" }, - // new key, BS* + // key, BS* { - "data": [ - 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value, BS* { - "data": [ - 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 6, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7], - "dims": [1, 1, 2, 4], + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { "data": [1], "dims": [1], @@ -485,30 +1199,20 @@ ], "outputs": [ { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], "type": "float32" }, { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], "type": "float32" }, { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], "type": "float32" } ] @@ -516,7 +1220,7 @@ ] }, { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2", + "name": "GroupQueryAttention PackedQKV 15", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -529,54 +1233,48 @@ "inputs": [ { "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 + 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, + 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, + 22, 21, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, + 1, 3, 4, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, + 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, + 2, 131, 22, 21 ], - "dims": [1, 7, 16], + "dims": [1, 3, 64], "type": "float32" }, - // new key, BS* + // key { - "data": [ - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, - 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // new value, BS* + // value { - "data": [ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 5, 8], + "data": null, "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - "dims": [1, 2, 2, 4], + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - "dims": [1, 2, 2, 4], + "data": [], + "dims": [1, 2, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [3], "dims": [1], "type": "int32" } @@ -584,29 +1282,29 @@ "outputs": [ { "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 + 1, 9, 1, 1, 2, 2, 2, 2, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 1, 12, 21, 131, 22, 21, 2, + 2, 8, 12, 233, 4, 5, 6, 7, 8, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4, + 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4 ], - "dims": [1, 7, 16], + "dims": [1, 3, 32], "type": "float32" }, { + // present key, BNSH "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 131, 22, 21, 2, 2, 131, 22, 21, 5, 6, 7, 8, 1, 1, 3, 4, + 8, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 2, 2, 2 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" }, { + // present value, BNSH "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + 1, 9, 1, 1, 2, 2, 2, 2, 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + 5, 6, 7, 8, 1, 1, 3, 4, 131, 22, 21, 2, 2, 131, 22, 21 ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" } ] diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h index 7553883a2478d..dff8663133c31 100644 --- a/onnxruntime/contrib_ops/js/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once - +#include "contrib_ops/cpu/bert/gqa_attention_base.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { @@ -11,31 +11,29 @@ namespace js { using onnxruntime::js::JsKernel; -class GroupQueryAttention : public JsKernel { +class GroupQueryAttention : public JsKernel, GQAAttentionBase { public: explicit GroupQueryAttention(const OpKernelInfo& info) - : JsKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - scale_ = info.GetAttrOrDefault("scale", 0.0f); + : JsKernel(info), GQAAttentionBase(info, false) { JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({ "numHeads" : $1, "kvNumHeads" : $2, "scale" : $3, + "softcap" : $4, + "doRotary" : $5, + "rotaryInterleaved" : $6, + "smoothSoftmax" : $7, + "localWindowSize" : $8 }), static_cast(num_heads_), static_cast(kv_num_heads_), - static_cast(scale_)); + static_cast(scale_), + static_cast(softcap_), + static_cast(do_rotary_), + static_cast(rotary_interleaved_), + static_cast(use_smooth_softmax_), + static_cast(local_window_size_)); } - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // number of k and v heads - float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size) }; } // namespace js diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 83eec9632054c..0b82e2859f0d2 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -165,49 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size, - bool& pre_packed) { - ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), - "Tensor does not have external data to read from."); - - ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), - "External data type cannot be UNDEFINED or STRING."); - - std::unique_ptr external_data_info; - ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - - pre_packed = external_data_info->GetPrePacked(); - - const auto& location = external_data_info->GetRelPath(); - - external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) - : (tensor_proto_dir / location); - - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); - const size_t external_data_length = external_data_info->GetLength(); - ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, - "TensorProto: ", tensor_proto.name(), - " external data size mismatch. Computed size: ", *&tensor_byte_size, - ", external_data.length: ", external_data_length); - - file_offset = external_data_info->GetOffset(); - - return Status::OK(); -} - -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size) { - bool pre_packed = false; - return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); -} - // Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully. // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. @@ -273,6 +230,40 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { +inline Status utils::GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size, + bool& pre_packed) { + ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), + "Tensor does not have external data to read from."); + + ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), + "External data type cannot be UNDEFINED or STRING."); + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + + pre_packed = external_data_info->GetPrePacked(); + + const auto& location = external_data_info->GetRelPath(); + + external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) + : (tensor_proto_dir / location); + + ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); + const size_t external_data_length = external_data_info->GetLength(); + ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, + "TensorProto: ", tensor_proto.name(), + " external data size mismatch. Computed size: ", *&tensor_byte_size, + ", external_data.length: ", external_data_length); + + file_offset = external_data_info->GetOffset(); + + return Status::OK(); +} + void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) { tensor_proto.set_raw_data(std::move(param)); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 3a3bd1a881fce..4ecea2161577a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -25,6 +25,32 @@ namespace onnxruntime { namespace utils { /** + * This function is used to get the external data info from the given tensor proto. + * @param tensor_proto given initializer tensor + * @param tensor_proto_dir directory of the tensor proto file + * @param external_file_path output external file path + * @param file_offset output tensor offset + * @param tensor_byte_size output tensor byte size + * @pre_packed whether output tensor is prepacked buffer serialize as tensor + * @returns Status::OK() if the function is executed successfully + */ +inline Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size, + bool& pre_packed); + +inline Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size) { + bool pre_packed = false; + return utils::GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); +} + + /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file * generated on little endian system. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index eaffe1e2ac224..34dcbd1d77fca 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -302,13 +302,21 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } Status QnnBackendManager::ResetQnnLogLevel() { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + std::lock_guard lock(logger_mutex_); + + if (backend_setup_completed_ && logger_ != nullptr) { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); + } + return Status::OK(); } Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); + ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); + ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -686,6 +694,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { + std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; return Status::OK(); @@ -972,6 +981,7 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } + std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b80f1374fcdc7..43007d4a5c244 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -12,9 +12,11 @@ #endif #include +#include #include #include #include + #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" #include "QnnTypes.h" @@ -233,6 +235,7 @@ class QnnBackendManager { private: const std::string backend_path_; + std::mutex logger_mutex_; const logging::Logger* logger_ = nullptr; QNN_INTERFACE_VER_TYPE qnn_interface_ = QNN_INTERFACE_VER_TYPE_INIT; QNN_SYSTEM_INTERFACE_VER_TYPE qnn_sys_interface_ = QNN_SYSTEM_INTERFACE_VER_TYPE_INIT; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index becb9a728b1e3..6735528bebbf9 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -258,49 +258,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. - // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); - } - } - } - - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); -#endif - // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { @@ -440,6 +397,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_arch, soc_model, enable_htp_weight_sharing_); + +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. + // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); +#endif } QNNExecutionProvider::~QNNExecutionProvider() { @@ -453,7 +453,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #ifdef _WIN32 - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 30e2fd53e9613..35c061de6132c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -151,7 +151,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 - onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index aecb1f7a03bb9..ec9993bf138ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -36,6 +36,31 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type); // Collects all the initializer tensors in the subGraph and its ancestor graphs. InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); +inline std::vector convertAxesFromNCHWtoNHWC(const std::vector& axes) { + constexpr std::array nchw_to_nhwc = {0, 3, 1, 2}; + std::vector new_axes; + new_axes.reserve(axes.size()); + for (int64_t axis : axes) { + if (axis >= nchw_to_nhwc.size()) { + ORT_THROW("Invalid axis value: ", axis); + } + new_axes.push_back(nchw_to_nhwc[static_cast(axis)]); + } + return new_axes; +} + +inline std::vector HandleNegativeAxes(const std::vector& axes, size_t input_size) { + std::vector new_axes(axes.size()); + for (size_t i = 0; i < axes.size(); ++i) { + new_axes[i] = HandleNegativeAxis(axes[i], input_size); + } + return new_axes; +} + +inline std::vector GetResolvedAxes(const NodeAttrHelper& helper, size_t input_size) { + return HandleNegativeAxes(helper.Get("axes", std::vector{}), input_size); +} + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -144,6 +169,17 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va return true; } +inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::string& name) { + if (name.empty() || !Contains(initializers, name)) { + return true; + } + + const auto& tensor = *initializers.at(name); + const auto dims = tensor.dims(); + // An empty tensor contains a 0 in the dimensions list. + return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); +} + bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 8da255a288f17..fffe964e6aaf2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -12,27 +12,6 @@ namespace onnxruntime { namespace webnn { - -// Shared functions. -bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, - const logging::Logger& logger) { - for (const auto* node_arg : node.InputDefs()) { - const auto& input_name(node_arg->Name()); - if (!Contains(initializers, input_name)) - continue; - - const auto& tensor = *initializers.at(input_name); - if (tensor.has_data_location() && - tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS(logger, VERBOSE) << "Initializer [" << input_name - << "] with external data location are not currently supported"; - return true; - } - } - - return false; -} - // Add operator related. Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, @@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; - // We do not support external initializers for now. - if (HasExternalInitializer(initializers, node, logger)) - return false; - if (!HasSupportedOpSet(node, logger)) return false; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f03e5b90ff6db..329db75316e82 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -133,7 +133,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, const auto out_t = dims[0], in_t = dims[1], h_t = dims[2], w_t = dims[3]; std::vector dest_shape; - if (is_conv == 1) + if (is_conv) dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight @@ -265,7 +265,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); if (is_constant_weight) { - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); } } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 9dc79f4f52f46..3442afbc2b3cd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -38,16 +38,33 @@ class ResizeOpBuilder : public BaseOpBuilder { }; // Helper functions -bool GetResizeScales(const InitializedTensorSet& initializers, - const Node& node, std::vector& scales, - const logging::Logger& logger) { +bool GetResizeScalesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& scales, + std::vector& axes, const bool is_nhwc, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 3) return false; + const bool has_axes = !axes.empty(); const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); - if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + if (scales_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'scales' should be a 1D tensor."; return false; + } + + // Number of elements of 'scales' tensor. + const auto num_of_scales = scales_tensor.dims()[0]; + + if (has_axes && num_of_scales != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'scales' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_scales != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'scales' should have 4 elements."; + return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); @@ -56,20 +73,65 @@ bool GetResizeScales(const InitializedTensorSet& initializers, return false; } const float* scales_data = reinterpret_cast(unpacked_tensor.data()); - scales = std::vector{scales_data, scales_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'scales' should have 2 elements. + scales = std::vector{scales_data, scales_data + 2}; + } else { + // Before opset 18, 'scales' should have 4 elements. + // Make sure 'scales' is not trying to scale on N/C channels here. + std::vector onnx_scales{scales_data, scales_data + 4}; + // 'scales' input has been transposed to NHWC layout if it is NHWC preferred layout. + const float scale_n = onnx_scales[0]; + const float scale_c = is_nhwc ? onnx_scales[3] : onnx_scales[1]; + const float scale_h = is_nhwc ? onnx_scales[1] : onnx_scales[2]; + const float scale_w = is_nhwc ? onnx_scales[2] : onnx_scales[3]; + if (scale_n != 1.0f || scale_c != 1.0f) { + LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" + << "Scales of N/C channels are not supported" + << ", scale_n, " << scale_n << ", scale_c, " << scale_c; + return false; + } + + scales = {scale_h, scale_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert axes from NCHW to NHWC. + axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } -bool GetResizeOutputSizes(const InitializedTensorSet& initializers, - const Node& node, std::vector& sizes, - const logging::Logger& logger) { +bool GetResizeSizesAndAxes(const InitializedTensorSet& initializers, + const Node& node, std::vector& sizes, + std::vector& axes, const bool is_nhwc, + const gsl::span& input_shape, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 4) return false; + const bool has_axes = !axes.empty(); const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); - if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + if (sizes_tensor.dims_size() != 1) { + LOGS(logger, ERROR) << "'sizes' should be a 1D tensor."; + return false; + } + + // Number of elements of sizes tensor. + const auto num_of_sizes = sizes_tensor.dims()[0]; + if (has_axes && num_of_sizes != 2) { + LOGS(logger, ERROR) << "When 'axes' is provided, 'sizes' should have 2 elements."; + return false; + } + + if (!has_axes && num_of_sizes != 4) { + LOGS(logger, ERROR) << "When 'axes' is not provided, 'sizes' should have 4 elements."; return false; + } std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); @@ -78,7 +140,35 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, return false; } const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); - sizes = std::vector{sizes_data, sizes_data + 4}; + + if (has_axes) { + // 'axes' is specified since opset 18+, 'sizes' should have 2 elements. + sizes = std::vector{sizes_data, sizes_data + 2}; + } else { + // Before opset 18, 'sizes' should have 4 elements. + // Make sure 'sizes' is not trying to resize on N/C channels here. + std::vector onnx_sizes{sizes_data, sizes_data + 4}; + auto size_n = onnx_sizes[0]; + const int c_idx = is_nhwc ? 3 : 1; + if (size_n != input_shape[0] || onnx_sizes[c_idx] != input_shape[c_idx]) { + LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " + << "Resize of N/C channels are not supported" + << ", input_size_n, " << input_shape[0] << ", output_size_n, " << size_n + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << onnx_sizes[c_idx]; + return false; + } + // 'sizes' input has been transposed to NHWC layout if it is NHWC preferred layout. + const int64_t sizes_h = is_nhwc ? onnx_sizes[1] : onnx_sizes[2]; + const int64_t sizes_w = is_nhwc ? onnx_sizes[2] : onnx_sizes[3]; + sizes = {sizes_h, sizes_w}; + axes = {2, 3}; + } + + if (is_nhwc) { + // For NHWC preferred layout, we need to convert 'axes' from NCHW to NHWC. + axes = convertAxesFromNCHWtoNHWC(axes); + } + return true; } @@ -103,9 +193,15 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& initializers(model_builder.GetInitializerTensors()); + NodeAttrHelper helper(node); + emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { options.set("mode", emscripten::val("linear")); @@ -113,45 +209,30 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("mode", emscripten::val("nearest-neighbor")); } - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - std::vector scales; - std::vector sizes; - std::vector scales_hw; - std::vector sizes_hw; - std::vector axes; - std::string scales_name = GetTensorName(input_defs, 2); + std::vector sizes; + std::vector webnn_sizes; + std::vector axes = GetResolvedAxes(helper, 4); // We already checked input shape is 4D in IsOpSupportedImpl. + std::string sizes_name = GetTensorName(input_defs, 3); const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; - if (!scales_name.empty()) { // Use scales. - ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - if (is_nhwc) { - scales_hw = {scales[1], scales[2]}; - } else { - scales_hw = {scales[2], scales[3]}; - } - options.set("scales", emscripten::val::array(scales_hw)); - } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. - std::vector output_sizes; - ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), - "Error getting resize output_sizes"); - std::transform(output_sizes.cbegin(), output_sizes.cend(), - std::back_inserter(sizes), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); - if (is_nhwc) { - sizes_hw = {sizes[1], sizes[2]}; - } else { - sizes_hw = {sizes[2], sizes[3]}; - } - options.set("sizes", emscripten::val::array(sizes_hw)); - } - if (is_nhwc) { - axes = {1, 2}; + // We know we have either a 'scales' or 'sizes' input so this is safe. + // Check for 'sizes' first. + // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. + bool using_sizes = !sizes_name.empty() && Contains(initializers, sizes_name); + if (using_sizes) { + ORT_RETURN_IF_NOT(GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger), + "Error getting Resize sizes"); + webnn_sizes = GetVecUint32FromVecInt64(sizes); + options.set("sizes", emscripten::val::array(webnn_sizes)); } else { - axes = {2, 3}; + ORT_RETURN_IF_NOT(GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger), + "Error getting Resize scales"); + options.set("scales", emscripten::val::array(scales)); } - options.set("axes", emscripten::val::array(axes)); + + std::vector webnn_axes = GetVecUint32FromVecInt64(axes); + options.set("axes", emscripten::val::array(webnn_axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = model_builder.GetBuilder().call("resample2d", input, options); @@ -166,6 +247,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) @@ -179,92 +261,81 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } { // Check attributes. - NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - bool is_linear_resize = mode == "linear"; - bool is_nearest_resize = mode == "nearest"; - // WebNN only supports "linear" and "nearest" modes. - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; + // antialias + if (helper.Get("antialias", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support antialias"; return false; } - const auto exclude_outside = helper.Get("exclude_outside", 0); - if (exclude_outside != 0) { - LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + // coordinate_transformation_mode + // Spec issue for supporting more coordinate transformation modes: + // https://github.com/webmachinelearning/webnn/issues/270 + const std::string coordinate_transformation_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + if (coordinate_transformation_mode != "half_pixel") { + LOGS(logger, VERBOSE) << "Resize does not support coordinate_transformation_mode: " + << coordinate_transformation_mode; return false; } - } - { // scales and sizes (if present) must be initializers. - const std::string scales_name = GetTensorName(input_defs, 2); - const std::string sizes_name = GetTensorName(input_defs, 3); - - // scales (scales may be empty tensor) - bool has_scales = !scales_name.empty(); - if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + // exclude_outside + const auto exclude_outside = helper.Get("exclude_outside", 0); + if (exclude_outside != 0) { + LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; return false; } - // sizes (sizes may be empty tensor) - bool has_sizes = !sizes_name.empty(); - if (has_sizes && !Contains(initializers, sizes_name)) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + // keep_aspect_ratio_policy + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (keep_aspect_ratio_policy != "stretch") { + LOGS(logger, VERBOSE) << "Resize does not support keep_aspect_ratio_policy: " << keep_aspect_ratio_policy; return false; } - if (has_scales && has_sizes) { - LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified"; + // mode + const auto mode = helper.Get("mode", "nearest"); + bool is_linear_resize = mode == "linear"; + bool is_nearest_resize = mode == "nearest"; + // WebNN only supports "linear" and "nearest" modes. + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; return false; } + } - const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; - // We want to check if the scales or sizes are not trying to resize on N/C channels here. - if (has_scales) { // We are using scales. - std::vector scales; - if (!GetResizeScales(initializers, node, scales, logger)) - return false; - - float scale_n = scales[0]; - float scale_c = is_nhwc ? scales[3] : scales[1]; - if (scale_n != 1.0f || scale_c != 1.0f) { - LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" - << "Resize of N/C channels are not supported" - << ", scale_n, " << scale_n << ", scale_c, " << scale_c; - return false; - } + { // 'scales' and 'sizes' (if present) must be non-empty initializers. + const std::string scales_name = GetTensorName(input_defs, 2); + const std::string sizes_name = GetTensorName(input_defs, 3); - // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. - // TODO support ResizeBilinear. - float scale_h = is_nhwc ? scales[1] : scales[2]; - float scale_w = is_nhwc ? scales[2] : scales[3]; + // Check for 'sizes' first. + // This handles Resize-11 where 'scales' was a required input but 'sizes' were used if provided. + // 'scales' or 'sizes' may be empty tensor. + bool using_sizes = !IsEmptyTensor(initializers, sizes_name); + bool using_scales = !using_sizes && !IsEmptyTensor(initializers, scales_name); - // Onnx spec requires scale to be a positive float, so we are not checking that here. - if (roundf(scale_h) != scale_h) { - LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; - return false; - } + if (!using_scales && !using_sizes) { + LOGS(logger, VERBOSE) << "Resize: only one of 'scales' and 'sizes' can be specified"; + return false; + } - if (roundf(scale_w) != scale_w) { - LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + // 'axes' is from opset 18 on and allows 'scales' or 'sizes' to have entries for the subset of 'axes'. + // We fill with default values if necessary so that the processing is consistent across all supported opsets. + std::vector axes = GetResolvedAxes(helper, input_size); + if (!axes.empty()) { // We have 'axes' attribute. + if (axes.size() != 2 || axes[0] >= input_size || axes[1] >= input_size) { + LOGS(logger, VERBOSE) << "Resize: invalid axes attribute"; return false; } } - if (has_sizes) { - // We are using sizes. - std::vector output_sizes; - if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; + if (using_sizes) { // We are using 'sizes'. + std::vector sizes; + if (!GetResizeSizesAndAxes(initializers, node, sizes, axes, is_nhwc, input_shape, logger)) { return false; - - auto output_size_n = output_sizes[0]; - const int c_idx = is_nhwc ? 3 : 1; - if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { - LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " - << "Resize of N/C channels are not supported" - << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; + } + } else { // We are using 'scales'. + std::vector scales; + if (!GetResizeScalesAndAxes(initializers, node, scales, axes, is_nhwc, logger)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 044baa738e8c4..8a7fea0cde431 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() { auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); std::byte* tensor_ptr = nullptr; - if (tensor.has_raw_data()) { - tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + + if (utils::HasExternalData(tensor)) { + // Create WebNN Constant from external data. + std::basic_string external_file_path; + onnxruntime::FileOffsetType data_offset; + SafeInt tensor_byte_size; + ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo( + tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size)); + + auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant"); + operand = jsepRegisterMLConstant(emscripten::val(external_file_path), + static_cast(data_offset), + static_cast(tensor_byte_size), + wnn_builder_, + desc); } else { - // Store temporary unpacked_tensor. - unpacked_tensors_.push_back({}); - std::vector& unpacked_tensor = unpacked_tensors_.back(); - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); - tensor_ptr = reinterpret_cast(unpacked_tensor.data()); - } - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - default: - break; + if (tensor.has_raw_data()) { + tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + } else { + // Store temporary unpacked_tensor. + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + tensor_ptr = reinterpret_cast(unpacked_tensor.data()); + } + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + default: + break; + } + + // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached + // buffers in JS side. Simply create a copy to fix it. + operand = wnn_builder_.call("constant", desc, view.call("slice")); } - - // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached - // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); } else { // TODO: support other type. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 8c9f0ba0f21be..8ff5990b7815a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -3,7 +3,7 @@ diffusers==0.28.0 transformers==4.41.2 numpy>=1.24.1 accelerate -onnx==1.16.0 +onnx==1.17.0 coloredlogs packaging # Use newer version of protobuf might cause crash diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 68332d07a9782..78d60326dd0a8 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => { Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); } + + Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => { + return backend['registerMLConstant']( + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + } } }; diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 50d4d8a912585..4e5d9a70beb66 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -34,11 +34,8 @@ stages: # build Python packages # Linux GPU only - ${{ if parameters.BuildPythonPackages }}: - - template: templates/py-packaging-stage.yml + - template: stages/py-gpu-packaging-stage.yml parameters: enable_linux_gpu: true - enable_linux_cpu: false - enable_windows_cpu: false enable_windows_gpu: false - enable_mac_cpu: false - enable_linux_arm: false + cuda_version: 12.2 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml deleted file mode 100644 index 5c3273f79bd30..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ /dev/null @@ -1,95 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- job: Linux_Build - timeoutInMinutes: 180 - workspace: - clean: all - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - pool: onnxruntime-Ubuntu-2204-Training-CPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" - Repository: onnxruntimecpubuildcentos8x64_packaging - - - task: Cache@2 - inputs: - key: '"$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - cacheHitVar: CACHE_RESTORED - restoreKeys: | - "$(TODAY)" | "$(Build.SourceBranch)" - "$(TODAY)" | - displayName: Cach Task - - - task: CmdLine@2 - displayName: 'build' - inputs: - script: | - set -e -x - mkdir -p $HOME/.onnx - mkdir -p $(Pipeline.Workspace)/ccache - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume /data/models:/build/models:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(Pipeline.Workspace)/ccache:/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - onnxruntimecpubuildcentos8x64_packaging \ - /onnxruntime_src/tools/ci_build/github/linux/build_training_ci.sh - workingDirectory: $(Build.SourcesDirectory) - - - task: PublishTestResults@2 - displayName: 'Publish unit test results' - inputs: - testResultsFiles: '**/*.results.xml' - searchFolder: '$(Build.BinariesDirectory)' - testRunTitle: 'Unit Test Run' - condition: succeededOrFailed() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml deleted file mode 100644 index 494035637a79d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ /dev/null @@ -1,55 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- template: templates/linux-ci.yml - parameters: - AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' - JobName: 'Onnxruntime_Linux_GPU_Training' - RunDockerBuildArgs: > - -o ubuntu20.04 -d gpu - -t onnxruntime_orttraining_ortmodule_tests_image - -u - -e - -x " - --enable_training - --config Release - --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 - --build_wheel - --enable_nvtx_profile - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 - " - RunInjectedPipeline: 'true' - InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' - DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - TimeoutInMinutes: 190 - # Enable unreleased onnx opsets in CI builds - # This facilitates testing the implementation for the new opsets - AllowReleasedOpsetOnly: '0' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml new file mode 100644 index 0000000000000..cc2977721d03b --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml @@ -0,0 +1,27 @@ +trigger: none + +parameters: + - name: enable_linux_cuda + type: boolean + default: true + + - name: enable_windows_cuda + type: boolean + default: true + + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_linux_cuda: ${{ parameters.enable_linux_cuda }} + enable_windows_cuda: ${{ parameters.enable_windows_cuda }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml index 3503857a9233c..7e6b1889687a3 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -1,12 +1,20 @@ trigger: none - +# The `resources` specify the location and version of the 1ES PT. +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release parameters: - - name: enable_linux_gpu + - name: enable_linux_cuda type: boolean default: true - - name: enable_windows_gpu + + - name: enable_windows_cuda type: boolean default: true + - name: cmake_build_type type: string default: 'Release' @@ -15,28 +23,22 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - - name: cuda_version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 - - name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - - name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + # Update the pool with your team's 1ES hosted pool. + pool: + name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool + os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS -stages: - - template: stages/py-cuda-packaging-stage.yml - parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} - cmake_build_type: ${{ parameters.cmake_build_type }} - cuda_version: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_linux_cuda: ${{ parameters.enable_linux_cuda }} + enable_windows_cuda: ${{ parameters.enable_windows_cuda }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml new file mode 100644 index 0000000000000..0c7c6abeb35da --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml @@ -0,0 +1,18 @@ +trigger: none + +parameters: + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +stages: + - template: stages/py-gpu-packaging-stage.yml + parameters: + enable_windows_dml: true + cmake_build_type: ${{ parameters.cmake_build_type }} + publish_symbols: true diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index de17db216da9c..ed992be31257a 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -4,21 +4,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' type: boolean @@ -74,12 +64,10 @@ parameters: trigger: none stages: -- template: templates/py-packaging-stage.yml +- template: stages/py-cpu-packaging-stage.yml parameters: - enable_linux_gpu: ${{ parameters.enable_linux_gpu }} enable_linux_cpu: ${{ parameters.enable_linux_cpu }} enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} @@ -90,3 +78,4 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} qnn_sdk_version: ${{ parameters.qnn_sdk_version }} publish_symbols: true + diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 9289935b4ef9c..a33f757c24408 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -57,15 +57,15 @@ jobs: - checkout: self - task: DownloadPipelineArtifact@2 inputs: - artifact: 'drop-linux-gpu-x86_64' - targetPath: '$(Build.SourcesDirectory)/drop-linux-gpu-x86_64' + artifact: 'linux_gpu_wheel_x86_64' + targetPath: '$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64' ${{ if ne(parameters.build_id, 'latest') }}: buildType: 'specific' project: '${{ parameters.project }}' pipeline: '${{ parameters.pipeline }}' buildVersionToDownload: 'specific' buildId: '${{ parameters.build_id }}' - displayName: 'Download Build Artifacts - drop-linux-gpu-x86_64' + displayName: 'Download Build Artifacts - linux_gpu_wheel_x86_64' - task: DownloadPipelineArtifact@2 inputs: @@ -82,7 +82,7 @@ jobs: - bash: | set -e -x ls $(Build.SourcesDirectory) - mv "$(Build.SourcesDirectory)/drop-linux-gpu-x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Build.SourcesDirectory)/linux_gpu_wheel_x86_64" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} mv "$(Build.SourcesDirectory)/onnxruntime_gpu" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml similarity index 64% rename from tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml rename to tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 10d7ce04747d9..e92761e20d9e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -10,21 +10,11 @@ parameters: type: boolean default: true -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' - type: boolean - default: true - - name: enable_windows_cpu displayName: 'Whether Windows CPU package is built.' type: boolean default: true -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' - type: boolean - default: true - - name: enable_mac_cpu displayName: 'Whether Mac CPU package is built.' type: boolean @@ -65,10 +55,6 @@ parameters: - RelWithDebInfo - MinSizeRel -- name: publish_symbols - type: boolean - default: false - # Only applies to QNN packages. - name: qnn_sdk_version type: string @@ -128,7 +114,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -142,7 +128,7 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: BatchScript@1 displayName: 'setup env' @@ -151,7 +137,7 @@ stages: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - template: download-deps.yml + - template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -180,24 +166,12 @@ stages: --enable_pybind --enable_onnx_tests ${{ parameters.build_py_parameters }} - --parallel --use_binskim_compliant_compile_flags --update + --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\onnxruntime.sln' - platform: $(MsbuildPlatform) - configuration: ${{ parameters.cmake_build_type }} - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}' - createLogFile: true - # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -251,29 +225,8 @@ stages: python onnx_backend_test_series.py workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' displayName: 'Run Python Tests' - - ${{ if eq(parameters.publish_symbols, true) }}: - - task: PublishSymbols@2 - displayName: 'Publish symbols' - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) - inputs: - SymbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - SearchPattern: | - onnxruntime_pybind11_state.pdb - onnxruntime_providers_shared.pdb - IndexSources: true - SymbolServerType: TeamServices - SymbolExpirationInDays: 3650 - SymbolsArtifactName: 'win_cpu_$(PythonVersion)_$(buildArch)_$(Build.BuildNumber)' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - continueOnError: true - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -281,87 +234,6 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- ${{ if eq(parameters.enable_windows_gpu, true) }}: - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' - PYTHON_VERSION: '3.13' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ENV_SETUP_SCRIPT: setup_env_gpu.bat - EP_NAME: gpu - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - template: py-win-gpu.yml - parameters: - MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' - PYTHON_VERSION: '3.13' - EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat - EP_NAME: directml - publish_symbols: ${{ parameters.publish_symbols }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - ${{ if eq(parameters.enable_mac_cpu, true) }}: - stage: Python_Packaging_MacOS dependsOn: [] @@ -395,9 +267,9 @@ stages: inputs: versionSpec: $(PythonVersion) - - template: use-xcode-version.yml + - template: ../templates/use-xcode-version.yml - - template: download-deps.yml + - template: ../templates/download-deps.yml - task: PythonScript@0 displayName: 'Update deps.txt' @@ -437,7 +309,7 @@ stages: inputs: ArtifactName: onnxruntime - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -446,7 +318,7 @@ stages: - stage: Python_Packaging_Linux_ARM dependsOn: [] jobs: - - template: py-linux.yml + - template: ../templates/py-linux.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' @@ -457,30 +329,18 @@ stages: - stage: Python_Packaging_Linux_CPU dependsOn: [] jobs: - - template: py-linux.yml - parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: py-linux-gpu.yml + - template: ../templates/py-linux.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - trt_version: '10.4.0.26-1.cuda11.8' - cuda_version: '11.8' - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - stage: Python_Packaging_Windows_ARM64_QNN dependsOn: [] jobs: - - template: py-win-arm64-qnn.yml + - template: ../templates/py-win-arm64-qnn.yml parameters: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -490,7 +350,7 @@ stages: - stage: Python_Packaging_Windows_arm64ec_QNN dependsOn: [] jobs: - - template: py-win-arm64ec-qnn.yml + - template: ../templates/py-win-arm64ec-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -500,7 +360,7 @@ stages: - stage: Python_Packaging_Windows_x64_QNN dependsOn: [] jobs: - - template: py-win-x64-qnn.yml + - template: ../templates/py-win-x64-qnn.yml parameters: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} @@ -510,7 +370,7 @@ stages: - stage: Python_Packaging_Linux_x64_QNN dependsOn: [] jobs: - - template: py-linux-qnn.yml + - template: ../templates/py-linux-qnn.yml parameters: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml similarity index 68% rename from tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml rename to tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index ae18687cb9e54..1ae95a296162c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -5,15 +5,20 @@ parameters: type: string default: '' -- name: enable_linux_gpu - displayName: 'Whether Linux GPU package is built.' +- name: enable_linux_cuda + displayName: 'Whether Linux CUDA package is built.' type: boolean - default: true + default: false -- name: enable_windows_gpu - displayName: 'Whether Windows GPU package is built.' +- name: enable_windows_cuda + displayName: 'Whether Windows CUDA package is built.' type: boolean - default: true + default: false + +- name: enable_windows_dml + displayName: 'Whether Windows DML package is built.' + type: boolean + default: false # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type @@ -34,16 +39,6 @@ parameters: - 11.8 - 12.2 -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - - name: PythonVersions type: object displayName: 'Python versions to build' @@ -53,23 +48,25 @@ parameters: - '3.12' - '3.13' +- name: publish_symbols + type: boolean + default: false + stages: - - ${{ if eq(parameters.enable_windows_gpu, true) }}: + - ${{ if eq(parameters.enable_windows_cuda, true) }}: - ${{ each python_version in parameters.PythonVersions }}: - - template: ../templates/py-win-gpu.yml + - template: py-win-gpu-stage.yml parameters: PYTHON_VERSION: ${{ python_version }} EP_NAME: gpu CudaVersion: ${{ parameters.cuda_version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} ${{ if eq(parameters.cuda_version, '11.8') }}: EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{ if eq(parameters.cuda_version, '12.2') }}: EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - - ${{ if eq(parameters.enable_linux_gpu, true) }}: - - template: ../templates/py-linux-gpu.yml + - ${{ if eq(parameters.enable_linux_cuda, true) }}: + - template: py-linux-gpu-stage.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' @@ -82,3 +79,15 @@ stages: ${{ if eq(parameters.cuda_version, '12.2') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 trt_version: 10.4.0.26-1.cuda12.6 + + - ${{ if eq(parameters.enable_windows_dml, true) }}: + - ${{ each python_version in parameters.PythonVersions }}: + - template: py-win-gpu-stage.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' + PYTHON_VERSION: ${{ python_version }} + EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos + ENV_SETUP_SCRIPT: setup_env.bat + EP_NAME: directml + publish_symbols: ${{ parameters.publish_symbols }} + cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml similarity index 63% rename from tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml index d19472bcbab5a..f9053cba56835 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml @@ -41,7 +41,27 @@ stages: timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: linux + templateContext: + codeSignValidation: + enabled: true + break: true + psscriptanalyzer: + enabled: true + sdl: + binskim: + enabled: true + scanOutputDirectoryOnly: true + targetPathPattern: '\".*.so\"' + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/dist + artifactName: onnxruntime_gpu + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/${{ parameters.cmake_build_type }} + artifactName: linux_gpu_wheel_${{ parameters.arch }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. - name: skipComponentGovernanceDetection @@ -56,9 +76,9 @@ stages: clean: true submodules: recursive - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml + - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda @@ -73,17 +93,18 @@ stages: filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime_gpu - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Test Binaries' - inputs: - artifactName: 'drop-linux-gpu-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/Release' + - script: | + set -e -x + mv $(Build.BinariesDirectory)/${{ parameters.cmake_build_type }} ./${{ parameters.cmake_build_type }} + mv $(Build.BinariesDirectory)/dist ./dist + pushd dist + find . -name \*.whl -exec unzip -qq -o {} \; + popd + pushd ${{ parameters.cmake_build_type }} + find . -name \*.whl -exec unzip -qq -o {} \; + popd + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'Move files' - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml similarity index 93% rename from tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml rename to tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 71500e4ef9025..0cbcd2b74371e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -28,16 +28,6 @@ parameters: - 11.8 - 12.2 -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - - name: cmake_build_type type: string displayName: 'Linux packages cmake build type. Linux Only.' @@ -75,7 +65,7 @@ stages: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: ../templates/telemetry-steps.yml - task: UsePythonVersion@0 inputs: @@ -89,10 +79,10 @@ stages: tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' appendSourceBranchName: false - - template: download-deps.yml + - template: ../templates/download-deps.yml - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/set-winenv.yml + - template: ../templates/jobs/set-winenv.yml parameters: EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: @@ -101,7 +91,7 @@ stages: DownloadTRT: true - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: jobs/download_win_gpu_library.yml + - template: ../templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: ${{ parameters.CudaVersion }} ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: @@ -123,7 +113,7 @@ stages: workingDirectory: '$(Build.BinariesDirectory)' arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\installed -build_config ${{ parameters.cmake_build_type }} - - template: set-nightly-build-option-variable-step.yml + - template: ../templates/set-nightly-build-option-variable-step.yml - task: PythonScript@0 displayName: 'Generate cmake config' @@ -153,7 +143,7 @@ stages: workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - - template: win-esrp-dll.yml + - template: ../templates/win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' @@ -216,7 +206,7 @@ stages: GdnPublishTsaOnboard: false GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - template: component-governance-component-detection-steps.yml + - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' @@ -243,13 +233,11 @@ stages: addToPath: true architecture: 'x64' - - template: flex-downloadPipelineArtifact.yml + - template: ../templates/flex-downloadPipelineArtifact.yml parameters: ArtifactName: onnxruntime_${{ parameters.EP_NAME }} StepName: 'Download Pipeline Artifact - Windows GPU Build' TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - task: PowerShell@2 displayName: 'Install ONNX' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 4ca462bf962f5..6a74d0e7befd3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -61,7 +61,7 @@ jobs: # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. - artifact: 'drop-linux-gpu-${{ parameters.arch }}' + artifact: 'linux_gpu_wheel_${{ parameters.arch }}' - download: build # pipeline resource identifier. artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' @@ -69,7 +69,7 @@ jobs: - bash: | set -e -x ls $(Pipeline.Workspace)/build - mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/linux_gpu_wheel_${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/linux/build_training_ci.sh b/tools/ci_build/github/linux/build_training_ci.sh deleted file mode 100755 index 82f75a5cbbc50..0000000000000 --- a/tools/ci_build/github/linux/build_training_ci.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -set -e -x -python3.12 -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt -python3.12 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --enable_training --skip_submodule_sync --parallel diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training deleted file mode 100644 index 4d11cbbde3354..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_gpu_training +++ /dev/null @@ -1,60 +0,0 @@ -ARG BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu18.04 - -FROM $BASEIMAGE - -ARG PYTHON_VERSION=3.9 -ARG INSTALL_DEPS_EXTRA_ARGS -ARG USE_CONDA=false - -ADD scripts /tmp/scripts -RUN /tmp/scripts/install_ubuntu.sh -p $PYTHON_VERSION && \ - /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS - -# If USE_CONDA is false, use root to install python dependencies. -RUN if [ "$USE_CONDA" = false ] ; \ - then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS ; \ - fi - -WORKDIR /root - -# Allow configure to pick up GDK and CuDNN where it expects it. -# (Note: $CUDNN_VERSION is defined by NVidia's base image) -RUN _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) && \ - mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/include && \ - ln -s /usr/include/cudnn.h /usr/local/cudnn-$_CUDNN_VERSION/cuda/include/cudnn.h && \ - mkdir -p /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64 && \ - ln -s /etc/alternatives/libcudnn_so /usr/local/cudnn-$_CUDNN_VERSION/cuda/lib64/libcudnn.so && \ - ln -s /usr/local/cudnn{-$_CUDNN_VERSION,} - -ENV LD_LIBRARY_PATH /usr/local/openblas/lib:$LD_LIBRARY_PATH - -ARG BUILD_USER=onnxruntimedev -ARG BUILD_UID=1000 -RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID -WORKDIR /home/$BUILD_USER -USER $BUILD_USER - -ARG MINICONDA_PREFIX=/home/$BUILD_USER/miniconda3 -RUN if [ "$USE_CONDA" = true ] ; \ - then MINICONDA=miniconda.sh && \ - wget --no-verbose https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O $MINICONDA && \ - chmod a+x $MINICONDA && \ - ./$MINICONDA -b -p $MINICONDA_PREFIX && \ - rm ./$MINICONDA && \ - $MINICONDA_PREFIX/bin/conda clean --yes --all && \ - $MINICONDA_PREFIX/bin/conda install -y python=$PYTHON_VERSION ; \ - fi - -ENV PATH /home/$BUILD_USER/miniconda3/bin:$PATH - -# If USE_CONDA is true, use onnxruntimedev user to install python dependencies -RUN if [ "$USE_CONDA" = true ] ; \ - then /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d gpu $INSTALL_DEPS_EXTRA_ARGS -c ; \ - fi - -WORKDIR /root -USER root -RUN rm -rf /tmp/scripts - -WORKDIR /home/$BUILD_USER -USER $BUILD_USER