diff --git a/md_doc/inference_with_snpe_dsp.md b/md_doc/inference_with_snpe_dsp.md
new file mode 100644
index 00000000..2789132f
--- /dev/null
+++ b/md_doc/inference_with_snpe_dsp.md
@@ -0,0 +1,81 @@
+# Deploy Model with SNPE DSP
+This document describes the quantization and deployment workflow for the SNPE DSP, and how PPQ writes its quantization parameters into an SNPE model.
+
+
+## Environment Setup
+Refer to the [Qualcomm official documentation](https://developer.qualcomm.com/sites/default/files/docs/snpe/setup.html) to configure the Linux host environment. SNPE model conversion and quantization are both performed on the Linux host. SNPE supports reading Caffe and ONNX models; this document uses an ONNX model as an example.
+
+## Quantize Your Network
+As described in [how_to_use](./how_to_use.md), we need to prepare a calibration dataloader, confirm
+the target platform on which the model will be deployed (*TargetPlatform.QNN_DSP_INT8* in this case), load the
+simplified model, initialize the quantizer and executor, and then run the quantization process:
+```python
+import os
+import numpy as np
+import torch
+from ppq.api import export_ppq_graph, load_onnx_graph
+from ppq.api.interface import dispatch_graph, QUANTIZER_COLLECTION
+from ppq.core import TargetPlatform
+from ppq.executor import TorchExecutor
+from ppq import QuantizationSettingFactory
+
+model_path = '/models/shufflenet-v2-sim.onnx'  # onnx simplified model
+data_path = '/data/ImageNet/calibration'       # calibration data folder
+EXECUTING_DEVICE = 'cuda'
+
+# initialize dataloader
+INPUT_SHAPE = [1, 3, 224, 224]
+npy_array = [np.fromfile(os.path.join(data_path, file_name), dtype=np.float32).reshape(*INPUT_SHAPE) for file_name in os.listdir(data_path)]
+dataloader = [torch.from_numpy(npy_tensor) for npy_tensor in npy_array]
+
+# confirm platform and setting
+target_platform = TargetPlatform.QNN_DSP_INT8
+setting = QuantizationSettingFactory.dsp_setting()
+
+# load and schedule graph
+ppq_graph_ir = load_onnx_graph(model_path)
+ppq_graph_ir = dispatch_graph(ppq_graph_ir, target_platform, setting)
+
+# initialize quantizer and executor
+executor = TorchExecutor(ppq_graph_ir, device=EXECUTING_DEVICE)
+quantizer = QUANTIZER_COLLECTION[target_platform](graph=ppq_graph_ir)
+
+# run quantization
+calib_steps = max(min(512, len(dataloader)), 8)   # 8 ~ 512
+dummy_input = dataloader[0].to(EXECUTING_DEVICE)  # random input for meta tracing
+quantizer.quantize(
+    inputs=dummy_input,                          # some random input tensor, should be a list or dict for multiple inputs
+    calib_dataloader=dataloader,                 # calibration dataloader
+    executor=executor,                           # executor in charge of graph execution wherever it is needed
+    setting=setting,                             # quantization setting
+    calib_steps=calib_steps,                     # number of batches used in calibration, 8 ~ 512
+    collate_fn=lambda x: x.to(EXECUTING_DEVICE)  # final processing of batched data tensor
+)
+
+# export quantization parameter file and model file
+export_ppq_graph(graph=ppq_graph_ir, platform=TargetPlatform.QNN_DSP_INT8, graph_save_to='shufflenet-v2-sim-ppq', config_save_to='shufflenet-v2-sim-ppq.table')
+```
+
+## Convert Your Model
+The snpe-onnx-to-dlc tool converts the ONNX model exported by PPQ into an equivalent DLC representation.
+```shell
+snpe-onnx-to-dlc -i ppq_export_fp32.onnx -o fp32.dlc
+```
+The snpe-dlc-quantize tool then converts the non-quantized DLC model into a quantized DLC model, generating 8- or 16-bit TensorFlow-style fixed-point encodings for the weights and activations of the floating-point SNPE model.
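+
+Both snpe-dlc-quantize and snpe-net-run read their input data from a text file that lists raw float32 binaries, one file per line. The sketch below is one possible way to produce such raw files and an input list from the calibration folder used above; the destination folder and file names are only placeholders, and the NCHW-to-NHWC transpose follows the data layout SNPE expects.
+```python
+import os
+import numpy as np
+
+data_path = '/data/ImageNet/calibration'      # float32 NCHW binaries, same folder as above
+raw_path  = '/data/ImageNet/calibration_raw'  # destination for SNPE raw inputs (placeholder)
+INPUT_SHAPE = [1, 3, 224, 224]
+
+os.makedirs(raw_path, exist_ok=True)
+with open(os.path.join(raw_path, 'input_list.txt'), 'w') as input_list:
+    for file_name in os.listdir(data_path):
+        array = np.fromfile(os.path.join(data_path, file_name), dtype=np.float32).reshape(*INPUT_SHAPE)
+        array = np.transpose(array, (0, 2, 3, 1))             # NCHW -> NHWC, the layout SNPE expects
+        raw_file = os.path.join(raw_path, file_name + '.raw')
+        array.tofile(raw_file)                                # raw float32 tensor consumed by SNPE tools
+        input_list.write(raw_file + '\n')
+```
+With the raw inputs and the list file prepared, run snpe-dlc-quantize: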
+```shell
+snpe-dlc-quantize --input_dlc fp32.dlc --input_list path_to_binary_calib_data --output_dlc quant.dlc
+```
+
+Finally, write the PPQ quantization parameters into quant.dlc. The script below has been fully tested with SNPE version 1.43. Note that in recent SNPE releases, if the option --quantization_overrides is provided during model conversion, the user can instead supply a JSON file with parameters to use for quantization; these are cached along with the model and can be used to override any quantization data carried over from conversion (e.g. TF fake quantization) or calculated during the normal quantization process in snpe-dlc-quantize.
+
+```shell
+python3 write_qparams_to_snpe_dlc.py --input_dlc_model quant.dlc --output_dlc_model ppq_export_quant.dlc --qparam quant.json
+```
+
+## Run Inference
+Run model inference on the mobile DSP. Inputs must be presented as tensors of shape (batch x height x width x channel).
+
+```shell
+snpe-net-run --container ppq_export_quant.dlc --input_list path_to_data --use_dsp
+```
diff --git a/ppq/quantization/quantizer/DSPQuantizer.py b/ppq/quantization/quantizer/DSPQuantizer.py
index 8d78436d..c009232c 100644
--- a/ppq/quantization/quantizer/DSPQuantizer.py
+++ b/ppq/quantization/quantizer/DSPQuantizer.py
@@ -75,7 +75,7 @@ def quant_operation_types(self) -> set:
             'GlobalMaxPool', 'GlobalAveragePool', 'Softmax',
             'Mul', 'Add', 'Max', 'Sub', 'Div', 'Reshape',
             'LeakyRelu', 'Concat', 'Sigmoid', 'Slice', 'Interp',
-            'ReduceMean'
+            'ReduceMean', 'Flatten',
         }
 
     @ property
@@ -178,4 +178,4 @@ def quantize_policy(self) -> QuantizationPolicy:
 
     @ property
     def activation_fusion_types(self) -> set:
-        return {'Relu', 'Clip'}
\ No newline at end of file
+        return {'Relu', 'Clip'}
diff --git a/ppq/utils/write_qparams_to_snpe_dlc.py b/ppq/utils/write_qparams_to_snpe_dlc.py
new file mode 100644
index 00000000..7791e165
--- /dev/null
+++ b/ppq/utils/write_qparams_to_snpe_dlc.py
@@ -0,0 +1,46 @@
+import argparse
+import json
+import os
+import snpe
+import qti.aisw.dlc_utils as dlc
+
+parser = argparse.ArgumentParser(description='Write ppq qparams to snpe dlc')
+parser.add_argument('--input_dlc_model', default='snpe_quantized.dlc', help='path to snpe quantized dlc model')
+parser.add_argument('--output_dlc_model', default='ppq_export.dlc', help='path to export quantized dlc')
+parser.add_argument('--qparam', default='quantized.json', help='path to ppq qparams json')
+
+def json_load(filename):
+    with open(filename) as json_file:
+        data = json.load(json_file)
+    return data
+
+def write_qparams_to_dlc_model(input_dlc, output_dlc, activation_qparams):
+    model = dlc.modeltools.Model()
+    model.load(input_dlc)
+    model.set_tf_encoding_type("TF")
+
+    for snpe_layer in model.get_layers():
+        print('\n write qparams to {}'.format(snpe_layer['name']))
+        for snpe_layer_out_ind, snpe_layer_out in enumerate(snpe_layer['output_names']):
+            layer_name = snpe_layer['name']
+            print('original quant encodings : ', model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind))
+            top = snpe_layer_out
+
+            if top not in activation_qparams.keys():
+                # Before a Reshape layer, SNPE inserts a shape-conversion layer (xxx.ncs)
+                # because SNPE arranges data as NHWC
+                assert top.endswith('.ncs'), '{} ranges not found'.format(top)
+                bottom = snpe_layer['input_names'][0]
+                new_enc = activation_qparams[bottom][0]  # List[dict]
+            else:
+                new_enc = activation_qparams[top][0]  # List[dict]
+
+            model.set_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind, bitwidth=8, min=new_enc["min"], max=new_enc["max"])
+            print('ppq quant encodings : ', model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind))
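+    # Quantize the weights and save the DLC, which now carries the PPQ activation encodings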
+    model.quantize_weights(should_quantize=True)
+    model.save(output_dlc)
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    act_ranges = json_load(args.qparam)['activation_encodings']
+    write_qparams_to_dlc_model(args.input_dlc_model, args.output_dlc_model, act_ranges)