Skip to content

Commit

Permalink
update bos token.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Apr 26, 2024
1 parent ae0af39 commit 0e2bca9
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion dpo_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,15 @@ def main():
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
prompt_template = get_conv_template(args.template_name)
if tokenizer.eos_token_id is None:
tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
tokenizer.eos_token = prompt_template.stop_str # eos token is required
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
logger.debug(f"Tokenizer: {tokenizer}")

# Get datasets
if args.dataset_name is not None:
Expand Down
5 changes: 5 additions & 0 deletions orpo_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,17 @@ def main():
tokenizer.eos_token = prompt_template.stop_str # eos token is required
tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token})
logger.info(f"Add eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}")
if tokenizer.bos_token_id is None:
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id
logger.info(f"Add bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}")
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Add pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}")
logger.debug(f"Tokenizer: {tokenizer}")

# Get datasets
if args.dataset_name is not None:
Expand Down
3 changes: 2 additions & 1 deletion ppo_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,15 @@ def main():
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
prompt_template = get_conv_template(args.template_name)
if tokenizer.eos_token_id is None:
tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
tokenizer.eos_token = prompt_template.stop_str # eos token is required
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
logger.debug(f"Tokenizer: {tokenizer}")

# Load model
peft_config = None
Expand Down
3 changes: 2 additions & 1 deletion reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,15 @@ def main():
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
prompt_template = get_conv_template(script_args.template_name)
if tokenizer.eos_token_id is None:
tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
tokenizer.eos_token = prompt_template.stop_str # eos token is required
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
logger.debug(f"Tokenizer: {tokenizer}")

if script_args.use_peft:
logger.info("Fine-tuning method: LoRA(PEFT)")
Expand Down

0 comments on commit 0e2bca9

Please sign in to comment.