Commit 55e012e (1 parent: 91fba2c)
Showing 3 changed files with 124 additions and 0 deletions.
@@ -0,0 +1,31 @@
from mmengine.config import read_base
from opencompass.models import LightllmApi
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

models = [
    dict(abbr='LightllmApi',
         type=LightllmApi,
         url='http://localhost:8080/generate',
         max_out_len=1024,
         batch_size=8,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
         )),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLInferTask)),
)
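
For reference, a sketch of the request body that the LightllmApi client added below would POST to http://localhost:8080/generate under this config; the prompt string is a placeholder. It follows from _generate further down, which sets max_new_tokens from max_out_len and then merges in generation_kwargs:

data = {
    'inputs': '<prompt text, e.g. a HumanEval problem>',  # placeholder
    'parameters': {
        'max_new_tokens': 1024,   # from max_out_len in this config
        'do_sample': False,
        'ignore_eos': False,
    },
}

Assuming a standard OpenCompass checkout, a config like this would typically be launched with `python run.py <path-to-this-config>`; the file path itself is not shown in this diff.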
@@ -0,0 +1,92 @@
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base_api import BaseAPIModel


@MODELS.register_module()
class LightllmApi(BaseAPIModel):

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmApi',
        url: str = 'http://localhost:8080/generate',
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry)
        self.logger = get_logger()
        self.url = url
        if generation_kwargs is not None:
            self.generation_kwargs = generation_kwargs
        else:
            self.generation_kwargs = {}
    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()
    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                # max_new_tokens comes from max_out_len; user-provided
                # generation_kwargs take precedence on key collisions.
                parameters = {'max_new_tokens': max_out_len}
                parameters.update(self.generation_kwargs)
                data = dict(inputs=input, parameters=parameters)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                # Count connection failures toward the retry budget so a
                # dead server cannot cause an infinite loop.
                self.logger.error('Got connection error, retrying...')
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()
                return response['generated_text']
            except ValueError:
                # .json() raises a JSONDecodeError (a ValueError subclass)
                # when the server returns a non-JSON body.
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            max_num_retries += 1

        raise RuntimeError('Calling LightllmApi failed after retrying for '
                           f'{self.retry} times. Check the logs for '
                           'details.')
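
For illustration, a minimal sketch of driving LightllmApi directly, outside the OpenCompass runner. It assumes a Lightllm server is already listening at the default URL, uses a placeholder prompt, and relies on LightllmApi being exported from opencompass.models as the config above does:

from opencompass.models import LightllmApi

# Assumes a Lightllm server is serving at localhost:8080 (hypothetical local setup).
model = LightllmApi(url='http://localhost:8080/generate',
                    generation_kwargs=dict(do_sample=False, ignore_eos=False))

prompts = ['def fibonacci(n):']  # placeholder prompt
outputs = model.generate(prompts, max_out_len=256)
print(outputs[0])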