[Feature] Support the DatasetInfoHook of DPO training (#787)
* [Feature] Support the DatasetInfoHook of DPO training

* fix yapf check
xu-song authored Jul 11, 2024
1 parent 48df4c8 commit b92481f
Showing 6 changed files with 28 additions and 21 deletions.
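
In the five DPO configs touched by this commit the change is the same one-line switch: the previously commented-out DatasetInfoHook entry in custom_hooks is enabled, so the first training sample is decoded and logged before training starts. A minimal sketch of the relevant pieces is below; the tokenizer dict is a generic placeholder modelled on typical xtuner configs, not copied from this commit.

from transformers import AutoTokenizer

from xtuner.engine.hooks import DatasetInfoHook

# Placeholder tokenizer config; point it at the model used by your own config.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path='internlm/internlm2-chat-1_8b',
    trust_remote_code=True,
    padding_side='right')

custom_hooks = [
    # Previously commented out in the DPO configs; now enabled by default.
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]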
4 changes: 2 additions & 2 deletions xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py
@@ -11,7 +11,7 @@
     preference_collate_fn
 from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                                orpo_dpo_mix_40k_map_fn)
-from xtuner.engine.hooks import (EvaluateChatHook,
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                  VarlenAttnArgsToMessageHubHook)
 from xtuner.engine.runner import TrainLoop
 from xtuner.model.dpo import DPO
@@ -141,7 +141,7 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
         type=EvaluateChatHook,
         tokenizer=tokenizer,
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -11,7 +11,7 @@
     preference_collate_fn
 from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                                orpo_dpo_mix_40k_map_fn)
-from xtuner.engine.hooks import (EvaluateChatHook,
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                  VarlenAttnArgsToMessageHubHook)
 from xtuner.engine.runner import TrainLoop
 from xtuner.model.dpo import DPO
@@ -151,7 +151,7 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
         type=EvaluateChatHook,
         tokenizer=tokenizer,
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -10,7 +10,7 @@
     preference_collate_fn
 from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                                load_jsonl_dataset)
-from xtuner.engine.hooks import (EvaluateChatHook,
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                  VarlenAttnArgsToMessageHubHook)
 from xtuner.engine.runner import TrainLoop
 from xtuner.model.dpo import DPO
@@ -155,7 +155,7 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
         type=EvaluateChatHook,
         tokenizer=tokenizer,
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -14,7 +14,7 @@
     preference_collate_fn
 from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                                orpo_dpo_mix_40k_map_fn)
-from xtuner.engine.hooks import (EvaluateChatHook,
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                  VarlenAttnArgsToMessageHubHook)
 from xtuner.engine.runner import TrainLoop
 from xtuner.model.dpo import DPO
@@ -170,7 +170,7 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
         type=EvaluateChatHook,
         tokenizer=tokenizer,
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -14,7 +14,7 @@
     preference_collate_fn
 from xtuner.dataset.preference_dataset import (build_preference_dataset,
                                                orpo_dpo_mix_40k_map_fn)
-from xtuner.engine.hooks import (EvaluateChatHook,
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                  VarlenAttnArgsToMessageHubHook)
 from xtuner.engine.runner import TrainLoop
 from xtuner.model.dpo import DPO
@@ -170,7 +170,7 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
         type=EvaluateChatHook,
         tokenizer=tokenizer,
29 changes: 18 additions & 11 deletions xtuner/engine/hooks/dataset_info_hook.py
@@ -25,19 +25,26 @@ def __init__(self, tokenizer, is_intern_repo_dataset=False):
         self.is_intern_repo_dataset = is_intern_repo_dataset
 
     def log(self, runner, dataset, mode='train'):
+
+        def _log(input_ids, log_prefix=''):
+            if self.is_intern_repo_dataset:
+                input_ids = [abs(x) for x in input_ids]
+            # Try to split list to be compatible with IMAGE token
+            input_ids = split_list(input_ids, IMAGE_TOKEN_INDEX)
+            text = log_prefix
+            for idx, ids in enumerate(input_ids):
+                text += self.tokenizer.decode(ids)
+                if idx != len(input_ids) - 1:
+                    text += DEFAULT_IMAGE_TOKEN
+            runner.logger.info(text)
+
         runner.logger.info(f'Num {mode} samples {len(dataset)}')
         runner.logger.info(f'{mode} example:')
-        input_ids = dataset[0]['input_ids']
-        if self.is_intern_repo_dataset:
-            input_ids = [abs(x) for x in input_ids]
-        # Try to split list to be compatible with IMAGE token
-        input_ids = split_list(input_ids, IMAGE_TOKEN_INDEX)
-        text = ''
-        for idx, ids in enumerate(input_ids):
-            text += self.tokenizer.decode(ids)
-            if idx != len(input_ids) - 1:
-                text += DEFAULT_IMAGE_TOKEN
-        runner.logger.info(text)
+        if 'chosen_ids' in dataset[0]:
+            _log(dataset[0]['chosen_ids'], log_prefix='chosen: ')
+            _log(dataset[0]['rejected_ids'], log_prefix='rejected: ')
+        else:
+            _log(dataset[0]['input_ids'])
 
     def before_train(self, runner) -> None:
         do_train = runner.train_loop is not None
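
The hook-side change above routes preference samples, which carry chosen_ids and rejected_ids instead of input_ids, through a shared _log helper so both responses of the first pair are decoded and logged. A self-contained toy sketch of that branching follows; the StubTokenizer, the sample dicts, and the plain print in place of runner.logger.info are illustrative stand-ins, not xtuner code.

# Toy reproduction of the new branching in DatasetInfoHook.log.
class StubTokenizer:
    # Made-up vocabulary, only for demonstration.
    vocab = {1: 'Paris', 2: ' is nice.', 3: 'Paris', 4: ' is meh.'}

    def decode(self, ids):
        return ''.join(self.vocab.get(i, '<unk>') for i in ids)


def log_first_sample(sample, tokenizer):
    def _log(input_ids, log_prefix=''):
        print(log_prefix + tokenizer.decode(input_ids))

    if 'chosen_ids' in sample:  # preference (DPO) sample
        _log(sample['chosen_ids'], log_prefix='chosen: ')
        _log(sample['rejected_ids'], log_prefix='rejected: ')
    else:                       # ordinary SFT sample
        _log(sample['input_ids'])


log_first_sample({'chosen_ids': [1, 2], 'rejected_ids': [3, 4]}, StubTokenizer())
# chosen: Paris is nice.
# rejected: Paris is meh.
log_first_sample({'input_ids': [1, 2]}, StubTokenizer())
# Paris is nice.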
