Skip to content

Commit

Permalink
Add functionality to download models from sources other than HuggingF…
Browse files Browse the repository at this point in the history
…ace (#946)

support openmind model and dataset
  • Loading branch information
starmountain1997 authored Nov 8, 2024
1 parent 697bc77 commit 90192ff
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 2 deletions.
16 changes: 16 additions & 0 deletions docs/zh_cn/training/modify_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ XTuner 使用 MMEngine 的「纯 Python 风格的配置文件」,直接利用
1. 参考 `文档 <../preparation/pretrained_model.md>`__ 将其下载至本地
2. 修改\ ``pretrained_model_name_or_path``\

使用 openMind 模型?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
可在配置文件中新增 ``model_resource`` 参数, ``args`` 用作可变参数(如下载私有模型需传入token的情况):

.. code:: python
from openmind_hub import snapshot_download
# Model
pretrained_model_name_or_path = 'Tianjin_Ascend/Qwen1.5-4B'
model_resource = {
"fn": snapshot_download,
"args":{
# "token":"xxxxxxxxxx"
}
}
微调类型
-------------

Expand Down
2 changes: 2 additions & 0 deletions docs/zh_cn/training/open_source_dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ XTuner 使用上游库 ``datasets`` 的统一载入接口 ``load_dataset``\ 。
``dataset=dict(type=load_dataset, path=data_path)`` 中的 ``path``
参数即可。

若想使用 openMind 数据集,可将 ``dataset=dict(type=load_dataset, path=data_path)`` 中的 ``type`` 替换为 ``openmind.OmDataset``。


字段格式
--------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
from openmind_hub import snapshot_download
from openmind import OmDataset

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'Tianjin_Ascend/Qwen1.5-4B'
model_resource = {
"fn": snapshot_download,
"args":{
# "token":"xxxxxxxxxx"
}
}
use_varlen_attn = False

# Data
alpaca_en_path = 'AI_Connect/alpaca'
prompt_template = PROMPT_TEMPLATE.default
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
'请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=SupervisedFinetune,
use_varlen_attn=use_varlen_attn,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
# NPU does not support quantization
# quantization_config=dict(
# type=BitsAndBytesConfig,
# load_in_4bit=True,
# load_in_8bit=False,
# llm_int8_threshold=6.0,
# llm_int8_has_fp16_weight=False,
# bnb_4bit_compute_dtype=torch.float16,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type='nf4')
),
lora=dict(
type=LoraConfig,
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias='none',
task_type='CAUSAL_LM'))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
alpaca_en = dict(
type=process_hf_dataset,
dataset=dict(type=OmDataset.load_dataset, path=alpaca_en_path),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=alpaca_map_fn,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length,
use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
if sequence_parallel_size > 1 else DefaultSampler

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=alpaca_en,
sampler=dict(type=sampler, shuffle=True),
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
system=SYSTEM,
prompt_template=prompt_template)
]

if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
2 changes: 1 addition & 1 deletion xtuner/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
__all__ = [
'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
'process_ms_dataset', 'LLaVADataset', 'expand2square',
'decode_base64_to_image', 'load_image', 'process_ms_dataset',
'decode_base64_to_image', 'load_image',
'load_intern_repo_tokenized_dataset',
'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
Expand Down
6 changes: 5 additions & 1 deletion xtuner/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict
from xtuner.registry import BUILDER, MAP_FUNC
from xtuner.tools.utils import (auto_dtype_of_deepspeed_config,
get_seed_from_checkpoint)
get_seed_from_checkpoint, set_model_resource)


def parse_args():
Expand Down Expand Up @@ -124,6 +124,9 @@ def check_cfg(cfg, args):
'deepspeed_zero3)`.')





def main():
args = parse_args()

Expand All @@ -136,6 +139,7 @@ def main():

# load config
cfg = Config.fromfile(args.config)
set_model_resource(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
Expand Down
18 changes: 18 additions & 0 deletions xtuner/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ def get_streamer(model):
else:
return DecodeOutputStreamer

def set_model_resource(cfg):
if cfg.get("model_resource"):
fn = cfg["model_resource"].get("fn")
args = cfg["model_resource"].get("args", {})
local_path = fn(cfg["pretrained_model_name_or_path"], **args)
s = [(cfg._cfg_dict, k, v) for k, v in cfg._cfg_dict.items()]
while s:
current_d, current_k, current_v = s.pop()
if current_k == "pretrained_model_name_or_path":
current_d[current_k] = local_path

if isinstance(current_v, dict):
s.extend([(current_v, k, v) for k, v in current_v.items()])
elif isinstance(current_v, list):
for i in current_v:
if isinstance(i, dict):
s.extend((i, k, v) for k, v in i.items())


class DecodeOutputStreamer(BaseStreamer):
"""Default streamer for HuggingFace models."""
Expand Down

0 comments on commit 90192ff

Please sign in to comment.