Skip to content

Commit

Permalink
Update Sensetime API (#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonysy committed Jan 26, 2024
1 parent 4aa7456 commit 8ed022b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 24 deletions.
17 changes: 16 additions & 1 deletion configs/api_examples/eval_api_sensetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,22 @@
query_per_second=1,
max_out_len=2048,
max_seq_len=2048,
batch_size=8),
batch_size=8,
parameters={
"temperature": 0.8,
"top_p": 0.7,
"max_new_tokens": 1024,
"repetition_penalty": 1.05,
"know_ids": [],
"stream": True,
"user": "#*#***TestUser***#*#",
"knowledge_config": {
"control_level": "normal",
"knowledge_base_result": False,
"online_search_result": False
}
}
)
]

infer = dict(
Expand Down
103 changes: 80 additions & 23 deletions opencompass/models/sensetime_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -30,24 +32,32 @@ class SenseTime(BaseAPIModel):
def __init__(
self,
path: str,
key: str,
url: str,
key: str = 'ENV',
query_per_second: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
parameters: Optional[Dict] = None,
):
super().__init__(path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry)

if isinstance(key, str):
self.keys = os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key
else:
self.keys = key

self.headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {key}'
'Authorization': f'Bearer {self.keys}'
}
self.url = url
self.model = path
self.params = parameters

def generate(
self,
Expand Down Expand Up @@ -104,38 +114,85 @@ def _generate(
messages.append(msg)

data = {'messages': messages, 'model': self.model}
data.update(self.params)

stream = data['stream']

max_num_retries = 0
while max_num_retries < self.retry:
self.acquire()

max_num_retries += 1
raw_response = requests.request('POST',
url=self.url,
headers=self.headers,
json=data)
response = raw_response.json()
requests_id = raw_response.headers['X-Request-Id'] # noqa
self.release()

if response is None:
print('Connection error, reconnect.')
# if connect error, frequent requests will casuse
# continuous unstable network, therefore wait here
# to slow down the request
self.wait()
continue
if raw_response.status_code == 200:
msg = response['data']['choices'][0]['message']
return msg

if (raw_response.status_code != 200):
if response['error']['code'] == 18:
# security issue
return 'error:unsafe'
if not stream:
response = raw_response.json()

if response is None:
print('Connection error, reconnect.')
# if connect error, frequent requests will casuse
# continuous unstable network, therefore wait here
# to slow down the request
self.wait()
continue
if raw_response.status_code == 200:
msg = response['data']['choices'][0]['message']
return msg

if (raw_response.status_code != 200):
if response['error']['code'] == 18:
# security issue
return 'error:unsafe'
if response['error']['code'] == 17:
return 'error:too long'
else:
print(raw_response.text)
time.sleep(1)
continue
else:
# stream data to msg
raw_response.encoding = 'utf-8'

if raw_response.status_code == 200:
response_text = raw_response.text
data_blocks = response_text.split('data:')
data_blocks = data_blocks[1:]

first_block = json.loads(data_blocks[0])
if first_block['status']['code'] != 0:
msg = f"error:{first_block['status']['code']},"
f" {first_block['status']['message']}"
self.logger.error(msg)
return msg

msg = ''
for i, part in enumerate(data_blocks):
# print(f'process {i}: {part}')
try:
if part.startswith('[DONE]'):
break

json_data = json.loads(part)
choices = json_data['data']['choices']
for c in choices:
delta = c.get('delta')
msg += delta
except json.decoder.JSONDecodeError as err:
print(err)
self.logger.error(f'Error decoding JSON: {part}')
return msg

else:
print(raw_response.text)
print(raw_response.text,
raw_response.headers.get('X-Request-Id'))
time.sleep(1)
continue

print(response)
max_num_retries += 1

raise RuntimeError(raw_response.text)
raise RuntimeError(
f'request id: '
f'{raw_response.headers.get("X-Request-Id")}, {raw_response.text}')

0 comments on commit 8ed022b

Please sign in to comment.