Skip to content

Commit

Permalink
adapt to lmdeploy v0.4.0 (#1073)
Browse files: browse the repository at this point in the history
* adapt to lmdeploy v0.4.0

* keep compatibility with lmdeploy versions earlier than v0.4.0
  • Loading branch information
lvhan028 authored Apr 28, 2024
1 parent 58a57a4 commit 1013dce
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
15 changes: 12 additions & 3 deletions opencompass/models/lmdeploy_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.pytorch import engine as tm
from lmdeploy.version import version_info

if engine_config is not None:
from lmdeploy.messages import PytorchEngineConfig
Expand All @@ -71,6 +72,7 @@ def __init__(self,
self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config
self.end_str = end_str
self.major_version, self.minor_version, _ = version_info

def generate(
self,
Expand Down Expand Up @@ -145,9 +147,16 @@ def _generate(self,
assert type(
prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt)
_, output_ids, _ = generator.infer(session_id,
input_ids,
gen_config=gen_config)
if self.major_version >= 0 and self.minor_version >= 4:
outputs = generator.infer(session_id,
input_ids,
gen_config=gen_config)
output_ids = outputs.token_ids
else:
_, output_ids, _ = generator.infer(session_id,
input_ids,
gen_config=gen_config)

# stop engine
if hasattr(generator, 'end'):
generator.end(session_id)
Expand Down
7 changes: 6 additions & 1 deletion opencompass/models/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self,
max_seq_len=max_seq_len,
meta_template=meta_template)
from lmdeploy.turbomind import TurboMind
from lmdeploy.version import version_info

if engine_config is not None:
from lmdeploy.messages import TurbomindEngineConfig
Expand All @@ -70,6 +71,7 @@ def __init__(self,
self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config
self.end_str = end_str
self.major_version, self.minor_version, _ = version_info

def generate(self,
inputs: List[str],
Expand Down Expand Up @@ -165,7 +167,10 @@ def _generate(self,
sequence_end=True,
step=0,
stream_output=False):
_, output_ids, _ = outputs
if self.major_version >= 0 and self.minor_version >= 4:
output_ids = outputs.token_ids
else:
_, output_ids, _ = outputs
response = self.tokenizer.decode(output_ids)
response = valid_str(response)
# used to trim
Expand Down

0 comments on commit 1013dce

Please sign in to comment.