From 71f5abcb0c45e1f94ee7364430f0ce4de7c39a61 Mon Sep 17 00:00:00 2001 From: AwesomeCodingBoy <43309460+ZhangZhiPku@users.noreply.github.com> Date: Thu, 9 Mar 2023 15:30:51 +0800 Subject: [PATCH] Update 20230306 (#403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 为 equalization pass 修复一个 bug * 为 lsq pass 添加一个接口,允许用户传入优化器 * 添加 fuse matmul+add 函数 * 修复了一些 typo * 上传了 QuantZoo 数据集 --- ppq/IR/morph.py | 32 ++ ppq/parser/onnx_exporter.py | 13 +- ppq/quantization/algorithm/equalization.py | 2 +- ppq/quantization/algorithm/training.py | 6 +- ppq/quantization/observer/range.py | 2 +- ppq/quantization/optim/training.py | 7 +- ppq/samples/QuantZoo/QuantZoo_Imagenet.py | 165 +++++++++ ppq/samples/QuantZoo/QuantZoo_OCR.py | 151 ++++++++ ppq/samples/QuantZoo/QuantZoo_Segmentation.py | 139 ++++++++ ppq/samples/QuantZoo/QuantZoo_SuperRes.py | 137 ++++++++ ppq/samples/QuantZoo/QuantZoo_Yolo.py | 179 ++++++++++ ppq/samples/QuantZoo/Readme.md | 330 ++++++++++++++++++ ppq/utils/fetch.py | 6 +- 13 files changed, 1147 insertions(+), 22 deletions(-) create mode 100644 ppq/samples/QuantZoo/QuantZoo_Imagenet.py create mode 100644 ppq/samples/QuantZoo/QuantZoo_OCR.py create mode 100644 ppq/samples/QuantZoo/QuantZoo_Segmentation.py create mode 100644 ppq/samples/QuantZoo/QuantZoo_SuperRes.py create mode 100644 ppq/samples/QuantZoo/QuantZoo_Yolo.py create mode 100644 ppq/samples/QuantZoo/Readme.md diff --git a/ppq/IR/morph.py b/ppq/IR/morph.py index 1f5feb8b..909b7937 100644 --- a/ppq/IR/morph.py +++ b/ppq/IR/morph.py @@ -1042,6 +1042,38 @@ def fuse_selfattention(self): if v is not None: non_empty_attr[k] = v op._attributes = non_empty_attr + def fuse_matmul_add(self, verbose: bool = True): + """ + Fuse Matmul + bias add to PPQBiasFusedMatMul + + PPQBiasFusedMatMul is a temporary operation which will be splited when exporting. + """ + graph, fused = self.graph, False + for current_op in [_ for _ in graph.operations.values()]: + if current_op.type != 'MatMul': continue + + # check down-stream op is add + next_ops = graph.get_downstream_operations(current_op) + if len(next_ops) != 1: continue + if next_ops[0].type != 'Add': continue + + # check if is a constant add + fusing_op = next_ops[0] + if fusing_op.num_of_parameter == 1: + + # do graph fusion + bias = fusing_op.parameters[0].value + graph.remove_operation(fusing_op, keep_coherence=True) + graph.create_variable(value=bias, is_parameter=True, dest_ops=[current_op]) + current_op.type = 'PPQBiasFusedMatMul' + fused = True + + if verbose: + print(f'Fusing graph op: {current_op.name} + {fusing_op.name}') + + if not fused: + ppq_warning("No suitable matmul + add was found, check your graph again.") + class GraphDecomposer(GraphCommandProcessor): """Since PPQ 0.6.4, GraphDecomposer is introduced to split some complex diff --git a/ppq/parser/onnx_exporter.py b/ppq/parser/onnx_exporter.py index a4892cd2..6f31c56d 100644 --- a/ppq/parser/onnx_exporter.py +++ b/ppq/parser/onnx_exporter.py @@ -55,17 +55,6 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation: graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op]) graph.remove_variable(op.inputs[-1]) -class PPQBiasFusedMatMulExporter(OperationExporter): - def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation: - if op.num_of_input == 3: bias = op.inputs[-1] - assert bias.is_parameter and bias.value is not None, 'MatMul Format Error' - - bias_op = graph.create_operation(op_type='Add') - op.type = 'MatMul' - graph.insert_op_after(bias_op, op) - graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op]) - graph.remove_variable(op.inputs[-1]) - OP_CONVERTERS = { 'ConstantOfShape': ConstantOfShapeExporter, 'MMCVRoiAlign': MMCVExporter, @@ -83,7 +72,7 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation: 'QLinearMul': OOSExporter, 'QLinearReduceMean': OOSExporter, 'QLinearSigmoid': OOSExporter, - 'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter + # 'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter } def convert_value(value: Union[int, float, np.ndarray, torch.Tensor]) -> str: diff --git a/ppq/quantization/algorithm/equalization.py b/ppq/quantization/algorithm/equalization.py index 43b54f98..74984259 100644 --- a/ppq/quantization/algorithm/equalization.py +++ b/ppq/quantization/algorithm/equalization.py @@ -82,7 +82,7 @@ def key_value_from_upstream( # ---------------------------------- # step - 3, extract activation from op: # ---------------------------------- - if including_act and op.inputs[0].value is not None: + if including_act and op.outputs[0].value is not None: a = op.outputs[0].value * act_multiplier buffer.append(a) diff --git a/ppq/quantization/algorithm/training.py b/ppq/quantization/algorithm/training.py index 587bee39..c340f320 100644 --- a/ppq/quantization/algorithm/training.py +++ b/ppq/quantization/algorithm/training.py @@ -332,7 +332,7 @@ def __init__( if self.is_parameter and is_parameter_trainable: self.param_backup = self.var.value.clone() - # There is 4 checks for training scale: + # There is 4 checks for scale training: # 1. scale is valid # 2. state is active # 3. do not have POWER_OF_2 policy but Must have Linear policy @@ -348,7 +348,7 @@ def __init__( self.is_scale_trainable = True self.scale_backup = self.config.scale.detach().clone() - # There is 4 checks for training offset: + # There is 4 checks for offset training: # 1. offset is valid # 2. state is active # 3. do not have SYMMETRICAL policy @@ -419,4 +419,4 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> to quantized = (quantized - offset.detach()) * scale quantized = quantized return quantized - + diff --git a/ppq/quantization/observer/range.py b/ppq/quantization/observer/range.py index 3dd8cdc1..3354c206 100644 --- a/ppq/quantization/observer/range.py +++ b/ppq/quantization/observer/range.py @@ -240,7 +240,7 @@ def hist_to_scale_offset( losses, quant_bins = [], 2 ** (config.num_of_bits - 1) - # following code is curcial, do not move + # following code is curcial, do not remove histogram[: int(hist_bins * .002)] = 0 histogram[int(hist_bins * .002)] = 1 diff --git a/ppq/quantization/optim/training.py b/ppq/quantization/optim/training.py index 25f1734b..78710934 100644 --- a/ppq/quantization/optim/training.py +++ b/ppq/quantization/optim/training.py @@ -709,8 +709,9 @@ class LearnedStepSizePass(TrainingBasedPass): def __init__( self, name: str = 'PPQ LSQ Optimization', interested_layers: List[str] = [], steps: int = 500, gamma: float = 0.0, is_scale_trainable: bool = True, - lr: float = 5e-5, block_size: int = None, expire_device: str = 'cpu', + lr: float = 5e-5, block_size: int = 5, expire_device: str = 'cpu', collecting_device: str = 'cuda', loss_fn: Callable = torch_mean_square_error, + optimizer: Any = None ) -> None: super().__init__(name=name) self.interested_layers = interested_layers @@ -722,6 +723,7 @@ def __init__( self.gamma = gamma self.steps = steps self.lr = lr + self.optimizer = optimizer def finetune( self, steps: int, learning_rate: float, block: TrainableBlock, executor: TorchExecutor, @@ -764,8 +766,9 @@ def finetune( return 0, 0 # initilize optimizer. - if optimizer is None: + if self.optimizer is None: optimizer = torch.optim.Adam(tensors, lr=learning_rate) + else: optimizer = self.optimizer(tensors, lr=learning_rate) dataset_length = len(qt_inputs) if dataset_length == 0: raise ValueError('Dataset is empty.') diff --git a/ppq/samples/QuantZoo/QuantZoo_Imagenet.py b/ppq/samples/QuantZoo/QuantZoo_Imagenet.py new file mode 100644 index 00000000..5b94768e --- /dev/null +++ b/ppq/samples/QuantZoo/QuantZoo_Imagenet.py @@ -0,0 +1,165 @@ +# Test Quantization System Performace on Image Classification Models with ILSVRC2012 Dataset + +# Should contains model file(.onnx) +MODEL_DIR = 'QuantZoo/Model/Imagenet' + +# Should contains Calib & Test Img Folder +CALIB_DIR = 'QuantZoo/Data/Imagenet/Calib' +TEST_DIR = 'QuantZoo/Data/Imagenet/Test' + +# calibration & test batchsize +BATCHSIZE = 32 + +# Quantizer Configuration +SYMMETRICAL = True +PER_CHANNEL = True +POWER_OF_2 = False +BIT_WIDTH = 8 + +# write report to here +REPORT_DIR = 'QuantZoo/Reports' + +CONFIGS = [ +{ + 'Model': 'efficientnet_v1_b0', + 'Output': ['/features/features.8/features.8.2/Mul_output_0'] +}, +{ + 'Model': 'efficientnet_v1_b1', + 'Output': ['/features/features.8/features.8.2/Mul_output_0'] +}, +{ + 'Model': 'efficientnet_v2_s', + 'Output': ['/features/features.7/features.7.2/Mul_output_0'] +}, +{ + 'Model': 'mnasnet0_5', + 'Output': ['/layers/layers.16/Relu_output_0'] +}, +{ + 'Model': 'mnasnet1_0', + 'Output': ['/layers/layers.16/Relu_output_0'] +}, +{ + 'Model': 'mobilenet_v2', + 'Output': ['/features/features.18/features.18.2/Clip_output_0'] +}, +{ + 'Model': 'resnet18', + 'Output': ['/layer4/layer4.1/relu_1/Relu_output_0'] +}, +{ + 'Model': 'resnet50', + 'Output': ['/layer4/layer4.2/relu_2/Relu_output_0'] +}, + +{ + 'Model': 'mobilenet_v3_large', + 'Output': ['/classifier/classifier.1/Mul_output_0'] +}, +{ + 'Model': 'mobilenet_v3_small', + 'Output': ['/classifier/classifier.1/Mul_output_0'] +}, +{ + 'Model': 'v100_gpu64@5ms_top1@71.6_finetune@25', + 'Output': ['471'] +}, +{ + 'Model': 'v100_gpu64@6ms_top1@73.0_finetune@25', + 'Output': ['471'] +}, +{ + # vit_b_16 requires BATCHSIZE = 1! + 'Model': 'vit_b_16', + 'Output': ['onnx::Gather_1703'] +} +] + +import os + +import torch + +import ppq.lib as PFL +from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph +from ppq.core import TargetPlatform +from ppq.executor import TorchExecutor +from ppq.quantization.optim import (LayerwiseEqualizationPass, + LearnedStepSizePass, ParameterQuantizePass, + RuntimeCalibrationPass) +from QuantZoo.Data.Imagenet.Eval import (evaluate_ppq_module_with_imagenet, + load_imagenet_from_directory) +from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer +from QuantZoo.Util import error_analyze + + +calib_loader = load_imagenet_from_directory( + directory=CALIB_DIR, batchsize=BATCHSIZE, + shuffle=False, require_label=False, + num_of_workers=8) + + +test_loader = load_imagenet_from_directory( + directory=TEST_DIR, batchsize=BATCHSIZE, + shuffle=False, require_label=True, + num_of_workers=8) + + +with ENABLE_CUDA_KERNEL(): + for config in CONFIGS: + model = config['Model'] + monitoring_vars = config['Output'] + + print(f"Ready to run quant benchmark on {model}") + graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx')) + + if model == 'vit_b_16': + if BATCHSIZE == 32: + raise Exception('To Evaluate vit_b_16, change batchsize to 1, change calibration method to minmax.') + from ppq.IR import GraphMerger + processor = GraphMerger(graph) + processor.fuse_matmul_add() + processor.fuse_layernorm() + processor.fuse_gelu() + + quantizer = MyInt8Quantizer( + graph=graph, sym=SYMMETRICAL, power_of_2=POWER_OF_2, + num_of_bits=BIT_WIDTH, per_channel=PER_CHANNEL) + # quantizer = MyFP8Quantizer(graph=graph) + + # convert op to quantable-op + for name, op in graph.operations.items(): + if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm', + 'PPQBiasFusedMatMul', 'LayerNormalization'}: + quantizer.quantize_operation(name, platform=TargetPlatform.INT8) + + # build quant pipeline. + pipeline = PFL.Pipeline([ + # LayerwiseEqualizationPass(iteration=10), + ParameterQuantizePass(), + RuntimeCalibrationPass(), + # LearnedStepSizePass(steps=500, collecting_device='cuda', block_size=5) + ]) + + # call pipeline. + executor = TorchExecutor(graph=graph) + executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 224, 224]).cuda()) + + pipeline.optimize( + graph=graph, dataloader=calib_loader, verbose=True, + calib_steps=32, collate_fn=lambda x: x.to('cuda'), executor=executor) + + # evaluation + acc = evaluate_ppq_module_with_imagenet( + model=graph, imagenet_validation_loader=test_loader, + batchsize=BATCHSIZE, device='cuda', verbose=False) + print(f'Model Classify Accurarcy = {acc: .4f}%') + + # error analyze + performance = error_analyze( + graph=graph, + outputs=monitoring_vars, + dataloader=test_loader, + collate_fn=lambda x: x[0].to('cuda'), + verbose=True + ) \ No newline at end of file diff --git a/ppq/samples/QuantZoo/QuantZoo_OCR.py b/ppq/samples/QuantZoo/QuantZoo_OCR.py new file mode 100644 index 00000000..f5a5058f --- /dev/null +++ b/ppq/samples/QuantZoo/QuantZoo_OCR.py @@ -0,0 +1,151 @@ +# Test Quantization System Performace on OCR Models with IC15 Dataset + +# Should contains model file(.onnx) +MODEL_DIR = 'QuantZoo/Model/ocr' + +# Should contains Calib & Test Img Folder +CALIB_DIR = 'QuantZoo/Data/IC15' +CALIB_LABEL = 'QuantZoo/Data/IC15/rec_gt_train.txt' +TEST_DIR = 'QuantZoo/Data/IC15' +TEST_LABEL = 'QuantZoo/Data/IC15/rec_gt_test.txt' +CHAR_DIR = 'QuantZoo/Data/IC15/ic15_dict.txt' + +# calibration & test batchsize +BATCHSIZE = 32 + +# Quantizer Configuration +SYMMETRICAL = True +PERCHANNEL = True +POWER_OF_2 = False +BIT_WIDTH = 8 + +# write report to here +REPORT_DIR = 'QuantZoo/Reports' + +CONFIGS = [ + +{ + 'Model': 'en_PP-OCRv3_rec_infer', + 'Output': ['swish_13.tmp_0'], + 'Dictionary': 'en_dict.txt', + 'Reshape': [3, 48, 320], + 'Language': 'en', +}, +{ + 'Model': 'en_number_mobile_v2.0_rec_infer', + 'Output': ['save_infer_model/scale_0.tmp_1'], + 'Dictionary': 'en_dict.txt', + 'Reshape': [3, 32, 320], + 'Language': 'en', +}, +{ + 'Model': 'ch_PP-OCRv2_rec_infer', + 'Output': ['p2o.LSTM.5'], + 'Dictionary': 'ppocr_keys_v1.txt', + 'Reshape': [3, 32, 320], + 'Language': 'ch', +}, +{ + 'Model': 'ch_PP-OCRv3_rec_infer', + 'Output': ['swish_27.tmp_0'], + 'Dictionary': 'ppocr_keys_v1.txt', + 'Reshape': [3, 48, 320], + 'Language': 'ch', +}, +{ + 'Model': 'ch_ppocr_mobile_v2.0_rec_infer', + 'Output': ['p2o.LSTM.5'], + 'Dictionary': 'ppocr_keys_v1.txt', + 'Reshape': [3, 32, 320], + 'Language': 'ch', +}, +{ + 'Model': 'ch_ppocr_server_v2.0_rec_infer', + 'Output': ['p2o.LSTM.5'], + 'Dictionary': 'ppocr_keys_v1.txt', + 'Reshape': [3, 32, 320], + 'Language': 'ch', +}, +] + +import os + +import torch + +import ppq.lib as PFL +from ppq import convert_any_to_torch_tensor, graphwise_error_analyse +from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph +from ppq.core import TargetPlatform +from ppq.executor import TorchExecutor +from ppq.quantization.optim import (LayerwiseEqualizationPass, + LearnedStepSizePass, ParameterQuantizePass, + RuntimeCalibrationPass) +from QuantZoo.Data.IC15.Data import IC15_PaddleOCR +from QuantZoo.Data.IC15.Eval import evaluate_ppq_module_with_ic15 +from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer +from QuantZoo.Util import error_analyze, report + + +with ENABLE_CUDA_KERNEL(): + for config in CONFIGS: + model = config['Model'] + monitoring_vars = config['Output'] + dictionary = config['Dictionary'] + shape = config['Reshape'] + chinese = config['Language'] == 'ch' + + calib_loader = IC15_PaddleOCR( + images_path=CALIB_DIR, + label_path=CALIB_LABEL, + input_shape=shape, + is_chinese_version=chinese).dataloader( + batchsize=BATCHSIZE, shuffle=False) + + test_loader = IC15_PaddleOCR( + images_path=TEST_DIR, + label_path=TEST_LABEL, + input_shape=shape, + is_chinese_version=chinese).dataloader( + batchsize=BATCHSIZE, shuffle=False) + + print(f"Ready to run quant benchmark on {model}") + graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx')) + + quantizer = MyInt8Quantizer(graph=graph, sym=SYMMETRICAL, power_of_2=POWER_OF_2, + num_of_bits=BIT_WIDTH, per_channel=PERCHANNEL) + # quantizer = MyFP8Quantizer(graph=graph) + + # convert op to quantable-op + for name, op in graph.operations.items(): + if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm'}: + quantizer.quantize_operation(name, platform=TargetPlatform.INT8) + + # build quant pipeline. + pipeline = PFL.Pipeline([ + # LayerwiseEqualizationPass(iteration=10), + ParameterQuantizePass(), + RuntimeCalibrationPass(), + # LearnedStepSizePass(steps=500, collecting_device='cuda') + ]) + + # call pipeline. + executor = TorchExecutor(graph=graph) + executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 32, 100]).cuda()) + + pipeline.optimize( + graph=graph, dataloader=calib_loader, verbose=True, + calib_steps=32, collate_fn=lambda x: x[0].to('cuda'), executor=executor) + + acc = evaluate_ppq_module_with_ic15( + executor=executor, character_dict_path=os.path.join(MODEL_DIR, dictionary), + dataloader=test_loader, collate_fn=lambda x: convert_any_to_torch_tensor(x).cuda()) + print(f'Model Performace on IC15: {acc * 100 :.4f}%') + + # error analyze + performance = error_analyze( + graph=graph, + outputs=monitoring_vars, + dataloader=test_loader, + collate_fn=lambda x: x[0].to('cuda'), + verbose=True + ) \ No newline at end of file diff --git a/ppq/samples/QuantZoo/QuantZoo_Segmentation.py b/ppq/samples/QuantZoo/QuantZoo_Segmentation.py new file mode 100644 index 00000000..cf2bb81a --- /dev/null +++ b/ppq/samples/QuantZoo/QuantZoo_Segmentation.py @@ -0,0 +1,139 @@ +# Test Quantization System Performace on Image Classification Models with ILSVRC2012 Dataset + +# Should contains model file(.onnx) +MODEL_DIR = 'QuantZoo/Model/mmseg' + +# Should contains Calib & Test Img Folder +CALIB_DIR = 'QuantZoo/Data/Cityscapes/Calib' +TEST_DIR = 'QuantZoo/Data/Cityscapes/Test' + +# calibration & test batchsize +BATCHSIZE = 1 + +# Quantizer Configuration +SYMMETRICAL = True +PERCHANNEL = True +POWER_OF_2 = False +BIT_WIDTH = 8 + +# write report to here +REPORT_DIR = 'QuantZoo/Reports' + +CONFIGS = [ + +{ + 'Model': 'stdc1_512x1024_80k_cityscapes', + 'Output': ['/convs/convs.0/activate/Relu_output_0'] +}, + +{ + 'Model': 'pspnet_r50-d8_512x1024_40k_cityscapes', + 'Output': ['/bottleneck/activate/Relu_output_0'] +}, + +{ + 'Model': 'pointrend_r50_512x1024_80k_cityscapes', # complex model + 'Output': ['/Concat_60_output_0'] +}, + +{ + 'Model': 'fpn_r50_512x1024_80k_cityscapes', + 'Output': ['/Add_2_output_0'] +}, + +{ + 'Model': 'icnet_r18-d8_832x832_80k_cityscapes', + 'Output': ['/convs/convs.0/activate/Relu_output_0'] +}, + +{ + 'Model': 'fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes', + 'Output': ['/convs/convs.0/activate/Relu_output_0'] +}, + +{ + 'Model': 'fast_scnn_lr0.12_8x4_160k_cityscapes', + 'Output': ['/convs/convs.1/pointwise_conv/activate/Relu_output_0'] +}, + +{ + 'Model': 'fcn_r50-d8_512x1024_40k_cityscapes', + 'Output': ['/conv_cat/activate/Relu_output_0'] +}, + +{ + 'Model': 'bisenetv2_fcn_4x4_1024x1024_160k_cityscapes', + 'Output': ['/convs/convs.0/activate/Relu_output_0'] +}, + +{ + 'Model': 'deeplabv3_r50-d8_512x1024_40k_cityscapes', + 'Output': ['/bottleneck/activate/Relu_output_0'] +}, + +] + +import os + +import torch + +import ppq.lib as PFL +from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph +from ppq.core import TargetPlatform +from ppq.executor import TorchExecutor +from ppq.quantization.optim import (LayerwiseEqualizationPass, + LearnedStepSizePass, ParameterQuantizePass, + RuntimeCalibrationPass) +from QuantZoo.Data.Cityscapes.Data import load_cityscapes_dataset +from QuantZoo.Data.Cityscapes.Eval import evaluation +from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer +from QuantZoo.Util import error_analyze, report + +calib_loader = load_cityscapes_dataset(img_folder=CALIB_DIR) +test_loader = load_cityscapes_dataset(img_folder=TEST_DIR) + + +with ENABLE_CUDA_KERNEL(): + for config in CONFIGS: + model = config['Model'] + monitoring_vars = config['Output'] + + print(f"Ready to run quant benchmark on {model}") + graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx')) + quantizer = MyInt8Quantizer( + graph=graph, sym=SYMMETRICAL, power_of_2=POWER_OF_2, + num_of_bits=BIT_WIDTH, per_channel=PERCHANNEL) + # quantizer = MyFP8Quantizer(graph=graph) + + # convert op to quantable-op + for name, op in graph.operations.items(): + if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm'}: + quantizer.quantize_operation(name, platform=TargetPlatform.INT8) + + # build quant pipeline. + pipeline = PFL.Pipeline([ + # LayerwiseEqualizationPass(iteration=10), + ParameterQuantizePass(), + RuntimeCalibrationPass(), + # LearnedStepSizePass(steps=500, collecting_device='cpu') + ]) + + # call pipeline. + executor = TorchExecutor(graph=graph) + executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 1024, 2048]).cuda()) + + pipeline.optimize( + graph=graph, dataloader=calib_loader, verbose=True, + calib_steps=32, collate_fn=lambda x: x[0].to('cuda'), executor=executor) + + miou = evaluation(graph=graph, dataloader=test_loader, working_directory=TEST_DIR) + print(f'Model Performance on CityScapes Miou: {miou * 100: .4f}%') + + # error analyze + performance = error_analyze( + graph=graph, + outputs=monitoring_vars, + dataloader=test_loader, + collate_fn=lambda x: x[0].cuda(), + verbose=True + ) \ No newline at end of file diff --git a/ppq/samples/QuantZoo/QuantZoo_SuperRes.py b/ppq/samples/QuantZoo/QuantZoo_SuperRes.py new file mode 100644 index 00000000..54f75a22 --- /dev/null +++ b/ppq/samples/QuantZoo/QuantZoo_SuperRes.py @@ -0,0 +1,137 @@ +# Test Quantization System Performance on Detection Models with Coco Dataset + +# Should contains model file(.onnx) +MODEL_DIR = 'QuantZoo/Model/mmedit' + +# Should contains Calib & Test Img Folder +TRAIN_HR_DIR = 'QuantZoo/Data/DIV2K/DIV2K_train_HR' +TRAIN_LR_DIR = 'QuantZoo/Data/DIV2K/DIV2K_train_LR_bicubic' +VALID_HR_DIR = 'QuantZoo/Data/DIV2K/DIV2K_valid_HR' +VALID_LR_DIR = 'QuantZoo/Data/DIV2K/DIV2K_valid_LR_bicubic' + +# calibration & test batchsize +# super resolution model must have batchsize = 1 +BATCHSIZE = 1 + +# Quantizer Configuration +SYMMETRICAL = True +PER_CHANNEL = True +POWER_OF_2 = False +BIT_WIDTH = 8 + +# write report to here +REPORT_DIR = 'QuantZoo/Reports' + +CONFIGS = [ +{ + 'Model': 'srcnn_x4k915_g1_1000k_div2k', + 'Output': ['output'], +}, +{ + 'Model': 'srgan_x4c64b16_g1_1000k_div2k', + 'Output': ['output'], +}, +{ + 'Model': 'rdn_x4c64b16_g1_1000k_div2k', + 'Output': ['output'], +}, +{ + 'Model': 'edsr_x4c64b16_g1_300k_div2k', + 'Output': ['/generator/conv_last/Conv_output_0'], +}, +] + +import os +from typing import Iterable + +import torch +from tqdm import tqdm + +import ppq.lib as PFL +from ppq import convert_any_to_numpy, graphwise_error_analyse +from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph +from ppq.core import TargetPlatform +from ppq.executor import TorchExecutor +from ppq.IR import BaseGraph, GraphFormatter +from ppq.quantization.optim import (LayerwiseEqualizationPass, + LearnedStepSizePass, ParameterQuantizePass, + RuntimeCalibrationPass) +from QuantZoo.Data.DIV2K.Data import load_div2k_dataset +from QuantZoo.Data.DIV2K.Eval import psnr, ssim +from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer +from QuantZoo.Util import error_analyze, report + + +def evaluation(graph: BaseGraph, dataloader: Iterable, method: str='psnr'): + if method not in {'psnr', 'ssim'}: raise Exception('Evaluation method not understood.') + executor = TorchExecutor(graph) + ret_collector = [] + + for lr_img, hr_img in tqdm(dataloader): + pred = executor.forward(lr_img.cuda())[0] + real = hr_img + + # post processing + pred = convert_any_to_numpy((pred.squeeze(0) * 255).round()) + real = convert_any_to_numpy((real.squeeze(0) * 255).round()) + + if method == 'psnr': sample_ret = psnr(img1=real, img2=pred, input_order='CHW') + else: sample_ret = ssim(img1=real, img2=pred, input_order='CHW') + ret_collector.append(sample_ret) + + return sum(ret_collector) / len(ret_collector) + +calib_loader = load_div2k_dataset( + lr_folder = TRAIN_LR_DIR, + hr_folder = TRAIN_HR_DIR) + +test_loader = load_div2k_dataset( + lr_folder = VALID_LR_DIR, + hr_folder = VALID_HR_DIR) + +with ENABLE_CUDA_KERNEL(): + for config in CONFIGS: + model = config['Model'] + monitoring_vars = config['Output'] + + print(f"Ready to run quant benchmark on {model}") + graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx')) + + quantizer = MyInt8Quantizer(graph=graph, sym=SYMMETRICAL, per_channel=PER_CHANNEL, + power_of_2=POWER_OF_2, num_of_bits=BIT_WIDTH) + # quantizer = MyFP8Quantizer(graph=graph) + + # convert op to quantable-op + for name, op in graph.operations.items(): + if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm'}: + quantizer.quantize_operation(name, platform=TargetPlatform.INT8) + + # build quant pipeline. + pipeline = PFL.Pipeline([ + # LayerwiseEqualizationPass(iteration=10), + ParameterQuantizePass(), + RuntimeCalibrationPass(), + # LearnedStepSizePass(steps=500, collecting_device='cpu') + ]) + + + # call pipeline. + executor = TorchExecutor(graph=graph) + executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 480, 640]).cuda()) + + pipeline.optimize( + graph=graph, dataloader=calib_loader, verbose=True, + calib_steps=32, collate_fn=lambda x: x[0].cuda(), + executor=executor) + + result = evaluation(graph=graph, dataloader=test_loader, method='psnr') + print(f'Model Performance on DIV2K PSNR: {result}') + + # error analyze + performance = error_analyze( + graph=graph, + outputs=monitoring_vars, + dataloader=test_loader, + collate_fn=lambda x: x[0].cuda(), + verbose=True + ) \ No newline at end of file diff --git a/ppq/samples/QuantZoo/QuantZoo_Yolo.py b/ppq/samples/QuantZoo/QuantZoo_Yolo.py new file mode 100644 index 00000000..69c8a1d9 --- /dev/null +++ b/ppq/samples/QuantZoo/QuantZoo_Yolo.py @@ -0,0 +1,179 @@ +# Test Quantization System Performance on Detection Models with Coco Dataset + +# Should contains model file(.onnx) +MODEL_DIR = 'QuantZoo/Model/yolo' + +# Should contains Calib & Test Img Folder +CALIB_DIR = 'QuantZoo/Data/Coco/Calib' +TEST_DIR = 'QuantZoo/Data/Coco/Test' +CALIB_ANN_FILE = 'QuantZoo/Data/Coco/Calib/DetectionAnnotation.json' +TEST_ANN_FILE = 'QuantZoo/Data/Coco/Test/DetectionAnnotation.json' +PRED_ANN_FILE = 'QuantZoo/Data/Coco/Test/DetectionPrediction.json' +VALID_DIR = '/mnt/hpc/share/wangpeiqi/COCO/val2017' +VALID_ANN_FILE = '/mnt/hpc/share/wangpeiqi/COCO/annotations/instances_val2017.json' +EVAL_MODE = True # for coco evaluation + +# calibration & test batchsize +# yolo requires batchsize = 1 +BATCHSIZE = 1 + +# Quantizer Configuration +SYMMETRICAL = True +PER_CHANNEL = True +POWER_OF_2 = False +BIT_WIDTH = 8 + +# write report to here +REPORT_DIR = 'QuantZoo/Reports' + +CONFIGS = [ +{ + 'Model': 'yolov6p5_n', + 'Output': ['/Concat_5_output_0', '/Concat_4_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolov6p5_t', + 'Output': ['/Concat_5_output_0', '/Concat_4_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolov5s6_n', + 'Output': ['/baseModel/head_module/convs_pred.1/Conv_output_0', '/baseModel/head_module/convs_pred.2/Conv_output_0', '/baseModel/head_module/convs_pred.0/Conv_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolov5s6_s', + 'Output': ['/baseModel/head_module/convs_pred.1/Conv_output_0', '/baseModel/head_module/convs_pred.2/Conv_output_0', '/baseModel/head_module/convs_pred.0/Conv_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolov7p5_tiny', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0', '/Concat_6_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolov7p5_l', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0', '/Concat_6_output_0'], + 'collate_fn': lambda x: x[0].cuda() # img preprocessing function +}, +{ + 'Model': 'yolox_s', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0', '/Concat_6_output_0'], + 'collate_fn': lambda x: x[0].cuda() * 255 # img preprocessing function +}, +{ + 'Model': 'yolox_tiny', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0', '/Concat_6_output_0'], + 'collate_fn': lambda x: x[0].cuda() * 255 # img preprocessing function +}, +{ + 'Model': 'ppyoloe_m', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0'], + 'collate_fn': lambda x: ( + x[0].cuda() * 255 - torch.tensor([103.53, 116.28, 123.675]).reshape([1, 3, 1, 1]).cuda() + ) / 255 # img preprocessing function +}, +{ + 'Model': 'ppyoloe_s', + 'Output': ['/Concat_4_output_0', '/Concat_5_output_0'], + 'collate_fn': lambda x: ( + x[0].cuda() * 255 - torch.tensor([103.53, 116.28, 123.675]).reshape([1, 3, 1, 1]).cuda() + ) / 255 # img preprocessing function +}, +] + +import os + +import torch + +import ppq.lib as PFL +from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph +from ppq.core import TargetPlatform +from ppq.executor import TorchExecutor +from ppq.quantization.optim import (LayerwiseEqualizationPass, + LearnedStepSizePass, ParameterQuantizePass, + RuntimeCalibrationPass) +from ppq.IR import GraphFormatter + +from QuantZoo.Data.Coco.Data import load_coco_detection_dataset +from QuantZoo.Data.Coco.Eval import evaluate_ppq_module_with_coco +from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer +from QuantZoo.Util import error_analyze + + +calib_loader = load_coco_detection_dataset( + data_dir=CALIB_DIR, + batchsize=BATCHSIZE) + +test_loader = load_coco_detection_dataset( + data_dir=TEST_DIR, + batchsize=BATCHSIZE) + + +with ENABLE_CUDA_KERNEL(): + for config in CONFIGS: + model = config['Model'] + monitoring_vars = config['Output'] + collate_fn = config['collate_fn'] + + print(f"Ready to run quant benchmark on {model}") + graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx')) + + # if EVAL_MODE == False, truncate graph + if EVAL_MODE == False: + graph.outputs.clear() + editor = GraphFormatter(graph) + for var in monitoring_vars: + graph.mark_variable_as_graph_output(graph.variables[var]) + editor.delete_isolated() + else: + editor = GraphFormatter(graph) + graph.outputs.pop('scores') + graph.outputs.pop('num_dets') + graph.mark_variable_as_graph_output(graph.variables['/Split_output_1']) + editor.delete_isolated() + + quantizer = MyInt8Quantizer(graph=graph, sym=SYMMETRICAL, + per_channel=PER_CHANNEL, power_of_2=POWER_OF_2, + num_of_bits=BIT_WIDTH) + # quantizer = MyFP8Quantizer(graph=graph) + + # convert op to quantable-op + for name, op in graph.operations.items(): + if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm'}: + quantizer.quantize_operation(name, platform=TargetPlatform.INT8) + + # build quant pipeline. + pipeline = PFL.Pipeline([ + # LayerwiseEqualizationPass(iteration=10), + ParameterQuantizePass(), + RuntimeCalibrationPass(), + # LearnedStepSizePass(steps=500, collecting_device='cpu') + ]) + + # call pipeline. + executor = TorchExecutor(graph=graph) + executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 640, 640]).cuda()) + + pipeline.optimize( + graph=graph, dataloader=calib_loader, verbose=True, + calib_steps=32, collate_fn=collate_fn, + executor=executor) + + # evaluation 好像 batchsize != 1 会错 + evaluate_ppq_module_with_coco( + ann_file=TEST_ANN_FILE, + output_file=PRED_ANN_FILE, + executor=executor, + dataloader=test_loader, + collate_fn=collate_fn) + + # error analyze + performance = error_analyze( + graph=graph, + outputs=monitoring_vars, + dataloader=test_loader, + collate_fn=collate_fn, + verbose=True + ) diff --git a/ppq/samples/QuantZoo/Readme.md b/ppq/samples/QuantZoo/Readme.md new file mode 100644 index 00000000..cc9722de --- /dev/null +++ b/ppq/samples/QuantZoo/Readme.md @@ -0,0 +1,330 @@ +# ONNX Quantization Model Zoo (OnnxQuant) + +:smile: OnnxQuant 是目前最大的模型量化数据集,它包含 ONNX 模型,数据,以及相关的测试脚本。该数据集的提出用于推动模型量化在视觉模型中的应用与量化算法的研究,具备以下特点: + +1. 可移植与可复现,所有模型均由 ONNX 格式提供。 +2. 包含图像分类、图像分割、图像超分辨率、图像-文字识别、目标检测、姿态检测等多个任务的模型。 +3. 提供切分好的 calibration 数据和 test 数据,提供模型精度测试脚本。 +4. 提供灵活的量化器用于确定模型在不同量化规则下的理论性能,并提供 FP8 量化器。 + +:eyes: OnnxQuant 目前处于公开测试阶段,近几个月内仍然将发生修改与变动。 + +## 1. 如何使用 + +### 1.1 下载数据集: + +1. 图像分类: https://pan.baidu.com/s/1CIrQBvMkPaI-19M8IpVP8w?pwd=z5z8 +2. 图像超分: https://pan.baidu.com/s/1u7ZAVNtlaMHBzDzzq-1eCw?pwd=gzsb +3. 图像分割: https://pan.baidu.com/s/1LAS8LYyklz7kgkVUuxDlLg?pwd=db6s +4. 目标检测: https://pan.baidu.com/s/1uBvK-Wm1AKVrNgvA9E4lhA?pwd=9n06 +5. 姿态检测: +6. 图像-文字识别: https://pan.baidu.com/s/1GyYvLbhibLL6kPIA1J0X7Q?pwd=vpxi +7. NLP: + +### 1.2 建立工作目录: + +在工作目录下建立文件夹 QuantZoo,解压上述文件到 QuantZoo 中。你将获得这样的文件夹结构 + +``` +~/QuantZoo/Data/Cityscapes +~/QuantZoo/Data/Coco +~/QuantZoo/Data/DIV2K +~/QuantZoo/Data/IC15 +~/QuantZoo/Model/Imagenet +~/QuantZoo/Model/mmedit +~/QuantZoo/Model/mmseg +~/QuantZoo/Model/ocr +~/QuantZoo/Quantizers.py +~/QuantZoo/Util.py +``` + +### 1.3 创建入口文件 + +将 https://github.com/openppl-public/ppq/tree/master/ppq/samples/QuantZoo 目录下的文件复制到工作目录下 + +``` +~/QuantZoo_Imagenet.py +~/QuantZoo_OCR.py +~/QuantZoo_Yolo.py +~/QuantZoo_SuperRes.py +~/QuantZoo_Segmentation.py +~/QuantZoo +``` + +运行工作目录中的 python 文件即可完成 OnnxQuant 测试。 + +## 2. 环境依赖 + +1. 下载量化工具 ppq 0.6.6 以上版本,用户可以通过 pypi 进行安装 + +``` bash +pip install ppq +``` + +亦可以使用 github 上的最新版本进行安装 + +``` bash +git clone https://github.com/openppl-public/ppq.git +cd ppq +pip install -r requirements.txt +python setup.py install +``` + +2. 下载并安装所需的其他依赖库 + + * numpy + * onnx >= 1.9.0 + * protobuf + * torch >= 1.6.0 + * tqdm + * mmcv-full + * cityscapesscripts + * pycocotools + +3. 安装编译环境(对于 FP8 量化而言,该环境是必须的) + +``` bash +apt-get install ninja-build +``` + +## 3. 数据集与模型简介 + +在 OnnxQuant 中的所有模型已经预先完成了 Batchnorm 层的合并,并且模型已经完成训练过程。 + +### 3.1 图像分类 + +数据集:Imagenet + +数据切分方式: + * Calibration 数据为 Imagenet Validation Set 中随机抽取的 5000 张图片。 + * Test 数据为 Imagenet Validation Set 中随机抽取的 5000 张图片。 + * 提供数据切分脚本。 + +模型 efficientnet, mnasnet, mobilenetv2, mobilenetv3, resnet18, resnet50, vit 来自 torchvision. + +模型 once_for_all: https://github.com/mit-han-lab/once-for-all + +模型测试标准:分类准确率。 + +### 3.2 实例分割 + +数据集:cityscapes + +数据切分方式: + * Calibration 数据为 Cityscapes val 中随机抽取的 300 张图片。 + * Test 数据为 Cityscapes val 中随机抽取的 200 张图片。 + * 提供数据切分脚本。 + +模型:全部来自 mmedit + +模型测试标准:Miou。 + +### 3.3 目标检测 + +数据集: Coco 2017 + +数据切分方式: + + * Calibration 数据为 Coco 2017 Validation Set 中随机抽取的 1500 张图片。 + * Test 数据为 Coco 2017 Validation Set 中随机抽取的 300 张图片。 + * 提供数据切分脚本。 + +模型:全部来自 mmyolo + +模型测试标准:目标检测精准度(mAP)。 + +### 3.4 OCR + +数据集:IC15 + +数据切分方式: + + * Calibration 数据为 IC15 train 数据集。 + * Test 数据为 IC15 test 数据集。 + +模型:全部来自 paddle ocr + +模型测试标准:文字识别准确率。 + +### 3.5 图像超分辨率 + +数据集:DIV2K + +数据切分方式: + + * Calibration 数据为 DIV2K_train 数据集。 + * Test 数据为 DIV2K_valid 数据集。降采样方式为 x4, bicubic。 + +模型:全部来自 mmedit + +模型测试标准:峰值信噪比。 + +## 4. OnnxQuant 模型量化规则 + +### 4.1 综述: +在 OnnxQuant 中,我们希望评估模型量化的理论性能,同时隔离推量框架的具体执行细节。 + +我们要求量化所有的 **卷积与全连接层**,并且只对上述层的 **输入变量和权重** 添加量化-反量化操作,其 **偏置项(Bias)与输出变量** 依然保留为浮点精度。 +这一规则能够总体上模拟推理框架的图融合情况,简化量化过程并提升计算效率,并可获得较为准确的模型量化性能。 + +对于 Transformer Based 模型,OnnxQuant 将在 Layernormalization 层的输入变量上添加量化-反量化操作,其权重不参与量化。 + +### 4.2 量化细则 +OnnxQuant 关注以下三类量化细则下的模型性能: +| INT8 PERCHANNEL | INT8 PERTENSOR POWER-OF-2 | GRAPHCORE FP8 | +|:---|:---|:---| +| 权重使用 PERCHANNEL 量化 | 权重使用 PERTENSOR 量化,Scale 附加 POWER-OF-2 限制 | 权重使用 PERTENSOR FP8 量化 | +| 激活值使用 PERTENSOR 量化 | 激活值使用 PERTENSOR 量化,Scale 附加 POWER-OF-2 限制 | 激活值使用 PERTENSOR FP8 量化 | +| 量化范围为[-128, 127] | 量化范围为[-128, 127] | 量化范围为[-448.0, 448.0] | + +## 5. OnnxQuant Baseline + +在前文中,我们已经介绍了不同模型的测试标准,OnnxQuant 将以此为标准测试量化模型在测试数据集上的分类 / 检测 / 超分辨率 / 文字识别 / 实例分割的性能。 + +除此以外,OnnxQuant 额外引入三项通用指标对模型量化结果进行评估: + + * Average Quantization Error(AQE): 模型平均量化误差 + + * Max Quantization Error(MQE): 模型最大量化误差 + + * Output Quantization Error(OQE): 模型输出最大量化误差 + +OnnxQuant 使用相对误差评估模型量化误差,对于量化网络中的任意一个卷积层、全连接层、Layernormalization 层而言,OnnxQuant 取该层的输出结果 A 与浮点网络对应层的输出结果 B 进行对比。 +相对误差等于 || A - B || / || B ||,其中 || B || 表示向量 B 的 F-范数。 + + * 模型平均量化误差(AQE):模型中所有层的量化误差均值 + + * 模型最大量化误差(MQE):模型中所有层的量化误差最大值 + + * 模型输出最大量化误差(OQE):模型中所有输出层的量化误差最大值 + +图例:❗: 量化精度很差的模型;💔: 很差的单一指标 + +### INT8 PERCHANNEL + +| Classification | Float Accuracy | Quant Accuracy | AQE | MQE | OQE | +| ------------------ | ----------- | -------------- | ------ | ------- | ------- | +| ❗efficientnet_v1_b0 | 76.90% | 66.19% | 20.81%💔 | 60.34%💔 | 70.51%💔 | +| efficientnet_v1_b1 | 76.66% | 75.64% | 4.16% | 20.23%💔 | 15.50%💔 | +| efficientnet_v2 | 80.29% | 80.03% | 6.52% | 44.53%💔 | 41.45%💔 | +| ❗mnasnet 0.5 | 67.75% | 64.42% | 5.40% | 15.51%💔 | 24.88%💔 | +| mnasnet 1.0 | 73.48% | 72.54% | 2.29% | 5.51% | 5.71% | +| mobilenet_v2 | 71.37% | 70.99% | 4.75% | 21.13%💔 | 6.02% | +| ❗mobilenet_v3_small | 67.89% | 2.80% | 55.57%💔 | 123.36%💔 | 131.26%💔 | +| mobilenet_v3_large | 73.37% | 72.48% | 2.17% | 7.10% | 5.57% | +| resnet18 | 69.65% | 69.51% | 0.55% | 1.48% | 1.17% | +| resnet50 | 75.56% | 75.48% | 1.24% | 3.60% | 1.95% | +| once_for_all_71 | 72.30% | 71.75% | 4.11% | 33.14%💔 | 7.88% | +| once_for_all_73 | 74.54% | 74.38% | 3.49% | 32.25%💔 | 5.51% | +| ❗vit_b_16 | 80.00% | 77.90% | \* | \* | \* | + +| Segmentation | Float mIou | Quant mIou | AQE | MQE | OQE | +| -------------------------------------------- | ------- | ---------- | ----- | ------ | ----- | +| stdc1_512x1024_80k_cityscapes | 71.44% | 71.21% | 1.31% | 2.88% | 0.66% | +| pspnet_r50-d8_512x1024_40k_cityscapes | 76.48% | 76.34% | 1.77% | 4.77% | 1.40% | +| pointrend_r50_512x1024_80k_cityscapes | 75.66% | 75.99% | \* | \* | \* | +| fpn_r50_512x1024_80k_cityscapes | 73.86% | 75.20% | 1.73% | 5.62% | 0.61% | +| icnet_r18-d8_832x832_80k_cityscapes | 67.07% | 66.72% | 0.61% | 1.27% | 0.31% | +| fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes | 68.34% | 68.17% | 1.26% | 2.92% | 0.54% | +| ❗fast_scnn_lr0.12_8x4_160k_cityscapes | 70.22% | 69.38% | 3.71% | 10.19%💔 | 0.75% | +| fcn_r50-d8_512x1024_40k_cityscapes | 70.96% | 73.67% | 2.12% | 7.81%💔 | 2.14% | +| bisenetv2_fcn_4x4_1024x1024_160k_cityscapes | 71.67% | 71.91% | 1.28% | 5.80% | 0.92% | +| deeplabv3_r50-d8_512x1024_40k_cityscapes | 77.90% | 77.48% | 1.97% | 7.95%💔 | 1.17% | + +| SuperRes | Float PSNR | Quant PSNR | AQE | MQE | OQE | +| -------- | ------- | ---------- | ----- | ------ | ----- | +| edsr | 28.98 | 28.76 | 0.15% | 0.36% | 0.06% | +| rdn | 29.20 | 28.65 | 7.99%💔 | 19.94%💔 | 0.30% | +| srcnn | 27.76 | 27.18 | 4.89% | 29.80%💔 | 0.13% | +| srgan | 26.56 | 26.53 | 2.56% | 6.16% | 0.10% | + +| OCR | Float Accuracy | Quant Accuracy | AQE | MQE | OQE | +| ------------------------------- | ----------- | -------------- | ------- | ------- | ------ | +| ❗en_PP-OCRv3_rec_infer | 67.36% | 40.64% | 24.41%💔 | 64.69%💔 | 22.11%💔 | +| ❗en_number_mobile_v2.0_rec_infer | 46.22% | 21.47% | 81.85%💔 | 398.62%💔 | 6.48%💔 | +| ❗ch_PP-OCRv2_rec_infer | 54.89% | 4.91% | 114.55%💔 | 985.66%💔 | 99.83%💔 | +| ❗ch_PP-OCRv3_rec_infer | 62.69% | 0.63% | 116.20%💔 | 516.44%💔 | 72.43%💔 | +| ❗ch_ppocr_mobile_v2.0_rec_infer | 40.15% | 10.39% | 97.96%💔 | 523.46%💔 | 41.99%💔 | +| ch_ppocr_server_v2.0_rec_infer | 54.98% | 53.92% | 3.30% | 14.72%💔 | 8.95% | + +| Detection | Float mAP | Quant mAP | AQE | MQE | OQE | +| ------------- | ------ | --------- | ----- | ------ | ----- | +| yolov6p5_n | 49.80% | 47.10% | 2.40% | 13.36%💔 | 1.46% | +| yolov6p5_t | 52.40% | 50.60% | 6.00% | 17.56%💔 | 3.40% | +| ❗yolov5s6_n | 39.80% | 35.60% | 1.20% | 2.60% | 0.29% | +| ❗yolov5s6_s | 47.90% | 42.40% | 1.29% | 3.05% | 0.20% | +| ❗yolov7p5_tiny | 46.60% | 41.50% | 1.63% | 3.60% | 3.03% | +| ❗yolov7p5_l | 59.50% | 51.20% | 1.29% | 2.54% | 2.68% | +| ❗yolox_s | 49.30% | 43.00% | 4.64% | 12.61%💔 | 3.96% | +| ❗yolox_tiny | 45.00% | 39.10% | 1.35% | 3.12% | 0.90% | +| ppyoloe_m | 55.80% | 54.60% | 2.69% | 8.10% | 0.87% | +| ppyoloe_s | 50.30% | 49.00% | 1.55% | 3.97% | 0.77% | + +### INT8 PERTENSOR POWER-OF-2 + +| Classification | Float Accuracy | Quant Accuracy | AQE | MQE | OQE | +| ------------------ | ----------- | -------------- | -------- | ---------- | ------- | +| ❗efficientnet_v1_b0 | 76.90% | 0.06% | 20609%💔 | 746505%💔 | 12887%💔 | +| ❗efficientnet_v1_b1 | 76.66% | 0.46% | 279.77%💔 | 5344.33%💔 | 138.59%💔 | +| efficientnet_v2 | 80.29% | 78.83% | 10.39% | 56.11%💔 | 24.38%💔 | +| ❗mnasnet 0.5 | 67.75% | 0.08% | 524.41%💔 | 5873.77%💔 | 255.34%💔 | +| ❗mnasnet 1.0 | 73.48% | 67.37% | 12.74%💔 | 27.71%💔 | 30.01%💔 | +| ❗mobilenet_v2 | 71.37% | 62.30% | 16.56%💔 | 44.13%💔 | 37.90%💔 | +| ❗mobilenet_v3_small | 67.89% | 0.14% | 3537.28% | 183137.50%💔 | 173.55%💔 | +| ❗mobilenet_v3_large | 73.37% | 68.53% | 7.96% | 21.39%💔 | 24.24%💔 | +| resnet18 | 69.65% | 68.45% | 1.98% | 5.39% | 4.20% | +| resnet50 | 75.56% | 75.18% | 3.19% | 11.04% | 5.06% | +| ❗once_for_all_71 | 72.30% | 0.24% | 324.85%💔 | 1351.20%💔 | 248.77%💔 | +| ❗once_for_all_73 | 74.54% | 0.48% | 352.27%💔 | 1570.90%💔 | 218.98%💔 | + +| Segmentation | Float mIou | Quant mIou | AQE | MQE | OQE | +| -------------------------------------------- | ------- | ---------- | ------ | ------ | ----- | +| stdc1_512x1024_80k_cityscapes | 71.44% | 71.36% | 2.45% | 5.56% | 1.36% | +| pspnet_r50-d8_512x1024_40k_cityscapes | 76.48% | 76.02% | 3.75% | 6.76% | 3.04% | +| pointrend_r50_512x1024_80k_cityscapes | 75.66% | 75.79% | \* | \* | \* | +| fpn_r50_512x1024_80k_cityscapes | 73.86% | 74.06% | 5.29% | 15.85%💔 | 2.53% | +| icnet_r18-d8_832x832_80k_cityscapes | 67.07% | 67.02% | 0.97% | 2.21% | 0.51% | +| fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes | 68.34% | 67.30% | 3.23% | 7.67% | 4.70% | +| ❗fast_scnn_lr0.12_8x4_160k_cityscapes | 70.22% | 68.33% | 10.04%💔 | 24.91%💔 | 2.03% | +| fcn_r50-d8_512x1024_40k_cityscapes | 70.96% | 73.10% | 4.55% | 15.92%💔 | 4.18% | +| bisenetv2_fcn_4x4_1024x1024_160k_cityscapes | 71.67% | 71.39% | 4.21% | 44.38%💔 | 2.76% | +| deeplabv3_r50-d8_512x1024_40k_cityscapes | 77.90% | 77.92% | 4.14% | 9.54% | 2.74% | + +| SuperRes | Float PSNR | Quant PSNR | AQE | MQE | OQE | +| -------- | ------- | ---------- | ----- | ------ | ----- | +| edsr | 28.98 | 28.83 | 4.40% | 12.12%💔 | 0.06% | +| rdn | 29.20 | 28.83 | 3.91% | 26.16%💔 | 0.07% | +| ❗srcnn | 27.76 | 22.21 | 6.38% | 16.31%💔 | 1.67% | +| srgan | 26.56 | 26.25 | 9.19% | 23.23%💔 | 0.33% | + +| OCR | Float Accuracy | Quant Accuracy | AQE | MQE | OQE | +| ------------------------------- | ----------- | -------------- | ---------- | ----------- | ------- | +| ❗en_PP-OCRv3_rec_infer | 67.36% | 0.00% | 123.00%💔 | 813.22%💔 | 92.18%💔 | +| ❗en_number_mobile_v2.0_rec_infer | 46.22% | 3.42% | 141.80%💔 | 566.63%💔 | 9.33% | +| ❗ch_PP-OCRv2_rec_infer | 54.89% | 0.00% | 211218.00%💔 | 5565940.00%💔 | 430.37%💔 | +| ❗ch_PP-OCRv3_rec_infer | 62.69% | 0.00% | 430.96%💔 | 5892.44%💔 | 218.08%💔 | +| ❗ch_ppocr_mobile_v2.0_rec_infer | 40.15% | 0.00% | 364.47%💔 | 5660.18%💔 | 54.53%💔 | +| ch_ppocr_server_v2.0_rec_infer | 54.98% | 54.41% | 3.84% | 13.75%💔 | 5.47% | + +| Detection | Float mAP | Quant mAP | AQE | MQE | OQE | +| ------------- | ------ | --------- | ------ | ------- | ------ | +| ❗yolov6p5_n | 49.80% | 42.70% | 19.20%💔 | 111.97%💔 | 9.81% | +| ❗yolov6p5_t | 52.40% | 20.40% | 54.68%💔 | 153.04%💔 | 34.37%💔 | +| ❗yolov5s6_n | 39.80% | 32.10% | 4.56% | 11.63%💔 | 1.16% | +| ❗yolov5s6_s | 47.90% | 38.10% | 4.73% | 9.80% | 0.80% | +| ❗yolov7p5_tiny | 46.60% | 37.50% | 7.02% | 14.71% | 10.45%💔 | +| ❗yolov7p5_l | 59.50% | 39.70% | 7.72% | 20.07%💔 | 12.56%💔 | +| ❗yolox_s | 49.30% | 37.40% | 15.44%💔 | 35.91%💔 | 12.00%💔 | +| ❗yolox_tiny | 45.00% | 34.70% | 4.27% | 11.52% | 3.66% | +| ❗ppyoloe_m | 55.80% | 51.20% | 14.31%💔 | 35.48%💔 | 4.74% | +| ❗ppyoloe_s | 50.30% | 45.50% | 14.69%💔 | 36.28%💔 | 6.52% | + +## 6 Contribution + +您可以在 github 上提交 issue 来反馈运行过程中遇到的问题。 + +您可以通过微信群,qq群,github等渠道联系我们提交新的模型与测试数据加入 OnnxQuant 测试集合 + +## Licence + +This project is distributed under the Apache License, Version 2.0. diff --git a/ppq/utils/fetch.py b/ppq/utils/fetch.py index a03bee8a..e0342283 100644 --- a/ppq/utils/fetch.py +++ b/ppq/utils/fetch.py @@ -42,7 +42,7 @@ def tensor_random_fetch( """ tensor = tensor.flatten() num_of_elements = tensor.numel() - assert num_of_elements > 0, ('Can not fetch data from tensor with less than 1 elements.') + assert num_of_elements > 0, ('Can not fetch data from empty tensor(0 element).') if seed is None: indexer = generate_torch_indexer(num_of_fetches=num_of_fetches, num_of_elements=num_of_elements) @@ -72,7 +72,7 @@ def channel_random_fetch( tensor = tensor.transpose(0, channel_axis) tensor = tensor.flatten(start_dim=1) num_of_elements = tensor.shape[-1] - assert num_of_elements > 0, ('Can not fetch data from tensor with less than 1 elements.') + assert num_of_elements > 0, ('Can not fetch data from empty tensor(0 element).') if seed is None: indexer = generate_torch_indexer(num_of_fetches=fetchs_per_channel, num_of_elements=num_of_elements) @@ -100,7 +100,7 @@ def batch_random_fetch( """ tensor = tensor.flatten(start_dim=1) num_of_elements = tensor.shape[-1] - assert num_of_elements > 0, ('Can not fetch data from tensor with less than 1 elements.') + assert num_of_elements > 0, ('Can not fetch data from empty tensor(0 element).') if seed is None: indexer = generate_torch_indexer(num_of_fetches=fetches_per_batch, num_of_elements=num_of_elements)