-
Notifications
You must be signed in to change notification settings - Fork 309
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support InternVL 1.5/2.0 finetune (#737)
* support internvl finetune * fix * fix * fix * update * update * update * update cfg * fix * support phi3 * support full+lora+qlora * support internvl 26b * fix lora cfg * update all * update * update * update config * update config * update config * fix type and add readme * update readme * RENAME * fix * update * support internvl2 * update
- Loading branch information
1 parent
5a93e7d
commit f30ad4c
Showing
30 changed files
with
4,867 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# InterVL Full Pipeline | ||
|
||
English | [简体中文](./README_zh-CN.md) | ||
|
||
## InterVL 2 | ||
|
||
> [InternVL-2: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy](https://internvl.github.io/blog/2024-07-02-InternVL-2.0/) | ||
We introduce InternVL-2, currently the most powerful open-source Multimodal Large Language Model (MLLM). The InternVL-2 family includes models ranging from a 2B model, suitable for edge devices, to a 108B model, which is significantly more powerful. With larger-scale language models, InternVL-2-Pro demonstrates outstanding multimodal understanding capabilities, matching the performance of commercial closed-source models across various benchmarks. | ||
|
||
InternVL-2 family is built upon the following designs: | ||
|
||
- Progressive with larger language models: We introduce a progressive alignment training strategy, resulting in the first vision foundation model aligned with large language models. By employing the progressive training strategy where the model scales from small to large while the data refines from coarse to fine, we have completed the training of large models at a relatively low cost. This approach has demonstrated excellent performance with limited resources. | ||
- Multimodal input: With one set of parameters, our model supports multiple modalities of input, including text, images, video, audio, and 3D point clouds. | ||
- Multitask output: Our model supports various output formats, such as images, bounding boxes, and masks, demonstrating extensive versatility. By connecting the MLLM with multiple downstream task decoders, InternVL-2 can be generalized to hundreds of vision-language tasks while achieving performance comparable to expert models. | ||
|
||
<div align="center"> | ||
<img src="https://github.com/OpenGVLab/InternVL/assets/17425982/07845268-8b2c-4dc7-88dd-d10a173bdafe" alt="Image" /> | ||
</div> | ||
|
||
### Basic Introduction | ||
|
||
- `./v2/` contains the configuration files for training InterVL 2 | ||
- Supported fine-tuning of the InternVL 2B/4B/8B/26B model in full/LoRA/QLoRA single-image mode for now. We will support fine-tuning on multiple images and videos as soon as possible. | ||
- After training, you can use the `./v1_5/convert_to_official.py` script to convert the model trained by XTuner to the official format, so as to reuse all the official supported toolchains | ||
- All configurations are based on 8xA100 80G graphics cards, 2B/4B can use ZERO1 training, 8B models can use ZERO2, 26B models must run ZERO3, and there is no excessive adjustment of parameters, you can modify them according to your own needs | ||
- It is verified with LLaVA SFT data, which cannot fully reflect the fine-tuning performance. You can customize the data according to your own needs. We will provide a relatively fair fine-tuning dataset later | ||
|
||
### Data preparation | ||
|
||
If you also want to use the LLaVA SFT dataset for training, please refer to the [document](../../../docs/en/user_guides/dataset_prepare.md#llava-dataset) to prepare the data. | ||
|
||
For custom data, support multiple json and jsonl formats, the data organization can refer to the LLaVA SFT format, and support data sampling operations. | ||
|
||
**(1) Support multiple json or jsonl data** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
**(2) Support custom sampling** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
repeat_times=[2,0.5,3.5], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
### Training | ||
|
||
The provided configuration is mainly used for fine-tuning based on the official weights. After preparing the data, you can use the following command to train: | ||
|
||
```bash | ||
NPROC_PER_NODE=8 xtuner train internvl_v2_internlm2_5_8b_lora_finetune --deepspeed deepspeed_zero2 | ||
``` | ||
|
||
Default saved in `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/`. | ||
|
||
### Model Conversion | ||
|
||
After training, we will get a set of weights, that is `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/iter_xxx.pth`, in order to facilitate evaluation and dialogue, we can convert it to official weights. | ||
|
||
```bash | ||
python xtuner/configs/internvl/v1_5/convert_to_official.py xtuner/configs/internvl/v2/internvl_v2_internlm2_5_8b_lora_finetune.py ./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/iter_xxx.pth ./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/convert_model/ | ||
``` | ||
|
||
Here, a complete set of official weights including configuration will be generated under `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/convert_model`, you can use the [official toolchain](https://huggingface.co/OpenGVLab/InternVL2-8B) for evaluation and dialogue. | ||
|
||
If you encounter any problems during use, please feel free to contact us!!! | ||
|
||
## InterVL 1.5 | ||
|
||
> [How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites](https://arxiv.org/abs/2404.16821) | ||
In this report, we introduce InternVL 1.5, an open-source multimodal large language model (MLLM) to bridge the capability gap between open-source and proprietary commercial models in multimodal understanding. We introduce three simple improvements: (1) Strong Vision Encoder: we explored a continuous learning strategy for the large-scale vision foundation model -- InternViT-6B, boosting its visual understanding capabilities, and making it can be transferred and reused in different LLMs. (2) Dynamic High-Resolution: we divide images into tiles ranging from 1 to 40 of 448×448 pixels according to the aspect ratio and resolution of the input images, which supports up to 4K resolution input. (3) High-Quality Bilingual Dataset: we carefully collected a high-quality bilingual dataset that covers common scenes, document images, and annotated them with English and Chinese question-answer pairs, significantly enhancing performance in OCR- and Chinese-related tasks. We evaluate InternVL 1.5 through a series of benchmarks and comparative studies. Compared to both open-source and proprietary models, InternVL 1.5 shows competitive performance, achieving state-of-the-art results in 8 of 18 benchmarks. | ||
|
||
<div align="center"> | ||
<img src="https://github.com/InternLM/xtuner/assets/17425982/6dbe6a46-f01a-4c9d-ba44-0d857e5c0373" alt="Image" width="700" /> | ||
</div> | ||
|
||
### Basic Introduction | ||
|
||
- `./v1_5/` contains the configuration files for training InterVL 1.5 | ||
- Support InternVL 2B/4B/26B model full/LoRA/Qing efficiency and performance, it is recommended to choose the 4B model first | ||
- After training, you can use the `./v1_5/convert_to_official.py` script to convert the model trained by XTuner to the official format, so as to reuse all the official supported toolchains | ||
- All configurations are based on 8xA100 80G graphics cards, 2B/4B can use ZERO1 training, 8B models must run ZERO2, and there is no excessive adjustment of parameters, you can modify them according to your own needs | ||
- It is verified with LLaVA SFT data, which cannot fully reflect the fine-tuning performance. You can customize the data according to your own needs. We will provide a relatively fair fine-tuning dataset later | ||
|
||
### Data preparation | ||
|
||
If you also want to use the LLaVA SFT dataset for training, please refer to the [document](../../../docs/en/user_guides/dataset_prepare.md#llava-dataset) to prepare the data. | ||
|
||
For custom data, support multiple json and jsonl formats, the data organization can refer to the LLaVA SFT format, and support data sampling operations. | ||
|
||
**(1) Support multiple json or jsonl data** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
**(2) Support custom sampling** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
repeat_times=[2,0.5,3.5], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
### Training | ||
|
||
The provided configuration is mainly used for fine-tuning based on the official weights. After preparing the data, you can use the following command to train: | ||
|
||
```bash | ||
NPROC_PER_NODE=8 xtuner train internvl_v1_5_phi3_4b_lora_finetune --deepspeed deepspeed_zero1 | ||
# NPROC_PER_NODE=8 xtuner train internvl_v1_5_internlm2_26b_lora_finetune.py --deepspeed deepspeed_zero3 | ||
``` | ||
|
||
Default saved in `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/`. | ||
|
||
### Model Conversion | ||
|
||
After training, we will get a set of weights, that is `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/iter_xxx.pth`, in order to facilitate evaluation and dialogue, we can convert it to official weights. | ||
|
||
```bash | ||
python xtuner/configs/internvl/v1_5/convert_to_official.py xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_lora_finetune.py ./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/iter_xxx.pth ./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/internvl_v1_5_phi3_4b/ | ||
``` | ||
|
||
Here, a complete set of official weights including configuration will be generated under `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/internvl_v1_5_phi3_4b/`, you can use the [official toolchain](https://github.com/OpenGVLab/InternVL) for evaluation and dialogue. | ||
|
||
If you encounter any problems during use, please feel free to contact us!!! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# InterVL 全流程 | ||
|
||
[English](./README.md) | 简体中文 | ||
|
||
## InterVL 2 | ||
|
||
> [InternVL-2: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy](https://internvl.github.io/blog/2024-07-02-InternVL-2.0/) | ||
我们引入了 InternVL-2,目前最强大的开源多模态大语言模型(MLLM)。InternVL-2 系列包括从适合于边缘设备的 2B 模型到强大的 108B 模型等多种规模的模型。借助更大规模的语言模型,InternVL-2-Pro 展现出了出色的多模态理解能力,在各种基准测试中的性能与商业闭源模型相匹配。 | ||
|
||
InternVL-2 系列基于以下设计: | ||
|
||
- 渐进式的大型语言模型:我们引入了一种渐进式对齐训练策略,实现了首个与大型语言模型对齐的视觉基础模型。通过采用从小到大模型扩展、从粗到细数据优化的渐进式训练策略,我们以较低的成本完成了大模型的训练。这种方法已经展示了出色的性能,资源有限的情况下也能取得良好的结果。 | ||
- 多模态输入:使用一套参数,我们的模型支持文本、图像、视频、音频和 3D 点云等多种输入模态。 | ||
- 多任务输出:我们的模型支持图像、边界框和掩码等各种输出格式,展现出广泛的多功能性。通过将 MLLM 与多个下游任务解码器相连接,InternVL-2 可以泛化到数百个视觉语言任务,并取得与专家模型相当的性能。 | ||
|
||
<div align="center"> | ||
<img src="https://github.com/OpenGVLab/InternVL/assets/17425982/07845268-8b2c-4dc7-88dd-d10a173bdafe" alt="Image" /> | ||
</div> | ||
|
||
### 基本说明 | ||
|
||
- `./v2/` 包含着 InterVL 2 训练配置的配置文件 | ||
- 支持了 InternVL 2B/4B/8B/26B 模型全量/LoRA/QLoRA 单图模式的微调,会尽快支持多图和视频的微调。 | ||
- 在训练完成后,可以使用 `./v1_5/convert_to_official.py` 脚本将 XTuner 训练的模型转换为官方格式,从而复用官方所支持的所有工具链 | ||
- 目前所有配置都是以 8xA100 80G 显卡为基准,2B/4B 可以使用 ZERO1 训练,8B 模型要 ZERO2 运行,26B 模型必须要 ZERO3,并且没有对参数进行过多的调整,你可以按照你自己的需求进行修改 | ||
- 目前是以 LLaVA SFT 数据进行验证,无法充分反应微调性能,你可以根据自己的需求进行数据自定义,后续我们会提供一个相对公平的微调数据集 | ||
|
||
### 数据准备 | ||
|
||
如果你也想使用 LLaVA SFT 数据集进行训练,请参考[文档](../../../docs/zh_cn/user_guides/dataset_prepare.md#llava-dataset) 准备数据。 | ||
|
||
对于自定义数据,支持多种 json 和 jsonl 格式,内部数据组织可以参考 LLaVA SFT 格式,且支持数据采样操作。 | ||
|
||
**(1) 支持多个 json 或者 jsonl 数据** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
**(2) 支持自定义采样** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
repeat_times=[2,0.5,3.5], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
### 训练流程 | ||
|
||
所提供的配置主要用于基于官方权重继续微调。在准备好数据后,你可以使用以下命令进行训练: | ||
|
||
```bash | ||
NPROC_PER_NODE=8 xtuner train internvl_v2_internlm2_5_8b_lora_finetune --deepspeed deepspeed_zero2 | ||
``` | ||
|
||
默认保存在 `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/`。 | ||
|
||
### 模型转换 | ||
|
||
训练后,我们将获得一组权重即 `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/iter_xxx.pth`,为了方便评测和对话,可以将其转换为官方权重。 | ||
|
||
```bash | ||
python xtuner/configs/internvl/v1_5/convert_to_official.py xtuner/configs/internvl/v2/internvl_v2_internlm2_5_8b_lora_finetune.py ./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/iter_xxx.pth ./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/convert_model/ | ||
``` | ||
|
||
此时,会在 `./work_dirs/internvl_v2_internlm2_5_8b_lora_finetune/convert_model` 下生成一组包括配置的完整官方权重,你可以使用[官方工具链](https://huggingface.co/OpenGVLab/InternVL2-8B)进行评测和对话。 | ||
|
||
如果你在使用中碰到任何问题,欢迎联系我们!!! | ||
|
||
## InterVL 1.5 | ||
|
||
> [How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites](https://arxiv.org/abs/2404.16821) | ||
在本报告中,我们介绍了开源多模态大语言模型 InternVL 1.5,以弥补开源模型与商业专有模型在多模态理解能力上的差距。我们引入了三项简单的改进:(1) 强大的视觉编码器:我们探索了大规模视觉基础模型 InternViT-6B 的连续学习策略,提升了其视觉理解能力,并使其可以在不同的大语言模型中进行迁移和重复利用。(2) 动态高分辨率:我们根据输入图像的长宽比和分辨率,将图像划分为从1到40个448×448像素的瓦片,支持高达4K分辨率的输入。(3) 高质量双语数据集:我们精心收集了一个高质量的双语数据集,涵盖了常见场景、文档图像,并用英语和中文问答对进行了注释,显著提升了在OCR和中文相关任务中的性能。我们通过一系列基准测试和对比研究评估了 InternVL 1.5。与开源和专有模型相比,InternVL 1.5 表现出了竞争力,在18个基准中的8个中取得了最先进的结果。 | ||
|
||
<div align="center"> | ||
<img src="https://github.com/InternLM/xtuner/assets/17425982/6dbe6a46-f01a-4c9d-ba44-0d857e5c0373" alt="Image" width="700" /> | ||
</div> | ||
|
||
### 基本说明 | ||
|
||
- `./v1_5/` 包含着 InterVL 1.5 训练配置的配置文件 | ||
- 支持 InternVL 2B/4B/26B 模型全量/LoRA/QLoRA 微调,综合考虑效率性能,建议你优先选择 4B 模型 | ||
- 在训练完成后,可以使用 `./v1_5/convert_to_official.py` 脚本将 XTuner 训练的模型转换为官方格式,从而复用官方所支持的所有工具链 | ||
- 目前所有配置都是以 8xA100 80G 显卡为基准,2B/4B 可以使用 ZERO1 训练,26B 模型必须要 ZERO3 运行,并且没有对参数进行过多的调整,你可以按照你自己的需求进行修改 | ||
- 目前是以 LLaVA SFT 数据进行验证,无法充分反应微调性能,你可以根据自己的需求进行数据自定义,后续我们会提供一个相对公平的微调数据集 | ||
|
||
### 数据准备 | ||
|
||
如果你也想使用 LLaVA SFT 数据集进行训练,请参考[文档](../../../docs/zh_cn/user_guides/dataset_prepare.md#llava-dataset) 准备数据。 | ||
|
||
对于自定义数据,支持多种 json 和 jsonl 格式,内部数据组织可以参考 LLaVA SFT 格式,且支持数据采样操作。 | ||
|
||
**(1) 支持多个 json 或者 jsonl 数据** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
**(2) 支持自定义采样** | ||
|
||
```text | ||
llava_dataset = dict( | ||
type=InternVL_V1_5_Dataset, | ||
model_path=path, | ||
data_paths=['a.json','b.jsonl','c.json'], | ||
image_folders=['a',None,'c'], | ||
repeat_times=[2,0.5,3.5], | ||
template=prompt_template, | ||
max_length=max_length) | ||
``` | ||
|
||
### 训练流程 | ||
|
||
所提供的配置主要用于基于官方权重继续微调。在准备好数据后,你可以使用以下命令进行训练: | ||
|
||
```bash | ||
NPROC_PER_NODE=8 xtuner train internvl_v1_5_phi3_4b_lora_finetune --deepspeed deepspeed_zero1 | ||
# NPROC_PER_NODE=8 xtuner train internvl_v1_5_internlm2_26b_lora_finetune.py --deepspeed deepspeed_zero3 | ||
``` | ||
|
||
默认保存在 `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/`。 | ||
|
||
### 模型转换 | ||
|
||
训练后,我们将获得一组权重即 `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/iter_xxx.pth`,为了方便评测和对话,可以将其转换为官方权重。 | ||
|
||
```bash | ||
python xtuner/configs/internvl/v1_5/convert_to_official.py xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_lora_finetune.py ./work_dirs/iter_xxx.pth ./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/internvl_v1_5_phi3_4b/ | ||
``` | ||
|
||
此时,会在 `./work_dirs/internvl_v1_5_phi3_4b_lora_finetune/internvl_v1_5_phi3_4b/` 下生成一组包括配置的完整官方权重,你可以使用[官方工具链](https://github.com/OpenGVLab/InternVL)进行评测和对话。 | ||
|
||
如果你在使用中碰到任何问题,欢迎联系我们!!! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import argparse | ||
import os.path as osp | ||
|
||
import torch | ||
from mmengine.config import Config | ||
from transformers import AutoTokenizer | ||
|
||
from xtuner.model.utils import LoadWoInit | ||
from xtuner.registry import BUILDER | ||
|
||
|
||
def convert_to_official(config, trained_path, save_path): | ||
cfg = Config.fromfile(config) | ||
cfg.model.pretrained_pth = trained_path | ||
cfg.model.quantization_vit = False | ||
cfg.model.quantization_llm = False | ||
|
||
with LoadWoInit(): | ||
model = BUILDER.build(cfg.model) | ||
model.to(torch.bfloat16) | ||
|
||
if model.use_visual_encoder_lora: | ||
vision_model = model.model.vision_model.merge_and_unload() | ||
model.model.vision_model = vision_model | ||
|
||
if model.use_llm_lora: | ||
language_model = model.model.language_model.merge_and_unload() | ||
model.model.language_model = language_model | ||
|
||
model.model.save_pretrained(save_path) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained( | ||
cfg.model.model_path, trust_remote_code=True) | ||
tokenizer.save_pretrained(save_path) | ||
|
||
print(model) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description='Convert the pth model to HuggingFace model') | ||
parser.add_argument('config', help='config file name or path.') | ||
parser.add_argument('trained_model_pth', help='The trained model path.') | ||
parser.add_argument( | ||
'save_path', help='The path to save the converted model.') | ||
args = parser.parse_args() | ||
|
||
if osp.realpath(args.trained_model_pth) == osp.realpath(args.save_path): | ||
raise ValueError( | ||
'The trained path and save path should not be the same.') | ||
|
||
convert_to_official(args.config, args.trained_model_pth, args.save_path) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.