Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support balanced dataset to speed-up VL training #906

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import InternVL_V1_5_Dataset, BalancedDataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook, VarlenAttnArgsToMessageHubHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import InternVL_V1_5
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
# `path` is reused below as the model path, the dataset's `model_path` and
# the tokenizer's `pretrained_model_name_or_path`.
path = '/model/internvl'

# Data
# About `pack_internvl_sft_1.2M.json`: use the script
# `data_preprocess_stastics.sh` in `xtuner/tools` to generate it.
# NOTE(review): "stastics" is kept exactly as the script name was written
# here -- confirm the spelling against the actual file in xtuner/tools.
data_path = 'pack_internvl_sft_1.2M.json'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 4096

# Scheduler & Optimizer
batch_size = 1  # per_device
# Used both as the optimizer wrapper's `accumulative_counts` and as a factor
# of the sampler's `per_device_batch_size`.
accumulative_counts = 4
dataloader_num_workers = 8
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# NOTE(review): presumably means the official recipe pairs a global batch
# size of 1024 with lr=4e-5 -- confirm; the value used here is 1e-6.
lr = 1e-6
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
# Multiplied by `max_epochs` to place the boundary between the two
# learning-rate schedulers.
warmup_ratio = 0.03

# Save
save_steps = 1000  # used as the checkpoint hook's `interval`
save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
# PART 2 Model & Tokenizer & Image Processor #
#######################################################################
# Model config: `type` is the `InternVL_V1_5` class and `model_path` is the
# `path` setting defined in PART 1.
model = dict(
    type=InternVL_V1_5,
    model_path=path,
    freeze_llm=False,
    freeze_visual_encoder=False,  # or True
    # Set to True here and also in the dataloader's `collate_fn` config.
    use_varlen_attn=True
)

# NOTE(review): not referenced anywhere else in the visible part of this
# config -- confirm whether it is still needed.
evaluation_freq = 0
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
# Training dataset config; passed to `train_dataloader` as its `dataset`.
llava_dataset = dict(
    type=BalancedDataset,
    model_path=path,
    data_path=data_path,
    vit_packed_length=9,  # The value for vit packed length
    # The value for llm packed length.
    # NOTE(review): hard-coded to the same number as `max_length` (4096);
    # presumably the two are meant to stay in sync -- confirm.
    llm_packed_length=4096,
    # The value for llm thresh.
    # NOTE(review): 4068 sits just below the packed length of 4096; confirm
    # this margin is intentional and not a typo.
    llm_thresh=4068,
    template=prompt_template,
    max_length=max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    # With varlen attention, `default_collate_fn` asserts that the batch
    # holds exactly one instance, and it only accepts image samples when
    # `balance_data` is True.
    collate_fn=dict(
        type=default_collate_fn, use_varlen_attn=True, balance_data=True))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
# AMP optimizer wrapper around `optim_type`, configured with dynamic loss
# scaling, float16 dtype, gradient clipping at `max_norm` and gradient
# accumulation over `accumulative_counts`.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
# Two schedulers back to back, both declared in epochs with
# `convert_to_iter_based=True`:
#   1. LinearLR from epoch 0 to `warmup_ratio * max_epochs` (warm-up),
#   2. CosineAnnealingLR from that point to `max_epochs`, with eta_min=0.0.
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
# Only a training loop is configured in the visible part of this file.
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Tokenizer config built from the same `path` as the model; it is handed to
# `DatasetInfoHook` below.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
)

# Extra hooks registered on top of `default_hooks`.
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    # Hook that varlen attention depends on.
    dict(type=VarlenAttnArgsToMessageHubHook),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every iteration (`interval=1`), with metrics logged by
    # iteration rather than by epoch.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`, keeping at most `save_total_limit`.
    # NOTE(review): `save_optimizer=False` presumably leaves optimizer state
    # out of the checkpoints -- confirm this is acceptable if `resume` is
    # ever switched on.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# environment configuration
env_cfg = dict(
    # cudnn benchmark is switched off
    cudnn_benchmark=False,
    # multi-process parameters
    mp_cfg=dict(
        mp_start_method='fork',
        opencv_num_threads=0,
    ),
    # distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# no visualizer is configured
visualizer = None

# logging level
log_level = 'INFO'

# checkpoint to load from; None means no checkpoint is given
load_from = None

# whether training resumes from the loaded checkpoint
resume = False

# no fixed seed is given and `deterministic` stays disabled
randomness = dict(seed=None, deterministic=False)

# log processor setting
log_processor = dict(by_epoch=False)

3 changes: 2 additions & 1 deletion xtuner/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .refcoco_json import (InvRefCOCOJsonDataset, RefCOCOJsonDataset,
RefCOCOJsonEvalDataset)
from .utils import decode_base64_to_image, expand2square, load_image
from .fast_dataset import BalancedDataset

# ignore FutureWarning in hf datasets
warnings.simplefilter(action='ignore', category=FutureWarning)
Expand All @@ -25,5 +26,5 @@
'load_intern_repo_tokenized_dataset',
'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
'load_json_file', 'InternVL_V1_5_Dataset'
'load_json_file', 'InternVL_V1_5_Dataset', 'BalancedDataset'
]
6 changes: 4 additions & 2 deletions xtuner/dataset/collate_fns/default_collate_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
def default_collate_fn(instances: Sequence[Dict],
pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
return_hf_format: bool = False,
use_varlen_attn: bool = False):
use_varlen_attn: bool = False,
balance_data: bool = False):
seq_parallel_world_size = get_sequence_parallel_world_size()

input_ids, labels = [], []
Expand All @@ -22,7 +23,7 @@ def default_collate_fn(instances: Sequence[Dict],
assert len(instances) == 1, (
f'If utilizing varlen attention, the batch size should be'
f' set to 1, but got {len(instances)}')
assert not has_image, 'Currently, it is not configured to '
assert not has_image or balance_data, 'Currently, it is not configured to '
'accommodate the use of varlen Attention in multimodal training'

if has_image:
Expand All @@ -39,6 +40,7 @@ def default_collate_fn(instances: Sequence[Dict],
pixel_values.append(example['pixel_values'])

ori_length = [len(ids) for ids in input_ids]

if len(instances) > 1:
input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=pad_index)
Expand Down
Loading