From 0e2bca93efeac616c70de9cf693c73da6aedea13 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Fri, 26 Apr 2024 21:02:07 +0800
Subject: [PATCH] update bos token.

---
 dpo_training.py    | 3 ++-
 orpo_training.py   | 5 +++++
 ppo_training.py    | 3 ++-
 reward_modeling.py | 3 ++-
 4 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/dpo_training.py b/dpo_training.py
index 518d2ba..6fb3fe6 100644
--- a/dpo_training.py
+++ b/dpo_training.py
@@ -232,7 +232,7 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
     prompt_template = get_conv_template(args.template_name)
     if tokenizer.eos_token_id is None:
-        tokenizer.eos_token = prompt_template.stop_str  # eos token is required for SFT
+        tokenizer.eos_token = prompt_template.stop_str  # eos token is required
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
     if tokenizer.pad_token_id is None:
         if tokenizer.unk_token_id is not None:
@@ -240,6 +240,7 @@ def main():
         else:
             tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
+    logger.debug(f"Tokenizer: {tokenizer}")
 
     # Get datasets
     if args.dataset_name is not None:
diff --git a/orpo_training.py b/orpo_training.py
index 3669c07..16b0064 100644
--- a/orpo_training.py
+++ b/orpo_training.py
@@ -237,12 +237,17 @@ def main():
         tokenizer.eos_token = prompt_template.stop_str  # eos token is required
         tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token})
         logger.info(f"Add eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}")
+    if tokenizer.bos_token_id is None:
+        tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
+        tokenizer.bos_token_id = tokenizer.eos_token_id
+        logger.info(f"Add bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}")
     if tokenizer.pad_token_id is None:
         if tokenizer.unk_token_id is not None:
             tokenizer.pad_token = tokenizer.unk_token
         else:
             tokenizer.pad_token = tokenizer.eos_token
         logger.info(f"Add pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}")
+    logger.debug(f"Tokenizer: {tokenizer}")
 
     # Get datasets
     if args.dataset_name is not None:
diff --git a/ppo_training.py b/ppo_training.py
index 774652a..23bd087 100644
--- a/ppo_training.py
+++ b/ppo_training.py
@@ -238,7 +238,7 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
     prompt_template = get_conv_template(args.template_name)
     if tokenizer.eos_token_id is None:
-        tokenizer.eos_token = prompt_template.stop_str  # eos token is required for SFT
+        tokenizer.eos_token = prompt_template.stop_str  # eos token is required
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
     if tokenizer.pad_token_id is None:
         if tokenizer.unk_token_id is not None:
@@ -246,6 +246,7 @@ def main():
         else:
             tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
+    logger.debug(f"Tokenizer: {tokenizer}")
 
     # Load model
     peft_config = None
diff --git a/reward_modeling.py b/reward_modeling.py
index bbabba2..b50b581 100644
--- a/reward_modeling.py
+++ b/reward_modeling.py
@@ -419,7 +419,7 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
     prompt_template = get_conv_template(script_args.template_name)
     if tokenizer.eos_token_id is None:
-        tokenizer.eos_token = prompt_template.stop_str  # eos token is required for SFT
+        tokenizer.eos_token = prompt_template.stop_str  # eos token is required
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
     if tokenizer.pad_token_id is None:
         if tokenizer.unk_token_id is not None:
@@ -427,6 +427,7 @@ def main():
         else:
             tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
+    logger.debug(f"Tokenizer: {tokenizer}")
 
     if script_args.use_peft:
         logger.info("Fine-tuning method: LoRA(PEFT)")