Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat.] support encoder-decoder tuning and ChatGLM, Vicuna inference #152

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion examples/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,16 @@ def main():
pipeline_args=pipeline_args,
)
dataset = Dataset(data_args)
model = AutoModel.get_model(model_args)
model = AutoModel.get_model(
model_args,
lang=data_args.lang,
forced_bos_token=data_args.forced_bos_token,
source_prefix = data_args.source_prefix,
streaming = data_args.streaming,
preprocessing_num_workers = data_args.preprocessing_num_workers,
overwrite_cache = data_args.overwrite_cache,
max_source_length = data_args.max_source_length
)

# Tokenization and text grouping must be done in the main process
with pipeline_args.main_process_first(desc="dataset map tokenization"):
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ wandb==0.14.0
deepspeed==0.8.3
trl @ git+https://github.com/lvwerra/trl.git#egg=trl-0.4.1
sentencepiece
icetk==0.0.7
cpm_kernels==1.0.11
transformers @ git+https://github.com/huggingface/transformers@c612628
flask
flask_cors
flask_cors
17 changes: 17 additions & 0 deletions scripts/run_chatbot_seq2seq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Launch the chatbot example with an encoder-decoder (seq2seq) model,
# defaulting to THUDM/chatglm-6b.
#
# Usage: run_chatbot_seq2seq.sh [model_name_or_path] [lora_model_path]

model=THUDM/chatglm-6b
lora_args=""
# First positional argument overrides the default base model.
if [ $# -ge 1 ]; then
model=$1
fi
# Second positional argument supplies LoRA weights to load on top of the base model.
if [ $# -ge 2 ]; then
lora_args="--lora_model_path $2"
fi

# Single-GPU launch through deepspeed. ${lora_args} is intentionally left
# unquoted so that an empty value expands to no arguments at all.
CUDA_VISIBLE_DEVICES=0 \
deepspeed examples/chatbot.py \
--arch_type encoder_decoder \
--deepspeed configs/ds_config_chatbot.json \
--model_name_or_path ${model} \
${lora_args}
101 changes: 101 additions & 0 deletions src/lmflow/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers import (
MODEL_FOR_CAUSAL_LM_MAPPING,
TrainingArguments,
Seq2SeqTrainingArguments
)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
Expand Down Expand Up @@ -99,6 +100,10 @@ class ModelArguments:
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
# NOTE: this is a string choice, not a flag; it is compared against
# "encoder_decoder"/"decoder_only" in AutoModel.get_model, so the
# annotation must be `str` (the original `bool` was a typo).
arch_type: str = field(
    default="decoder_only",
    metadata={"help": "The architecture type of the model. Currently supported decoder_only or encoder_decoder"}
)
config_overrides: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -165,6 +170,15 @@ class ModelArguments:
default=True,
metadata={"help": "Whether use disk mapping when memory is not enough."}
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": (
"Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
)
},
)

def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
Expand Down Expand Up @@ -225,6 +239,8 @@ class DatasetArguments:
each parameter, such as a help message.
"""

lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})

dataset_path: Optional[str] = field(
default=None, metadata={"help": "The path of the dataset to use."}
)
Expand Down Expand Up @@ -309,6 +325,83 @@ class DatasetArguments:
default=None,
metadata={"help": "Evaluation File Path"},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
# Upper bound on tokenized validation targets; falls back to
# `max_target_length` when left as None. Note the trailing space before
# "This argument" — without it the adjacent string literals concatenate
# into "`max_target_length`.This".
val_max_target_length: Optional[int] = field(
    default=None,
    metadata={
        "help": (
            "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        )
    },
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": (
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
)
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)

# Token forced as the first generated token (mBART-style multilingual
# generation). The trailing spaces inside each literal are required; the
# original concatenation produced "token_id.Useful" and "tokenneeds".
forced_bos_token: Optional[str] = field(
    default=None,
    metadata={
        "help": (
            "The token to force as the first generated token after the decoder_start_token_id. "
            "Useful for multilingual models like mBART where the first generated token "
            "needs to be the target language token (Usually it is the target language token)"
        )
    },
)

def __post_init__(self):
if self.streaming:
Expand All @@ -332,6 +425,13 @@ class FinetunerArguments(TrainingArguments):
"""
pass

@dataclass
class Seq2SeqFinetunerArguments(Seq2SeqTrainingArguments):
    """Adapt transformers.Seq2SeqTrainingArguments for the seq2seq finetuner.

    All fields are inherited unchanged; the subclass exists so the
    PIPELINE_ARGUMENT_MAPPING can route the "seq2seq_finetuner" pipeline to
    generation-aware training arguments. (The original docstring mistakenly
    referenced plain TrainingArguments.)
    """


@dataclass
class EvaluatorArguments:
Expand Down Expand Up @@ -497,6 +597,7 @@ class InferencerArguments:


PIPELINE_ARGUMENT_MAPPING = {
"seq2seq_finetuner": Seq2SeqFinetunerArguments,
"finetuner": FinetunerArguments,
"evaluator": EvaluatorArguments,
"inferencer": InferencerArguments,
Expand Down
7 changes: 5 additions & 2 deletions src/lmflow/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
"""

from lmflow.models.hf_decoder_model import HFDecoderModel

from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel

class AutoModel:
    """Factory that dispatches on ``model_args.arch_type`` to a concrete model.

    Currently supports "decoder_only" (HFDecoderModel) and
    "encoder_decoder" (HFEncoderDecoderModel).
    """

    @classmethod
    def get_model(cls, model_args, *args, **kwargs):
        """Instantiate the model class matching ``model_args.arch_type``.

        Raises NotImplementedError for unknown architecture types instead of
        silently returning None, as the original fall-through did.
        """
        # TODO (add new models)
        arch_type = model_args.arch_type
        if arch_type == "encoder_decoder":
            return HFEncoderDecoderModel(model_args, *args, **kwargs)
        if arch_type == "decoder_only":
            return HFDecoderModel(model_args, *args, **kwargs)
        raise NotImplementedError(
            f"model architecture type '{arch_type}' is not supported"
        )
22 changes: 22 additions & 0 deletions src/lmflow/models/encoder_decoder_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# coding=utf-8
"""Base class for encoder-decoder (seq2seq) models.

``EncoderDecoderModel`` anchors the encoder-decoder side of the model
hierarchy; concrete backends such as ``HFEncoderDecoderModel`` derive from
it, mirroring how decoder-only backends derive from ``DecoderModel``.
"""

from lmflow.models.base_model import BaseModel


class EncoderDecoderModel(BaseModel):
    """Marker base class for encoder-decoder models.

    The constructor accepts and ignores arbitrary positional and keyword
    arguments so subclasses can forward their init signatures up the
    hierarchy without caring about this level.
    """

    def __init__(self, *args, **kwargs):
        # No shared state at this level; subclasses handle model/tokenizer
        # setup themselves.
        pass
2 changes: 0 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@
from lmflow.models.decoder_model import DecoderModel
from lmflow.models.interfaces.tunable import Tunable


logger = logging.getLogger(__name__)


class HFDecoderModel(DecoderModel, Tunable):
r"""
Initializes a HFDecoderModel instance.
Expand Down
Loading