-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1dfd4db
Showing
141 changed files
with
328,452 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,321 @@ | ||
|
||
import * as flatbuffers from './flatbuffers.js'; | ||
|
||
const armnn = {}; | ||
|
||
armnn.ModelFactory = class { | ||
|
||
match(context) { | ||
const identifier = context.identifier; | ||
const extension = identifier.split('.').pop().toLowerCase(); | ||
const stream = context.stream; | ||
if (stream && extension === 'armnn') { | ||
return { name: 'armnn.flatbuffers', value: stream }; | ||
} | ||
if (extension === 'json') { | ||
const obj = context.peek('json'); | ||
if (obj && obj.layers && obj.inputIds && obj.outputIds) { | ||
return { name: 'armnn.flatbuffers.json', value: obj }; | ||
} | ||
} | ||
return undefined; | ||
} | ||
|
||
async open(context, target) { | ||
await context.require('./armnn-schema'); | ||
armnn.schema = flatbuffers.get('armnn').armnnSerializer; | ||
let model = null; | ||
switch (target.name) { | ||
case 'armnn.flatbuffers': { | ||
try { | ||
const stream = target.value; | ||
const reader = flatbuffers.BinaryReader.open(stream); | ||
model = armnn.schema.SerializedGraph.create(reader); | ||
} catch (error) { | ||
const message = error && error.message ? error.message : error.toString(); | ||
throw new armnn.Error(`File format is not armnn.SerializedGraph (${message.replace(/\.$/, '')}).`); | ||
} | ||
break; | ||
} | ||
case 'armnn.flatbuffers.json': { | ||
try { | ||
const obj = target.value; | ||
const reader = flatbuffers.TextReader.open(obj); | ||
model = armnn.schema.SerializedGraph.createText(reader); | ||
} catch (error) { | ||
const message = error && error.message ? error.message : error.toString(); | ||
throw new armnn.Error(`File text format is not armnn.SerializedGraph (${message.replace(/\.$/, '')}).`); | ||
} | ||
break; | ||
} | ||
default: { | ||
throw new armnn.Error(`Unsupported Arm NN '${target}'.`); | ||
} | ||
} | ||
const metadata = await context.metadata('armnn-metadata.json'); | ||
return new armnn.Model(metadata, model); | ||
} | ||
}; | ||
|
||
armnn.Model = class { | ||
|
||
constructor(metadata, model) { | ||
this.format = 'Arm NN'; | ||
this.graphs = [ new armnn.Graph(metadata, model) ]; | ||
} | ||
}; | ||
|
||
armnn.Graph = class { | ||
|
||
constructor(metadata, graph) { | ||
this.name = ''; | ||
this.nodes = []; | ||
this.inputs = []; | ||
this.outputs = []; | ||
const counts = new Map(); | ||
for (const layer of graph.layers) { | ||
const base = armnn.Node.getBase(layer); | ||
for (const slot of base.inputSlots) { | ||
const name = `${slot.connection.sourceLayerIndex}:${slot.connection.outputSlotIndex}`; | ||
counts.set(name, counts.has(name) ? counts.get(name) + 1 : 1); | ||
} | ||
} | ||
const values = new Map(); | ||
const value = (layerIndex, slotIndex, tensor) => { | ||
const name = `${layerIndex}:${slotIndex}`; | ||
if (!values.has(name)) { | ||
const layer = graph.layers[layerIndex]; | ||
const base = layerIndex < graph.layers.length ? armnn.Node.getBase(layer) : null; | ||
const tensorInfo = base && slotIndex < base.outputSlots.length ? base.outputSlots[slotIndex].tensorInfo : null; | ||
values.set(name, new armnn.Value(name, tensorInfo, tensor)); | ||
} | ||
return values.get(name); | ||
}; | ||
const layers = graph.layers.filter((layer) => { | ||
const base = armnn.Node.getBase(layer); | ||
if (base.layerType == armnn.schema.LayerType.Constant && base.outputSlots.length === 1 && layer.layer.input) { | ||
/* eslint-disable prefer-destructuring */ | ||
const slot = base.outputSlots[0]; | ||
/* eslint-enable prefer-destructuring */ | ||
const name = `${base.index}:${slot.index}`; | ||
if (counts.get(name) === 1) { | ||
const tensor = new armnn.Tensor(layer.layer.input, 'Constant'); | ||
value(base.index, slot.index, tensor); | ||
return false; | ||
} | ||
} | ||
return true; | ||
}); | ||
for (const layer of layers) { | ||
const base = armnn.Node.getBase(layer); | ||
for (const slot of base.inputSlots) { | ||
value(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex); | ||
} | ||
} | ||
for (const layer of layers) { | ||
const base = armnn.Node.getBase(layer); | ||
switch (base.layerType) { | ||
case armnn.schema.LayerType.Input: { | ||
const name = base ? base.layerName : ''; | ||
for (const slot of base.outputSlots) { | ||
const argument = new armnn.Argument(name, [ value(base.index, slot.index) ]); | ||
this.inputs.push(argument); | ||
} | ||
break; | ||
} | ||
case armnn.schema.LayerType.Output: { | ||
const base = armnn.Node.getBase(layer); | ||
const name = base ? base.layerName : ''; | ||
for (const slot of base.inputSlots) { | ||
const argument = new armnn.Argument(name, [ value(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex) ]); | ||
this.outputs.push(argument); | ||
} | ||
break; | ||
} | ||
default: | ||
this.nodes.push(new armnn.Node(metadata, layer, value)); | ||
break; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
armnn.Node = class { | ||
|
||
constructor(metadata, layer, value) { | ||
const type = layer.layer.constructor.name; | ||
this.type = Object.assign({}, metadata.type(type) || { name: type }); | ||
this.type.name = this.type.name.replace(/Layer$/, ''); | ||
this.name = ''; | ||
this.outputs = []; | ||
this.inputs = []; | ||
this.attributes = []; | ||
const inputSchemas = (this.type && this.type.inputs) ? [...this.type.inputs] : [ { name: 'input' } ]; | ||
const outputSchemas = (this.type && this.type.outputs) ? [...this.type.outputs] : [ { name: 'output' } ]; | ||
const base = armnn.Node.getBase(layer); | ||
if (base) { | ||
this.name = base.layerName; | ||
const inputs = [...base.inputSlots]; | ||
while (inputs.length > 0) { | ||
const inputSchema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: '?' }; | ||
const count = inputSchema.list ? inputs.length : 1; | ||
const argument = new armnn.Argument(inputSchema.name, inputs.splice(0, count).map((inputSlot) => { | ||
return value(inputSlot.connection.sourceLayerIndex, inputSlot.connection.outputSlotIndex); | ||
})); | ||
this.inputs.push(argument); | ||
} | ||
const outputs = [...base.outputSlots]; | ||
while (outputs.length > 0) { | ||
const outputSchema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: '?' }; | ||
const count = outputSchema.list ? outputs.length : 1; | ||
this.outputs.push(new armnn.Argument(outputSchema.name, outputs.splice(0, count).map((outputSlot) => { | ||
return value(base.index, outputSlot.index); | ||
}))); | ||
} | ||
} | ||
if (layer.layer) { | ||
if (layer.layer.descriptor && this.type.attributes) { | ||
for (const [name, value] of Object.entries(layer.layer.descriptor)) { | ||
const attribute = new armnn.Attribute(metadata.attribute(type, name), name, value); | ||
this.attributes.push(attribute); | ||
} | ||
} | ||
for (const [name, tensor] of Object.entries(layer.layer).filter(([, value]) => value instanceof armnn.schema.ConstTensor)) { | ||
const value = new armnn.Value('', tensor.info, new armnn.Tensor(tensor)); | ||
const argument = new armnn.Argument(name, [ value ]); | ||
this.inputs.push(argument); | ||
} | ||
} | ||
} | ||
|
||
static getBase(layer) { | ||
return layer.layer.base.base ? layer.layer.base.base : layer.layer.base; | ||
} | ||
|
||
static makeKey(layer_id, index) { | ||
return `${layer_id}_${index}`; | ||
} | ||
}; | ||
|
||
armnn.Attribute = class { | ||
|
||
constructor(metadata, name, value) { | ||
this.name = name; | ||
this.type = metadata ? metadata.type : null; | ||
this.value = ArrayBuffer.isView(value) ? Array.from(value) : value; | ||
if (armnn.schema[this.type]) { | ||
this.value = armnn.Utility.enum(this.type, this.value); | ||
} | ||
} | ||
}; | ||
|
||
armnn.Argument = class { | ||
|
||
constructor(name, value) { | ||
this.name = name; | ||
this.value = value; | ||
} | ||
}; | ||
|
||
armnn.Value = class { | ||
|
||
constructor(name, tensorInfo, initializer) { | ||
if (typeof name !== 'string') { | ||
throw new armnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`); | ||
} | ||
this.name = name; | ||
this.type = new armnn.TensorType(tensorInfo); | ||
this.initializer = initializer; | ||
|
||
if (this.type.dataType.startsWith('q') && tensorInfo) { | ||
this._scale = tensorInfo.quantizationScale; | ||
this._zeroPoint = tensorInfo.quantizationOffset; | ||
} | ||
} | ||
|
||
get quantization() { | ||
if (this._scale !== undefined && this._zeroPoint !== undefined) { | ||
return `${this._scale} * ${this._zeroPoint == 0 ? 'q' : (`(q - ${this._zeroPoint})`)}`; | ||
} | ||
return undefined; | ||
} | ||
}; | ||
|
||
armnn.Tensor = class { | ||
|
||
constructor(tensor, category) { | ||
this.type = new armnn.TensorType(tensor.info); | ||
this.category = category || ''; | ||
const data = tensor.data.data.slice(0); | ||
this.values = new Uint8Array(data.buffer, data.byteOffset, data.byteLength); | ||
} | ||
}; | ||
|
||
armnn.TensorType = class { | ||
|
||
constructor(tensorInfo) { | ||
const dataType = tensorInfo.dataType; | ||
switch (dataType) { | ||
case 0: this.dataType = 'float16'; break; | ||
case 1: this.dataType = 'float32'; break; | ||
case 2: this.dataType = 'quint8'; break; // QuantisedAsymm8 | ||
case 3: this.dataType = 'int32'; break; | ||
case 4: this.dataType = 'boolean'; break; | ||
case 5: this.dataType = 'qint16'; break; // QuantisedSymm16 | ||
case 6: this.dataType = 'quint8'; break; // QAsymmU8 | ||
case 7: this.dataType = 'qint16'; break; // QSymmS16 | ||
case 8: this.dataType = 'qint8'; break; // QAsymmS8 | ||
case 9: this.dataType = 'qint8'; break; // QSymmS8 | ||
default: | ||
throw new armnn.Error(`Unsupported data type '${JSON.stringify(dataType)}'.`); | ||
} | ||
this.shape = new armnn.TensorShape(tensorInfo.dimensions); | ||
} | ||
|
||
toString() { | ||
return this.dataType + this.shape.toString(); | ||
} | ||
}; | ||
|
||
armnn.TensorShape = class { | ||
|
||
constructor(dimensions) { | ||
this.dimensions = Array.from(dimensions); | ||
} | ||
|
||
toString() { | ||
if (!this.dimensions || this.dimensions.length == 0) { | ||
return ''; | ||
} | ||
return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`; | ||
} | ||
}; | ||
|
||
armnn.Utility = class { | ||
|
||
static enum(name, value) { | ||
const type = name && armnn.schema ? armnn.schema[name] : undefined; | ||
if (type) { | ||
armnn.Utility._enums = armnn.Utility._enums || new Map(); | ||
if (!armnn.Utility._enums.has(name)) { | ||
const entries = new Map(Object.entries(type).map(([key, value]) => [ value, key ])); | ||
armnn.Utility._enums.set(name, entries); | ||
} | ||
const entries = armnn.Utility._enums.get(name); | ||
if (entries.has(value)) { | ||
return entries.get(value); | ||
} | ||
} | ||
return value; | ||
} | ||
}; | ||
|
||
armnn.Error = class extends Error { | ||
|
||
constructor(message) { | ||
super(message); | ||
this.name = 'Error loading Arm NN model.'; | ||
} | ||
}; | ||
|
||
export const ModelFactory = armnn.ModelFactory; |
Oops, something went wrong.