Commit: Update 20230306 (#403)
* Fix a bug in the equalization pass
* Add an interface to the LSQ pass that allows users to pass in an optimizer
* Add a fuse matmul+add function
* Fix some typos
* Upload the QuantZoo dataset
ZhangZhiPku committed Mar 9, 2023
1 parent e0298ad commit 71f5abc
Showing 13 changed files with 1,147 additions and 22 deletions.
32 changes: 32 additions & 0 deletions ppq/IR/morph.py
@@ -1042,6 +1042,38 @@ def fuse_selfattention(self):
if v is not None: non_empty_attr[k] = v
op._attributes = non_empty_attr

def fuse_matmul_add(self, verbose: bool = True):
"""
Fuse MatMul + bias Add into PPQBiasFusedMatMul.
PPQBiasFusedMatMul is a temporary operation which will be split when exporting.
"""
graph, fused = self.graph, False
for current_op in [_ for _ in graph.operations.values()]:
if current_op.type != 'MatMul': continue

# check that the downstream op is an Add
next_ops = graph.get_downstream_operations(current_op)
if len(next_ops) != 1: continue
if next_ops[0].type != 'Add': continue

# check whether it is a constant (bias) Add
fusing_op = next_ops[0]
if fusing_op.num_of_parameter == 1:

# do graph fusion
bias = fusing_op.parameters[0].value
graph.remove_operation(fusing_op, keep_coherence=True)
graph.create_variable(value=bias, is_parameter=True, dest_ops=[current_op])
current_op.type = 'PPQBiasFusedMatMul'
fused = True

if verbose:
print(f'Fusing graph op: {current_op.name} + {fusing_op.name}')

if not fused:
ppq_warning("No suitable matmul + add was found, check your graph again.")


class GraphDecomposer(GraphCommandProcessor):
"""Since PPQ 0.6.4, GraphDecomposer is introduced to split some complex
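For reference, a minimal usage sketch of the new fusion (the ONNX path below is only a placeholder); it mirrors how the QuantZoo_Imagenet.py script added later in this commit calls it:

from ppq.api import load_onnx_graph
from ppq.IR import GraphMerger

# load an ONNX graph, then fuse MatMul + constant Add into PPQBiasFusedMatMul
graph = load_onnx_graph(onnx_import_file='model.onnx')  # placeholder path
processor = GraphMerger(graph)
processor.fuse_matmul_add(verbose=True)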
13 changes: 1 addition & 12 deletions ppq/parser/onnx_exporter.py
@@ -55,17 +55,6 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op])
graph.remove_variable(op.inputs[-1])

class PPQBiasFusedMatMulExporter(OperationExporter):
def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
if op.num_of_input == 3: bias = op.inputs[-1]
assert bias.is_parameter and bias.value is not None, 'MatMul Format Error'

bias_op = graph.create_operation(op_type='Add')
op.type = 'MatMul'
graph.insert_op_after(bias_op, op)
graph.create_variable(value=bias.value, is_parameter=True, dest_ops=[bias_op])
graph.remove_variable(op.inputs[-1])

OP_CONVERTERS = {
'ConstantOfShape': ConstantOfShapeExporter,
'MMCVRoiAlign': MMCVExporter,
@@ -83,7 +72,7 @@ def export(self, op: Operation, graph: BaseGraph, **kwargs) -> Operation:
'QLinearMul': OOSExporter,
'QLinearReduceMean': OOSExporter,
'QLinearSigmoid': OOSExporter,
'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter
# 'PPQBiasFusedMatMul': PPQBiasFusedMatMulExporter
}

def convert_value(value: Union[int, float, np.ndarray, torch.Tensor]) -> str:
2 changes: 1 addition & 1 deletion ppq/quantization/algorithm/equalization.py
@@ -82,7 +82,7 @@ def key_value_from_upstream(
# ----------------------------------
# step - 3, extract activation from op:
# ----------------------------------
if including_act and op.inputs[0].value is not None:
if including_act and op.outputs[0].value is not None:
a = op.outputs[0].value * act_multiplier
buffer.append(a)

6 changes: 3 additions & 3 deletions ppq/quantization/algorithm/training.py
@@ -332,7 +332,7 @@ def __init__(
if self.is_parameter and is_parameter_trainable:
self.param_backup = self.var.value.clone()

# There are 4 checks for training scale:
# There are 4 checks for scale training:
# 1. scale is valid
# 2. state is active
# 3. do not have POWER_OF_2 policy but Must have Linear policy
Expand All @@ -348,7 +348,7 @@ def __init__(
self.is_scale_trainable = True
self.scale_backup = self.config.scale.detach().clone()

# There are 4 checks for training offset:
# There are 4 checks for offset training:
# 1. offset is valid
# 2. state is active
# 3. do not have SYMMETRICAL policy
@@ -419,4 +419,4 @@ def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
quantized = (quantized - offset.detach()) * scale
quantized = quantized
return quantized

2 changes: 1 addition & 1 deletion ppq/quantization/observer/range.py
@@ -240,7 +240,7 @@ def hist_to_scale_offset(

losses, quant_bins = [], 2 ** (config.num_of_bits - 1)

# following code is crucial, do not move
# following code is crucial, do not remove
histogram[: int(hist_bins * .002)] = 0
histogram[int(hist_bins * .002)] = 1

7 changes: 5 additions & 2 deletions ppq/quantization/optim/training.py
@@ -709,8 +709,9 @@ class LearnedStepSizePass(TrainingBasedPass):
def __init__(
self, name: str = 'PPQ LSQ Optimization', interested_layers: List[str] = [],
steps: int = 500, gamma: float = 0.0, is_scale_trainable: bool = True,
lr: float = 5e-5, block_size: int = None, expire_device: str = 'cpu',
lr: float = 5e-5, block_size: int = 5, expire_device: str = 'cpu',
collecting_device: str = 'cuda', loss_fn: Callable = torch_mean_square_error,
optimizer: Any = None
) -> None:
super().__init__(name=name)
self.interested_layers = interested_layers
Expand All @@ -722,6 +723,7 @@ def __init__(
self.gamma = gamma
self.steps = steps
self.lr = lr
self.optimizer = optimizer

def finetune(
self, steps: int, learning_rate: float, block: TrainableBlock, executor: TorchExecutor,
@@ -764,8 +766,9 @@ def finetune(
return 0, 0

# initialize optimizer.
if optimizer is None:
if self.optimizer is None:
optimizer = torch.optim.Adam(tensors, lr=learning_rate)
else: optimizer = self.optimizer(tensors, lr=learning_rate)

dataset_length = len(qt_inputs)
if dataset_length == 0: raise ValueError('Dataset is empty.')
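A minimal sketch of the new optimizer hook (parameter values here are illustrative): anything callable as optimizer(tensors, lr=learning_rate), such as a torch.optim optimizer class, can be passed; when it is left as None the pass still constructs torch.optim.Adam as before.

import torch
from ppq.quantization.optim import LearnedStepSizePass

# use SGD instead of the default Adam for scale training
lsq = LearnedStepSizePass(
    steps=500, collecting_device='cuda', block_size=5,
    optimizer=torch.optim.SGD)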
165 changes: 165 additions & 0 deletions ppq/samples/QuantZoo/QuantZoo_Imagenet.py
@@ -0,0 +1,165 @@
# Test Quantization System Performance on Image Classification Models with ILSVRC2012 Dataset

# Should contain the model files (.onnx)
MODEL_DIR = 'QuantZoo/Model/Imagenet'

# Should contain the Calib & Test image folders
CALIB_DIR = 'QuantZoo/Data/Imagenet/Calib'
TEST_DIR = 'QuantZoo/Data/Imagenet/Test'

# calibration & test batchsize
BATCHSIZE = 32

# Quantizer Configuration
SYMMETRICAL = True
PER_CHANNEL = True
POWER_OF_2 = False
BIT_WIDTH = 8

# write reports to this directory
REPORT_DIR = 'QuantZoo/Reports'

CONFIGS = [
{
'Model': 'efficientnet_v1_b0',
'Output': ['/features/features.8/features.8.2/Mul_output_0']
},
{
'Model': 'efficientnet_v1_b1',
'Output': ['/features/features.8/features.8.2/Mul_output_0']
},
{
'Model': 'efficientnet_v2_s',
'Output': ['/features/features.7/features.7.2/Mul_output_0']
},
{
'Model': 'mnasnet0_5',
'Output': ['/layers/layers.16/Relu_output_0']
},
{
'Model': 'mnasnet1_0',
'Output': ['/layers/layers.16/Relu_output_0']
},
{
'Model': 'mobilenet_v2',
'Output': ['/features/features.18/features.18.2/Clip_output_0']
},
{
'Model': 'resnet18',
'Output': ['/layer4/layer4.1/relu_1/Relu_output_0']
},
{
'Model': 'resnet50',
'Output': ['/layer4/layer4.2/relu_2/Relu_output_0']
},

{
'Model': 'mobilenet_v3_large',
'Output': ['/classifier/classifier.1/Mul_output_0']
},
{
'Model': 'mobilenet_v3_small',
'Output': ['/classifier/classifier.1/Mul_output_0']
},
{
'Model': 'v100_gpu64@[email protected]_finetune@25',
'Output': ['471']
},
{
'Model': 'v100_gpu64@[email protected]_finetune@25',
'Output': ['471']
},
{
# vit_b_16 requires BATCHSIZE = 1!
'Model': 'vit_b_16',
'Output': ['onnx::Gather_1703']
}
]

import os

import torch

import ppq.lib as PFL
from ppq.api import ENABLE_CUDA_KERNEL, load_onnx_graph
from ppq.core import TargetPlatform
from ppq.executor import TorchExecutor
from ppq.quantization.optim import (LayerwiseEqualizationPass,
LearnedStepSizePass, ParameterQuantizePass,
RuntimeCalibrationPass)
from QuantZoo.Data.Imagenet.Eval import (evaluate_ppq_module_with_imagenet,
load_imagenet_from_directory)
from QuantZoo.Quantizers import MyFP8Quantizer, MyInt8Quantizer
from QuantZoo.Util import error_analyze


calib_loader = load_imagenet_from_directory(
directory=CALIB_DIR, batchsize=BATCHSIZE,
shuffle=False, require_label=False,
num_of_workers=8)


test_loader = load_imagenet_from_directory(
directory=TEST_DIR, batchsize=BATCHSIZE,
shuffle=False, require_label=True,
num_of_workers=8)


with ENABLE_CUDA_KERNEL():
for config in CONFIGS:
model = config['Model']
monitoring_vars = config['Output']

print(f"Ready to run quant benchmark on {model}")
graph = load_onnx_graph(onnx_import_file=os.path.join(MODEL_DIR, model + '.onnx'))

if model == 'vit_b_16':
if BATCHSIZE == 32:
raise Exception('To evaluate vit_b_16, change batchsize to 1 and change the calibration method to minmax.')
from ppq.IR import GraphMerger
processor = GraphMerger(graph)
processor.fuse_matmul_add()
processor.fuse_layernorm()
processor.fuse_gelu()

quantizer = MyInt8Quantizer(
graph=graph, sym=SYMMETRICAL, power_of_2=POWER_OF_2,
num_of_bits=BIT_WIDTH, per_channel=PER_CHANNEL)
# quantizer = MyFP8Quantizer(graph=graph)

# convert op to quantable-op
for name, op in graph.operations.items():
if op.type in {'Conv', 'ConvTranspose', 'MatMul', 'Gemm',
'PPQBiasFusedMatMul', 'LayerNormalization'}:
quantizer.quantize_operation(name, platform=TargetPlatform.INT8)

# build quant pipeline.
pipeline = PFL.Pipeline([
# LayerwiseEqualizationPass(iteration=10),
ParameterQuantizePass(),
RuntimeCalibrationPass(),
# LearnedStepSizePass(steps=500, collecting_device='cuda', block_size=5)
])

# call pipeline.
executor = TorchExecutor(graph=graph)
executor.tracing_operation_meta(torch.zeros(size=[BATCHSIZE, 3, 224, 224]).cuda())

pipeline.optimize(
graph=graph, dataloader=calib_loader, verbose=True,
calib_steps=32, collate_fn=lambda x: x.to('cuda'), executor=executor)

# evaluation
acc = evaluate_ppq_module_with_imagenet(
model=graph, imagenet_validation_loader=test_loader,
batchsize=BATCHSIZE, device='cuda', verbose=False)
print(f'Model Classification Accuracy = {acc: .4f}%')

# error analyze
performance = error_analyze(
graph=graph,
outputs=monitoring_vars,
dataloader=test_loader,
collate_fn=lambda x: x[0].to('cuda'),
verbose=True
)
(Diffs for the remaining 6 changed files are not shown here.)