[WIP][Feature] DPO #434
Open
amulil wants to merge 19 commits into InternLM:main from amulil:dpo
Commits (19)
d1ca875  init dpo file (amulil)
67c9251  test dpo (amulil)
e0e4f54  [WIP] len_chosen must not be 0 (xiaohangguo)
ec18b27  xx (xiaohangguo)
cc2fbe3  update how to get the length of chosen answer (amulil)
df99211  [WIP] test dpo __getitem__ (xiaohangguo)
548cf79  Merge branch 'main' into dpo (xiaohangguo)
d121ca8  xx (xiaohangguo)
5eb65dd  Merge branch 'dpo' of https://github.com/amulil/xtuner into dpo (xiaohangguo)
ca8881f  update dpo (amulil)
90c8d20  update dpo (amulil)
3fbdb60  fix loss nan problem (amulil)
46e24be  add full dpo config (amulil)
aefd1de  support dpo with qlora (amulil)
8797fbc  add dpo loss type (amulil)
b2ef99b  fix conflicts (hhaAndroid)
231000f  update (hhaAndroid)
26e878a  update (hhaAndroid)
15e337f  update (amulil)
213 changes: 213 additions & 0 deletions
xtuner/configs/internlm/internlm2_chat_1_8b/internlm2_chat_1_8b_qlora_dpo_ultra_e3.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import DPODataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import ultra_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import DPO
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b'
use_varlen_attn = False

# Data
ultra_path = 'HuggingFaceH4/ultrachat_200k'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=DPO,  # TODO
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'),
    beta=0.1)

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
ultra = dict(
    type=DPODataset,  # TODO
    data_path=ultra_path,
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=ultra_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=ultra,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
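Assuming this config wires into xtuner's standard workflow the same way the repository's existing QLoRA configs do, a run would be launched with `xtuner train internlm2_chat_1_8b_qlora_dpo_ultra_e3.py` (optionally with `--deepspeed deepspeed_zero2`), and the saved adapter converted afterwards via `xtuner convert pth_to_hf`. Note that with `batch_size = 1` and `accumulative_counts = 16`, the effective batch size is 16 per GPU.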
36 changes: 36 additions & 0 deletions
(new DPODataset module in xtuner/dataset; exact file path not captured in this view)
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset

from xtuner.registry import BUILDER
from .huggingface import process_hf_dataset
from .utils import expand2square


class DPODataset(Dataset):

    def __init__(self,
                 data_path,
                 tokenizer,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048):
        super().__init__()
        # TODO
        pass

    def __len__(self):
        # TODO
        pass

    def __getitem__(self, index):
        # TODO
        pass
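The class body is still a stub. As a point of reference only, here is a minimal sketch of the item layout a DPO `__getitem__` commonly returns: a chosen/rejected pair over the same prompt, with prompt tokens masked out of the labels. The field names and token values are assumptions, not this PR's final schema; the `len_chosen` check mirrors the guard added in commit e0e4f54.

```python
IGNORE_INDEX = -100  # xtuner's label value for tokens excluded from the loss

# Hypothetical item for one preference pair: the same prompt followed by
# the chosen and the rejected response, with prompt positions masked so
# only response tokens contribute to the per-sequence log-prob sums.
item = {
    'chosen_input_ids':   [1, 333, 42, 7, 2],
    'chosen_labels':      [IGNORE_INDEX, IGNORE_INDEX, 42, 7, 2],
    'rejected_input_ids': [1, 333, 99, 5, 2],
    'rejected_labels':    [IGNORE_INDEX, IGNORE_INDEX, 99, 5, 2],
}

# "len_chosen must not be 0": with no unmasked chosen tokens, the chosen
# log-prob sum is empty and the DPO loss degenerates to NaN.
len_chosen = sum(label != IGNORE_INDEX for label in item['chosen_labels'])
assert len_chosen > 0
```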
136 changes: 136 additions & 0 deletions
(new DPO model module in xtuner/model; exact file path not captured in this view)
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from peft import get_peft_model, prepare_model_for_kbit_training
from torch import nn

from xtuner.registry import BUILDER
from .modules import dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, make_inputs_require_grad,
                    traverse_dict)


class DPO(BaseModel):

    def __init__(self,
                 llm,
                 ref_llm=None,
                 lora=None,
                 peft_model=None,
                 use_activation_checkpointing=True,
                 use_varlen_attn=False):
        super().__init__()
        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)

            # enable gradient checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
                lora, ConfigDict):
            self.lora = BUILDER.build(lora)
        else:
            self.lora = lora
        self.peft_model = peft_model
        self.use_lora = lora is not None
        if self.use_lora:
            self._prepare_for_lora(peft_model, use_activation_checkpointing)

        self._is_init = True
        # Determines whether to calculate attention based on the
        # seq_len dimension (use_varlen_attn = False) or the actual length of
        # the sequence.
        self.use_varlen_attn = use_varlen_attn

        # TODO: Add ref model and ref model config
        self.ref_llm = None

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()

    def _prepare_for_lora(self,
                          peft_model=None,
                          use_activation_checkpointing=True):
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if self.lora.target_modules is None:
            modules = find_all_linear_names(self.llm)
            self.lora.target_modules = modules

        self.llm = get_peft_model(self.llm, self.lora)
        if peft_model is not None:
            _ = load_checkpoint(self, peft_model)

    def init_weights(self):
        pass

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        # TODO
        pass

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        if not self.use_lora:
            return state_dict
        to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
        return OrderedDict(to_return)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
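`compute_loss` is still a TODO. For orientation, here is a minimal sketch of the standard sigmoid DPO objective from Rafailov et al. (2023), loss = -log sigmoid(beta * ((log pi(y_w|x) - log pi(y_l|x)) - (log pi_ref(y_w|x) - log pi_ref(y_l|x)))), assuming the per-sequence sums of response-token log-probabilities are already computed; the function and argument names are illustrative, not this PR's API.

```python
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1):
    """Sigmoid DPO loss over (batch,) tensors of summed response-token
    log-probabilities under the policy and the frozen reference model."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # -log(sigmoid(beta * logits)), averaged over the batch
    loss = -F.logsigmoid(beta * logits).mean()
    # Implicit rewards, handy for logging preference margins in training.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (
        policy_rejected_logps - ref_rejected_logps).detach()
    return loss, chosen_rewards, rejected_rewards
```

Commit 8797fbc adds a loss-type switch; variants such as IPO or a hinge loss would replace the `logsigmoid` term while keeping the same log-ratio inputs.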
Review comment: ref_llm should also support API models.
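A possible shape for this suggestion, sketched against `DPO.__init__` above: this assumes `ref_llm` accepts the same config-or-module forms as `llm`, and the adapter-disabling fallback mirrors TRL's `DPOTrainer`; none of it is implemented by this diff. An API-backed reference model would further replace the local reference forward pass with a remote log-prob query.

```python
def build_ref_llm(dpo_model, ref_llm=None):
    """Hypothetical helper (not in this diff): resolve the reference
    model for DPO. `dpo_model` is assumed to be a DPO instance as
    defined above."""
    if ref_llm is not None:
        # Build a frozen reference copy from a config dict or nn.Module,
        # reusing the same builder the policy model goes through.
        ref = dpo_model._build_from_cfg_or_module(ref_llm)
        ref.requires_grad_(False)
        ref.eval()
        return ref
    # With LoRA, the untouched base weights can act as an implicit
    # reference by disabling the adapters during the reference forward
    # pass (the approach TRL's DPOTrainer takes), so returning None and
    # branching in compute_loss avoids loading a second model.
    return None
```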