add snpe example (#143)
Co-authored-by: jizhe <[email protected]>
Jzz24 and jizhe committed May 31, 2022
1 parent 54c0e3f commit 4486ab6
Showing 3 changed files with 129 additions and 2 deletions.
81 changes: 81 additions & 0 deletions md_doc/inference_with_snpe_dsp.md
@@ -0,0 +1,81 @@
# Deploy Model with SNPE DSP
This document describes the quantization deployment process of the SNPE DSP and how PPQ writes quantization parameters to the SNPE model.


## Environment setup
Refer to [Qualcomm official documentation](https://developer.qualcomm.com/sites/default/files/docs/snpe/setup.html) to configure the Linux host environment. The SNPE model conversion and quantization are all done on the Linux host. SNPE supports reading Caffe and Onnx models. This document uses the ONNX model as an example.
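For reference, a typical host setup might look like the sketch below; the SDK path and the way the ONNX package directory is located are assumptions to be adjusted to your installation.
```shell
# Assumed SDK location -- adjust to your installation.
export SNPE_ROOT=/opt/snpe-1.43.0
cd $SNPE_ROOT
# Point the SDK at the ONNX python package, per the Qualcomm setup guide.
source bin/envsetup.sh -o $(python3 -c "import onnx, os; print(os.path.dirname(onnx.__file__))")
```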

## Quantize Your Network
As described in [how_to_use](./how_to_use.md), we need to prepare a calibration dataloader, confirm the target platform on which we want to deploy the model (*TargetPlatform.QNN_DSP_INT8* in this case), load the simplified model, initialize the quantizer and executor, and then run the quantization process:
```python
import os
import numpy as np
import torch
from ppq.api import load_onnx_graph, export_ppq_graph
from ppq.api.interface import dispatch_graph, QUANTIZER_COLLECTION
from ppq.core import TargetPlatform
from ppq.executor import TorchExecutor
from ppq import QuantizationSettingFactory

model_path = '/models/shufflenet-v2-sim.onnx' # onnx simplified model
data_path = '/data/ImageNet/calibration' # calibration data folder
EXECUTING_DEVICE = 'cuda'

# initialize dataloader
INPUT_SHAPE = [1, 3, 224, 224]
npy_array = [np.fromfile(os.path.join(data_path, file_name), dtype=np.float32).reshape(*INPUT_SHAPE) for file_name in os.listdir(data_path)]
dataloader = [torch.from_numpy(tensor) for tensor in npy_array]  # arrays are already in memory, no np.load needed

# confirm platform and setting
target_platform = TargetPlatform.QNN_DSP_INT8
setting = QuantizationSettingFactory.dsp_setting()

# load and schedule graph
ppq_graph_ir = load_onnx_graph(model_path)
ppq_graph_ir = dispatch_graph(ppq_graph_ir, target_platform, setting)

# initialize quantizer and executor
executor = TorchExecutor(ppq_graph_ir, device=EXECUTING_DEVICE)
quantizer = QUANTIZER_COLLECTION[target_platform](graph=ppq_graph_ir)

# run quantization
calib_steps = max(min(512, len(dataloader)), 8) # 8 ~ 512
dummy_input = dataloader[0].to(EXECUTING_DEVICE) # random input for meta tracing
quantizer.quantize(
    inputs=dummy_input,                          # some random input tensor, should be list or dict for multiple inputs
    calib_dataloader=dataloader,                 # calibration dataloader
    executor=executor,                           # executor in charge of everywhere graph execution is needed
    setting=setting,                             # quantization setting
    calib_steps=calib_steps,                     # number of batched data needed in calibration, 8~512
    collate_fn=lambda x: x.to(EXECUTING_DEVICE)  # final processing of batched data tensor
)

# export quantization param file and model file
export_ppq_graph(graph=ppq_graph_ir, platform=TargetPlatform.QNN_DSP_INT8, graph_save_to='shufflenet-v2-sim-ppq', config_save_to='shufflenet-v2-sim-ppq.table')
```
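Note that the file names used in the following steps (`ppq_export_fp32.onnx`, `quant.json`) correspond to the graph and config files exported here; adjust the paths to match your own export.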

## Convert Your Model
The snpe-onnx-to-dlc tool converts the PPQ-exported ONNX model into an equivalent DLC representation.
```shell
snpe-onnx-to-dlc -i ppq_export_fp32.onnx -o fp32.dlc
```
The snpe-dlc-quantize tool converts a non-quantized DLC model into a quantized DLC model, generating 8- or 16-bit TensorFlow-style fixed-point encodings for the weights and activations of the floating-point SNPE model.
```shell
snpe-dlc-quantize --input_dlc fp32.dlc --input_list path_to_binary_calidata --output_dlc quant.dlc
```
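Here `path_to_binary_calidata` is a plain-text input list: each line is the path to one raw (binary float32) calibration input. A sketch with hypothetical file names:
```shell
# Contents of path_to_binary_calidata, one raw input file per line:
#   /data/ImageNet/calibration/img_0001.raw
#   /data/ImageNet/calibration/img_0002.raw
```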

Finally, write the PPQ quantization parameters into quant.dlc. We have fully tested the script against SNPE version 1.43. In recent SNPE releases, if the option `--quantization_overrides` is provided during model conversion, the user can supply a JSON file with parameters to use for quantization. These are cached along with the model and can be used to override any quantization data carried over from conversion (e.g. TF fake quantization) or calculated during the normal quantization process in snpe-dlc-quantize.

```shell
python3 write_qparams_to_snpe_dlc.py --input_dlc_model quant.dlc --qparam quant.json
```
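The script reads only the `activation_encodings` section of the JSON file, taking `min` and `max` from the first encoding listed for each tensor; a minimal sketch of that structure (tensor names and ranges are illustrative):
```json
{
    "activation_encodings": {
        "conv1_out": [{"min": -1.25, "max": 3.75}],
        "relu1_out": [{"min": 0.0, "max": 6.0}]
    }
}
```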

## Run Inference
Run model inference on the mobile DSP. The inputs must be presented as tensors of shape (batch x height x width x channel).

```shell
snpe-net-run --container ppq_export_quant.dlc --input_list path_to_data --use_dsp
```
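A minimal sketch of preparing one such raw NHWC input file, assuming float32 data and hypothetical file names:
```python
import numpy as np

# Hypothetical preprocessing: convert a float32 CHW tensor to a flat NHWC raw file.
chw = np.fromfile('/data/ImageNet/calibration/img_0001.bin', dtype=np.float32).reshape(3, 224, 224)
nhwc = np.expand_dims(chw.transpose(1, 2, 0), axis=0)  # (1, 224, 224, 3)
nhwc.tofile('img_0001.raw')  # snpe-net-run reads flat raw files listed in --input_list
```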
4 changes: 2 additions & 2 deletions ppq/quantization/quantizer/DSPQuantizer.py
@@ -75,7 +75,7 @@ def quant_operation_types(self) -> set:
'GlobalMaxPool', 'GlobalAveragePool', 'Softmax',
'Mul', 'Add', 'Max', 'Sub', 'Div', 'Reshape',
'LeakyRelu', 'Concat', 'Sigmoid', 'Slice', 'Interp',
'ReduceMean'
'ReduceMean', 'Flatten',
}

@ property
@@ -178,4 +178,4 @@ def quantize_policy(self) -> QuantizationPolicy:

@ property
def activation_fusion_types(self) -> set:
return {'Relu', 'Clip'}
return {'Relu', 'Clip'}
46 changes: 46 additions & 0 deletions ppq/utils/write_qparams_to_snpe_dlc.py
@@ -0,0 +1,46 @@
import argparse
import json
import os
import snpe
import qti.aisw.dlc_utils as dlc

parser = argparse.ArgumentParser(description='Write ppq qparams to snpe dlc')
parser.add_argument('--input_dlc_model', default='snpe_quantized.dlc', help='path to snpe quantized dlc model')
parser.add_argument('--output_dlc_model', default='ppq_export.dlc', help='path to export quantized dlc')
parser.add_argument('--qparam', default='quantized.json', help='path to ppq qparams json')

def json_load(filename):
    with open(filename) as json_file:
        data = json.load(json_file)
    return data

def write_qparams_to_dlc_model(input_dlc, output_dlc, activation_qparams):
    model = dlc.modeltools.Model()
    model.load(input_dlc)
    model.set_tf_encoding_type("TF")

    for snpe_layer in model.get_layers():
        layer_name = snpe_layer['name']
        print('\nwrite qparams to {}'.format(layer_name))
        for snpe_layer_out_ind, snpe_layer_out in enumerate(snpe_layer['output_names']):
            print('original quant encodings : ', model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind))
            top = snpe_layer_out

            if top not in activation_qparams.keys():
                # SNPE arranges data as NHWC, so it inserts a shape-conversion
                # layer (output name suffixed with .ncs) before each Reshape layer;
                # reuse the encoding of the layer input in that case.
                assert top.endswith('.ncs'), '{} ranges do not exist'.format(top)
                bottom = snpe_layer['input_names'][0]
                new_enc = activation_qparams[bottom][0]  # List[dict], take the first encoding
            else:
                new_enc = activation_qparams[top][0]  # List[dict], take the first encoding

            model.set_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind, bitwidth=8, min=new_enc["min"], max=new_enc["max"])
            print('ppq quant encodings : ', model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind))
    model.quantize_weights(should_quantize=True)
    model.save(output_dlc)

if __name__ == '__main__':
    args = parser.parse_args()
    act_ranges = json_load(args.qparam)['activation_encodings']
    write_qparams_to_dlc_model(args.input_dlc_model, args.output_dlc_model, act_ranges)
