[Feat] support lightllm api
helloyongyang committed Nov 19, 2023
1 parent 91fba2c commit 55e012e
Showing 3 changed files with 124 additions and 0 deletions.
31 changes: 31 additions & 0 deletions configs/eval_lightllm.py
@@ -0,0 +1,31 @@
from mmengine.config import read_base
from opencompass.models import LightllmApi
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

models = [
    dict(abbr='LightllmApi',
         type=LightllmApi,
         url='http://localhost:8080/generate',
         max_out_len=1024,
         batch_size=8,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
         ),
         ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLInferTask)),
)
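Note: once a LightLLM server is listening at the URL above, the endpoint can be smoke-tested with a direct POST that mirrors the payload LightllmApi builds (a minimal sketch; the prompt string and token budget are illustrative, and it assumes the server runs at the default localhost:8080):

import requests

# Same request shape as lightllm_api.py below: 'inputs' plus a
# 'parameters' dict carrying max_new_tokens and the generation_kwargs.
data = {
    'inputs': 'def hello_world():',
    'parameters': {'max_new_tokens': 16, 'do_sample': False, 'ignore_eos': False},
}
resp = requests.post('http://localhost:8080/generate', json=data, timeout=30)
print(resp.json()['generated_text'])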
1 change: 1 addition & 0 deletions opencompass/models/__init__.py
@@ -11,3 +11,4 @@
from .openai_api import OpenAI # noqa: F401
from .xunfei_api import XunFei # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401
from .lightllm_api import LightllmApi # noqa: F401
92 changes: 92 additions & 0 deletions opencompass/models/lightllm_api.py
@@ -0,0 +1,92 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import json
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base_api import BaseAPIModel


@MODELS.register_module()
class LightllmApi(BaseAPIModel):

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmApi',
        url: str = 'http://localhost:8080/generate',
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry)
        self.logger = get_logger()
        self.url = url
        self.generation_kwargs = generation_kwargs or {}

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                parameters = {'max_new_tokens': max_out_len}
                parameters.update(self.generation_kwargs)
                data = dict(inputs=input, parameters=parameters)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                # Count the failed attempt, otherwise an unreachable
                # server would make this loop forever.
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()
                return response['generated_text']
            except (ValueError, KeyError):
                # Invalid JSON, or a response without 'generated_text'.
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            max_num_retries += 1

        raise RuntimeError('Calling LightllmApi failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
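Note: because the class is registered via MODELS.register_module(), it can also be exercised directly, outside a config file. A minimal sketch (the prompt is illustrative, a LightLLM server is assumed to be running at the default URL, and rate limiting comes from BaseAPIModel defaults not shown in this diff):

from opencompass.models import LightllmApi

model = LightllmApi(url='http://localhost:8080/generate', retry=2)
# generate() fans prompts out over a thread pool, one HTTP request each;
# each worker calls wait() so the shared token bucket throttles the rate.
outputs = model.generate(['def add(a, b):'], max_out_len=64)
print(outputs[0])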
