diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py
index e97099a0..d1e9da7b 100644
--- a/ppq/executor/op/torch/default.py
+++ b/ppq/executor/op/torch/default.py
@@ -333,6 +333,35 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCont
     return multiplicand * multiplier
 
 
+def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
+    if len(values) != 11:
+        raise NotImplementedError('Simplified MultiHeadAttention is not implemented')
+
+    q, k, v, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b = values
+    embed_dim = op.attributes.get('embed_dim')
+    num_heads = op.attributes.get('num_heads')
+
+    if embed_dim is None or num_heads is None:
+        raise ValueError('Cannot fetch embed_dim or num_heads')
+
+    # setup parameters
+    batch_size = q.shape[0]
+    head_dim = embed_dim // num_heads
+    scale = head_dim ** -0.5
+
+    q = F.linear(q, q_w, q_b).reshape(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)
+    k = F.linear(k, k_w, k_b).reshape(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)
+    v = F.linear(v, v_w, v_b).reshape(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)
+
+    energy = (q @ k.transpose(-2, -1)) * scale
+    attn = energy.softmax(dim=-1)
+
+    x = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim)
+    x = F.linear(x, o_w, o_b)
+
+    return x
+
+
 def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     """Performs element-wise binary addition (with Numpy-style broadcasting
     support).
@@ -786,6 +815,9 @@ def GatherND_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
     reshaped_output = output.reshape(*shape_i, *shape_j, *shape_k)
     return output
 
+def Gelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
+    [input_value] = values
+    return F.gelu(input_value)
 
 def Greater_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     input_a, input_b = values
@@ -1436,7 +1468,7 @@ def Split_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCo
     split = op.attributes.get('split', 0)
     [input_value] = values
     if 'split' not in op.attributes:
-        split = input_value.shape[axis] // len(op.outputs)
+        split = input_value.shape[axis] // len(op.outputs)
     outputs = torch.split(input_value, split, axis)
     return outputs
 
@@ -1525,6 +1557,18 @@ def LeakyRelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke
     return output
 
 
+def LayerNorm_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
+    if len(values) != 3:
+        raise ValueError('LayerNorm without affine parameters is not supported')
+
+    input_data, weight, bias = values
+    eps = op.attributes.get('epsilon', 1e-5)
+    normalized_shape = weight.shape
+
+    output = F.layer_norm(input_data, normalized_shape, weight, bias, eps)
+    return output
+
+
 def Pad_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
     mode = op.attributes.get('mode', 'constant')
     input_data = values[0]
@@ -2118,20 +2162,20 @@ def Identity_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
     return values[0]
 
 def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
-    """
-    Produces a one-hot tensor based on inputs. The locations represented by the index values in the 'indices'
-    input tensor will have 'on_value' and the other locations will have 'off_value' in the output tensor,
-
-    where 'on_value' and 'off_value' are specified as part of required input argument 'values',
-    which is a two-element tensor of format [off_value, on_value].
-
-    The rank of the output tensor will be one greater than the rank of the input tensor.
-    The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
-    If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension,
-    i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
-
-    The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
-    input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
+    """Produces a one-hot tensor based on inputs. The locations represented by
+    the index values in the 'indices' input tensor will have 'on_value' and the
+    other locations will have 'off_value' in the output tensor,
+
+    where 'on_value' and 'off_value' are specified as part of required input argument 'values',
+    which is a two-element tensor of format [off_value, on_value].
+
+    The rank of the output tensor will be one greater than the rank of the input tensor.
+    The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
+    If 'axis' is not specified then the additional dimension will be inserted as the innermost dimension,
+    i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
+
+    The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
+    input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
     with all 'off_value' values in the output tensor.
 
     when axis = 0:
@@ -2144,30 +2188,30 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
 
     Attributes
         axis : int (default is -1)
-            (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that
-            the additional dimension will be inserted as the innermost/last dimension in the output tensor.
+            (Optional) Axis along which one-hot representation is added. Default: axis=-1. axis=-1 means that
+            the additional dimension will be inserted as the innermost/last dimension in the output tensor.
             Negative value means counting dimensions from the back. Accepted range is [-r-1, r] where r = rank(indices).
-
+
     Inputs
         indices (non-differentiable) : T1
             Input tensor containing indices. Any entries in the 'indices' input tensor with values outside the range [-depth, depth-1]
-            will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type,
+            will result in one-hot representation with all 'off_value' values in the output tensor. In case 'indices' is of non-integer type,
            the values will be casted to int64 before use.
-
+
         depth (non-differentiable) : T2
-            Scalar specifying the number of classes in one-hot tensor.
+            Scalar specifying the number of classes in one-hot tensor.
             This is also the size of the one-hot dimension (specified by 'axis' attribute) added on in the output tensor.
-            The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
+            The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
             In case 'depth' is of non-integer type, it will be casted to int64 before use.
 
         values (non-differentiable) : T3
-            Rank 1 tensor containing exactly two elements,
-            in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
+            Rank 1 tensor containing exactly two elements,
+            in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
             and 'off_value' is the value used for filling locations other than those specified in 'indices' input tensor.
 
     Outputs
         output (non-differentiable) : T3
-            Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
+            Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
             The data type for the elements of the output tensor is the same as the type of input 'values' is used.
     """
     # implementation from https://github.com/ToriML/onnx2pytorch/blob/master/onnx2pytorch/operations/onehot.py
@@ -2187,10 +2231,10 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
     order.insert(axis, -1)
     out = out.permute(order)
     return out
-
+
 def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     """
-    Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
+    Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
     y = 1/x, is applied to the tensor elementwise.
 
     Version
@@ -2231,11 +2275,13 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack
     'Gather': Gather_forward,
     'GatherElements': Gather_forward,
     'GatherND': GatherND_forward,
+    'Gelu': Gelu_forward,
     'Gemm': Gemm_forward,
     'grid_sampler': Grid_sampler_forward,
     'GlobalAveragePool': AveragePool_forward,
     'GlobalMaxPool': MaxPool2d_forward,
     'Greater': Greater_forward,
+    'LayerNorm': LayerNorm_forward,
     'LeakyRelu': LeakyRelu_forward,
     'Less': Less_forward,
     'MatMul': MatMul_forward,
@@ -2243,6 +2289,7 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack
     'MaxPool': MaxPool2d_forward,
     'Min': Eltwise_forward,
     'Mul': Mul_forward,
+    'MultiHeadAttention': MultiHeadAttention_forward,
     'NonMaxSuppression': _NMS_forward,
     'NonZero': NonZero_forward,
     'Not': Not_forward,
diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py
index 0556f49d..491f6b45 100644
--- a/ppq/quantization/quantizer/base.py
+++ b/ppq/quantization/quantizer/base.py
@@ -110,12 +110,12 @@ def quantize_operations(
                 operation_platforms[op_name] = self.target_platform
             else:
                 operation_platforms[op_name] = self.default_platform
-            # maunnl override.
+            # manual override.
             if op_name in operation_platforms:
                 operation.platform = operation_platforms[op_name]
 
         # build operation_quantization_configs
-        # every quantable op MUST have a quantization config
+        # every quantizable op MUST have a quantization config
         # if operation.type is listed in quantable_operation_types while a operation_quantization_configs is given
         # it will override the setting of quantable_operation_types
         for op_name, operation in self._graph.operations.items():
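
The standalone sketch below is not part of the patch; it only illustrates the attention arithmetic that MultiHeadAttention_forward implements (separate q/k/v input projections, per-head scaled dot-product attention, then an output projection) and compares it against torch.nn.functional.multi_head_attention_forward. All shapes and variable names are illustrative assumptions; only the 11-input layout q, k, v, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b mirrors the operator above, and a reasonably recent PyTorch is assumed.

# Hedged sanity check: re-implements the MultiHeadAttention math with plain torch ops
# and compares it against torch's functional reference. Self-attention only (k = v = q).
import torch
import torch.nn.functional as F

batch, seq, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads

q = torch.randn(batch, seq, embed_dim)
q_w, k_w, v_w, o_w = (torch.randn(embed_dim, embed_dim) for _ in range(4))
q_b, k_b, v_b, o_b = (torch.randn(embed_dim) for _ in range(4))

def reference_mha(q, k, v):
    # project, split into (batch, heads, seq, head_dim), attend, merge heads, project out
    scale = head_dim ** -0.5
    q_ = F.linear(q, q_w, q_b).reshape(batch, -1, num_heads, head_dim).permute(0, 2, 1, 3)
    k_ = F.linear(k, k_w, k_b).reshape(batch, -1, num_heads, head_dim).permute(0, 2, 1, 3)
    v_ = F.linear(v, v_w, v_b).reshape(batch, -1, num_heads, head_dim).permute(0, 2, 1, 3)
    attn = ((q_ @ k_.transpose(-2, -1)) * scale).softmax(dim=-1)
    x = (attn @ v_).transpose(1, 2).reshape(batch, -1, embed_dim)
    return F.linear(x, o_w, o_b)

# torch's reference expects (seq, batch, embed) inputs and separate projection weights
expected, _ = F.multi_head_attention_forward(
    q.transpose(0, 1), q.transpose(0, 1), q.transpose(0, 1),
    embed_dim, num_heads,
    in_proj_weight=None, in_proj_bias=torch.cat([q_b, k_b, v_b]),
    bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0.0,
    out_proj_weight=o_w, out_proj_bias=o_b,
    use_separate_proj_weight=True,
    q_proj_weight=q_w, k_proj_weight=k_w, v_proj_weight=v_w)

print(torch.allclose(reference_mha(q, q, q), expected.transpose(0, 1), atol=1e-4))

Running this should print True, since both paths scale the attention scores by head_dim ** -0.5 before the softmax and share the same projection layout.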