diff --git a/opencompass/models/lmdeploy_pytorch.py b/opencompass/models/lmdeploy_pytorch.py
index 814c3cc68..f9d67da45 100644
--- a/opencompass/models/lmdeploy_pytorch.py
+++ b/opencompass/models/lmdeploy_pytorch.py
@@ -50,6 +50,7 @@ def __init__(self,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.pytorch import engine as tm
+        from lmdeploy.version import version_info
 
         if engine_config is not None:
             from lmdeploy.messages import PytorchEngineConfig
@@ -71,6 +72,7 @@ def __init__(self,
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(
         self,
@@ -145,9 +147,16 @@ def _generate(self,
         assert type(
             prompt) is str, 'We only support string for TurboMind Python API'
         input_ids = self.tokenizer.encode(prompt)
-        _, output_ids, _ = generator.infer(session_id,
-                                           input_ids,
-                                           gen_config=gen_config)
+        if (self.major_version, self.minor_version) >= (0, 4):
+            outputs = generator.infer(session_id,
+                                      input_ids,
+                                      gen_config=gen_config)
+            output_ids = outputs.token_ids
+        else:
+            _, output_ids, _ = generator.infer(session_id,
+                                               input_ids,
+                                               gen_config=gen_config)
+
         # stop engine
         if hasattr(generator, 'end'):
             generator.end(session_id)
diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py
index 50c3e5ca6..9479f02f9 100644
--- a/opencompass/models/turbomind.py
+++ b/opencompass/models/turbomind.py
@@ -54,6 +54,7 @@ def __init__(self,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.turbomind import TurboMind
+        from lmdeploy.version import version_info
 
         if engine_config is not None:
             from lmdeploy.messages import TurbomindEngineConfig
@@ -70,6 +71,7 @@ def __init__(self,
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(self,
                  inputs: List[str],
@@ -165,7 +167,10 @@ def _generate(self,
                          sequence_end=True,
                          step=0,
                          stream_output=False):
-            _, output_ids, _ = outputs
+            if (self.major_version, self.minor_version) >= (0, 4):
+                output_ids = outputs.token_ids
+            else:
+                _, output_ids, _ = outputs
             response = self.tokenizer.decode(output_ids)
             response = valid_str(response)
             # used to trim