Merge pull request #857 from OptimalScale/yizhenjia-maintenance
Usability update
wheresmyhair committed Jun 14, 2024
2 parents f9116ab + f9d99e1 commit d7055e3
Showing 10 changed files with 301 additions and 78 deletions.
3 changes: 1 addition & 2 deletions scripts/run_reward_modeling.sh
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
#!/bin/bash
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
# Parses arguments
model_name_or_path=google/gemma-2b-it
3 changes: 1 addition & 2 deletions scripts/run_reward_modeling_with_lisa.sh
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
#!/bin/bash
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
# Parses arguments
model_name_or_path=google/gemma-2b-it
3 changes: 1 addition & 2 deletions scripts/run_reward_modeling_with_lora.sh
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
#!/bin/bash
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
# Parses arguments
model_name_or_path=google/gemma-2b-it
12 changes: 12 additions & 0 deletions src/lmflow/args.py
@@ -93,6 +93,8 @@ class ModelArguments:
arch_type : str
Model architecture type.
padding_side : str
The side on which the tokenizer should have padding applied.
"""

model_name_or_path: Optional[str] = field(
@@ -296,6 +298,16 @@ class ModelArguments:
"choices": [None, "left", "right"],
},
)
padding_side: str = field(
default='right',
metadata={
"help": (
"The side on which the tokenizer should have padding applied. "
"LMFlow uses right padding by default. When set to `auto`, will "
"use padding_side from tokenizer.padding_side."),
"choices": ["right", "left", "auto"],
}
)

def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
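As a hedged illustration of the option added above (not part of this diff), the snippet below shows how a `padding_side` value could be applied when loading a tokenizer; `load_tokenizer` is a hypothetical helper, and the real wiring lives in `hf_model_mixin.py` further down.

```python
# Illustrative sketch only: apply an explicit padding_side unless "auto" is chosen,
# in which case the tokenizer keeps whatever side it ships with.
from transformers import AutoTokenizer

def load_tokenizer(model_name: str, padding_side: str = "right"):
    tokenizer_kwargs = {}
    if padding_side != "auto":
        tokenizer_kwargs["padding_side"] = padding_side
    return AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

tokenizer = load_tokenizer("google/gemma-2b-it", padding_side="left")
print(tokenizer.padding_side)  # "left"
```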
1 change: 1 addition & 0 deletions src/lmflow/models/hf_decoder_model.py
@@ -246,6 +246,7 @@ def tokenize(self, dataset, add_special_tokens=True, *args, **kwargs):
(
raw_datasets.get_fingerprint()
+ str(self.tokenizer)
+ f'###padding_side={self.tokenizer.padding_side}'
+ ('###conversation_template=' + str(conversation_template) if "conversation" in dataset_type else "")
+ f'###disable_group_texts={data_args.disable_group_texts}'
+ f'###block_size={data_args.block_size}'
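For context, the fingerprint above is part of the cache key for tokenized datasets. The rough sketch below (with hypothetical inputs, not LMFlow's exact code) shows why `padding_side` has to be folded into the hash so a cache built under one padding setting is not reused under another.

```python
# Rough sketch: every setting that changes tokenization output must enter the hash,
# otherwise a cache built with a different padding_side could be reused by mistake.
import hashlib

def tokenization_fingerprint(base: str, tokenizer_repr: str,
                             padding_side: str, block_size: int) -> str:
    raw = (
        base
        + tokenizer_repr
        + f"###padding_side={padding_side}"
        + f"###block_size={block_size}"
    )
    return hashlib.md5(raw.encode("utf-8")).hexdigest()

print(tokenization_fingerprint("abc123", "LlamaTokenizerFast(...)", "right", 512))
```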
48 changes: 36 additions & 12 deletions src/lmflow/models/hf_model_mixin.py
@@ -3,7 +3,7 @@
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import os
import logging
from typing import Union, Optional
from typing import Union, Optional, Dict

import torch
import deepspeed
@@ -30,6 +30,7 @@
from lmflow.utils.constants import (
LMFLOW_LORA_TARGET_MODULES_MAPPING
)
from lmflow.args import ModelArguments


logger = logging.getLogger(__name__)
@@ -51,11 +52,12 @@
class HFModelMixin(BaseModel):
def __init__(
self,
model_args,
model_args: ModelArguments,
do_train: bool,
ds_config=None,
device: Optional[str]="gpu",
use_accelerator: bool=False,
hf_auto_model_additional_args: Optional[Dict]=None,
*args,
**kwargs
):
@@ -88,7 +90,7 @@ def __init__(
self.model_args = model_args
self.tokenizer = self.__prepare_tokenizer(model_args)
self.torch_dtype = self.__prepare_dtype(model_args)
self.hf_model_config = self.__prepare_model_config(model_args)
self.hf_model_config = self.__prepare_model_config(model_args, hf_auto_model_additional_args)
self.quant_config = self.__prepare_quant_config(model_args)
self.peft_config = self.__prepare_peft_config(model_args)

@@ -106,11 +108,13 @@ def __init__(
self.tokenizer.eos_token_id = self.backend_model.config.eos_token_id
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if self.backend_model.config.pad_token_id is None:
self.backend_model.config.pad_token_id = self.tokenizer.pad_token_id


def __prepare_tokenizer(
self,
model_args
model_args: ModelArguments,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
@@ -119,6 +123,8 @@ def __prepare_tokenizer(
"use_auth_token": True if model_args.use_auth_token else None,
"trust_remote_code": model_args.trust_remote_code,
}
if model_args.padding_side != 'auto':
tokenizer_kwargs["padding_side"] = model_args.padding_side

try:
if model_args.tokenizer_name:
@@ -163,7 +169,7 @@

def __prepare_dtype(
self,
model_args
model_args: ModelArguments,
) -> torch.dtype:
if model_args.arch_type == 'text_regression':
if model_args.torch_dtype in ["auto", None, "bf16", "bfloat16"]:
@@ -189,8 +195,23 @@

def __prepare_model_config(
self,
model_args
model_args: ModelArguments,
hf_auto_model_additional_args: Optional[Dict]=None,
):
"""Prepare model configuration for hf auto register,
Parameters
----------
model_args : ModelArguments
LMFlow model arguments.
hf_auto_model_additional_args : Optional[Dict], optional
Special configurations such as `num_labels` in `AutoModelForSequenceClassification`
(commonly used in reward modeling) will not preset in __prepare_model_config,
so it should be passed in hf_auto_model_additional_args.
Returns
-------
config : ModelConfig
hf model config.
"""
config_kwargs = {
"torch_dtype": self.torch_dtype,
"attn_implementation": "flash_attention_2" if model_args.use_flash_attention else None,
@@ -200,6 +221,9 @@
"trust_remote_code": model_args.trust_remote_code,
"from_tf": bool(".ckpt" in model_args.model_name_or_path),
}
if hf_auto_model_additional_args is not None:
config_kwargs.update(hf_auto_model_additional_args)

if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
elif model_args.model_name_or_path:
@@ -217,7 +241,7 @@

def __prepare_quant_config(
self,
model_args
model_args: ModelArguments,
):
quant_config = None
if model_args.use_qlora:
@@ -236,7 +260,7 @@

def __prepare_peft_config(
self,
model_args
model_args: ModelArguments,
):
peft_config = None
if model_args.use_lora:
@@ -267,7 +291,7 @@

def __model_module_inject(
self,
model_args
model_args: ModelArguments,
) -> None:
"""Override some model modules with custom implementations.
@@ -286,8 +310,8 @@ def __model_module_inject(

def __prepare_model_for_training(
self,
model_args,
hf_auto_model: HF_AUTOMODEL_TYPE
model_args: ModelArguments,
hf_auto_model: HF_AUTOMODEL_TYPE,
):
# TODO: change to accelerate
logger.info("Preparing model for training")
@@ -326,7 +350,7 @@ def __prepare_model_for_training(

def __prepare_model_for_inference(
self,
model_args,
model_args: ModelArguments,
hf_auto_model: HF_AUTOMODEL_TYPE,
use_accelerator,
ds_config
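A minimal sketch of how `hf_auto_model_additional_args` might flow into the Hugging Face config, assuming the `config_kwargs.update(...)` behavior shown in the hunk above; the helper name and base kwargs here are illustrative, not LMFlow's exact implementation.

```python
# Sketch: extra keys such as num_labels (needed by AutoModelForSequenceClassification,
# e.g. for reward modeling) are merged into the base kwargs before AutoConfig runs.
from typing import Dict, Optional
from transformers import AutoConfig

def prepare_model_config(model_name: str, additional_args: Optional[Dict] = None):
    config_kwargs = {"trust_remote_code": False}  # illustrative base kwargs
    if additional_args is not None:
        config_kwargs.update(additional_args)     # e.g. {"num_labels": 1}
    return AutoConfig.from_pretrained(model_name, **config_kwargs)

config = prepare_model_config("google/gemma-2b-it", {"num_labels": 1})
print(config.num_labels)  # 1
```

The text-regression model in the next file uses exactly this hook to pass `{"num_labels": 1}` for its reward-model head.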
47 changes: 37 additions & 10 deletions src/lmflow/models/hf_text_regression_model.py
@@ -31,10 +31,12 @@
from lmflow.models.interfaces.tunable import Tunable
from lmflow.models.hf_model_mixin import HFModelMixin
from lmflow.models.text_regression_model import TextRegressionModel
from lmflow.tokenization.hf_text_regression_model import tokenize_function
from lmflow.tokenization.hf_text_regression_model import paired_conversation_tokenize_function, tokenize_function
from lmflow.utils.conversation_template import PRESET_TEMPLATES
from lmflow.utils.constants import (
PAIRED_CONVERSATION_DATASET_DESCRIPTION,
TEXT2TEXT_DATASET_DESCRIPTION,
TEXT_ONLY_DATASET_DESCRIPTION,
CONVERSATION_ROLE_NAMES,
)

@@ -81,13 +83,15 @@ def __init__(
:param tune_strategy: tuning strategy: normal, none, lora or adapter
:param ds_config: deepspeed configuration for distributed training
"""
config_additional_args = {"num_labels": 1}
HFModelMixin.__init__(
self,
model_args=model_args,
do_train=True if tune_strategy == "normal" else False,
ds_config=ds_config,
device=device,
use_accelerator=use_accelerator,
hf_auto_model_additional_args=config_additional_args,
*args,
**kwargs
)
@@ -133,14 +137,28 @@ def tokenize(
raw_datasets = dataset
hf_raw_datasets = dataset.get_backend_dataset()
column_names = list(hf_raw_datasets.features) # e.g., in paired conversation datasets these would be 'chosen' and 'rejected'

# Since this will be pickled (to avoid a _LazyModule error in Hasher), force
# logger loading before tokenize_function.
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

data_args = raw_datasets.get_data_args()

if dataset_type == "paired_conversation":
# Requires three types of information for tokenizing different datasets
# 1) Which fields require tokenization, e.g.
# "text2float": "text", but not "float"
# "text2text": both "input" and "output"
# 2) How the tokenized sequences will be concatenated together, e.g.
# "text_only": "text" -> "text"
# "text2text": "input", "output" -> "input" + "output"
# 3) Which fields require loss in final computation, e.g.
# "text_only": "text"
# "text2text": "output" only
tokenized_column_order = None # Handles 1) and 2)
label_columns = None # Handles 3)
if dataset_type == "text_only":
tokenized_column_order = ["text"]
label_columns = ["text"]
elif dataset_type == "text2text":
tokenized_column_order = ["input", "output"]
label_columns = ["output"]
add_special_tokens = False
elif dataset_type == "paired_conversation":
if data_args.conversation_template:
if data_args.conversation_template in PRESET_TEMPLATES.keys():
conversation_template = PRESET_TEMPLATES[data_args.conversation_template]
@@ -157,28 +175,37 @@
raise NotImplementedError(
f"Dataset type \"{dataset_type}\" is not supported, currently"
" only support following data types for HFTextRegressionModel:\n"
f" {PAIRED_CONVERSATION_DATASET_DESCRIPTION}\n"
f" 1) {TEXT_ONLY_DATASET_DESCRIPTION}\n"
f" 2) {TEXT2TEXT_DATASET_DESCRIPTION}\n"
f" 3) {PAIRED_CONVERSATION_DATASET_DESCRIPTION}\n"
)

# Whether to truncate long sequences to fit into max_length
use_truncation = False
if model_args.use_lora or data_args.disable_group_texts:
use_truncation = True

tokenize_fn = tokenize_function
tokenize_fn = paired_conversation_tokenize_function if "conversation" in dataset_type else tokenize_function
tokenize_fn_kwargs = {
"data_args": data_args,
"tokenizer": self.tokenizer,
"column_names": column_names,
"conversation_template": conversation_template
}
if "conversation" in dataset_type:
tokenize_fn_kwargs["conversation_template"] = conversation_template
else:
tokenize_fn_kwargs["label_columns"] = label_columns
tokenize_fn_kwargs["tokenized_column_order"] = tokenized_column_order
tokenize_fn_kwargs["add_special_tokens"] = add_special_tokens
tokenize_fn_kwargs["use_truncation"] = use_truncation

tokenize_kwargs = {}
if not data_args.streaming:
fingerprint = hashlib.md5(
(
raw_datasets.get_fingerprint()
+ str(self.tokenizer)
+ f'###padding_side={self.tokenizer.padding_side}'
+ ('###conversation_template=' + str(conversation_template) if "conversation" in dataset_type else "")
+ f'###disable_group_texts={data_args.disable_group_texts}'
+ f'###block_size={data_args.block_size}'
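A hedged sketch of the per-dataset-type dispatch shown above: conversation-style data receives a conversation template, while `text_only` / `text2text` data receives column-order and label-column hints instead. The helper below is illustrative only, not LMFlow's actual function.

```python
# Illustrative dispatch: which kwargs the tokenize function receives depends on
# the dataset type, mirroring the branches added in this hunk.
def build_tokenize_fn_kwargs(dataset_type, data_args, tokenizer, column_names,
                             conversation_template=None):
    kwargs = {
        "data_args": data_args,
        "tokenizer": tokenizer,
        "column_names": column_names,
    }
    if "conversation" in dataset_type:
        kwargs["conversation_template"] = conversation_template
    elif dataset_type == "text2text":
        kwargs["tokenized_column_order"] = ["input", "output"]  # fields to tokenize, in order
        kwargs["label_columns"] = ["output"]                    # only the output carries loss
    else:  # "text_only"
        kwargs["tokenized_column_order"] = ["text"]
        kwargs["label_columns"] = ["text"]
    return kwargs
```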
7 changes: 7 additions & 0 deletions src/lmflow/pipeline/rm_tuner.py
@@ -194,6 +194,13 @@ def switch_active_layers(self):
elif last_checkpoint is not None:
checkpoint = last_checkpoint

if self.finetuner_args.gradient_checkpointing:
if model.get_backend_model().config.use_cache:
logger.warning(
"Backend model config `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
model.get_backend_model().config.use_cache = False

train_result = trainer.train(resume_from_checkpoint=checkpoint)

trainer.save_model() # Saves the tokenizer too for easy upload
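The guard added above reflects a general Hugging Face constraint: a model's `use_cache` option conflicts with gradient checkpointing, so it is disabled before training starts. Below is a hedged, stand-alone sketch of the same guard outside LMFlow; the model name and output path are illustrative.

```python
# Stand-alone sketch of the same guard: gradient checkpointing recomputes
# activations during backprop, so the generation KV cache must be turned off.
from transformers import AutoModelForSequenceClassification, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2b-it", num_labels=1  # reward-model style single-score head
)
training_args = TrainingArguments(
    output_dir="output_models/rm_example",  # illustrative path
    gradient_checkpointing=True,
)
if training_args.gradient_checkpointing and model.config.use_cache:
    model.config.use_cache = False  # avoid the use_cache / checkpointing conflict
```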
(Diffs for the remaining 2 changed files are not shown here.)
