From 9985e85118ea5e3ddeb975664b30b0a38105e9e4 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 19 Sep 2024 18:26:57 -0700 Subject: [PATCH] Update openvino.js --- source/openvino.js | 108 +++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/source/openvino.js b/source/openvino.js index f8a2cfa65d..579a08d43a 100644 --- a/source/openvino.js +++ b/source/openvino.js @@ -4,8 +4,8 @@ const openvino = {}; openvino.ModelFactory = class { match(context) { - const identifier = context.identifier; - const extension = identifier.split('.').pop().toLowerCase(); + const identifier = context.identifier.toLowerCase(); + const extension = identifier.split('.').pop(); if (/^.*\.ncnn\.bin$/.test(identifier) || /^.*\.pnnx\.bin$/.test(identifier) || /^.*pytorch_model.*\.bin$/.test(identifier) || @@ -16,10 +16,6 @@ openvino.ModelFactory = class { if (extension === 'bin') { const stream = context.stream; const length = stream.length; - const signature = [0x21, 0xA8, 0xEF, 0xBE, 0xAD, 0xDE]; - if (signature.length <= length && stream.peek(signature.length).every((value, index) => value === signature[index])) { - return; - } if (length >= 4) { let buffer = stream.peek(Math.min(0x20000, length)); const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.length); @@ -31,52 +27,68 @@ openvino.ModelFactory = class { return; } } - if (identifier.endsWith('.bin') || identifier.endsWith('.serialized')) { - const stream = context.stream; - const signatures = [ - [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], - [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01] - ]; - if (stream.length >= 16 && signatures.some((signature) => stream.peek(signature.length).every((value, index) => value === signature[index]))) { - return; + const match = (pattern, identifier, buffer) => { + if (pattern.identifier && typeof pattern.identifier === 'string' && identifier !== pattern.identifier) { + return false; + } else if (pattern.identifier && pattern.identifier instanceof RegExp && !pattern.identifier.test(identifier)) { + return false; + } else if (pattern.signature && !pattern.signature.every((value, index) => value === buffer[index])) { + return false; } + return true; + }; + const include = [ + { identifier: /googlenet-v1\.bin$/, signature: [0x80, 0xD6, 0x50, 0xD7, 0xB0, 0xD7, 0xA5, 0x2D, 0xCA, 0x28, 0x49, 0x2A, 0x35, 0x31, 0x0A, 0x31] } + ]; + if (include.some((pattern) => match(pattern, identifier, buffer))) { + context.type = 'openvino.bin'; + return; } - const identifiers = new Set([ - 'config.bin', 'model.bin', '__model__.bin', 'weights.bin', - 'programs.bin', 'best.bin', 'ncnn.bin', - 'stories15M.bin','stories42M.bin','stories110M.bin','stories260K.bin' - ]); - if (!identifiers.has(identifier) && signature !== 0x00000001) { - const size = Math.min(buffer.length & 0xfffffffc, 128); - buffer = buffer.subarray(0, size); - if (Array.from(buffer).every((value) => value === 0)) { - context.type = 'openvino.bin'; - return; - } - const f32 = new Array(buffer.length >> 2); - for (let i = 0; i < f32.length; i++) { - f32[i] = view.getFloat32(i << 2, true); - } - const f16 = new Array(buffer.length >> 1); - for (let i = 0; i < f16.length; i++) { - f16[i] = view.getFloat16(i << 1, true); - } - const i32 = new Array(buffer.length >> 2); - for (let i = 0; i < f32.length; i++) { - i32[i] = view.getInt32(i << 2, true); - } - const validateFloat = (array) => array[0] !== 0 && array.every((x) => !Number.isNaN(x) && Number.isFinite(x)) && - (array.every((x) => x > -20.0 && x < 20.0 && (x >= 0 || x < -0.0000001) && (x <= 0 || x > 0.0000001)) || - array.every((x) => x > -100.0 && x < 100.0 && (x * 10) % 1 === 0)); - const validateInt = (array) => array.length > 32 && - array.slice(0, 32).every((x) => x === 0 || x === 1 || x === 2 || x === 0x7fffffff); - if (validateFloat(f32) || validateFloat(f16) || validateInt(i32)) { - context.type = 'openvino.bin'; - return; - } + const exclude = [ + { identifier: '__model__.bin' }, + { identifier: 'config.bin' }, + { identifier: 'model.bin' }, + { identifier: 'ncnn.bin' }, + { identifier: 'programs.bin' }, + { identifier: 'weights.bin' }, + { identifier: /stories\d+(m|k)\.bin$/ }, + { signature: [0x21, 0xA8, 0xEF, 0xBE, 0xAD, 0xDE] } + ]; + if (exclude.some((pattern) => match(pattern, identifier, buffer))) { + return; + } + if (signature === 0x00000001) { + return; + } + const size = Math.min(buffer.length & 0xfffffffc, 128); + buffer = buffer.subarray(0, size); + if (Array.from(buffer).every((value) => value === 0)) { + context.type = 'openvino.bin'; + return; + } + const f32 = new Array(buffer.length >> 2); + for (let i = 0; i < f32.length; i++) { + f32[i] = view.getFloat32(i << 2, true); + } + const f16 = new Array(buffer.length >> 1); + for (let i = 0; i < f16.length; i++) { + f16[i] = view.getFloat16(i << 1, true); + } + const i32 = new Array(buffer.length >> 2); + for (let i = 0; i < f32.length; i++) { + i32[i] = view.getInt32(i << 2, true); + } + const validateFloat = (array) => array[0] !== 0 && array.every((x) => !Number.isNaN(x) && Number.isFinite(x)) && + (array.every((x) => x > -20.0 && x < 20.0 && (x >= 0 || x < -0.0000001) && (x <= 0 || x > 0.0000001)) || + array.every((x) => x > -100.0 && x < 100.0 && (x * 10) % 1 === 0)); + const validateInt = (array) => array.length > 32 && + array.slice(0, 32).every((x) => x === 0 || x === 1 || x === 2 || x === 0x7fffffff); + if (validateFloat(f32) || validateFloat(f16) || validateInt(i32)) { + context.type = 'openvino.bin'; + return; } - return; } + return; } const tags = context.tags('xml'); if (tags.has('net')) {