From 9f2943eb31ccc084a463cbd491a23bcc036d8929 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:46:59 -0400 Subject: [PATCH] Add GPTQ Quantization (#1216) * v1 test draft * code runs but outputs gibberish. * draft v1.1 * remove duplicate * remove dep to transformers and cleaning * Add serialization and loading * Clean code and doc * add flexibility * remove triton * remove some dep with transformers * add testing * make style * add accelerate flag * handle device placement * make style * Apply suggestions Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * add doc in data.py * apply suggestion for utils file * remove multiple output * fix Optional * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * remove useless check * fix doc and style * fix name * replace catcher by prefoward hook * update doctstring for true_sequential * apply suggestion * Fix import * Add docstring for tests * move args * fix typo * fix cpu offload and tokenizer * fix typo * fix offload cpu * modify attribute * more explicit error * dataset optional * add tqdm bar instead * style * add doc * replace by tqdm.auto * change model * add CI * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update .github/workflows/test_gptq.yml Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * add peft compatibility * Apply suggestions from code review doc Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * merge examples * code review * fix test * make style * change var * fix doc * add exllama * change naming * more doc --------- Co-authored-by: younesbelkada Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> --- .github/workflows/test_gptq.yml | 90 +++ docs/source/_toctree.yml | 5 + docs/source/concept_guides/quantization.mdx | 1 + .../usage_guides/quantization.mdx | 104 +++ optimum/gptq/__init__.py | 15 + optimum/gptq/constants.py | 23 + optimum/gptq/data.py | 263 +++++++ optimum/gptq/quantizer.py | 639 ++++++++++++++++++ optimum/gptq/utils.py | 115 ++++ optimum/utils/__init__.py | 1 + optimum/utils/import_utils.py | 5 + optimum/utils/testing_utils.py | 9 +- tests/gptq/Dockerfile_quantization_gpu | 26 + tests/gptq/test_quantization.py | 206 ++++++ 14 files changed, 1501 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test_gptq.yml create mode 100644 docs/source/llm_quantization/usage_guides/quantization.mdx create mode 100644 optimum/gptq/__init__.py create mode 100644 optimum/gptq/constants.py create mode 100644 optimum/gptq/data.py create mode 100644 optimum/gptq/quantizer.py create mode 100644 optimum/gptq/utils.py create mode 100644 tests/gptq/Dockerfile_quantization_gpu create mode 100644 tests/gptq/test_quantization.py diff --git a/.github/workflows/test_gptq.yml b/.github/workflows/test_gptq.yml new file mode 100644 index 0000000000..dcb7fb5565 --- /dev/null +++ b/.github/workflows/test_gptq.yml @@ -0,0 +1,90 @@ +name: GPTQ Quantization / Test GPU + +on: + workflow_dispatch: + schedule: + - cron: 0 1 */3 * * # at 1am every 3 days + pull_request: + types: [opened, synchronize, reopened, labeled] + # uncomment 
to enable on PR merge on main branch: + #push: + # branches: + # - main + +jobs: + start-runner: + if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || contains( github.event.pull_request.labels.*.name, 'gpu-test') }} + name: Start self-hosted EC2 runner + runs-on: ubuntu-latest + env: + AWS_REGION: us-east-1 + EC2_AMI_ID: ami-0dc1c26161f869ed1 + EC2_INSTANCE_TYPE: g4dn.xlarge + EC2_SUBNET_ID: subnet-859322b4,subnet-b7533b96,subnet-47cfad21,subnet-a396b2ad,subnet-06576a4b,subnet-df0f6180 + EC2_SECURITY_GROUP: sg-0bb210cd3ec725a13 + EC2_IAM_ROLE: optimum-ec2-github-actions-role + outputs: + label: ${{ steps.start-ec2-runner.outputs.label }} + ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ env.AWS_REGION }} + - name: Start EC2 runner + id: start-ec2-runner + uses: philschmid/philschmid-ec2-github-runner@main + with: + mode: start + github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} + ec2-image-id: ${{ env.EC2_AMI_ID }} + ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} + subnet-id: ${{ env.EC2_SUBNET_ID }} + security-group-id: ${{ env.EC2_SECURITY_GROUP }} + iam-role-name: ${{ env.EC2_IAM_ROLE }} + aws-resource-tags: > # optional, requires additional permissions + [ + {"Key": "Name", "Value": "ec2-optimum-github-runner"}, + {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} + ] + do-the-job: + name: Setup + needs: start-runner # required to start the main job when the runner is ready + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + env: + AWS_REGION: us-east-1 + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Build image + run: | + docker build -f tests/gptq/docker/Dockerfile_quantization_gpu -t gptq-gpu . 
+      - name: Test with unittest within docker container
+        run: |
+          docker run --rm --gpus all -v $(pwd)/hf_cache:/root/.cache/huggingface --workdir=/workspace/optimum/tests gptq-gpu:latest
+
+  stop-runner:
+    name: Stop self-hosted EC2 runner
+    needs:
+      - start-runner # required to get output from the start-runner job
+      - do-the-job # required to wait when the main job is done
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+    if: ${{ always() && !(needs.start-runner.result == 'skipped' && needs.do-the-job.result == 'skipped') }} # stop the runner even if an error happened in the previous jobs, unless both of them were skipped
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Stop EC2 runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: stop
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          label: ${{ needs.start-runner.outputs.label }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 185f51e884..7d62f00720 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -125,6 +125,11 @@
       isExpanded: false
     title: BetterTransformer
   isExpanded: false
+- sections:
+  - local: llm_quantization/usage_guides/quantization
+    title: GPTQ quantization
+  title: LLM quantization
+  isExpanded: false
 - sections:
   - local: utils/dummy_input_generators
     title: Dummy input generators
diff --git a/docs/source/concept_guides/quantization.mdx b/docs/source/concept_guides/quantization.mdx
index f751e9d47a..b9aca25ee9 100644
--- a/docs/source/concept_guides/quantization.mdx
+++ b/docs/source/concept_guides/quantization.mdx
@@ -185,6 +185,7 @@ models while respecting accuracy and latency constraints.
   [PyTorch quantization functions](https://pytorch.org/docs/stable/quantization-support.html#torch-quantization-quantize-fx)
   to allow graph-mode quantization of 🤗 Transformers models in PyTorch. This is a lower-level API compared to the two
   mentioned above, giving more flexibility, but requiring more work on your end.
+- The `optimum.llm_quantization` package allows you to [quantize and run LLMs](https://huggingface.co/docs/optimum/llm_quantization/usage_guides/quantization)

 ## Going further: How do machines represent numbers?

diff --git a/docs/source/llm_quantization/usage_guides/quantization.mdx b/docs/source/llm_quantization/usage_guides/quantization.mdx
new file mode 100644
index 0000000000..58b85d514c
--- /dev/null
+++ b/docs/source/llm_quantization/usage_guides/quantization.mdx
@@ -0,0 +1,104 @@
+# Quantization
+
+## AutoGPTQ Integration
+
+🤗 Optimum collaborated with the [AutoGPTQ library](https://github.com/PanQiWei/AutoGPTQ) to provide a simple API that applies GPTQ quantization to language models. With GPTQ quantization, you can quantize your favorite language model to 8, 6, 4 or even 2 bits. This comes with only a minor drop in performance and faster inference speed. It is supported by most GPU hardware.
+
+If you want to quantize 🤗 Transformers models with GPTQ, follow this [documentation](https://huggingface.co/docs/transformers/main_classes/quantization).
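+
+For reference, the Transformers route looks roughly like the snippet below (a minimal sketch — check the linked documentation for the exact `GPTQConfig` arguments; the model name and settings are only an example):
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+model_id = "facebook/opt-125m"
+# the tokenizer is needed to prepare the calibration dataset ("c4" here)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
+```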
+
+To learn more about the quantization technique used in GPTQ, please refer to:
+- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
+- the [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library used as the backend
+
+Note that the AutoGPTQ library provides more advanced options (Triton backend, fused attention, fused MLP) that are not integrated with Optimum. For now, we only leverage the CUDA kernel for GPTQ.
+
+### Requirements
+
+You need to have the following requirements installed to run the code below:
+
+- AutoGPTQ library:
+`pip install auto-gptq`
+
+- Optimum library:
+`pip install --upgrade optimum`
+
+- Install latest `transformers` library from source:
+`pip install --upgrade git+https://github.com/huggingface/transformers.git`
+
+- Install latest `accelerate` library:
+`pip install --upgrade accelerate`
+
+### Load and quantize a model
+
+The [`~optimum.gptq.GPTQQuantizer`] class is used to quantize your model. In order to quantize your model, you need to provide a few arguments:
+- the number of bits: `bits`
+- the dataset used to calibrate the quantization: `dataset`
+- the model sequence length used to process the dataset: `model_seqlen`
+- the block name to quantize: `block_name_to_quantize`
+
+With the 🤗 Transformers integration, you don't need to pass `block_name_to_quantize` and `model_seqlen` as we can retrieve them automatically. However, for custom models, you need to specify them. Also, make sure that your model is converted to `torch.float16` before quantization.
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from optimum.gptq import GPTQQuantizer, load_quantized_model
+import torch
+
+model_name = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+
+quantizer = GPTQQuantizer(bits=4, dataset="c4", block_name_to_quantize="model.decoder.layers", model_seqlen=2048)
+quantized_model = quantizer.quantize_model(model, tokenizer)
+```
+
+GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (quantizing a 175B model takes about 4 GPU-hours on an NVIDIA A100). Check the Hugging Face Hub first to see whether a GPTQ-quantized version of the model you want to quantize already exists.
+
+### Save the model
+
+To save your model, use the `save` method of the [`~optimum.gptq.GPTQQuantizer`] class. It will create a folder containing your model state dict along with the quantization config.
+```python
+save_folder = "/path/to/save_folder/"
+quantizer.save(model, save_folder)
+```
+
+### Load quantized weights
+
+You can load your quantized weights with the [`~optimum.gptq.load_quantized_model`] function.
+Thanks to the Accelerate library, the model can be loaded faster and with lower memory usage: it is first initialized with empty weights, and the actual weights are then loaded into it as a second step.
+```python
+from accelerate import init_empty_weights
+
+with init_empty_weights():
+    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+empty_model.tie_weights()
+quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
+```
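+
+Once loaded, the quantized model can be used for inference like any other 🤗 Transformers model. The generation sketch below is only illustrative (the prompt, `max_new_tokens` and the use of GPU 0 are arbitrary choices, not part of the API):
+
+```py
+inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)  # assumes the model was dispatched on GPU 0
+generated_ids = quantized_model.generate(**inputs, max_new_tokens=10)
+print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+```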
+
+### Exllama kernels for faster inference
+
+For 4-bit models, you can use the exllama kernels to get faster inference speed. They are enabled by default. If you want to change this behaviour, pass `disable_exllama` to [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on GPUs.
+
+```py
+from optimum.gptq import GPTQQuantizer, load_quantized_model
+import torch
+
+from accelerate import init_empty_weights
+
+with init_empty_weights():
+    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+empty_model.tie_weights()
+quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False)
+```
+
+Note that only 4-bit models are supported with exllama kernels for now. Furthermore, it is recommended to disable the exllama kernels when you are fine-tuning your model with PEFT.
+
+#### Fine-tune a quantized model
+
+With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
+Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.
+
+### References
+
+[[autodoc]] gptq.GPTQQuantizer
+    - all
+
+[[autodoc]] gptq.load_quantized_model
+    - all
\ No newline at end of file
diff --git a/optimum/gptq/__init__.py b/optimum/gptq/__init__.py
new file mode 100644
index 0000000000..6c13647990
--- /dev/null
+++ b/optimum/gptq/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .quantizer import GPTQQuantizer, load_quantized_model
diff --git a/optimum/gptq/constants.py b/optimum/gptq/constants.py
new file mode 100644
index 0000000000..70c2526651
--- /dev/null
+++ b/optimum/gptq/constants.py
@@ -0,0 +1,23 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+SEQLEN_KEYS_TRANFORMERS = ["max_position_embeddings", "seq_length", "n_positions"]
+BLOCK_PATTERNS = [
+    "transformer.h",
+    "model.decoder.layers",
+    "gpt_neox.layers",
+    "model.layers",
+]
+
+GPTQ_CONFIG = "quantization_config.json"
diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py
new file mode 100644
index 0000000000..bc70747e00
--- /dev/null
+++ b/optimum/gptq/data.py
@@ -0,0 +1,263 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from datasets import load_dataset + + +""" +Set of utilities for loading most used datasets (original dataset from GPTQ paper) and be able to easily use them during quantization +""" + + +def prepare_dataset( + examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, pad_token_id: Optional[int] = None +): + """ + Prepare the dataset by making sure that we have the right format and `batch_size` + Args: + examples (`List[Dict[str, torch.LongTensor]]`): + List of data to prepare + batch_size (`int`, defaults to `1`): + Batch size of the data + pad_token_id (`Optional[int]`, defaults to `None`): + Pad token id of the model + Returns: + ` List[Dict[str, torch.LongTensor]]`: Batched dataset + """ + new_examples = [] + for example in examples: + input_ids = example["input_ids"] + attention_mask = example["attention_mask"] + new_examples.append( + {"input_ids": torch.LongTensor(input_ids), "attention_mask": torch.LongTensor(attention_mask)} + ) + if batch_size > 1 and pad_token_id is None: + raise ValueError( + "You need to pass a `pad_token_id` in `quantize_model` if you want to have examples with batch size > 1" + ) + new_examples = [ + collate_data(new_examples[start : start + batch_size], pad_token_id) + for start in range(0, len(new_examples), batch_size) + ] + return new_examples + + +def collate_data( + blocks: List[Dict[str, torch.LongTensor]], + contain_labels: bool = False, + pad_token_id: Optional[int] = None, +) -> Dict[str, torch.LongTensor]: + """ + Collate data in `blocks` + Args: + blocks (`List[Dict[str, torch.LongTensor]]`): + List of tensors that we need to batch together + pad_token_id (`Optional[int]`, defaults to `None`): + Pad token id of the model + contain_labels (`bool`, defaults to `False`): + Set True to also process the labels + + Returns: + `Dict[str, torch.LongTensor]`: Batched data + """ + + def pad_block(block, pads): + return torch.cat((pads.to(block.device), block), dim=-1).long() + + input_ids_blocks = [block["input_ids"] for block in blocks] + attention_mask_blocks = [block["attention_mask"] for block in blocks] + if contain_labels: + label_blocks = [block["labels"] for block in blocks] + label_max_len = max([block.size(-1) for block in label_blocks]) + + bsz = len(blocks) + inp_max_len = max([block.size(-1) for block in input_ids_blocks]) + + for i in range(bsz): + block_bsz, block_inp_len = input_ids_blocks[i].shape + pad_num = inp_max_len - block_inp_len + if pad_num > 0: + input_ids_blocks[i] = pad_block(input_ids_blocks[i], torch.ones((block_bsz, pad_num)) * pad_token_id) + attention_mask_blocks[i] = pad_block(attention_mask_blocks[i], torch.zeros((block_bsz, pad_num))) + if contain_labels: + block_label_len = label_blocks[i].shape[-1] + label_pad_num = label_max_len - block_label_len + if label_pad_num > 0: + label_blocks[i] = pad_block(label_blocks[i], torch.ones((block_bsz, label_pad_num)) * -100) + + data = { + "input_ids": torch.cat(input_ids_blocks, dim=0).long(), + "attention_mask": torch.cat(attention_mask_blocks, dim=0).long(), + 
} + if contain_labels: + data["labels"] = torch.cat(label_blocks, dim=0).long() + + return data + + +def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): + if split == "train": + data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + elif split == "validation": + data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + text = "".join([" \n" if s == "" else s for s in data["text"]]) + + enc = tokenizer(text, return_tensors="pt") + dataset = [] + for _ in range(nsamples): + i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = enc.input_ids[:, i:j] + attention_mask = torch.ones_like(inp) + dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + return dataset + + +def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): + if split == "train": + data = load_dataset( + "allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train" + ) + elif split == "validation": + data = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + dataset = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(data) - 1) + enc = tokenizer(data[i]["text"], return_tensors="pt") + if enc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = enc.input_ids[:, i:j] + attention_mask = torch.ones_like(inp) + dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + + return dataset + + +def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): + if split == "train": + data = load_dataset( + "allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train" + ) + elif split == "validation": + data = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + dataset = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(data) - 1) + enc = tokenizer(data[i]["text"], return_tensors="pt") + if enc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = enc.input_ids[:, i:j] + attention_mask = torch.ones_like(inp) + dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + + return dataset + + +def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): + if split == "train": + data = load_dataset("ptb_text_only", "penn_treebank", split="train") + elif split == "validation": + data = load_dataset("ptb_text_only", "penn_treebank", split="validation") + + enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") + + dataset = [] + for _ in range(nsamples): + i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = enc.input_ids[:, i:j] + attention_mask = torch.ones_like(inp) + dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + + return dataset + + +def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): + if split == "train": + data = load_dataset("ptb_text_only", "penn_treebank", split="train") + elif split == "validation": + data = load_dataset("ptb_text_only", "penn_treebank", split="test") + + enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") + + dataset = [] + for _ in range(nsamples): + i = random.randint(0, 
enc.input_ids.shape[1] - seqlen - 1)
+        j = i + seqlen
+        inp = enc.input_ids[:, i:j]
+        attention_mask = torch.ones_like(inp)
+        dataset.append({"input_ids": inp, "attention_mask": attention_mask})
+    return dataset
+
+
+def get_dataset(
+    dataset_name: str, tokenizer: Any, nsamples: int = 128, seqlen: int = 2048, seed: int = 0, split: str = "train"
+):
+    """
+    Get one of the datasets used in the original GPTQ paper.
+
+    Args:
+        dataset_name (`str`):
+            Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb-new']`.
+        tokenizer (`Any`):
+            Tokenizer of the model
+        nsamples (`int`, defaults to `128`):
+            Number of samples
+        seqlen (`int`, defaults to `2048`):
+            The sequence length of the model
+        seed (`int`, defaults to `0`):
+            Random seed
+        split (`str`, defaults to `train`):
+            Split of the dataset. Can be either "train" or "validation"
+    Returns:
+        `List[Dict[str, torch.LongTensor]]`: The tokenized dataset.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+    get_dataset_map = {
+        "wikitext2": get_wikitext2,
+        "c4": get_c4,
+        "c4-new": get_c4_new,
+        "ptb": get_ptb,
+        "ptb-new": get_ptb_new,
+    }
+    # the dataset loaders above only accept the "train" and "validation" splits
+    if split not in ["train", "validation"]:
+        raise ValueError(f"The split needs to be 'train' or 'validation' but found {split}")
+    if dataset_name not in get_dataset_map:
+        raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}")
+    get_dataset_fn = get_dataset_map[dataset_name]
+    return get_dataset_fn(tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen, split=split)
diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
new file mode 100644
index 0000000000..1352946b60
--- /dev/null
+++ b/optimum/gptq/quantizer.py
@@ -0,0 +1,639 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. team and GPTQ and AutoGPTQ authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
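+"""
+GPTQ quantization entry points (`GPTQQuantizer` and `load_quantized_model`).
+
+Illustrative usage, mirroring `docs/source/llm_quantization/usage_guides/quantization.mdx`
+(the model name and hyper-parameters below are just examples):
+
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from optimum.gptq import GPTQQuantizer
+
+    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
+    quantizer = GPTQQuantizer(bits=4, dataset="c4", model_seqlen=2048)
+    quantized_model = quantizer.quantize_model(model, tokenizer)
+    quantizer.save(quantized_model, "opt-125m-gptq")
+"""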
+import copy +import json +import os +from logging import getLogger +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.pytorch_utils import Conv1D +from transformers.utils.quantization_config import QuantizationMethod + +from ..utils import is_accelerate_available, is_auto_gptq_available +from ..utils.modeling_utils import recurse_getattr +from .constants import GPTQ_CONFIG +from .data import get_dataset, prepare_dataset +from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen + + +if is_accelerate_available(): + from accelerate import ( + Accelerator, + cpu_offload_with_hook, + load_checkpoint_and_dispatch, + ) + from accelerate.hooks import remove_hook_from_module + +if is_auto_gptq_available(): + from auto_gptq.modeling._utils import autogptq_post_init + from auto_gptq.quantization import GPTQ + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + +logger = getLogger(__name__) + + +class GPTQQuantizer(object): + r""" + A simple API for GPTQ Quantization + """ + + def __init__( + self, + bits: int, + dataset: Optional[Union[List[str], str]] = None, + group_size: int = 128, + damp_percent: float = 0.01, + desc_act: bool = True, + sym: bool = True, + true_sequential: bool = True, + use_cuda_fp16: bool = False, + model_seqlen: Optional[int] = None, + block_name_to_quantize: Optional[str] = None, + module_name_preceding_first_block: Optional[List[str]] = None, + batch_size: int = 1, + pad_token_id: Optional[int] = None, + disable_exllama: bool = False, + *args, + **kwargs, + ): + """ + Args: + bits (`int`): + The number of bits to quantize to, supported numbers are (2, 3, 4, 8). + dataset (`Union[List[str],str]`, defaults to None): + The dataset used for quantization. You can provide your own dataset in a list of string or just use the original datasets used + in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new']. + group_size (int, defaults to 128): + The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. + damp_percent (`float`, defaults to `0.01`): + The percent of the average Hessian diagonal to use for dampening, recommended value is 0.01. + desc_act (`bool`, defaults to `True`): + Whether to quantize columns in order of decreasing activation size. + Setting it to False can significantly speed up inference but the perplexity may become slightly worse. + Also known as act-order. + sym (`bool`, defaults to `True`): + Whether to use symetric quantization. + true_sequential (`bool`, defaults to `True`): + Whether to perform sequential quantization even within a single Transformer block. + Instead of quantizing the entire block at once, we perform layer-wise quantization. + As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers. + use_cuda_fp16 (`bool`, defaults to `False`): + Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. + model_seqlen (`Optional[int]`, defaults to `None`): + The maximum sequence length that the model can take. + block_name_to_quantize (`Optional[str]`, defaults to `None`): + The transformers block name to quantize. + module_name_preceding_first_block (`Optional[List[str]]`, defaults to `None`): + The layers that are preceding the first Transformer block. 
+ batch_size (`int`, defaults to `1`): + The batch size of the dataset + pad_token_id (`Optional[int]`, defaults to `None`): + The pad token id. Needed to prepare the dataset when `batch_size` > 1. + disable_exllama (`bool`, defaults to `False`): + Whether to use exllama backend. Only works with `bits` = 4. + """ + + self.bits = bits + self.dataset = dataset + self.group_size = group_size + self.damp_percent = damp_percent + self.desc_act = desc_act + self.sym = sym + self.true_sequential = true_sequential + self.use_cuda_fp16 = use_cuda_fp16 + self.model_seqlen = model_seqlen + self.block_name_to_quantize = block_name_to_quantize + self.module_name_preceding_first_block = module_name_preceding_first_block + self.batch_size = batch_size + self.pad_token_id = pad_token_id + self.disable_exllama = disable_exllama + + if self.bits not in [2, 4, 6, 8]: + raise ValueError("only support quantize to [2,4,6,8] bits.") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + + def to_dict(self): + """ + Returns the args in dict format. + """ + return copy.deepcopy(self.__dict__) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]): + """ + Instantiates a `GPTQQuantizer` using config_dict as kwargs + + Args: + config_dict (`Dict[str,Any]`): + quantization config + + Returns: + `GPTQQuantizer`: The quantizer object instantiated from those parameters. + """ + return cls(**config_dict) + + def convert_model(self, model: nn.Module): + """ + Convert the model to a GPTQ model by getting and replacing the layers. + + Args: + model (`nn.Module`): + Model to be converted + + """ + if self.block_name_to_quantize is None: + self.block_name_to_quantize = get_block_name_with_pattern(model) + block_name = self.block_name_to_quantize + layers_to_be_replaced = get_layers(model, prefix=block_name) + self._replace_by_quant_layers(model, layers_to_be_replaced) + + return model + + def get_no_split_module_classes(self, model): + """ + Get the modules that should not be split across multiple devices. + Args: + model (`nn.Module`): + The input model + """ + + block_class_name = recurse_getattr(model, self.block_name_to_quantize)[0].__class__.__name__ + no_split_module_classes = [block_class_name] + return no_split_module_classes + + def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""): + """ + Replaces linear layers in `module` by `QuantLinear` + + Args: + module (`nn.Module`): + Module to quantize + names (`List[str]`): + List of names of the module to quantize + name (`str`, defaults to `""`): + To keep track of the name of the current module + """ + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama, + ) + if isinstance(module, QuantLinear): + return + for attr in dir(module): + layer = getattr(module, attr) + name1 = name + "." 
+ attr if name != "" else attr + if name1 in names: + device = get_device(layer) + delattr(module, attr) + if isinstance(layer, nn.Linear): + in_features = layer.in_features + out_features = layer.out_features + elif isinstance(layer, nn.Conv2d): + in_features = layer.in_channels + out_features = layer.out_channels + elif isinstance(layer, Conv1D): + in_features = layer.weight.shape[0] + out_features = layer.weight.shape[1] + if not (self.desc_act) or self.group_size == -1: + new_layer = QuantLinear( + self.bits, self.group_size, in_features, out_features, True, use_cuda_fp16=self.use_cuda_fp16 + ) + else: + new_layer = QuantLinear(self.bits, self.group_size, in_features, out_features, True) + new_layer.device = device + setattr(module, attr, new_layer.to(device)) + for name1, child in module.named_children(): + self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1) + + @torch.no_grad() + def quantize_model(self, model: nn.Module, tokenizer: Any): + """ + Quantizes the model using the dataset + + Args: + model (`nn.Module`): + The model to quantize + tokenizer (`Any`): + The tokenizer to use in order to prepare the dataset. You can pass either: + - A custom tokenizer object. + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + Returns: + `nn.Module`: The quantized model + """ + + if not is_auto_gptq_available(): + raise RuntimeError("auto-gptq is required in order to perform quantzation : `pip install auto-gptq`") + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed to quantize model.") + + model.eval() + + # For Transformer model + has_config = False + has_device_map = False + if hasattr(model, "config"): + has_config = True + use_cache = model.config.use_cache + model.config.use_cache = False + + if hasattr(model, "hf_device_map"): + devices = list(model.hf_device_map.values()) + if "disk" in devices: + raise ValueError("disk offload is not supported with GPTQ quantization") + if "cpu" in devices and len(model.hf_device_map) > 1: + logger.info("Cpu offload is not recommended. There might be some issues with the memory") + hook = None + for name, device in model.hf_device_map.items(): + if device == "cpu": + module = recurse_getattr(model, name) + remove_hook_from_module(module, recurse=True) + module, hook = cpu_offload_with_hook(module, prev_module_hook=hook) + # If the model has a device_map, we don't move to model. We have already dispatched the hook that will do the work + has_device_map = True + + if hasattr(model, "dtype"): + self.use_cuda_fp16 = model.dtype == torch.float16 + + if self.model_seqlen is None: + self.model_seqlen = get_seqlen(model) + + device = get_device(model) + + # Step 1: Prepare the data + if isinstance(tokenizer, str): + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + except Exception: + raise ValueError( + f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained` + with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input. + For now, we only support quantization for text model. 
Support for vision, speech and multimodel will come later.""" + ) + if self.dataset is None: + raise ValueError("You need to pass `dataset` in order to quantize your model") + elif isinstance(self.dataset, str): + dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train") + elif isinstance(self.dataset, list): + dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset] + else: + raise ValueError("You need to pass a list of string or a string for `dataset`") + + dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size) + + # Step 2: get the input of the 1st block + # To do that, we need to put the modules preceding the first block on the same device as the first bloc. + # Then we run the model and it will stop at the first bloc as we added a prehook that raise an Exception after storing the inputs. + + layer_inputs = [] + layer_outputs = [] + layer_input_kwargs = [] + + if self.block_name_to_quantize is None: + self.block_name_to_quantize = get_block_name_with_pattern(model) + + if self.module_name_preceding_first_block is None: + self.module_name_preceding_first_block = get_preceding_modules(model, self.block_name_to_quantize) + + blocks = recurse_getattr(model, self.block_name_to_quantize) + + if not has_device_map: + # put modules from module_name_preceding_first_block on cuda + for module_name in self.module_name_preceding_first_block: + module = recurse_getattr(model, module_name) + if module is None: + raise ValueError(f"Module {module_name} was not found in model") + module = module.to(0) + blocks[0] = blocks[0].to(0) + + def store_input_hook(_, input, *args): + kwargs = args[0] + input = input[0] + if input is None: + if "hidden_states" in kwargs: + input = kwargs["hidden_states"] + else: + raise ValueError("No input value found in the foward pass") + layer_inputs.append(input) + other_kwargs = {} + for k, v in kwargs.items(): # make sure other arguments also be captured + if k not in ["hidden_states"]: + other_kwargs[k] = v + layer_input_kwargs.append(other_kwargs) + raise ValueError + + handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) + for data in dataset: + for k, v in data.items(): + # put the data on gpu, we won't put them back to cpu + data[k] = v.to(0) + try: + model(**data) + except ValueError: + pass + + handle.remove() + if not has_device_map: + blocks[0].to(device) + for module_name in self.module_name_preceding_first_block: + module = recurse_getattr(model, module_name) + if module is None: + raise ValueError(f"Module {module_name} was not found in model") + + torch.cuda.empty_cache() + + # Step 3: Quantize the blocks + quantizers = {} + for i, block in enumerate(tqdm(blocks, desc=f"Quantizing {self.block_name_to_quantize} blocks ")): + logger.info(f"Start quantizing block {self.block_name_to_quantize} {i + 1}/{len(blocks)}") + # move block to cuda if needed + # in case we have offload modules, we need to put them on cuda because of GPTQ object + if not has_device_map or get_device(block) == torch.device("cpu"): + block = block.to(0) + layers = get_layers(block) + if self.true_sequential: + # lazy sequential but works well + layers_name_list = [[key] for key in layers.keys()] + else: + layers_name_list = [list(layers.keys())] + logger.info(f"Module to quantize {layers_name_list}") + for subset_name_list in tqdm(layers_name_list, leave=False, desc="Quantizing layers inside the block"): + subset_layers = {name: layers[name] for name in subset_name_list} + gptq 
= {} + handles = [] + # add hook for each layer in subset_layers + for name in subset_layers: + gptq[name] = GPTQ(subset_layers[name]) + gptq[name].quantizer.configure(bits=self.bits, sym=self.sym, perchannel=True) + + def add_batch(name): + def tmp(_, input, output): + gptq[name].add_batch(input[0].data, output.data) + + return tmp + + # because it adding a hook will replace the old one. + handles.append(subset_layers[name].register_forward_hook(add_batch(name))) + # update Hessian for each layer in subset_layers thanks to the hook + for j in range(len(dataset)): + # the args are already on the gpu + # don't need to store the output + block(layer_inputs[j], **layer_input_kwargs[j]) + # remove hook + for h in handles: + h.remove() + for name in subset_name_list: + logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...") + scale, zero, g_idx = gptq[name].fasterquant( + percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act + ) + quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = ( + gptq[name].quantizer, + scale, + zero, + g_idx, + ) + gptq[name].free() + del subset_layers + # we get the new output from the partial quantized block + for j in range(len(dataset)): + layer_output = block(layer_inputs[j], **layer_input_kwargs[j])[0] + layer_outputs.append(layer_output) + + # put back to device + if not has_device_map: + blocks[i] = block.to(device) + del layers + del layer_inputs + layer_inputs, layer_outputs = layer_outputs, [] + torch.cuda.empty_cache() + + if self.bits == 4 and not self.disable_exllama: + if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])): + logger.warning( + "Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`" + ) + self.disable_exllama = True + # Step 4: Pack the model at the end (Replacing the layers) + self.pack_model(model=model, quantizers=quantizers) + + model.is_quantized = True + model.quantization_method = QuantizationMethod.GPTQ + if has_config: + model.config.use_cache = use_cache + model.config.quantization_config = self.to_dict() + + # Step 5: Any post-initialization that require device information, for example buffers initialization on device. + model = self.post_init_model(model) + + torch.cuda.empty_cache() + return model + + def post_init_model(self, model): + """ + Post-initialization that require device information, for example buffers initialization on device. + + Args: + model (`nn.Module`): + The input model + """ + if self.bits == 4 and not self.disable_exllama: + if get_device(model) == torch.device("cpu") or ( + hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"]) + ): + raise ValueError( + "Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU." 
+ "You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object" + ) + + return autogptq_post_init(model, use_act_order=self.desc_act) + + def pack_model( + self, + model: nn.Module, + quantizers: Dict[str, Tuple], + ): + """ + Pack the model by replacing the layers by quantized layers + + Args: + model (`nn.Module`): + The model to pack + quantizers (`Dict[str,Tuple]`): + A mapping of the layer name and the data needed to pack the layer + """ + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama, + ) + logger.info("Packing model...") + layers = get_layers(model) + layers = {n: layers[n] for n in quantizers} + self._replace_by_quant_layers(model, quantizers) + qlayers = get_layers(model, [QuantLinear]) + for name in qlayers: + logger.info(name) + quantizers[name], scale, zero, g_idx = quantizers[name] + # so far can only pack layer on CPU + layer_device = qlayers[name].device + qlayers[name].to("cpu") + layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu") + qlayers[name].pack(layers[name], scale, zero, g_idx) + qlayers[name].to(layer_device) + + logger.info("Model packed.") + + def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = False): + """ + Save model state dict and configs + + Args: + model (`nn.Module`): + Model to be saved. The model can be wrapped or unwraped. + save_dir (`str`): + Directory to which to save. Will be created if it doesn't exist. + max_shard_size (`str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + safe_serialization (`bool`, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + + """ + + if not is_accelerate_available(): + raise RuntimeError( + "You need to install accelerate in order to save a quantized model. You can do it with `pip install accelerate`" + ) + + os.makedirs(save_dir, exist_ok=True) + model = model.to("cpu") + # save model and config + accelerator = Accelerator() + accelerator.save_model(model, save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) + with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2) + + +def load_quantized_model( + model: nn.Module, + save_folder: str, + quant_config_name: str = GPTQ_CONFIG, + state_dict_name: Optional[str] = None, + device_map: Optional[str] = None, + max_memory: Optional[Dict] = None, + no_split_module_classes: Optional[Dict] = None, + offload_folder: Optional[str] = None, + offload_buffers: Optional[str] = None, + offload_state_dict: bool = False, + disable_exllama: bool = False, +): + """ + Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map. + + Args: + model (`nn.Module`): + The model can be enpty or not. + save_folder (`str`): + Directory to which to load the weights. 
+ quant_config_name (`str`, defaults to `GPTQ_CONFIG`): + Name of the quantization config file + state_dict_name (`Optional[str]`, defaults to `None`): + Name of the state dict file + device_map (`Optional[str]`, defaults to `None`): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. + max_memory (`Optional[Dict]`, defaults to `None`): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. + no_split_module_classes (`Optional[Dict]`, defaults to `None`): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + offload_folder (`Optional[str]`, defaults to `None`): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_buffers (`Optional[str]`, defaults to `None`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + offload_state_dict (`bool`, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map + picked contains `"disk"` values. + disable_exllama (`bool`, defaults to `False`): + Whether to use exllama backend. Only works with `bits` = 4. + + Returns: + `nn.Module`: The quantized model + """ + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed to run quantized model.") + if not is_auto_gptq_available(): + raise RuntimeError("auto-gptq is required in order to load quantized weights : `pip install auto-gptq`") + if not is_accelerate_available(): + raise RuntimeError( + "You need to install accelerate in order to load and dispatch weights to" + "a quantized model. You can do it with `pip install accelerate`" + ) + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.") + + with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f: + quantize_config_dict = json.load(f) + quantizer = GPTQQuantizer.from_dict(quantize_config_dict) + quantizer.disable_exllama = disable_exllama + + model = quantizer.convert_model(model) + + if no_split_module_classes is None: + no_split_module_classes = quantizer.get_no_split_module_classes(model) + + model = load_checkpoint_and_dispatch( + model, + checkpoint=os.path.join(save_folder, state_dict_name) if state_dict_name is not None else save_folder, + device_map=device_map, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + offload_folder=offload_folder, + offload_buffers=offload_buffers, + offload_state_dict=offload_state_dict, + ) + + model = quantizer.post_init_model(model) + model.is_quantized = True + model.quantization_method = QuantizationMethod.GPTQ + model.eval() + return model diff --git a/optimum/gptq/utils.py b/optimum/gptq/utils.py new file mode 100644 index 0000000000..b7387561c2 --- /dev/null +++ b/optimum/gptq/utils.py @@ -0,0 +1,115 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from logging import getLogger +from typing import Optional, Union + +import torch +from torch import nn +from transformers.pytorch_utils import Conv1D + +from .constants import BLOCK_PATTERNS, SEQLEN_KEYS_TRANFORMERS + + +logger = getLogger(__name__) + + +""" +Set of utilities to get specific attributes of a model +""" + + +def get_layers(module: nn.Module, layers=[Conv1D, nn.Conv2d, nn.Linear], prefix: Optional[str] = None, name: str = ""): + """ + Get all the layers with a specific prefix in the module + Args: + module (`nn.Module`): + The module that contains our layers + layers (`list`, defaults to `[Conv1D, nn.Conv2d, nn.Linear]`): + Type of the layers that we want to get + prefix (`Optional[str]`, defaults to `None`): + Prefix of layers + name (`str`, defaults to `""`): + Used for recursion. Don't modify + + Returns: + `Dict[str,Union[Conv1D, nn.Conv2d, nn.Linear]]`: Mapping of the name of the layer and the actual layer + """ + for layer in layers: + if isinstance(module, layer): + if prefix is not None: + if name.startswith(prefix): + return {name: module} + else: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(get_layers(child, layers=layers, prefix=prefix, name=name + "." + name1 if name != "" else name1)) + return res + + +def get_block_name_with_pattern(model: nn.Module): + """ + Get the name of the module that contains the transformers blocks by checking if any modules has a specific pattern + + Args: + model (`nn.Module`): + The input model + Returns: + `str`: The name of the module that contains the Transformer blocks. + """ + modules_names = [n for n, _ in model.named_modules()] + for pattern_candidate in BLOCK_PATTERNS: + pattern_candidate = pattern_candidate + if any([pattern_candidate in name for name in modules_names]): + return pattern_candidate + raise ValueError("Block pattern could not be match. Pass `block_name_to_quantize` argument in `quantize_model`") + + +def get_preceding_modules(model: nn.Module, module_name: str): + previous_module_name = [] + stop_adding = False + + def _get_preceding_modules(model: nn.Module, module_name: str, name: str = ""): + nonlocal stop_adding + for name_bis, child in model.named_children(): + new_name = name + "." 
+ name_bis if name != "" else name_bis + if new_name == module_name: + stop_adding = True + break + _get_preceding_modules(child, module_name, name=new_name) + if not stop_adding: + previous_module_name.append(name) + return previous_module_name + + return _get_preceding_modules(model, module_name) + + +def get_device(obj: Union[torch.Tensor, nn.Module]): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + +def get_seqlen(model: nn.Module): + if hasattr(model, "config"): + model_config = model.config.to_dict() + if any([k in model_config for k in SEQLEN_KEYS_TRANFORMERS]): + for key in SEQLEN_KEYS_TRANFORMERS: + if key in model_config: + return model_config[key] + logger.info( + "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`" + ) + return 2048 diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 534885c851..153499497d 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -31,6 +31,7 @@ check_if_pytorch_greater, check_if_transformers_greater, is_accelerate_available, + is_auto_gptq_available, is_diffusers_available, is_onnx_available, is_onnxruntime_available, diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 5e6049bd41..a08bb1af19 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -46,6 +46,7 @@ _pydantic_available = importlib.util.find_spec("pydantic") is not None _accelerate_available = importlib.util.find_spec("accelerate") is not None _diffusers_available = importlib.util.find_spec("diffusers") is not None +_auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None torch_version = None if is_torch_available(): @@ -100,6 +101,10 @@ def is_diffusers_available(): return _diffusers_available +def is_auto_gptq_available(): + return _auto_gptq_available + + @contextmanager def check_if_pytorch_greater(target_version: str, message: str): r""" diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index e48a128051..cdbce3df8f 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -24,7 +24,7 @@ import torch -from . import is_accelerate_available, is_diffusers_available +from . import is_accelerate_available, is_auto_gptq_available, is_diffusers_available # Used to test the hub @@ -55,6 +55,13 @@ def require_accelerate(test_case): return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) +def require_auto_gptq(test_case): + """ + Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed. 
+ """ + return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" torch_device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/gptq/Dockerfile_quantization_gpu b/tests/gptq/Dockerfile_quantization_gpu new file mode 100644 index 0000000000..34a2a13552 --- /dev/null +++ b/tests/gptq/Dockerfile_quantization_gpu @@ -0,0 +1,26 @@ +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 +CMD nvidia-smi + +# Ignore interactive questions during `docker build` +ENV DEBIAN_FRONTEND noninteractive + +# Install and update tools to minimize security vulnerabilities +RUN apt-get update +RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \ + bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev python3-pip && \ + apt-get clean +RUN unattended-upgrade +RUN apt-get autoremove -y + +RUN python3 -m pip install -U pip + +RUN pip install torch torchvision torchaudio +RUN pip install transformers accelerate auto-gptq datasets + +# Install Optimum +COPY . /workspace/optimum +RUN pip install /workspace/optimum[tests] + +ENV RUN_SLOW=1 +WORKDIR /workspace/optimum/tests/ +CMD pytest gptq/test_*.py --durations=0 -s -vvvvv diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py new file mode 100644 index 0000000000..77b668c2a0 --- /dev/null +++ b/tests/gptq/test_quantization.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.testing_utils import slow + +from optimum.gptq import GPTQQuantizer, load_quantized_model +from optimum.gptq.data import get_dataset +from optimum.utils.testing_utils import require_accelerate, require_auto_gptq, require_torch_gpu + + +@slow +@require_auto_gptq +@require_torch_gpu +class GPTQTest(unittest.TestCase): + model_name = "bigscience/bloom-560m" + + input_text = "Hello my name is" + EXPECTED_OUTPUTS = set() + EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") + EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.") + + # this seems a little small considering that we are doing 4bit quant but we have a small model and ww don't quantize the embeddings + EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 + + bits = 4 + group_size = 128 + desc_act = False + disable_exllama = True + + dataset = [ + "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm." 
+ ] + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + cls.model_fp16 = AutoModelForCausalLM.from_pretrained( + cls.model_name, torch_dtype=torch.float16, device_map={"": 0} + ) + cls.mem_fp16 = cls.model_fp16.get_memory_footprint() + + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + cls.quantizer = GPTQQuantizer( + bits=cls.bits, + dataset=cls.dataset, + group_size=cls.group_size, + desc_act=cls.desc_act, + disable_exllama=cls.disable_exllama, + ) + + cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer) + + def test_memory_footprint(self): + """ + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + + mem_quantized = self.quantized_model.get_memory_footprint() + + self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + + def test_quantized_layers_class(self): + """ + A simple test to check if the model conversion has been done correctly by checking on the + the class type of the linear layers of the converted models + """ + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama, + ) + self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear) + + def check_inference_correctness(self, model): + """ + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
+ """ + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # Get the generation + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + + # Check the exactness of the result + self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_generate_quality(self): + self.check_inference_correctness(self.quantized_model) + + @require_accelerate + def test_serialization(self): + """ + Test the serialization of the model and the loading of the quantized weights + """ + from accelerate import init_empty_weights + + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantizer.save(self.quantized_model, tmpdirname) + self.quantized_model.config.save_pretrained(tmpdirname) + with init_empty_weights(): + empty_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16) + empty_model.tie_weights() + quantized_model_from_saved = load_quantized_model(empty_model, save_folder=tmpdirname, device_map={"": 0}) + self.check_inference_correctness(quantized_model_from_saved) + + +class GPTQTestExllama(GPTQTest): + disable_exllama = False + + +class GPTQUtilsTest(unittest.TestCase): + """ + Test utilities + """ + + model_name = "facebook/opt-125m" + expected_seqlen = 2048 + expected_block_name = "model.decoder.layers" + expected_block_name_class = "OPTDecoderLayer" + expected_preceding_modules = [ + "model.decoder.embed_tokens", + "model.decoder.embed_positions", + "model.decoder.final_layer_norm", + ] + + def test_get_seqlen(self): + from optimum.gptq.utils import get_seqlen + + model = AutoModelForCausalLM.from_pretrained(self.model_name) + seqlen = get_seqlen(model) + self.assertEqual(seqlen, self.expected_seqlen) + + def test_get_block_name(self): + from optimum.gptq.utils import get_block_name_with_pattern + from optimum.utils import recurse_getattr + + model = AutoModelForCausalLM.from_pretrained(self.model_name) + block_name = get_block_name_with_pattern(model) + self.assertEqual(block_name, self.expected_block_name) + block_class_name = recurse_getattr(model, block_name)[0].__class__.__name__ + self.assertEqual(block_class_name, self.expected_block_name_class) + + def test_get_preceding_modules(self): + from optimum.gptq.utils import get_preceding_modules + + model = AutoModelForCausalLM.from_pretrained(self.model_name) + modules_names = get_preceding_modules(model, self.expected_block_name) + self.assertCountEqual(modules_names, self.expected_preceding_modules) + + +class BloomGPTQUtilsTest(GPTQUtilsTest): + model_name = "bigscience/bloom-560m" + expected_seqlen = 2048 + expected_block_name = "transformer.h" + expected_block_name_class = "BloomBlock" + expected_preceding_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm"] + + +class GPTQDataTest(unittest.TestCase): + """ + Test data + """ + + model_name = "facebook/opt-125m" + NBSAMPLES = 128 + SEQLEN = 2048 + + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) + + @parameterized.expand(["wikitext2", "c4", "ptb", "c4-new", "ptb-new"]) + def test_dataset(self, dataset): + train_dataset = get_dataset( + dataset, self.tokenizer, nsamples=self.NBSAMPLES, seqlen=self.SEQLEN, split="train" + ) + self.assertEqual(len(train_dataset), self.NBSAMPLES) + self.assertCountEqual(list(train_dataset[0].keys()), ["input_ids", "attention_mask"]) + 
self.assertEqual(list(train_dataset[0]["input_ids"].size()), [1, self.SEQLEN])