add exllama

SunMarc committed Aug 9, 2023
1 parent c506947 commit 744c249

Showing 3 changed files with 86 additions and 7 deletions.
17 changes: 17 additions & 0 deletions docs/source/optimization_toolbox/usage_guides/quantization.mdx
@@ -73,6 +73,23 @@ empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
```

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels to get faster inference speed. You just need to pass `disable_exllama=False` to [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, the entire model needs to be on GPUs.

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
from transformers import AutoModelForCausalLM
import torch

from accelerate import init_empty_weights

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False)
```

Note that only 4-bit models are supported with exllama kernels for now.
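
The same flag also exists on [`~optimum.gptq.GPTQQuantizer`], so you can keep the exllama kernels enabled when quantizing a model yourself. Here is a minimal sketch; the checkpoint name and calibration sentence are placeholders, and a CUDA device is assumed:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.gptq import GPTQQuantizer

model_name = "facebook/opt-125m"  # placeholder checkpoint, any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

# disable_exllama defaults to False, so the exllama kernels are used for 4-bit models
# as long as every module ends up on GPU; otherwise the quantizer falls back automatically.
quantizer = GPTQQuantizer(
    bits=4,
    group_size=128,
    dataset=["auto-gptq is an easy-to-use model quantization library."],  # placeholder calibration data
    disable_exllama=False,
)
quantized_model = quantizer.quantize_model(model, tokenizer)
```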

### References

[[autodoc]] gptq.GPTQQuantizer
58 changes: 54 additions & 4 deletions optimum/gptq/quantizer.py
@@ -33,10 +33,15 @@


if is_accelerate_available():
from accelerate import Accelerator, cpu_offload_with_hook, load_checkpoint_and_dispatch
from accelerate import (
Accelerator,
cpu_offload_with_hook,
load_checkpoint_and_dispatch,
)
from accelerate.hooks import remove_hook_from_module

if is_auto_gptq_available():
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

@@ -63,6 +68,7 @@ def __init__(
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
*args,
**kwargs,
):
@@ -99,6 +105,8 @@ def __init__(
The batch size of the dataset
pad_token_id (`Optional[int]`, defaults to `None`):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, defaults to `False`):
Whether to disable the exllama backend. The exllama kernels only work with `bits` = 4.
"""

self.bits = bits
@@ -114,6 +122,7 @@
self.module_name_preceding_first_block = module_name_preceding_first_block
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama

if self.bits not in [2, 4, 6, 8]:
raise ValueError("Only 2, 4, 6 and 8-bit quantization is supported.")
@@ -184,7 +193,11 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
To keep track of the name of the current module
"""
QuantLinear = dynamically_import_QuantLinear(
use_triton=False, desc_act=self.desc_act, group_size=self.group_size
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
)
if isinstance(module, QuantLinear):
return
@@ -420,6 +433,12 @@ def tmp(_, input, output):
layer_inputs, layer_outputs = layer_outputs, []
torch.cuda.empty_cache()

if self.bits == 4 and not self.disable_exllama:
if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])):
logger.warning(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
)
self.disable_exllama = True
# Step 4: Pack the model at the end (Replacing the layers)
self.pack_model(model=model, quantizers=quantizers)

@@ -429,9 +448,31 @@ def tmp(_, input, output):
model.config.use_cache = use_cache
model.config.quantization_config = self.to_dict()

# Step 5: Any post-initialization that requires device information, for example buffer initialization on device.
model = self.post_init_model(model)

torch.cuda.empty_cache()
return model

def post_init_model(self, model):
"""
Post-initialization that requires device information, for example buffer initialization on device.
Args:
model (`nn.Module`):
The input model
"""
if self.bits == 4 and not self.disable_exllama:
if get_device(model) == torch.device("cpu") or (
hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
):
raise ValueError(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. "
"You can deactivate the exllama backend by setting `disable_exllama=True` in the quantization config object."
)

return autogptq_post_init(model, use_act_order=self.desc_act)

def pack_model(
self,
model: nn.Module,
@@ -447,7 +488,11 @@ def pack_model(
A mapping of the layer name and the data needed to pack the layer
"""
QuantLinear = dynamically_import_QuantLinear(
use_triton=False, desc_act=self.desc_act, group_size=self.group_size
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
)
logger.info("Packing model...")
layers = get_layers(model)
@@ -514,6 +559,7 @@ def load_quantized_model(
offload_folder: Optional[str] = None,
offload_buffers: Optional[str] = None,
offload_state_dict: bool = False,
disable_exllama: bool = False,
):
"""
Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map.
@@ -546,6 +592,8 @@
If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
picked contains `"disk"` values.
disable_exllama (`bool`, defaults to `False`):
Whether to disable the exllama backend. The exllama kernels only work with `bits` = 4.
Returns:
`nn.Module`: The quantized model
@@ -566,6 +614,7 @@
with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f:
quantize_config_dict = json.load(f)
quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
quantizer.disable_exllama = disable_exllama

model = quantizer.convert_model(model)

@@ -582,8 +631,9 @@
offload_buffers=offload_buffers,
offload_state_dict=offload_state_dict,
)

model = quantizer.post_init_model(model)
model.is_quantized = True
model.quantization_method = QuantizationMethod.GPTQ
# put on eval mode
model.eval()
return model
18 changes: 15 additions & 3 deletions tests/gptq/test_quantization.py
@@ -43,6 +43,7 @@ class GPTQTest(unittest.TestCase):
bits = 4
group_size = 128
desc_act = False
disable_exllama = True

dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -61,7 +62,11 @@ def setUpClass(cls):

cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
cls.quantizer = GPTQQuantizer(
bits=cls.bits, dataset=cls.dataset, group_size=cls.group_size, desc_act=cls.desc_act
bits=cls.bits,
dataset=cls.dataset,
group_size=cls.group_size,
desc_act=cls.desc_act,
disable_exllama=cls.disable_exllama,
)

cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer)
Expand All @@ -84,7 +89,11 @@ def test_quantized_layers_class(self):
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(
use_triton=False, desc_act=self.desc_act, group_size=self.group_size
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

@@ -116,14 +125,17 @@ def test_serialization(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)

with init_empty_weights():
empty_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(empty_model, save_folder=tmpdirname, device_map={"": 0})
self.check_inference_correctness(quantized_model_from_saved)


class GPTQTestExllama(GPTQTest):
disable_exllama = False


class GPTQUtilsTest(unittest.TestCase):
"""
Test utilities
