[Feature] Support Lightllm API (#613)
* [Feature] Support Lightllm api

* formatting & renaming

---------

Co-authored-by: Leymore <[email protected]>
helloyongyang and Leymore committed Nov 21, 2023
1 parent 7199acc commit d3b0d5c
Showing 7 changed files with 249 additions and 0 deletions.
33 changes: 33 additions & 0 deletions configs/eval_lightllm.py
@@ -0,0 +1,33 @@
from mmengine.config import read_base
from opencompass.models import LightllmAPI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
        url='http://localhost:8080/generate',
        max_out_len=1024,
        batch_size=8,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
        ),
    ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLInferTask),
    ),
)
63 changes: 63 additions & 0 deletions docs/en/advanced_guides/evaluation_lightllm.md
@@ -0,0 +1,63 @@
# Evaluation with Lightllm

We now support evaluating large language models served through [Lightllm](https://github.com/ModelTC/lightllm) for inference. Developed by SenseTime, Lightllm is a Python-based LLM (Large Language Model) inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance. Lightllm supports a variety of large language models and lets users deploy them locally as a service. During evaluation, OpenCompass feeds data to Lightllm through its API and processes the responses. OpenCompass has been adapted for compatibility with Lightllm, and this tutorial will guide you through using OpenCompass to evaluate models with Lightllm as the inference backend.

## Setup

### Install OpenCompass

Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install OpenCompass and prepare the evaluation datasets.

### Install Lightllm

Please follow the [Lightllm homepage](https://github.com/ModelTC/lightllm) to install Lightllm. Pay attention to aligning the versions of its dependencies, especially the version of `transformers`.

## Evaluation

We use the evaluation of the llama2-7B model on HumanEval as an example.

### Step 1: Deploy the model locally as a service using Lightllm

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                     --host 0.0.0.0 \
                                     --port 8080 \
                                     --tp 1 \
                                     --max_total_token_num 120000
```

**Note:** `--tp` can be configured to enable TensorParallel inference across several GPUs, which is suitable for the inference of very large models.
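For example, a two-GPU launch might look like the following (a sketch; the model path is a placeholder and `--max_total_token_num` should be tuned to your hardware):

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                     --host 0.0.0.0 \
                                     --port 8080 \
                                     --tp 2 \
                                     --max_total_token_num 120000
```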

You can use the following Python script to quickly test whether the service has started successfully.

```python
import json

import requests

url = 'http://localhost:8080/generate'
headers = {'Content-Type': 'application/json'}
data = {
    'inputs': 'What is AI?',
    'parameters': {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 1024,
    }
}
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
    print(response.json())
else:
    print('Error:', response.status_code, response.text)
```
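If you prefer the shell, the same request can be sent with curl (assuming the same endpoint and request schema as the script above):

```shell
curl http://localhost:8080/generate \
     -H 'Content-Type: application/json' \
     -d '{"inputs": "What is AI?", "parameters": {"do_sample": false, "ignore_eos": false, "max_new_tokens": 64}}'
```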

### Step 2: Evaluate the above model using OpenCompass

```shell
python run.py configs/eval_lightllm.py
```

You should get the evaluation results once inference and evaluation complete.

**Note:** In `eval_lightllm.py`, make sure the configured URL matches the service address from the previous step, as sketched below.
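For example, if the service from Step 1 were started on another host or port (the address below is hypothetical), the model entry in `configs/eval_lightllm.py` would change accordingly:

```python
models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
        # hypothetical address; use the host and port from Step 1
        url='http://10.0.0.5:8090/generate',
        max_out_len=1024,
        batch_size=8,
    ),
]
```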
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -60,6 +60,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
advanced_guides/new_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/evaluation_lightllm.md
advanced_guides/code_eval_service.md
advanced_guides/multimodal_eval.md
advanced_guides/prompt_attack.md
63 changes: 63 additions & 0 deletions docs/zh_cn/advanced_guides/evaluation_lightllm.md
@@ -0,0 +1,63 @@
# Evaluating Models with Lightllm

We support evaluating large language models that use [Lightllm](https://github.com/ModelTC/lightllm) for inference. Developed by SenseTime, Lightllm is a Python-based LLM inference and serving framework known for its lightweight design, easy scalability, and high performance, and it supports a wide range of large models. Users can run model inference through Lightllm and deploy it locally as a service; during evaluation, OpenCompass feeds data to Lightllm through its API and processes the returned results. OpenCompass has been adapted for Lightllm, and this tutorial introduces how to use OpenCompass to evaluate models with Lightllm as the inference backend.

## Environment Setup

### Install OpenCompass

Please follow the OpenCompass [installation guide](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the library and prepare the datasets.

### Install Lightllm

Please follow the [Lightllm homepage](https://github.com/ModelTC/lightllm) to install Lightllm. Pay attention to aligning the versions of its dependencies, especially the version of `transformers`.

## Evaluation

We use the evaluation of llama2-7B on HumanEval as an example.

### Step 1: Deploy the model locally as a service with Lightllm

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                     --host 0.0.0.0 \
                                     --port 8080 \
                                     --tp 1 \
                                     --max_total_token_num 120000
```

**Note:** The `--tp` value in the command above sets the number of GPUs used for TensorParallel inference, which is suitable for larger models; see the sketch below.
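For example, a two-GPU launch might look like this (a sketch; the model path is a placeholder):

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                     --host 0.0.0.0 \
                                     --port 8080 \
                                     --tp 2 \
                                     --max_total_token_num 120000
```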

You can use the following Python script to quickly test whether the service has started successfully.

```python
import json

import requests

url = 'http://localhost:8080/generate'
headers = {'Content-Type': 'application/json'}
data = {
    'inputs': 'What is AI?',
    'parameters': {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 1024,
    }
}
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
    print(response.json())
else:
    print('Error:', response.status_code, response.text)
```
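Equivalently, a quick curl check from the shell (assuming the same endpoint and request schema as above):

```shell
curl http://localhost:8080/generate \
     -H 'Content-Type: application/json' \
     -d '{"inputs": "What is AI?", "parameters": {"do_sample": false, "ignore_eos": false, "max_new_tokens": 64}}'
```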

### Step 2: Evaluate the model with OpenCompass

```shell
python run.py configs/eval_lightllm.py
```

Once the model has finished inference and metric computation, you will get the evaluation results.

**Note:** The URL configured in `eval_lightllm.py` must match the service address from the previous step, as sketched below.
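For example, with the service on a different host or port (the address below is hypothetical), the model entry in `configs/eval_lightllm.py` would be:

```python
models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
        # hypothetical address; use the host and port from Step 1
        url='http://10.0.0.5:8090/generate',
        max_out_len=1024,
        batch_size=8,
    ),
]
```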
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -60,6 +60,7 @@ OpenCompass Getting Started Roadmap
advanced_guides/new_dataset.md
advanced_guides/new_model.md
advanced_guides/evaluation_turbomind.md
advanced_guides/evaluation_lightllm.md
advanced_guides/code_eval_service.md
advanced_guides/multimodal_eval.md
advanced_guides/prompt_attack.md
1 change: 1 addition & 0 deletions opencompass/models/__init__.py
@@ -7,6 +7,7 @@
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
from .huggingface import HuggingFaceChatGLM3 # noqa: F401, F403
from .intern_model import InternLM # noqa: F401, F403
from .lightllm_api import LightllmAPI # noqa: F401
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
from .minimax_api import MiniMax # noqa: F401
from .openai_api import OpenAI # noqa: F401
87 changes: 87 additions & 0 deletions opencompass/models/lightllm_api.py
@@ -0,0 +1,87 @@
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base_api import BaseAPIModel


@MODELS.register_module()
class LightllmAPI(BaseAPIModel):

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:8080/generate',
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = None,
    ):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry)
        self.logger = get_logger()
        self.url = url
        if generation_kwargs is not None:
            self.generation_kwargs = generation_kwargs
        else:
            self.generation_kwargs = {}
        self.do_sample = self.generation_kwargs.get('do_sample', False)
        self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        # Each prompt is sent to the Lightllm server in its own thread.
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input,
                            parameters=dict(do_sample=self.do_sample,
                                            ignore_eos=self.ignore_eos,
                                            max_new_tokens=max_out_len))
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                # Count connection failures towards the retry budget so a
                # dead server cannot cause an infinite loop.
                max_num_retries += 1
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                return response['generated_text']
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
                max_num_retries += 1

        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
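As a minimal usage sketch (assuming a Lightllm server is already running at the default URL, as in the docs above):

```python
from opencompass.models import LightllmAPI

# Instantiate the API wrapper and generate a completion for one prompt.
model = LightllmAPI(url='http://localhost:8080/generate')
outputs = model.generate(['What is AI?'], max_out_len=64)
print(outputs[0])
```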
