[Feature] Support the reasoning from BaiLing LLM #1541
Open
cuauty wants to merge 5 commits into open-compass:main from cuauty:bailing_api_oc
Changes from all commits (5 commits):
dafc12e: [Feature] Support the reasoning from BaiLing LLM (cuauty)
82bf0ca: Add the api example (cuauty)
7851fd5: Revise the generation arguments (cuauty)
d96cf4c: [fix] set the batch size (cuauty)
caa5cf2: Retry under flowcontrol of serverside (christopherdy)
@@ -0,0 +1,38 @@
from mmengine.config import read_base

from opencompass.models import BailingAPI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from opencompass.configs.datasets.ceval.ceval_gen import ceval_datasets
    from opencompass.configs.summarizers.medium import summarizer

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        path="Bailing-Lite-0830",
        token="xxxxxx",  # please provide your key
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        generation_kwargs={},
        query_per_second=1,
        max_seq_len=4096,
    ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,
        concurrent_users=2,
        task=dict(type=OpenICLInferTask),
    ),
)

work_dir = "outputs/api_bailing/"
@@ -0,0 +1,30 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Lite-0830",
        token="",  # set your token
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
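The api_meta_template above tells OpenCompass how to tag each dialogue turn; the model wrapper then maps those OpenCompass roles (HUMAN/BOT/SYSTEM) onto OpenAI-style chat roles before calling the endpoint. A minimal sketch of that mapping (the helper name is illustrative, not part of the PR):

```python
# Illustrative role mapping from OpenCompass prompt items to
# OpenAI-style chat messages; mirrors what the API wrapper does.
ROLE_MAP = {"HUMAN": "user", "BOT": "assistant", "SYSTEM": "system"}

def to_chat_messages(prompt_items):
    messages = []
    for item in prompt_items:
        content = item.get("prompt")
        if not content:  # skip empty turns
            continue
        role = ROLE_MAP.get(item["role"], item["role"])
        messages.append({"role": role, "content": content})
    return messages

print(to_chat_messages([
    {"role": "SYSTEM", "prompt": "You are helpful."},
    {"role": "HUMAN", "prompt": "Hi"},
    {"role": "BOT", "prompt": ""},  # empty turn, dropped
]))
```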
@@ -0,0 +1,30 @@
from opencompass.models import BailingAPI

api_meta_template = dict(
    round=[
        dict(role="HUMAN", api_role="HUMAN"),
        dict(role="BOT", api_role="BOT", generate=False),
    ],
    reserved_roles=[dict(role="SYSTEM", api_role="SYSTEM")],
)

models = [
    dict(
        path="Bailing-Pro-0920",
        token="",  # set your token
        url="https://bailingchat.alipay.com/chat/completions",
        type=BailingAPI,
        meta_template=api_meta_template,
        query_per_second=1,
        max_seq_len=4096,
        batch_size=1,
        generation_kwargs={
            "temperature": 0.4,
            "top_p": 1.0,
            "top_k": -1,
            "n": 1,
            "logprobs": 1,
            "use_beam_search": False,
        },
    ),
]
@@ -0,0 +1,215 @@
import concurrent
import concurrent.futures
import os
import socket
import traceback
from typing import Dict, List, Optional, Union

import requests
from requests.adapters import HTTPAdapter
from retrying import retry
from urllib3.connection import HTTPConnection

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class HTTPAdapterWithSocketOptions(HTTPAdapter):

    def __init__(self, *args, **kwargs):
        self._socket_options = HTTPConnection.default_socket_options + [
            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
            (socket.SOL_TCP, socket.TCP_KEEPIDLE, 75),
            (socket.SOL_TCP, socket.TCP_KEEPINTVL, 30),
            (socket.SOL_TCP, socket.TCP_KEEPCNT, 120),
        ]
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        if self._socket_options is not None:
            kwargs["socket_options"] = self._socket_options
        super().init_poolmanager(*args, **kwargs)
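The adapter above applies TCP keepalive options to every pooled connection, so long-idle evaluation runs are less likely to be dropped by intermediaries. A stdlib-only sketch of the same socket options, outside of requests (the TCP_KEEP* constants are Linux-specific; values are taken from the adapter above):

```python
import socket

def make_keepalive_socket() -> socket.socket:
    """Create a TCP socket with the same keepalive settings the
    HTTP adapter applies (Linux-specific constants)."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)  # enable keepalive
    sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, 75)    # idle secs before probing
    sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, 30)   # secs between probes
    sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, 120)    # failed probes before drop
    return sock

if __name__ == "__main__":
    s = make_keepalive_socket()
    print(s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE))
    s.close()
```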
class BailingAPI(BaseAPIModel):
    """Model wrapper around the Bailing service.

    Args:
        path (str): Name of the model served by the API.
        token (str): API token; falls back to the ``BAILING_API_KEY``
            environment variable if empty.
        url (str): Endpoint of the chat-completions service.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 3.
        generation_kwargs: Additional generation parameters passed through
            to the API.
    """

    def __init__(
        self,
        path: str,
        token: str,
        url: str,
        meta_template: Optional[Dict] = None,
        query_per_second: int = 1,
        retry: int = 3,
        generation_kwargs: Dict = {},
        max_seq_len=4096,
    ):
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
            generation_kwargs=generation_kwargs,
        )

        self.logger.info(f"Bailing API Model Init path: {path} url={url}")
        if not token:
            token = os.environ.get("BAILING_API_KEY")
        if token:
            self._headers = {"Authorization": f"Bearer {token}"}
        else:
            raise RuntimeError(
                "No valid token; pass `token` or set BAILING_API_KEY.")
        self._headers["Content-Type"] = "application/json"
        self._url = url if url else "https://bailingchat.alipay.com/chat/completions"
        self._model = path
        self._sessions = []
        self._num = int(os.environ.get("BAILING_API_PARALLEL_NUM") or 1)
        try:
            for _ in range(self._num):
                adapter = HTTPAdapterWithSocketOptions()
                sess = requests.Session()
                sess.mount("http://", adapter)
                sess.mount("https://", adapter)
                self._sessions.append(sess)
        except Exception as e:
            self.logger.error(f"Failed to set up the session. {e}")
            raise
    def generate(
        self,
        inputs: Union[List[str], PromptList],
        max_out_len: int = 4096,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (Union[List[str], PromptList]): A list of strings or
                PromptDicts. The PromptDict should be organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: Generated strings, in the same order as ``inputs``.
        """
        results = [""] * len(inputs)
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self._num) as executor:
            future_to_m = {
                executor.submit(
                    self._generate,
                    self._sessions[i % self._num],
                    input,
                    max_out_len,
                ): i
                for i, input in enumerate(inputs)
            }
            for future in concurrent.futures.as_completed(future_to_m):
                m = future_to_m[future]
                resp = future.result()
                if resp and resp.status_code == 200:
                    try:
                        result = resp.json()
                    except ValueError:  # malformed JSON body
                        continue
                    if (result.get("choices")
                            and result["choices"][0].get("message")
                            and result["choices"][0]["message"].get("content")):
                        # Write back by submitted index so outputs stay
                        # aligned with inputs regardless of completion order.
                        results[m] = result["choices"][0]["message"]["content"]
        self.flush()
        return results
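`as_completed` yields futures in completion order, not submission order, so results have to be written back by the index the future was submitted with rather than appended as they arrive. A self-contained sketch of the pattern (the task and function names are illustrative):

```python
import concurrent.futures
import time

def collect_in_order(inputs):
    """Run one task per input across threads and return results aligned
    with the input order, even though futures complete out of order."""
    def task(x):
        # Later inputs finish first, forcing out-of-order completion.
        time.sleep(0.01 * (len(inputs) - x))
        return x * 10

    results = [None] * len(inputs)
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        future_to_idx = {executor.submit(task, v): i
                         for i, v in enumerate(inputs)}
        for future in concurrent.futures.as_completed(future_to_idx):
            # Index by submission position, not arrival order.
            results[future_to_idx[future]] = future.result()
    return results

print(collect_in_order([0, 1, 2, 3]))  # [0, 10, 20, 30]
```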
    def _generate(
        self,
        sess,
        input: Union[str, PromptList],
        max_out_len: int,
    ) -> requests.Response:
        """Generate a result for a single input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            requests.Response: The server response.
        """
        if isinstance(input, str):
            messages = [{"role": "user", "content": input}]
        else:
            messages = []
            for item in input:
                content = item["prompt"]
                if not content:
                    continue
                message = {"content": content}
                if item["role"] == "HUMAN":
                    message["role"] = "user"
                elif item["role"] == "BOT":
                    message["role"] = "assistant"
                elif item["role"] == "SYSTEM":
                    message["role"] = "system"
                else:
                    message["role"] = item["role"]
                messages.append(message)
        request = {
            "model": self._model,
            "messages": messages,
            "max_seq_len": max(
                max_out_len if max_out_len else 4096,
                self.max_seq_len if self.max_seq_len else 4096,
            ),
        }
        request.update(self.generation_kwargs)
        try:
            retry_num = 0
            while retry_num < self.retry:
                response = self._infer_result(request, sess)
                if response.status_code == 200:
                    break  # success
                elif response.status_code == 426:
                    retry_num += 1  # retry under server-side flow control
                else:
                    raise ValueError(f"Status code = {response.status_code}")
            else:  # while-else: the loop ended without a successful break
                raise ValueError(
                    "Exceeded the maximum number of retries. Last status "
                    f"code = {response.status_code}")
        except Exception as e:
            self.logger.error(
                f"Failed to infer request={request}; "
                f"model_name={self.path}; error={e}, "
                f"stack:{traceback.format_exc()}")
            raise
        return response
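The while/else retry above re-issues the request only on HTTP 426, which the server uses for flow control, and gives up once the retry budget is spent. A minimal sketch of the same control flow with a stand-in requester (the function and exception names are illustrative, not part of the PR):

```python
class FlowControlError(RuntimeError):
    pass

def request_with_426_retry(send, max_retry=3):
    """Call `send()` until it returns 200, retrying only on 426.

    `send` is any zero-argument callable returning a status code.
    Other status codes raise immediately; exhausting the retry budget
    raises FlowControlError (while/else fires when the loop ends
    without a `break`).
    """
    retry_num = 0
    while retry_num < max_retry:
        status = send()
        if status == 200:
            break           # success
        elif status == 426:
            retry_num += 1  # server asked us to back off; try again
        else:
            raise ValueError(f"Status code = {status}")
    else:
        raise FlowControlError("exceeded the maximal retry times")
    return status

# Stand-in server: two flow-control responses, then success.
responses = iter([426, 426, 200])
print(request_with_426_retry(lambda: next(responses)))  # 200
```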
    @retry(stop_max_attempt_number=3, wait_fixed=16000)  # wait is in ms
    def _infer_result(self, request, sess):
        response = sess.request(
            "POST",
            self._url,
            json=request,
            headers=self._headers,
            timeout=500,
        )
        return response
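The retrying library's `@retry` decorator, as configured here, re-invokes the wrapped call on any exception, up to `stop_max_attempt_number` attempts with `wait_fixed` milliseconds between them. A stdlib-only stand-in sketching that behavior (`retry_fixed` is illustrative, not the library itself):

```python
import functools
import time

def retry_fixed(stop_max_attempt_number=3, wait_fixed=16000):
    """Minimal stand-in for retrying.retry(stop_max_attempt_number,
    wait_fixed): re-invoke on any exception, up to N attempts, sleeping
    wait_fixed ms between attempts, then re-raise the last error."""
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, stop_max_attempt_number + 1):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if attempt == stop_max_attempt_number:
                        raise  # budget exhausted: surface the last error
                    time.sleep(wait_fixed / 1000.0)
        return wrapper
    return deco

calls = {"n": 0}

@retry_fixed(stop_max_attempt_number=3, wait_fixed=1)  # 1 ms for the demo
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"

print(flaky())  # ok (succeeds on the third attempt)
```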
Review comment:
For .py files, please use lowercase. Additionally, the models should be copied into the opencompass/configs/modeling folder.