Add GPTQ Quantization (#1216)
* v1 test draft

* code runs but outputs gibberish.

* draft v1.1

* remove duplicate

* remove dep to transformers and cleaning

* Add serialization and loading

* Clean code and doc

* add flexibility

* remove triton

* remove some dep with transformers

* add testing

* make style

* add accelerate flag

* handle device placement

* make style

* Apply suggestions

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: fxmarty <[email protected]>

* add doc in data.py

* apply suggestion for utils file

* remove multiple output

* fix Optional

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: fxmarty <[email protected]>

* remove useless check

* fix doc and style

* fix name

* replace catcher by prefoward hook

* update doctstring for true_sequential

* apply suggestion

* Fix import

* Add docstring for tests

* move args

* fix typo

* fix cpu offload and tokenizer

* fix typo

* fix offload cpu

* modify attribute

* more explicit error

* dataset optional

* add tqdm bar instead

* style

* add doc

* replace by tqdm.auto

* change model

* add CI

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* Update .github/workflows/test_gptq.yml

Co-authored-by: Younes Belkada <[email protected]>

* add peft compatibility

* Apply suggestions from code review doc

Co-authored-by: fxmarty <[email protected]>

* merge examples

* code review

* fix test

* make style

* change var

* fix doc

* add exllama

* change naming

* more doc

---------

Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: fxmarty <[email protected]>
4 people authored Aug 10, 2023
1 parent 94bf766 commit 9f2943e
Showing 14 changed files with 1,501 additions and 1 deletion.
90 changes: 90 additions & 0 deletions .github/workflows/test_gptq.yml
@@ -0,0 +1,90 @@
name: GPTQ Quantization / Test GPU

on:
  workflow_dispatch:
  schedule:
    - cron: 0 1 */3 * *  # at 1am every 3 days
  pull_request:
    types: [opened, synchronize, reopened, labeled]
  # uncomment to enable on PR merge on main branch:
  #push:
  #  branches:
  #    - main

jobs:
  start-runner:
    if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || contains( github.event.pull_request.labels.*.name, 'gpu-test') }}
    name: Start self-hosted EC2 runner
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
      EC2_AMI_ID: ami-0dc1c26161f869ed1
      EC2_INSTANCE_TYPE: g4dn.xlarge
      EC2_SUBNET_ID: subnet-859322b4,subnet-b7533b96,subnet-47cfad21,subnet-a396b2ad,subnet-06576a4b,subnet-df0f6180
      EC2_SECURITY_GROUP: sg-0bb210cd3ec725a13
      EC2_IAM_ROLE: optimum-ec2-github-actions-role
    outputs:
      label: ${{ steps.start-ec2-runner.outputs.label }}
      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Start EC2 runner
        id: start-ec2-runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: start
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          ec2-image-id: ${{ env.EC2_AMI_ID }}
          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
          subnet-id: ${{ env.EC2_SUBNET_ID }}
          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
          iam-role-name: ${{ env.EC2_IAM_ROLE }}
          aws-resource-tags: >  # optional, requires additional permissions
            [
              {"Key": "Name", "Value": "ec2-optimum-github-runner"},
              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
            ]
  do-the-job:
    name: Setup
    needs: start-runner  # required to start the main job when the runner is ready
    runs-on: ${{ needs.start-runner.outputs.label }}  # run the job on the newly created runner
    env:
      AWS_REGION: us-east-1
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Build image
        run: |
          docker build -f tests/gptq/docker/Dockerfile_quantization_gpu -t gptq-gpu .
      - name: Test with unittest within docker container
        run: |
          docker run --rm --gpus all -v $(pwd)/hf_cache:/root/.cache/huggingface --workdir=/workspace/optimum/tests gptq-gpu:latest
  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner  # required to get output from the start-runner job
      - do-the-job  # required to wait until the main job is done
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
    if: ${{ always() && !(needs.start-runner.result == 'skipped' && needs.do-the-job.result == 'skipped') }}  # required to stop the runner even if an error happened in the previous jobs, unless they were all skipped
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Stop EC2 runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: stop
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          label: ${{ needs.start-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
5 changes: 5 additions & 0 deletions docs/source/_toctree.yml
@@ -125,6 +125,11 @@
    isExpanded: false
  title: BetterTransformer
  isExpanded: false
- sections:
  - local: llm_quantization/usage_guides/quantization
    title: GPTQ quantization
  title: LLM quantization
  isExpanded: false
- sections:
  - local: utils/dummy_input_generators
    title: Dummy input generators
1 change: 1 addition & 0 deletions docs/source/concept_guides/quantization.mdx
@@ -185,6 +185,7 @@ models while respecting accuracy and latency constraints.
[PyTorch quantization functions](https://pytorch.org/docs/stable/quantization-support.html#torch-quantization-quantize-fx)
to allow graph-mode quantization of 🤗 Transformers models in PyTorch. This is a lower-level API compared to the two
mentioned above, giving more flexibility, but requiring more work on your end.
- The `optimum.llm_quantization` package allows you to [quantize and run LLMs](https://huggingface.co/docs/optimum/llm_quantization/usage_guides/quantization).

## Going further: How do machines represent numbers?

104 changes: 104 additions & 0 deletions docs/source/llm_quantization/usage_guides/quantization.mdx
@@ -0,0 +1,104 @@
# Quantization

## AutoGPTQ Integration

🤗 Optimum collaborated with the [AutoGPTQ library](https://github.com/PanQiWei/AutoGPTQ) to provide a simple API that applies GPTQ quantization to language models. With GPTQ quantization, you can quantize your favorite language model to 8, 6, 4 or even 2 bits, without a significant drop in performance and with faster inference speed. This is supported by most GPU hardware.

If you want to quantize 🤗 Transformers models with GPTQ, follow this [documentation](https://huggingface.co/docs/transformers/main_classes/quantization).

To learn more about the quantization technique used in GPTQ, please refer to:
- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
- the [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library used as the backend

Note that the AutoGPTQ library provides more advanced features (Triton backend, fused attention, fused MLP) that are not integrated with Optimum. For now, we leverage only the CUDA kernel for GPTQ.

### Requirements

You need the following libraries installed to run the code below:

- AutoGPTQ library:
`pip install auto-gptq`

- Optimum library:
`pip install --upgrade optimum`

- Latest `transformers` library installed from source:
`pip install --upgrade git+https://github.com/huggingface/transformers.git`

- Latest `accelerate` library:
`pip install --upgrade accelerate`

### Load and quantize a model

The [`~optimum.gptq.GPTQQuantizer`] class is used to quantize your model. To do so, you need to provide a few arguments:
- the number of bits: `bits`
- the dataset used to calibrate the quantization: `dataset`
- the model sequence length used to process the dataset: `model_seqlen`
- the block name to quantize: `block_name_to_quantize`

With the 🤗 Transformers integration, you don't need to pass `block_name_to_quantize` and `model_seqlen`, as we can retrieve them automatically. However, for a custom model, you need to specify them. Also, make sure that your model is converted to `torch.float16` before quantization.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer, load_quantized_model
import torch
model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

quantizer = GPTQQuantizer(bits=4, dataset="c4", block_name_to_quantize="model.decoder.layers", model_seqlen=2048)
quantized_model = quantizer.quantize_model(model, tokenizer)
```
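The `dataset` argument can also take your own list of text samples instead of a named dataset. The snippet below is a minimal sketch under that assumption, reusing the `model` and `tokenizer` from the example above:

```py
# Sketch: calibrating on a custom list of text samples instead of "c4".
# Assumes `dataset` accepts a list of strings and reuses `model` / `tokenizer` from above.
calibration_samples = [
    "GPTQ is a post-training quantization method for large language models.",
    "The quick brown fox jumps over the lazy dog.",
]

custom_quantizer = GPTQQuantizer(
    bits=4,
    dataset=calibration_samples,
    block_name_to_quantize="model.decoder.layers",
    model_seqlen=2048,
)
quantized_model = custom_quantizer.quantize_model(model, tokenizer)
```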

<Tip warning={true}>
GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (a 175B model takes about 4 GPU-hours on an NVIDIA A100). Please check on the Hugging Face Hub whether a GPTQ-quantized version of the model you would like to quantize already exists.
</Tip>
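One way to check the Hub programmatically is to search model names with `huggingface_hub`. This is only a heuristic sketch; the search string below is an assumption, not an official GPTQ filter:

```py
# Sketch: looking for already-quantized checkpoints before quantizing yourself.
# The search string is a naming heuristic, not a dedicated GPTQ filter.
from huggingface_hub import HfApi

api = HfApi()
candidates = api.list_models(search="opt-125m gptq", sort="downloads", direction=-1, limit=5)
for m in candidates:
    print(m.id)
```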

### Save the model

To save your model, use the `save` method of the [`~optimum.gptq.GPTQQuantizer`] class. It will create a folder with your model state dict along with the quantization config.
```python
save_folder = "/path/to/save_folder/"
quantizer.save(model, save_folder)
```

### Load quantized weights

You can load your quantized weights by using the [`~optimum.gptq.load_quantized_model`] function.
Through the Accelerate library, it is possible to load a model faster and with lower memory usage. The model first needs to be initialized with empty weights; the quantized weights are then loaded in a second step.
```python
from accelerate import init_empty_weights
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
```

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels to get faster inference speed. They are activated by default. If you want to change this behavior, just pass `disable_exllama` to [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on GPUs.

```py
from transformers import AutoModelForCausalLM
from optimum.gptq import GPTQQuantizer, load_quantized_model
import torch

from accelerate import init_empty_weights
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False)
```

Note that only 4-bit models are supported with exllama kernels for now. Furthermore, it is recommended to disable the exllama kernels when you are fine-tuning your model with PEFT.

#### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.
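A minimal sketch of attaching a LoRA adapter with `peft` is shown below. It assumes the quantized model was saved to `save_folder` as above, reloads it with `disable_exllama=True` as recommended, and uses `target_modules` names that are assumptions for an OPT-style model rather than a general recipe.

```py
# Sketch: fine-tuning a GPTQ-quantized model with a LoRA adapter via peft.
# Assumptions: the model was saved to `save_folder` as shown earlier, and the
# target module names match an OPT-style architecture.
from accelerate import init_empty_weights
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
import torch

from optimum.gptq import load_quantized_model

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
# Exllama kernels are disabled here, as recommended when fine-tuning with peft.
model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=True)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed names for OPT attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# The adapted model can then be trained with the usual Trainer or a custom training loop.
```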

### References

[[autodoc]] gptq.GPTQQuantizer
- all

[[autodoc]] gptq.load_quantized_model
- all
15 changes: 15 additions & 0 deletions optimum/gptq/__init__.py
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .quantizer import GPTQQuantizer, load_quantized_model
23 changes: 23 additions & 0 deletions optimum/gptq/constants.py
@@ -0,0 +1,23 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

SEQLEN_KEYS_TRANFORMERS = ["max_position_embeddings", "seq_length", "n_positions"]
BLOCK_PATTERNS = [
    "transformer.h",
    "model.decoder.layers",
    "gpt_neox.layers",
    "model.layers",
]

GPTQ_CONFIG = "quantization_config.json"
