Commit e51f46f

[Feature] Support chat style inferencer.

mzr1996 committed Nov 28, 2023
1 parent d4af31b

Showing 10 changed files with 699 additions and 346 deletions.
53 changes: 27 additions & 26 deletions configs/eval_openai_agent.py → configs/eval_chat_agent.py
@@ -1,17 +1,16 @@
from mmengine.config import read_base
from opencompass.models.openai_api import OpenAI
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.openicl import AgentInferencer

with read_base():
from .summarizers.medium import summarizer
from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets as datasets

from opencompass.models.lagent import LagentAgent
from lagent.llms import GPTAPI
from lagent.agents.react import ReAct, ReActProtocol
from lagent.actions import PythonInterpreter
from lagent import PythonInterpreter, ReAct
from lagent.agents.react import ReActProtocol

FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""

@@ -109,28 +108,30 @@ def solution():
```'''

models = [
dict(abbr='gpt-3.5-react',
type=LagentAgent,
agent_type=ReAct,
max_turn=3,
llm=dict(
type=GPTAPI,
model_type='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=PythonInterpreter,
description=PYTHON_INTERPRETER_DESCRIPTION),
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=8),
dict(
abbr='gpt-3.5-react',
type=LagentAgent,
agent_type=ReAct,
max_turn=3,
llm=dict(
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=PythonInterpreter,
description=PYTHON_INTERPRETER_DESCRIPTION),
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=1,
),
]

for dataset in datasets:
82 changes: 82 additions & 0 deletions configs/eval_chat_cibench.py
@@ -0,0 +1,82 @@
from lagent.agents.react import ReActProtocol
from mmengine.config import read_base

from opencompass.lagent.actions.ipython_interpreter import IPythonInterpreter
from opencompass.lagent.agents.react import CIReAct
from opencompass.models.lagent import CodeAgent
from opencompass.models.openai_api import OpenAI
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
from .datasets.CIBench.CIBench_gen_eb42f9 import \
cibench_datasets as datasets

FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""

FEWSHOT_INSTRUCTION = """\
You are an assistant who can utilize external tools.
{tool_description}
To use a tool, please respond with the following format:
```
{thought} Think about what you need to solve. Do you need to use tools?
{action} The tool name; it should be one of [{action_names}].
{action_input} The input to the tool that you want to use.
```
The tool will reply after your response, using the following format:
```
{response} The results of calling the tool.
```
Therefore, DO NOT generate the tool's response yourself.
Also, please follow these guidelines:
1. Always use the code interpreter to solve the problem.
2. The generated code should always be in a markdown code block.
3. The generated code will be executed in an IPython manner and the results will be cached.
4. Your code should always be simple and solve only the problem in the current step.
Begin!
"""

IPYTHON_INTERPRETER_DESCRIPTION = '''\
It can run Python code like a Jupyter notebook. The code must be valid Python code containing only Python methods.'''

models = [
dict(
abbr='gpt-3.5-code',
type=CodeAgent,
agent_type=CIReAct,
max_turn=3,
llm=dict(
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=IPythonInterpreter,
description=IPYTHON_INTERPRETER_DESCRIPTION)
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=1,
),
]

for dataset in datasets:
# Evaluate on every assistant response
dataset['infer_cfg']['inferencer']['infer_mode'] = 'every'

infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)
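For illustration, a single model turn under FEWSHOT_INSTRUCTION above might look like the sketch below. The keyword prefixes (Thought:, Action:, Action Input:, Response:) are assumptions based on ReActProtocol's usual defaults; they are not shown in this diff.

```
# Hypothetical single ReAct turn (keyword prefixes assumed, data invented):
model_turn = (
    'Thought: I need the mean of column A, so I should run code.\n'
    'Action: IPythonInterpreter\n'
    'Action Input: df["A"].mean()\n'
)
# The executor would then append a system message such as 'Response: 4.2'
# before the next turn, matching the {response} block above.
```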
35 changes: 35 additions & 0 deletions configs/eval_chat_last.py
@@ -0,0 +1,35 @@
from mmengine.config import read_base

from opencompass.models.openai_api import OpenAI
from opencompass.openicl import ChatInferencer
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets as datasets

models = [
dict(
abbr='gpt-3.5',
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]

for dataset in datasets:
# Use ChatInferencer instead of GenInferencer
dataset['infer_cfg']['inferencer'] = dict(type=ChatInferencer)

infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)
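For reference, ChatInferencer consumes multi-turn chat samples (lists of role/content messages) rather than single prompt strings; this matches the schema load_experiment builds in opencompass/datasets/cibench.py below. A minimal sketch with invented content:

```
# Hypothetical chat-format sample; the role/content schema follows
# load_experiment in opencompass/datasets/cibench.py (content invented):
sample = [
    dict(role='user', content='Janet has 3 apples and buys 2 more. How many now?'),
    dict(role='assistant', content='3 + 2 = 5, so Janet has 5 apples.'),
]
```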
5 changes: 4 additions & 1 deletion opencompass/datasets/cibench.py
@@ -43,7 +43,10 @@ def load_experiment(file: str) -> dict:
outputs.append(None)
return dict(
experiment=file,
questions=questions,
questions=sum(([
dict(role='user', content=question),
dict(role='assistant', content=output)
] for question, output in zip(questions, outputs)), []),
references=dict(outputs=outputs, tags=tags, experiment=file),
)

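A minimal, runnable illustration of the sum(..., []) interleaving added above, with invented data:

```
# Interleave each question with its reference output into one chat
# transcript (data invented for illustration):
questions = ['Load data.csv', 'Plot column A']
outputs = ['df = pd.read_csv("data.csv")', 'df["A"].plot()']
merged = sum(([dict(role='user', content=q),
               dict(role='assistant', content=o)]
              for q, o in zip(questions, outputs)), [])
# merged == [
#     {'role': 'user', 'content': 'Load data.csv'},
#     {'role': 'assistant', 'content': 'df = pd.read_csv("data.csv")'},
#     {'role': 'user', 'content': 'Plot column A'},
#     {'role': 'assistant', 'content': 'df["A"].plot()'},
# ]
```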
2 changes: 2 additions & 0 deletions opencompass/lagent/actions/ipython_interpreter.py
@@ -72,6 +72,8 @@ def __init__(self,
user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
self.user_data_dir = user_data_dir
self._initialized = False
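# Ensure the module-level WORK_DIR exists before the interpreter is used.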
if not os.path.exists(WORK_DIR):
os.mkdir(WORK_DIR)

@staticmethod
def start_kernel():
134 changes: 3 additions & 131 deletions opencompass/lagent/agents/react.py
@@ -1,133 +1,7 @@
import re
from typing import Union

from lagent.actions import ActionExecutor
from lagent.agents.base_agent import BaseAgent
from lagent.agents.react import ReActProtocol
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.agents.react import ReAct
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn


class ReAct(BaseAgent):
"""An implementation of ReAct (https://arxiv.org/abs/2210.03629)
Args:
llm (BaseModel or BaseAPIModel): an LLM service which can chat
and act as the backend.
action_executor (ActionExecutor): an action executor to manage
all actions and their responses.
protocol (ReActProtocol): a wrapper to generate prompts and
parse the responses from the LLM / actions.
max_turn (int): the maximum number of trials for the LLM to generate
plans that can be successfully parsed by the ReAct protocol.
"""

def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
action_executor: ActionExecutor,
protocol: ReActProtocol = ReActProtocol(),
max_turn: int = 2) -> None:
self.max_turn = max_turn
super().__init__(llm=llm,
action_executor=action_executor,
protocol=protocol)

def reset(self):
"""Reset history."""
self._session_history = []

def opencompass_adapter(self, prompt):
# adapter for prompt parsing
if isinstance(prompt, list):
system_prompt = []
merged_prompt = []
for p in prompt:
tmp_p = p.copy()
if 'content' in tmp_p:
tmp_p['prompt'] = tmp_p.pop('content')
if 'role' in tmp_p:
if tmp_p['role'] == 'system':
# skip system prompt
system_prompt.append(tmp_p['prompt'])
continue
# no system role for the meta template temporarily
if tmp_p['role'] == 'assistant':
tmp_p['role'] = 'BOT'
if tmp_p['role'] == 'user':
# merge previous system prompt to user
system_str = ''.join(system_prompt)
tmp_p['prompt'] = system_str + tmp_p['prompt']
tmp_p['role'] = 'HUMAN'
system_prompt = []
merged_prompt.append(tmp_p)

# merge if system still exists
if system_prompt:
if 'role' in merged_prompt[-1]:
if merged_prompt[-1]['role'] == 'HUMAN':
# append to the final human prompt
merged_prompt[-1]['prompt'] += ''.join(system_prompt)
else:
# create a human prompt behind
merged_prompt.append(
dict(role='HUMAN', prompt=''.join(system_prompt)))

from opencompass.utils.prompt import PromptList
new_prompt = PromptList()
# adapter for meta template
new_prompt.append(dict(section='round', pos='begin'))
new_prompt.extend(merged_prompt)
new_prompt.append(dict(section='round', pos='end'))

return new_prompt

def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
agent_return = AgentReturn()
force_stop = False
default_response = '对不起,我无法回答你的问题'  # i.e., "Sorry, I cannot answer your question."
for turn in range(self.max_turn):
prompt = self._protocol.format(
chat_history=self.session_history,
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
prompt = self.opencompass_adapter(prompt)
# allow single generation
response = self._llm.generate_from_template([prompt], 512)[0]
self._inner_history.append(dict(role='assistant',
content=response))
thought, action, action_input = self._protocol.parse(
response, self._action_executor)

# TODO: hard code here
action_input = re.sub('<eoa>', '', action_input)

if 'tensorflow' in action_input:
# skip tensorflow currently
break
action_return: ActionReturn = self._action_executor(
action, action_input)
action_return.thought = thought
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
return agent_return
self._inner_history.append(
dict(role='system',
content=self._protocol.format_response(action_return)))
if turn == self.max_turn - 1:
force_stop = True
agent_return.response = default_response
# only append the user and final response
self._session_history.append(dict(role='user', content=message))
self._session_history.append(
dict(role='assistant', content=agent_return.response))
return agent_return


class CIReAct(ReAct):
"""Code Interpreter version of ReAct. The success state is different from
ReAct.
@@ -165,9 +39,7 @@ def chat(self, message: str) -> AgentReturn:
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
prompt = self.opencompass_adapter(prompt)
# allow single generation
response = self._llm.generate_from_template([prompt], 512)[0]
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(dict(role='assistant',
content=response))
thought, action, action_input = self._protocol.parse(
@@ -179,7 +51,7 @@ def chat(self, message: str) -> AgentReturn:
if action_return.state == ActionStatusCode.SUCCESS:
# if success, stash model response and system response
self._session_history.append(
dict(role='assistant', content=action_return.args['text']))
dict(role='assistant', content=response))
self._session_history.append(
dict(
role='system',
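For context on the opencompass_adapter removed above: it merged system messages into the following user turn and mapped roles onto OpenCompass's meta-template names (HUMAN/BOT), wrapping the result in a PromptList round. A minimal sketch of the transformation, with invented messages:

```
# Hypothetical input to the (removed) opencompass_adapter:
prompt = [
    dict(role='system', content='You are a helpful assistant. '),
    dict(role='user', content='What is 2 + 2?'),
    dict(role='assistant', content='4'),
]
# Expected output, per the removed code above (roughly):
# PromptList([
#     {'section': 'round', 'pos': 'begin'},
#     {'role': 'HUMAN', 'prompt': 'You are a helpful assistant. What is 2 + 2?'},
#     {'role': 'BOT', 'prompt': '4'},
#     {'section': 'round', 'pos': 'end'},
# ])
```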
(Diffs for the remaining 4 changed files are not shown.)