[Feature] Integrated Training and Inference -- Part 1 #532
Open · pppppM wants to merge 32 commits into InternLM:main from pppppM:refactor-llm
Changes shown from 17 of 32 commits.

Commits
360391a  [Feature] Support Sequence parallel (#456)  (HIT-cwh)
5fffd8c  integrated norm chat finetune and inference  (pppppM)
0f31481  Merge branch 'main' into refactor-llm  (pppppM)
f1111a9  Merge branch 'main' into refactor-llm  (pppppM)
7edb76d  remove encoder  (pppppM)
c5f38d2  add open-source dataset convert tool  (pppppM)
e0ab003  fix shard count  (pppppM)
b0e71b1  fix dataset bugs  (pppppM)
5da532c  add alpaca example  (pppppM)
5cfd71f  refactored the inheritance hierarchy  (pppppM)
2e1d238  adjust dir structure  (pppppM)
046e943  add BaseAlogrithm docstrings  (pppppM)
68afab4  add dataset docstring  (pppppM)
cf8e8af  add pack dataset docstrings  (pppppM)
9dc1142  remove old collate fn  (pppppM)
662cebb  Merge branch 'main' into refactor-llm  (pppppM)
7ac1e0f  add new chat hook  (pppppM)
0ad84f2  add gradient disable interface  (pppppM)
e85d176  add llava dataset example  (pppppM)
9f39627  batch_infer is no longer an abstract method  (pppppM)
c0655a1  support auto model  (pppppM)
fd9ecca  rename  (pppppM)
15860f9  update auto model  (pppppM)
7236d40  refactor dataset  (pppppM)
4db2955  enhance dataset convert  (pppppM)
aad9ee3  remove useless code  (pppppM)
8daafcb  diff files support diff sample ratios  (pppppM)
df60d91  unified naming  (pppppM)
6185c9b  Merge branch 'main' of github.com:InternLM/xtuner into refactor-llm  (HIT-cwh)
ca272bf  support sp in TextFinetune  (HIT-cwh)
b658d76  Merge pull request #2 from HIT-cwh/refactor-llm  (pppppM)
d2f1002  Merge branch 'main' into refactor-llm  (pppppM)
Empty file.
@@ -0,0 +1,4 @@
from .huggingface import HFBackend
from .lmdeploy import LMDeployBackend

__all__ = ['HFBackend', 'LMDeployBackend']
@@ -0,0 +1,30 @@
from abc import abstractmethod
from typing import List, Optional

from xtuner.chat.streamer import SteamerType
from xtuner.types import (ChatBackendProtocol, ChatMessages, ChatTemplate,
                          SampleParams)


class BaseBackend(ChatBackendProtocol):

    @property
    def chat_template(self) -> ChatTemplate:
        pass

    @abstractmethod
    def create_streamer(self, iterable: bool = False) -> SteamerType:
        pass

    @abstractmethod
    def chat(self,
             messages: ChatMessages,
             sample_params: Optional[SampleParams] = None,
             streamer: Optional[SteamerType] = None):
        pass

    @abstractmethod
    def batch_infer(self,
                    messages: List[ChatMessages],
                    sample_params: Optional[SampleParams] = None):
        pass
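The point of this base class is that callers can be written once against the `BaseBackend` interface and stay agnostic to whether HuggingFace or LMDeploy serves the model. A minimal sketch of such a caller follows; the `run_chat` helper is not part of this PR, and the `ChatMessages(...)` construction is an assumption about the intended API:

```python
from typing import Optional

from xtuner.types import ChatMessages, SampleParams


def run_chat(backend, user_text: str,
             sample_params: Optional[SampleParams] = None) -> str:
    # Any BaseBackend implementation (HFBackend, LMDeployBackend) exposes the
    # same chat()/create_streamer() surface, so this helper works with either.
    # NOTE: the ChatMessages(...) construction below is an assumption about the
    # intended interface, not something defined in this file.
    messages = ChatMessages(messages=[{'role': 'user', 'content': user_text}])
    streamer = backend.create_streamer(iterable=False)
    return backend.chat(messages, sample_params=sample_params, streamer=streamer)
```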
@@ -0,0 +1,153 @@
from typing import List, Optional

import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)
from transformers import GenerationConfig as HFGenerationConfig
from transformers import PreTrainedModel, PreTrainedTokenizer

from xtuner.chat.streamer import HFTextIteratorStreamer, HFTextStreamer
from xtuner.model.utils import LoadWoInit
from xtuner.tools.utils import get_stop_criteria
from xtuner.types import ChatMessages, ChatTemplate, SampleParams
from .base import BaseBackend


class HFBackend(BaseBackend):

    def __init__(
        self,
        chat_template: ChatTemplate,
        llm: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.llm = llm
        self.llm.cuda()
        self.tokenizer = tokenizer

        self._chat_template = chat_template

    @property
    def chat_template(self) -> ChatTemplate:
        return self._chat_template

    @property
    def eos_token_id(self):
        if self.tokenizer.pad_token_id:
            return self.tokenizer.eos_token_id
        else:
            return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    def build_llm_and_tokenizer(self,
                                model_name_or_path,
                                adapter=None,
                                bits=None):

        if bits is None:
            quantization_config = None
            load_in_8bit = False
        elif bits == 4:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')
            load_in_8bit = False
        elif bits == 8:
            quantization_config = None
            load_in_8bit = True

        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            encode_special_tokens=True)

        with LoadWoInit():
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map='auto',
                load_in_8bit=load_in_8bit,
                quantization_config=quantization_config,
                trust_remote_code=True,
                torch_dtype=torch.float16)

        if adapter is not None:
            model = PeftModel.from_pretrained(model, adapter)

        model.eval()
        return model, tokenizer

    def create_streamer(self, iterable=False):
        if iterable:
            return HFTextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                chat_template=self.chat_template)
        else:
            return HFTextStreamer(
                self.tokenizer,
                skip_prompt=True,
                chat_template=self.chat_template)

    def parse_sample_params(self, params: SampleParams) -> HFGenerationConfig:

        if params is None:
            params = SampleParams()

        hf_gen_config = HFGenerationConfig(
            max_new_tokens=params.max_new_tokens,
            do_sample=params.temperature > 0,
            temperature=params.temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            repetition_penalty=params.repetition_penalty,
            seed=params.seed,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id)

        stop_words = params.stop_words
        stop_words.extend(self.chat_template.stop_words)

        return hf_gen_config, stop_words

    def chat(self,
             messages: ChatMessages,
             streamer=None,
             sample_params: Optional[SampleParams] = None):

        prompt = messages.get_prompt(self.chat_template)
        ids = self.tokenizer.encode(prompt, return_tensors='pt')

        hf_gen_config, stop_words = self.parse_sample_params(sample_params)

        stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)

        generate_output = self.llm.generate(
            inputs=ids.cuda(),
            streamer=streamer,
            generation_config=hf_gen_config,
            stopping_criteria=stop_criteria)

        output = self.tokenizer.decode(
            generate_output[0][len(ids[0]):], skip_special_tokens=True)

        for word in stop_words:
            output = output.rstrip(word)

        return output

    def batch_infer(self,
                    messages: List[ChatMessages],
                    sample_params: SampleParams | None = None):
        raise NotImplementedError
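For reference, a usage sketch of `HFBackend`. The model name is illustrative, `chat_template` is assumed to be a `ChatTemplate` built for that model elsewhere, and the `ChatMessages(...)` construction is a guess at the intended interface; none of this is prescribed by the PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.chat.backend import HFBackend
from xtuner.types import ChatMessages, SampleParams

model_name = 'internlm/internlm2-chat-7b'  # hypothetical model choice
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, trust_remote_code=True)

# `chat_template` must be a ChatTemplate matching the model's dialogue format;
# its construction is assumed here, not shown in this file.
backend = HFBackend(chat_template, llm, tokenizer)
reply = backend.chat(
    ChatMessages(messages=[{'role': 'user', 'content': 'Hello!'}]),
    sample_params=SampleParams(max_new_tokens=128))
print(reply)
```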
@@ -0,0 +1,3 @@
from .backend import LMDeployBackend

__all__ = ['LMDeployBackend']
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.serve.async_engine import AsyncEngine

from xtuner.types import ChatMessages, ChatTemplate


class _AsyncEngine(AsyncEngine):
    """Async inference engine."""

    def __init__(self, chat_template: ChatTemplate, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert self.model_name == 'base'
        self.chat_template = chat_template

    async def _get_prompt_input(self, prompt: ChatMessages,
                                do_preprocess: bool, sequence_start: bool):
        """get input_ids, embeddings and offsets."""

        decorated = prompt.get_prompt(self.chat_template)

        results = {}

        input_ids = self.tokenizer.encode(decorated, add_bos=sequence_start)

        results['input_ids'] = input_ids
        results['prompt'] = decorated
        return results
@@ -0,0 +1,94 @@
import asyncio
import os
from typing import List, Optional, Union

from lmdeploy.utils import get_logger

from xtuner.types import ChatMessages, ChatTemplate, SampleParams
from ...streamer import LMDeployTextIteratorStreamer, LMDeployTextStreamer
from ..base import BaseBackend
from ._engine import _AsyncEngine

os.environ['TM_LOG_LEVEL'] = 'ERROR'
logger = get_logger('lmdeploy')
logger.setLevel('ERROR')

_StreamerType = Union[LMDeployTextStreamer, LMDeployTextIteratorStreamer]


class LMDeployBackend(BaseBackend):

    def __init__(self, chat_template, llm_name_or_path) -> None:
        super().__init__()

        self._engine = _AsyncEngine(
            chat_template, model_path=llm_name_or_path, model_name='base')

        self._chat_template = chat_template

    @property
    def chat_template(self) -> ChatTemplate:
        return self._chat_template

    def create_streamer(self, iterable=False):

        if iterable:
            return LMDeployTextIteratorStreamer()
        else:
            return LMDeployTextStreamer()

    def parse_sample_params(self, params: SampleParams):

        if params is None:
            params = SampleParams()

        stop_words = params.stop_words
        stop_words.extend(self.chat_template.stop_words)

        from lmdeploy.messages import GenerationConfig as LMDGenerationConfig
        lmd_gen_config = LMDGenerationConfig(
            max_new_tokens=params.max_new_tokens,
            temperature=params.temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            repetition_penalty=params.repetition_penalty,
            random_seed=params.seed,
            stop_words=stop_words)

        return lmd_gen_config

    def chat(self,
             messages: ChatMessages,
             streamer: Optional[_StreamerType] = None,
             sample_params: Optional[SampleParams] = None):

        lmd_gen_config = self.parse_sample_params(sample_params)
        self.session_id += 1
        import random

        generator = self._engine.generate(
            messages, random.randint(1, 100000), gen_config=lmd_gen_config)

        async def get_response():
            out = ''
            async for res in generator:
                out += res.response
                if streamer:
                    streamer.put(res.response)
            if streamer:
                streamer.end()
            return out

        loop = asyncio.new_event_loop()
        response = loop.run_until_complete(get_response())
        return response

    def batch_infer(self,
                    messages: List[ChatMessages],
                    sample_params: Optional[SampleParams] = None):

        lmd_gen_config = self.parse_sample_params(sample_params)

        results = self._engine.batch_infer(messages, gen_config=lmd_gen_config)

        return [r.text for r in results]
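LMDeployBackend hides lmdeploy's AsyncEngine behind a blocking interface by creating a fresh event loop for each chat() call, while batch_infer delegates to the engine directly. Note that chat() increments self.session_id, which is not initialized in __init__ as shown. A usage sketch, with the model path and message construction as assumptions:

```python
from xtuner.chat.backend import LMDeployBackend
from xtuner.types import ChatMessages, SampleParams

# `chat_template` is assumed to be the same ChatTemplate used with HFBackend;
# the model path is illustrative.
backend = LMDeployBackend(chat_template, 'internlm/internlm2-chat-7b')

questions = ['What is XTuner?', 'Explain sequence parallelism in one sentence.']
replies = backend.batch_infer(
    [ChatMessages(messages=[{'role': 'user', 'content': q}]) for q in questions],
    sample_params=SampleParams(max_new_tokens=64))
```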
@@ -0,0 +1,12 @@
from typing import Union

from .huggingface import HFTextIteratorStreamer, HFTextStreamer
from .lmdeploy import LMDeployTextIteratorStreamer, LMDeployTextStreamer

SteamerType = Union[HFTextIteratorStreamer, HFTextStreamer,
                    LMDeployTextIteratorStreamer, LMDeployTextStreamer]

__all__ = [
    'HFTextIteratorStreamer', 'HFTextStreamer', 'LMDeployTextIteratorStreamer',
    'LMDeployTextStreamer'
]
@@ -0,0 +1,37 @@
from transformers import TextIteratorStreamer, TextStreamer
from transformers.models.auto import AutoTokenizer


class HFTextIteratorStreamer(TextIteratorStreamer):

    def __init__(self,
                 tokenizer: AutoTokenizer,
                 skip_prompt: bool = False,
                 timeout=None,
                 chat_template=None,
                 **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.chat_template = chat_template

    def on_finalized_text(self, text: str, stream_end: bool = False):

        for word in self.chat_template.stop_words:
            text = text.rstrip(word)
        super().on_finalized_text(text, stream_end)


class HFTextStreamer(TextStreamer):

    def __init__(self,
                 tokenizer: AutoTokenizer,
                 skip_prompt: bool = False,
                 chat_template=None,
                 **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.chat_template = chat_template

    def on_finalized_text(self, text: str, stream_end: bool = False):

        for word in self.chat_template.stop_words:
            text = text.rstrip(word)
        super().on_finalized_text(text, stream_end)
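Both subclasses strip the chat template's stop words from each finalized chunk before emitting it. The iterator variant is normally consumed from a separate thread, as in this sketch (`backend` is an HFBackend instance and `messages` a ChatMessages instance as in the earlier usage sketch; neither is defined here):

```python
from threading import Thread

streamer = backend.create_streamer(iterable=True)  # HFTextIteratorStreamer

# Generation runs in a background thread while the caller iterates over the
# decoded text chunks as they are produced.
thread = Thread(target=backend.chat,
                kwargs=dict(messages=messages, streamer=streamer))
thread.start()
for chunk in streamer:
    print(chunk, end='', flush=True)
thread.join()
```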
Review comment: If the model is quantized, won't calling .cuda() on it directly cause problems?
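One way to address this concern would be to guard the move: a model loaded with bitsandbytes quantization or with device_map='auto' is already placed, and an explicit .cuda() on it can fail or reshuffle weights. A sketch of such a guard, not part of the PR; the attributes checked are the ones transformers/accelerate set on loaded models:

```python
def maybe_cuda(llm):
    # Skip the explicit .cuda() when the model is already quantized with
    # bitsandbytes or dispatched across devices by accelerate.
    is_quantized = getattr(llm, 'is_loaded_in_4bit', False) or \
        getattr(llm, 'is_loaded_in_8bit', False)
    is_dispatched = getattr(llm, 'hf_device_map', None) is not None
    if not is_quantized and not is_dispatched:
        llm.cuda()
    return llm
```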