-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ab8c961
commit 21dab52
Showing
2 changed files
with
419 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
<!--- | ||
Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--> | ||
|
||
# Supervised Fine-Tuning of Llama 3 70B on one AWS Trainium instance | ||
|
||
[[open-in-collab]] | ||
|
||
_Note: The complete script for this tutorial can be downloaded [here](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/lora_finetune_llm.py)._ | ||
|
||
This tutorial will teach you how to fine-tune open source LLMs like [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-70B) on AWS Trainium. In our example, we are going to leverage the [Optimum Neuron](https://huggingface.co/docs/optimum-neuron/index), [Transformers](https://huggingface.co/docs/transformers/index) and [Datasets](https://huggingface.co/docs/datasets/index) libraries. | ||
|
||
You will learn how to: | ||
|
||
1. [Setup AWS Environment](#1-setup-aws-environment) | ||
2. [Load and process the dataset](#2-load-and-prepare-the-dataset) | ||
3. [Fine-tune Llama using LoRA on AWS Trainium with the `NeuronTrainer`](#3-fine-tune-llama-using-lora-on-aws-trainium-with-the-neurontrainer) | ||
4. [Launch Training](#4-launch-training) | ||
5. [Evaluate and test fine-tuned Llama model](#5-evaluate-and-test-fine-tuned-llama-model) | ||
|
||
<Tip> | ||
|
||
While we will use `Llama-3 70B` in this tutorial, it is completely possible to use other models, simply by swtiching the `model_id`. | ||
|
||
</Tip> | ||
|
||
## 1. Setup AWS Environment | ||
|
||
Before starting this tutorial, you will need to setup your environment: | ||
|
||
1. Create an AWS Trainium instance. **You will need a `trn1.32xlarge`, which contains 16 Neuron Devices.** You can follow this [guide](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance) to create one. | ||
2. Make sure you are logged in on the Hugging Face Hub: | ||
```bash | ||
huggingface-cli login --token YOUR_TOKEN | ||
``` | ||
3. Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. Here we will be training Llama-3 70B, for which there are two possibilities: | ||
* The official gated repo: [`meta-llama/Meta-Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | ||
* The non-official un-gated repo: [`NousResearch/Meta-Llama-3-70B`](https://huggingface.co/NousResearch/Meta-Llama-3-70B) | ||
4. Clone the Optimum Neuron repository, **which contains the [complete script](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/lora_finetune_llm.py) described in this tutorial:** | ||
```bash | ||
git clone https://github.com/huggingface/optimum-neuron.git | ||
``` | ||
|
||
## 2. Load and prepare the dataset | ||
|
||
For this tutorial, we will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open source dataset of instruction-following records on categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. | ||
|
||
Example: | ||
|
||
```python | ||
{ | ||
"instruction": "What is world of warcraft", | ||
"context": "", | ||
"response": ( | ||
"World of warcraft is a massive online multi player role playing game. " | ||
"It was released in 2004 by blizarre entertainment" | ||
) | ||
} | ||
``` | ||
|
||
We can use the `load_dataset()` method from the 🤗 Datasets library to load the `dolly` dataset very easily. | ||
|
||
```python | ||
from datasets import load_dataset | ||
from random import randrange | ||
|
||
# Load dataset from the hub | ||
dataset = load_dataset("databricks/databricks-dolly-15k", split="train") | ||
|
||
print(f"dataset size: {len(dataset)}") | ||
print(dataset[randrange(len(dataset))]) | ||
# dataset size: 15011 | ||
``` | ||
To instruct tune our model we need to convert our structured examples into a collection of tasks described via instructions. We define a `format_dolly` that takes a raw sample and returns a string with our format instruction. | ||
```python | ||
def format_dolly(sample): | ||
instruction = f"### Instruction\n{sample['instruction']}" | ||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None | ||
response = f"### Answer\n{sample['response']}" | ||
# join all the parts together | ||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) | ||
return prompt | ||
``` | ||
In addition to formatting our samples, we also want to pack multiple samples to one sequence to have a more efficient training. In other words, we are stacking multiple samples to one sequence and split them with an EOS Token. | ||
While we could do this manually, we will use the `NeuronSFTTrainer` instead in the next section to do so. | ||
## 3. Supervised Fine-Tuning of Llama on AWS Trainium with the `NeuronSFTTrainer` | ||
Normally you would use the **[SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer)** and **[SFTConfig](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)** classes to perform supervised fine-tuning of PyTorch-based transformer models. | ||
Instead, here we will be using the [~`optimum.neuron.NeuronSFTTrainer`] and [~`optimum.neuron.NeuronSFTConfig`], these classes replicate the ones from the `trl` library while making sure they work properly on Neuron cores. | ||
Since Llama-3 70B is a big model it will not fit on a single Neuron device, even with distributed training. To actually fine-tune a 70B model using only one Trainium instance we need to use both LoRA and distributed training. | ||
In Optimum Neuron we support: | ||
1. [ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html): It is an optimization of data-parallelism which consists in sharding the optimizer state (which usually represents half or more of the memory needed on the device) over the data-parallel ranks. | ||
2. [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): It is a technique which consists in sharding each of your model matrix-multiplications along a given axis (row or column) on multiple devices. It also known as intra-layer model parallelism. The number of devices to shard your parameters on is called the `tensor_parallel_size`. | ||
3. [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf): It is an optimization over Tensor Parallelism which shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations. | ||
4. [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): It consists in sharding the model block layers on multiple devices. It is also known as inter-layer model parallelism. The number of devices to shard your layers on is called the `pipeline_parallel_size`. | ||
<Tip> | ||
If you want to know more about distributed training you can take a look at the [documentation](https://huggingface.co/docs/optimum-neuron/guides/distributed_training). | ||
</Tip> | ||
Here, we will use Tensor Parallelism in conjuction with LoRA. | ||
Our training code will look as follows: | ||
```python | ||
from peft import LoraConfig | ||
from optimum.neuron import NeuronTrainer as Trainer | ||
from optimum.neuron.distributed import lazy_load_for_parallelism | ||
from optimum.neuron.utils import get_peft_model | ||
# Define the tensor_parallel_size | ||
tensor_parallel_size = 8 | ||
# Load model from the Hugging face Hub | ||
with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size): | ||
model = AutoModelForCausalLM.from_pretrained(model_id) | ||
# Injecting the LoRA adapters | ||
config = LoraConfig( | ||
r=64, | ||
lora_alpha=128, | ||
lora_dropout=0.0, | ||
target_modules=["embed_tokens", "lm_head", "q_proj", "v_proj"], | ||
task_type="CAUSAL_LM", | ||
) | ||
model = get_peft_model(model, config) | ||
trainer = Trainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
args=training_args, | ||
train_dataset=dataset, | ||
data_collator=default_data_collator, # no special collator needed since we stacked the dataset | ||
) | ||
# Start training | ||
trainer.train() | ||
trainer.save_model() # saves the tokenizer too for easy upload | ||
``` | ||
|
||
The key points here are: | ||
|
||
- We use the `lazy_load_for_parallelism` context manager to lazily load the model. This will not load the full model weights on each worker, but instead only load the required weights (sharded or full). **This is much more memory efficient, and often mandatory to use.** | ||
- We inject the LoRA adapters using `optimum.neuron.utils.get_peft_model` by specifying both the model to transform and the LoRA config. | ||
- We use the [~`optimum.neuron.NeuronTrainer`] to perform training. It will take the lazily loaded model, along with the `training_args`, which are an instance of [~`optimum.neuron.NeuronTrainingArguments`], and will handle all the parallelization and training on the Neuron cores. | ||
|
||
## 4. Launch Training | ||
|
||
We prepared a script called [lora_finetune_llm.py](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/lora_finetune_llm.py) summing up everything mentioned in this tutorial. | ||
|
||
PyTorch Neuron uses `torch_xla`. It evaluates operations lazily during execution of the training loops, which means it builds a symbolic graph in the background and the graph is executed on the hardware only when the tensor is printed, transfered to CPU, or `xm.mark_step()` is called. During execution, multiple graphs can be build depending on control-flow and it can take time to compile each graph sequentially. To alleviate that, the Neuron SDK provides `neuron_parallel_compile`, a tool which performs a fast trial run that builds all the graphs and compile them in parallel. This step is usually called precompilation. | ||
|
||
### Precompilation | ||
|
||
When training models on AWS Trainium we first need to compile our model with our training arguments. | ||
|
||
To overcome this, we added a [model cache repository](https://huggingface.co/docs/optimum-neuron/guides/cache_system), which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses. | ||
|
||
_Note: If your model configuration is not cached please open an issue on [Github](https://github.com/huggingface/optimum-neuron/issues), we are happy to include it._ | ||
|
||
The compilation command simply consists in calling your script as an input to the `neuron_parallel_compile` utility: | ||
|
||
```bash | ||
MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node=32 lora_finetune_llm.py \ | ||
--model_id meta-llama/Meta-Llama-3-70B \ | ||
--bf16 True \ | ||
--learning_rate 5e-5 \ | ||
--output_dir dolly_llama \ | ||
--overwrite_output_dir True \ | ||
--per_device_train_batch_size 1 \ | ||
--gradient_accumulation_steps 16 \ | ||
--gradient_checkpointing True \ | ||
--tensor_parallel_size 8 \ | ||
--max_steps 10 \ | ||
--logging_steps 10 | ||
``` | ||
|
||
<Tip> | ||
|
||
Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training. | ||
|
||
</Tip> | ||
|
||
_Note: Compiling without a cache can take a while. It will also create dummy files in the `dolly_llama_sharded` during compilation you will have to remove them afterwards. We also need to add `MALLOC_ARENA_MAX=64` to limit the CPU allocation to avoid potential crashes, don't remove it for now._ | ||
|
||
```bash | ||
# remove dummy artifacts which are created by the precompilation command | ||
rm -rf dolly_llama | ||
``` | ||
|
||
### Actual Training | ||
|
||
After compilation is done we can start our actual training with a similar command, we just need to remove the use of `neuron_parallel_compile`. | ||
|
||
We will use `torchrun` to launch our training script. `torchrun` is a tool that automatically distributes a PyTorch model across multiple accelerators. We can pass the number of accelerators as `nproc_per_node` arguments alongside our hyperparameters. | ||
|
||
The difference to the compilation command is that we changed from `max_steps=10` to `num_train_epochs=3`. | ||
|
||
Launch the training, with the following command. | ||
|
||
```bash | ||
MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 torchrun --nproc_per_node=32 lora_finetune_llm.py \ | ||
--model_id meta-llama/Meta-Llama-3-70B \ | ||
--bf16 True \ | ||
--learning_rate 5e-5 \ | ||
--output_dir dolly_llama \ | ||
--overwrite_output_dir True \ | ||
--skip_cache_push True \ | ||
--per_device_train_batch_size 1 \ | ||
--gradient_accumulation_steps 16 \ | ||
--gradient_checkpointing True \ | ||
--tensor_parallel_size 8 \ | ||
--num_train_epochs 3 \ | ||
--logging_steps 10 | ||
``` | ||
|
||
That's it, we successfully trained Llama-3 70B on AWS Trainium! | ||
|
||
But before we can share and test our model we need to consolidate our model. Since we used Tensor Parallelism during training, we saved sharded versions of the checkpoints. We need to consolidate them now. | ||
|
||
### Consolidate the Checkpoint | ||
|
||
The Optimum CLI provides a way of doing that very easily via the `optimum neuron consolidate [sharded_checkpoint] [output_dir]` command: | ||
|
||
```bash | ||
optimum-cli neuron consolidate dolly_llama dolly_llama | ||
``` | ||
|
||
### Merge the LoRA adapters | ||
|
||
TODO | ||
|
||
## 5. Evaluate and test fine-tuned Llama model | ||
|
||
As for training, to be able to run inference on AWS Trainium or AWS Inferentia2 we need to compile our model. In this case, we will use our Trainium instance for the inference test, but we recommend customer to switch to Inferentia2 for inference. | ||
|
||
Optimum Neuron implements similar to Transformers AutoModel classes for easy inference use. We will use the `NeuronModelForCausalLM` class to load our vanilla transformers checkpoint and convert it to neuron. | ||
|
||
```python | ||
from optimum.neuron import NeuronModelForCausalLM | ||
from transformers import AutoTokenizer | ||
|
||
compiler_args = {"num_cores": 2, "auto_cast_type": 'fp16'} | ||
input_shapes = {"batch_size": 1, "sequence_length": 2048} | ||
|
||
tokenizer = AutoTokenizer.from_pretrained("dolly_llama") | ||
model = NeuronModelForCausalLM.from_pretrained( | ||
"dolly_llama", | ||
export=True, | ||
**compiler_args, | ||
**input_shapes) | ||
``` | ||
|
||
_Note: Inference compilation can take ~25minutes. Luckily, you need to only run this onces. Since you can save the model afterwards. If you are going to run on Inferentia2 you need to recompile again. The compilation is parameter and hardware specific._ | ||
|
||
```python | ||
# COMMENT IN if you want to save the compiled model | ||
# model.save_pretrained("compiled_dolly_llama") | ||
``` | ||
|
||
We can now test inference, but have to make sure we format our input to our prompt format we used for fine-tuning. Therefore we created a helper method, which accepts a `dict` with our `instruction` and optionally a `context`. | ||
|
||
```python | ||
def format_dolly_inference(sample): | ||
instruction = f"### Instruction\n{sample['instruction']}" | ||
context = f"### Context\n{sample['context']}" if "context" in sample else None | ||
response = f"### Answer\n" | ||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) | ||
return prompt | ||
|
||
|
||
def generate(sample): | ||
prompt = format_dolly_inference(sample) | ||
inputs = tokenizer(prompt, return_tensors="pt") | ||
outputs = model.generate( | ||
**inputs, | ||
max_new_tokens=512, | ||
do_sample=True, | ||
temperature=0.9, | ||
top_k=50, | ||
top_p=0.9 | ||
) | ||
return tokenizer.decode(outputs[0], skip_special_tokens=False)[len(prompt):] | ||
``` | ||
|
||
Let's test inference. First we test without a context. | ||
|
||
_Note: Inference is not expected to be super fast on AWS Trainium using 2 cores. For Inference we recommend using Inferentia2._ | ||
|
||
```python | ||
prompt = { | ||
"instruction": "Can you tell me something about AWS?" | ||
} | ||
res = generate(prompt) | ||
|
||
print(res) | ||
``` | ||
|
||
> AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys. | ||
That looks correct. Now, lets add some context, e.g. as you would do for RAG applications: | ||
|
||
```python | ||
prompt = { | ||
"instruction": "How can I train models on AWS Trainium?", | ||
"context": "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks." | ||
} | ||
res = generate(prompt) | ||
|
||
print(res) | ||
``` | ||
|
||
> You can use the Optimum Neuron interface to train models on AWS Trainium. | ||
Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium. |
Oops, something went wrong.