Add WeightQuantizer and DynamicActQuantizer
Summary:
This exposes the AffineQuantizedTensor and LinearActQuantizedTensor
subclasses as a model-level API (`TensorSubclassQuantizer`) that replaces the weights of linear layers,
in preparation for replacing the existing tensor subclass APIs such as `change_linear_weights_to_int4_woqtensors`.
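A minimal usage sketch, mirroring `test_quantized_tensor_subclass_int8` in the diff below (the `MappingType.SYMMETRIC`, `eps`, and `zero_point_dtype` values are assumptions, since those lines are collapsed in this view):

```python
import torch
from torchao.quantization.quant_api import TensorSubclassQuantizer
from torchao.quantization.subclass import to_aqt, AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# per-channel int8 weight-only: one quantization block per output row
def get_block_size(x):
    return (1, x.shape[1])

weight_quantizer = TensorSubclassQuantizer(
    to_aqt,
    mapping_type=MappingType.SYMMETRIC,  # assumed, collapsed in the hunk below
    get_block_size=get_block_size,
    target_dtype=torch.int8,
    eps=torch.finfo(torch.float32).eps,  # assumed
    zero_point_dtype=torch.int64,        # assumed
)

m = torch.nn.Sequential(torch.nn.Linear(32, 64, bias=False)).to(torch.bfloat16)
m = weight_quantizer.quantize(m)  # swaps every matching linear weight in place
assert isinstance(m[0].weight, AffineQuantizedTensor)
```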
Currently, however, the two quantizers can't be combined, due to a problem with parametrization/nn.Parameter. The error is:

```
raise KeyError(f"attribute '{name}' already exists")
KeyError: "attribute 'weight' already exists"
```

and it happens in
```
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
```
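A minimal sketch of this failure mode (my reading, assuming a parametrization such as the `UnwrapTensorSubclass` added below is already registered on `weight`): after `register_parametrization`, `weight` becomes a class-level property backed by `parametrizations.weight.original` instead of an entry in `_parameters`, so a later `Parameter` assignment falls into the "attribute already exists" branch of `nn.Module.register_parameter`:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class Noop(nn.Module):  # stand-in for UnwrapTensorSubclass
    def forward(self, x):
        return x

lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", Noop())

# "weight" is no longer in lin._parameters, but hasattr(lin, "weight") is
# still True, so this assignment raises
#   KeyError: "attribute 'weight' already exists"
lin.weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)
```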

Test Plan:
regression tests:
```
python test/quantization/test_quant_api.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 23, 2024
1 parent 5741aa2 commit bd75fb2
Showing 3 changed files with 165 additions and 64 deletions.
117 changes: 82 additions & 35 deletions test/quantization/test_quant_api.py
@@ -18,12 +18,19 @@
get_symmetric_quantization_config,
)

from torchao.quantization.subclass import (
to_aqt,
to_laqt,
AffineQuantizedTensor,
LinearActQuantizedTensor,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
apply_dynamic_quant,
apply_weight_only_int8_quant,
Quantizer,
TwoStepQuantizer,
TensorSubclassQuantizer,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
@@ -92,8 +99,8 @@ def __init__(self, m=64, n=32, k=64):
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, self.linear1.in_features).to(torch.float),)
def example_inputs(self, batch_size=1):
return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -423,20 +430,31 @@ def get_per_token_block_size(x):
# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

def dynamic_quant(linear):
# note: order is important
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
dynamic_quant(m.linear1)
dynamic_quant(m.linear2)

weight_quantizer = TensorSubclassQuantizer(
to_aqt,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps
)
dynamic_act_quantizer = TensorSubclassQuantizer(to_laqt, input_quant_func=input_quant_func)

# note: order is important
m = weight_quantizer.quantize(m)
m = dynamic_act_quantizer.quantize(m)

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -475,16 +493,19 @@ def test_quantized_tensor_subclass_int4(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return AffineQuantizedTensor.from_float(
weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
weight_quantizer = TensorSubclassQuantizer(
to_aqt,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
)
m = weight_quantizer.quantize(m)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -515,12 +536,20 @@ def test_quantized_tensor_subclass_int8(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
def get_block_size(x):
return (1, x.shape[1])

weight_quantizer = TensorSubclassQuantizer(
to_aqt,
mapping_type=mapping_type,
get_block_size=get_block_size,
target_dtype=target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype
)

m = weight_quantizer.quantize(m)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -537,7 +566,7 @@ def to_quantized(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.subclass import to_aqt
from torchao.quantization.subclass import LinearActQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import ZeroPointDomain
@@ -563,20 +592,26 @@ def get_per_token_block_size(x):
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def dynamic_quant(linear):
# note: order is important
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))

weight_quantizer = TensorSubclassQuantizer(
to_aqt,
mapping_type=mapping_type,
get_block_size=get_weight_block_size,
target_dtype=target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype
)
dynamic_act_quantizer = TensorSubclassQuantizer(to_laqt, input_quant_func=input_quant_func)
m = weight_quantizer.quantize(m)
m = dynamic_act_quantizer.quantize(m)

dynamic_quant(m.linear1)
dynamic_quant(m.linear2)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +626,18 @@ def dynamic_quant(linear):

self.assertTrue(torch.equal(res, ref))

# workaround for export path
from torchao.quantization.quant_api import _unwrap_tensor_subclass
m = _unwrap_tensor_subclass(m)
m = torch.export.export(m, example_inputs).module()
exported_model_res = m(*example_inputs)

self.assertTrue(torch.equal(exported_model_res, ref))

# make sure it compiles
torch._export.aot_compile(m, example_inputs)



if __name__ == "__main__":
unittest.main()
77 changes: 77 additions & 0 deletions torchao/quantization/quant_api.py
@@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -48,6 +49,7 @@
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"TensorSubclassQuantizer",
"autoquant"
]

@@ -214,3 +216,78 @@ def replace_conv2d_1x1(conv):
_replace_with_custom_fn_if_matches_filter(
model, replace_conv2d_1x1, filter_fn=filter_fn
)

class UnwrapTensorSubclass(nn.Module):
def forward(self, *tensors):
todo = list(tensors)
for tp, meta, inner_tensors in reversed(self.rebuild_stack):
nb_tensor = len(inner_tensors)
inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
todo = todo[nb_tensor:]
rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
todo.append(rebuilt)

assert len(todo) == 1
return todo[0]

def right_inverse(self, tensor):
assert type(tensor) is not torch.Tensor
rebuild_stack = []
plain_tensors = []
todo = [tensor]
while todo:
obj = todo.pop()
inner_tensors, metadata = obj.__tensor_flatten__()
rebuild_stack.append((type(obj), metadata, inner_tensors))
for attr_name in inner_tensors:
val = getattr(obj, attr_name)
if type(val) is torch.Tensor:
plain_tensors.append(val)
else:
assert isinstance(val, torch.Tensor)
todo.append(val)

self.rebuild_stack = rebuild_stack

return plain_tensors

def _unwrap_tensor_subclass(model, filter_fn=None):
def insert_parametrization(lin):
parametrize.register_parametrization(lin, "weight", UnwrapTensorSubclass())
return lin

_replace_with_custom_fn_if_matches_filter(
model,
insert_parametrization,
_is_linear if filter_fn is None else filter_fn,
)

return model


def _get_linear_subclass_inserter(constructor, **kwargs):
def insert_subclass(lin):
# so that we don't modify the original kwargs
copied_kwargs = dict(kwargs)
get_block_size = copied_kwargs.pop("get_block_size", None)
if get_block_size:
block_size = get_block_size(lin.weight)
copied_kwargs["block_size"] = block_size
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
return lin

return insert_subclass

class TensorSubclassQuantizer(Quantizer):
def __init__(self, factory_fn, **kwargs):
super().__init__()
self.factory_fn = factory_fn
self.kwargs = kwargs

def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(self.factory_fn, **self.kwargs),
_is_linear if filter_fn is None else filter_fn,
)
return model
35 changes: 6 additions & 29 deletions torchao/quantization/subclass.py
@@ -35,6 +35,7 @@
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"AffineQuantizedTensor",
"LinearActQuantizedTensor",
]


@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):

self.q_scales = q_scales
super().__init__(int_data, transposed)

@@ -629,32 +629,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

def to_aqt(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min = None,
quant_max = None,
eps = None,
scale_dtype = None,
zero_point_dtype = None,
preserve_zero = True,
zero_point_domain = ZeroPointDomain.INT,
):
return AffineQuantizedTensor.from_float(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain
)

# TODO: merge with nf4 implements decorator
# aten ops to their __torch_dispatch__ implementations for the tensor subclass
@@ -777,7 +751,7 @@ def dequantize(self, output_dtype=None):
return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

def __tensor_flatten__(self):
return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@classmethod
def __tensor_unflatten__(
@@ -1091,7 +1065,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
input_quant_func = tensor_attributes
input_quant_func, = tensor_attributes
return cls(
original_weight_tensor,
input_quant_func,
@@ -1176,3 +1150,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
raise NotImplementedError(
f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

to_aqt = AffineQuantizedTensor.from_float
to_laqt = LinearActQuantizedTensor.from_float
