Commit 395ea37

Merge branch 'huggingface:main' into andreyan/exporters_model_configs

andreyanufr authored Jul 5, 2023
2 parents 5cb208f + bc5f825

Showing 31 changed files with 2,935 additions and 352 deletions.
78 changes: 78 additions & 0 deletions docs/source/exporters/onnx/usage_guides/export_a_model.mdx

@@ -239,3 +239,81 @@ For each model architecture, you can find the list of supported tasks via the [`
```

You can then pass one of these tasks to the `--task` argument in the `optimum-cli export onnx` command, as mentioned above.
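
For example, to export a Whisper checkpoint for speech recognition (the model name and output folder below are illustrative):

```bash
optimum-cli export onnx --model openai/whisper-tiny.en --task automatic-speech-recognition whisper_onnx/
```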

## Custom export of Transformers models

Optimum gives advanced users finer-grained control over the configuration of the ONNX export. This is especially useful if you would like to export models with non-default keyword arguments, for example `output_attentions=True` or `output_hidden_states=True`.

To support these use cases, [`~exporters.__main__.main_export`] supports two arguments: `model_kwargs` and `custom_onnx_configs`, which are used in the following fashion:

* `model_kwargs` allows you to override some of the default arguments passed to the model's `forward` method; in practice, the model is called as `model(**reference_model_inputs, **model_kwargs)`.
* `custom_onnx_configs` should be a `Dict[str, OnnxConfig]` mapping a submodel name (usually `model`, `encoder_model`, `decoder_model`, or `decoder_with_past_model` - [reference](https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py)) to a custom ONNX configuration for that submodel.

A complete example is given below; it exports a Whisper model with `output_attentions=True`.

```python
from optimum.exporters.onnx.__main__ import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig

from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict

class CustomWhisperOnnxConfig(WhisperOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs

if self._behavior is ConfigBehavior.ENCODER:
for i in range(self._config.encoder_layers):
common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"}
elif self._behavior is ConfigBehavior.DECODER:
for i in range(self._config.decoder_layers):
common_outputs[f"decoder_attentions.{i}"] = {
0: "batch_size",
2: "decoder_sequence_length",
3: "past_decoder_sequence_length + 1"
}
for i in range(self._config.decoder_layers):
common_outputs[f"cross_attentions.{i}"] = {
0: "batch_size",
2: "decoder_sequence_length",
3: "encoder_sequence_length_out"
}

return common_outputs

@property
def torch_to_onnx_output_map(self):
if self._behavior is ConfigBehavior.ENCODER:
# The encoder export uses WhisperEncoder that returns the key "attentions"
return {"attentions": "encoder_attentions"}
else:
return {}

model_id = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_id)

custom_whisper_onnx_config = CustomWhisperOnnxConfig(
config=config,
task="automatic-speech-recognition",
)

encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False)
decoder_with_past_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=True)

custom_onnx_configs = {
"encoder_model": encoder_config,
"decoder_model": decoder_config,
"decoder_with_past_model": decoder_with_past_config,
}

main_export(
model_id,
output="custom_whisper_onnx",
no_post_process=True,
model_kwargs={"output_attentions": True},
custom_onnx_configs=custom_onnx_configs
)
```
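
As a sanity check, you can inspect the exported encoder and confirm that the attention tensors were indeed added as graph outputs. Below is a minimal sketch, assuming the export above succeeded and wrote `encoder_model.onnx` into the `custom_whisper_onnx` folder:

```python
import onnx

# List the graph outputs of the exported encoder. With the custom config above,
# "encoder_attentions.0", "encoder_attentions.1", ... should appear alongside
# the default "last_hidden_state" output.
encoder = onnx.load("custom_whisper_onnx/encoder_model.onnx")
print([output.name for output in encoder.graph.output])
```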
134 changes: 134 additions & 0 deletions examples/onnxruntime/training/stable-diffusion/text-to-image/README.md

@@ -0,0 +1,134 @@
# Stable Diffusion Text-to-Image Fine-Tuning

This example shows how to leverage ONNX Runtime Training to fine-tune a Stable Diffusion model on your own dataset.

Our team has tested fine-tuning the `CompVis/stable-diffusion-v1-4` model on the `lambdalabs/pokemon-blip-captions` dataset and achieved the following speedup:
![image](https://github.com/microsoft/onnxruntime-training-examples/assets/31260940/00f199b1-3a84-4369-924d-fd6c613bd3b4)

___Note___:

___This script is experimental. It fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset.___


## Running locally with PyTorch
### Installing the dependencies

___Note___: This example requires nightly builds of PyTorch and [ONNX Runtime](https://github.com/Microsoft/onnxruntime):
```bash
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu118
pip install onnxruntime-training --pre -f https://download.onnxruntime.ai/onnxruntime_nightly_cu118.html
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install
```
Or get your environment ready via Docker: [examples/onnxruntime/training/docker/Dockerfile-ort-nightly-cu118](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort-nightly-cu118)
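
To quickly verify that the nightly builds are importable, here is an optional sanity check (assuming a CUDA machine):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "from onnxruntime.training.ortmodule import ORTModule; print('ORTModule OK')"
```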

Then, cd into the example folder and run:
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

### Pokemon example

You need to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate with your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps.


#### Hardware
The performance metrics cited above were measured on a cluster of 8 NVIDIA V100 GPUs.

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
<!-- accelerate_snippet_start -->
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py --ort \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```
<!-- accelerate_snippet_end -->


To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find instructions on how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
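
For reference, the `imagefolder` format with metadata looks roughly like the sketch below, where the caption column (`text` here) must match the script's `--caption_column` argument (which defaults to `text`):

```
path_to_your_dataset/
├── metadata.jsonl   # one JSON object per line, e.g. {"file_name": "0001.png", "text": "a caption"}
├── 0001.png
└── 0002.png
```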

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"

accelerate launch --mixed_precision="fp16" train_text_to_image.py --ort \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```


Once training is finished, the model will be saved in the `output_dir` specified in the command; in this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `ORTStableDiffusionPipeline`:


```python
from optimum.onnxruntime import ORTStableDiffusionPipeline

model_path = "path_to_saved_model"
# The fine-tuned checkpoint is saved in the PyTorch/diffusers format, so pass
# export=True to convert it to ONNX when loading the pipeline.
pipe = ORTStableDiffusionPipeline.from_pretrained(model_path, export=True)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
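
If you would rather not re-export the checkpoint on every load, you can persist the ONNX files once with `pipe.save_pretrained("sd-pokemon-model-onnx")` and load from that folder afterwards without `export=True`.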

#### Training with multiple GPUs

`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
for running distributed training with `accelerate`. Here is an example command:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py --ort \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```
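
Note that `accelerate launch` uses the number of processes set during `accelerate config`; you can override it on the command line if needed, e.g. `--num_processes=8` for the 8-GPU setup mentioned above.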
7 changes: 7 additions & 0 deletions examples/onnxruntime/training/stable-diffusion/text-to-image/requirements.txt

@@ -0,0 +1,7 @@
accelerate>=0.16.0
transformers>=4.25.1
datasets
git+https://github.com/huggingface/diffusers
ftfy
tensorboard
Jinja2