Merge pull request #260 from polywrap/feat/custom-model-support
Feat/custom model support
cbrzn authored Jun 18, 2024
2 parents 42f5bc5 + 2f295d8 commit f7b0893
Showing 11 changed files with 207 additions and 27 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/cd.pypi.yaml
@@ -0,0 +1,30 @@
+name: Publish to Pypi
+
+on:
+  push:
+    tags:
+      - "v*.*.*"
+
+jobs:
+  Publish:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v2
+        with:
+          python-version: "3.10"
+
+      - name: Install Poetry
+        run: curl -sSL https://install.python-poetry.org | python3 -
+
+      - name: Install dependencies
+        run: poetry install
+
+      - name: Check types
+        run: poetry run build-check
+
+      - name: Release to pypi
+        run: poetry publish --build --username __token__ --password ${{ secrets.PYPI_ACCESS_TOKEN }}
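
For reference (not part of this diff): with this workflow, publishing is triggered by pushing a matching version tag, e.g. git tag v0.1.0 && git push origin v0.1.0, after which the Publish job type-checks, builds, and uploads the package using the PYPI_ACCESS_TOKEN secret. The tag name here is a placeholder.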
43 changes: 33 additions & 10 deletions autotx/AutoTx.py
@@ -5,8 +5,8 @@
 import os
 from textwrap import dedent
 from typing import Any, Dict, Optional, Callable
-from dataclasses import dataclass
-from autogen import Agent as AutogenAgent
+from dataclasses import dataclass, field
+from autogen import Agent as AutogenAgent, ModelClient
 from termcolor import cprint
 from typing import Optional

@@ -21,20 +21,27 @@
 from autotx.utils.constants import OPENAI_BASE_URL, OPENAI_MODEL_NAME
 from autotx.wallets.smart_wallet import SmartWallet
 
+@dataclass(kw_only=True)
+class CustomModel:
+    client: ModelClient
+    arguments: Optional[Dict[str, Any]] = None
+
 @dataclass(kw_only=True)
 class Config:
     verbose: bool
     logs_dir: Optional[str] = None
     log_costs: bool
     max_rounds: int
     get_llm_config: Callable[[], Optional[Dict[str, Any]]]
+    custom_model: Optional[CustomModel] = None
 
-    def __init__(self, verbose: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], logs_dir: Optional[str], max_rounds: Optional[int] = None, log_costs: Optional[bool] = None):
+    def __init__(self, verbose: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], logs_dir: Optional[str], max_rounds: Optional[int] = None, log_costs: Optional[bool] = None, custom_model: Optional[CustomModel] = None):
         self.verbose = verbose
         self.get_llm_config = get_llm_config
         self.logs_dir = logs_dir
         self.log_costs = log_costs if log_costs is not None else False
         self.max_rounds = max_rounds if max_rounds is not None else 100
+        self.custom_model = custom_model
 
 @dataclass
 class PastRun:
@@ -62,6 +69,7 @@ class AutoTx:
     intents: list[Intent]
     network: NetworkInfo
     get_llm_config: Callable[[], Optional[Dict[str, Any]]]
+    custom_model: Optional[CustomModel]
     agents: list[AutoTxAgent]
     log_costs: bool
     max_rounds: int
@@ -80,6 +88,9 @@ def __init__(
         config: Config,
         on_notify_user: Callable[[str], None] | None = None
     ):
+        if len(agents) == 0:
+            raise Exception("Agents attribute cannot be an empty list")
+
         self.web3 = web3
         self.wallet = wallet
         self.network = network
@@ -97,6 +108,7 @@ def __init__(
         self.current_run_cost_with_cache = 0
         self.info_messages = []
         self.on_notify_user = on_notify_user
+        self.custom_model = config.custom_model
 
     def run(self, prompt: str, non_interactive: bool, summary_method: str = "last_msg") -> RunResult:
         return asyncio.run(self.a_run(prompt, non_interactive, summary_method))
@@ -107,8 +119,15 @@ async def a_run(self, prompt: str, non_interactive: bool, summary_method: str =
         info_messages = []
 
         if self.verbose:
-            print(f"LLM model: {OPENAI_MODEL_NAME}")
-            print(f"LLM API URL: {OPENAI_BASE_URL}")
+            available_config = self.get_llm_config()
+            if available_config and "config_list" in available_config:
+                print("Available LLM configurations:")
+                for config in available_config["config_list"]:
+                    if "model" in config:
+                        print(f"LLM model: {config['model']}")
+                    if "base_url" in config:
+                        print(f"LLM API URL: {config['base_url']}")
+            print("==" * 10)
 
         while True:
             result = await self.try_run(prompt, non_interactive, summary_method)
@@ -175,22 +194,26 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str

         agents_information = self.get_agents_information(self.agents)
 
-        user_proxy_agent = user_proxy.build(prompt, agents_information, self.get_llm_config)
-        clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config, self.notify_user)
+        user_proxy_agent = user_proxy.build(prompt, agents_information, self.get_llm_config, self.custom_model)
 
         helper_agents: list[AutogenAgent] = [
             user_proxy_agent,
         ]
 
         if not non_interactive:
+            clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config, self.notify_user, self.custom_model)
             helper_agents.append(clarifier_agent)
 
-        autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config()) for agent in self.agents]
+        autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config(), self.custom_model) for agent in self.agents]
 
-        manager_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config)
+        recipient_agent = None
+        if len(autogen_agents) > 1:
+            recipient_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config, self.custom_model)
+        else:
+            recipient_agent = autogen_agents[0]
 
         chat = await user_proxy_agent.a_initiate_chat(
-            manager_agent,
+            recipient_agent,
             message=dedent(
                 f"""
                 I am currently connected with the following wallet: {self.wallet.address}, on network: {self.network.chain_id.name}
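
Note on usage (not part of the diff): with these changes a caller hands AutoTx a custom client through Config. A minimal sketch, assuming an already-constructed llama_cpp.Llama instance named llm and the LlamaClient added later in this commit; the model name is a placeholder:

    from autotx.AutoTx import Config, CustomModel
    from autotx.utils.LlamaClient import LlamaClient

    # Assumption: autogen instantiates the client class inside
    # register_model_client, forwarding `arguments` as constructor kwargs,
    # so `client` is the class itself rather than an instance.
    config = Config(
        verbose=True,
        get_llm_config=lambda: {"config_list": [{"model": "local-model", "model_client_cls": "LlamaClient"}]},
        logs_dir=None,
        custom_model=CustomModel(client=LlamaClient, arguments={"llm": llm}),
    )

The "model_client_cls" entry in the llm config is how autogen marks a config as served by a registered custom client; the string must match the client class name.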
3 changes: 2 additions & 1 deletion autotx/__init__.py
@@ -1,5 +1,6 @@
 from autotx.AutoTx import AutoTx
 from autotx.autotx_agent import AutoTxAgent
 from autotx.autotx_tool import AutoTxTool
+from autotx.utils.LlamaClient import LlamaClient
 
-__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool']
+__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool', 'LlamaClient']
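
This re-export lets downstream code import the client from the package root, e.g. (illustrative):

    from autotx import AutoTx, LlamaClient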
2 changes: 1 addition & 1 deletion autotx/agents/SendTokensAgent.py
@@ -97,7 +97,7 @@ def run(

             autotx.notify_user(f"Prepared transaction: {intent.summary}")
 
-            return intent.summary
+            return f"{intent.summary} has been prepared."
 
         return run
 
7 changes: 5 additions & 2 deletions autotx/autotx_agent.py
@@ -4,7 +4,7 @@
 from autotx.utils.color import Color
 if TYPE_CHECKING:
     from autotx.autotx_tool import AutoTxTool
-    from autotx.AutoTx import AutoTx
+    from autotx.AutoTx import AutoTx, CustomModel
 
 class AutoTxAgent():
     name: str
@@ -18,7 +18,7 @@ def __init__(self) -> None:
             f"{tool.name}: {tool.description}" for tool in self.tools
         ]
 
-    def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]]) -> autogen.Agent:
+    def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]], custom_model: Optional['CustomModel'] = None) -> autogen.Agent:
         system_message = None
         if isinstance(self.system_message, str):
             system_message = self.system_message
@@ -58,4 +58,7 @@ def send_message_hook(
         for tool in self.tools:
             tool.register_tool(autotx, agent, user_proxy)
 
+        if custom_model:
+            agent.register_model_client(model_client_cls=custom_model.client, **(custom_model.arguments or {}))
+
         return agent
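
For orientation (illustrative, not from this diff): any AutoTxAgent subclass now flows through this path, so an agent defined as below would have its replies served by the custom client whenever custom_model is passed:

    from autotx import AutoTxAgent

    class ExampleAgent(AutoTxAgent):
        name = "example"
        system_message = "You help with example tasks."
        tools = []  # real agents in this repo also define AutoTxTool instances

    # Inside AutoTx.try_run this becomes, roughly:
    # agent = ExampleAgent().build_autogen_agent(autotx, user_proxy_agent, get_llm_config(), custom_model)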
14 changes: 10 additions & 4 deletions autotx/helper_agents/clarifier.py
@@ -1,10 +1,13 @@
 from textwrap import dedent
-from typing import Annotated, Any, Callable, Dict, Optional
-from autogen import UserProxyAgent, AssistantAgent
+from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, Optional
+from autogen import UserProxyAgent, AssistantAgent, ModelClient
 
 from autotx.utils.color import Color
 
-def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], notify_user: Callable[[str, Color | None], None]) -> AssistantAgent:
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
+
+def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], notify_user: Callable[[str, Color | None], None], custom_model: Optional['CustomModel']) -> AssistantAgent:
     missing_1 = dedent("""
         If the goal is not clear or missing information, you MUST ask for more information by calling the request_user_input tool.
         Always ensure you have all the information needed to define the goal that can be executed without prior context.
@@ -77,5 +80,8 @@ def goal_outside_scope(

     clarifier_agent.register_for_llm(name="goal_outside_scope", description="Notify the user about their goal not being in the scope of the agents")(goal_outside_scope)
     user_proxy.register_for_execution(name="goal_outside_scope")(goal_outside_scope)
 
+    if custom_model:
+        clarifier_agent.register_model_client(model_client_cls=custom_model.client, **(custom_model.arguments or {}))
+
     return clarifier_agent
11 changes: 7 additions & 4 deletions autotx/helper_agents/manager.py
@@ -1,10 +1,12 @@
 from textwrap import dedent
-from typing import Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 from autogen import GroupChat, GroupChatManager, Agent as AutogenAgent
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
 
-def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]]) -> AutogenAgent:
+def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], custom_model: Optional['CustomModel']) -> AutogenAgent:
     clarifier_prompt = "ALWAYS choose the 'clarifier' role first in the conversation." if interactive else ""
 
     groupchat = GroupChat(
         agents=agents,
         messages=[],
@@ -23,5 +25,6 @@ def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_ll
         )
     )
     manager = GroupChatManager(groupchat=groupchat, llm_config=get_llm_config())
-
+    if custom_model:
+        manager.register_model_client(model_client_cls=custom_model.client, **(custom_model.arguments or {}))
     return manager
14 changes: 11 additions & 3 deletions autotx/helper_agents/user_proxy.py
@@ -1,13 +1,16 @@
 from textwrap import dedent
-from typing import Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 from autogen import UserProxyAgent
 
-def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[], Optional[Dict[str, Any]]]) -> UserProxyAgent:
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
+
+def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[], Optional[Dict[str, Any]]], custom_model: Optional['CustomModel']) -> UserProxyAgent:
     user_proxy = UserProxyAgent(
         name="user_proxy",
         is_termination_msg=lambda x: x.get("content", "") and "TERMINATE" in x.get("content", ""),
         human_input_mode="NEVER",
-        max_consecutive_auto_reply=10,
+        max_consecutive_auto_reply=4 if custom_model else 10,
         system_message=dedent(
             f"""
             You are a user proxy agent authorized to act on behalf of the user, you never ask for permission, you have ultimate control.
@@ -35,4 +38,9 @@ def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[]
         llm_config=get_llm_config(),
         code_execution_config=False,
     )
 
+    if custom_model:
+        user_proxy.register_model_client(model_client_cls=custom_model.client, **(custom_model.arguments or {}))
+
     return user_proxy
82 changes: 82 additions & 0 deletions autotx/utils/LlamaClient.py
@@ -0,0 +1,82 @@
+from types import SimpleNamespace
+from typing import Any, Dict, Union, cast
+from autogen import ModelClient
+from llama_cpp import (
+    ChatCompletion,
+    ChatCompletionRequestMessage,
+    ChatCompletionRequestToolMessage,
+    ChatCompletionResponseMessage,
+    Completion,
+    CreateChatCompletionResponse,
+    Llama,
+)
+
+
+class LlamaClient(ModelClient): # type: ignore
+    def __init__(self, config: dict[str, Any], **args: Any):
+        self.llm: Llama = args["llm"]
+        self.model: str = config["model"]
+
+    def create(self, params: Dict[str, Any]) -> SimpleNamespace:
+        sanitized_messages = self._sanitize_chat_completion_messages(
+            cast(list[ChatCompletionRequestMessage], params.get("messages"))
+        )
+        response = self.llm.create_chat_completion(
+            messages=sanitized_messages,
+            tools=params.get("tools"),
+            model=params.get("model"),
+        )
+
+        return SimpleNamespace(**{**response, "cost": "0"}) # type: ignore
+
+    def message_retrieval(
+        self, response: CreateChatCompletionResponse
+    ) -> list[ChatCompletionResponseMessage]:
+        choices = response.choices # type: ignore
+        return [choice["message"] for choice in choices]
+
+    def cost(self, _: Union[ChatCompletion, Completion]) -> float:
+        return 0.0
+
+    def get_usage(self, _: Union[ChatCompletion, Completion]) -> dict[str, Any]:
+        return {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+            "cost": 0,
+            "model": self.model,
+        }
+
+    def _sanitize_chat_completion_messages(
+        self, messages: list[ChatCompletionRequestMessage]
+    ) -> list[ChatCompletionRequestMessage]:
+        sanitized_messages: list[ChatCompletionRequestMessage] = []
+        for message in messages:
+            if "tool_call_id" in message:
+                id: str = cast(ChatCompletionRequestToolMessage, message)[
+                    "tool_call_id"
+                ]
+
+                def get_tool_name(messages, id: str) -> Union[str, None]: # type: ignore
+                    return next(
+                        (
+                            message["tool_calls"][0]["function"]["name"]
+                            for message in messages
+                            if "tool_calls" in message
+                            and message["tool_calls"][0]["id"] == id
+                        ),
+                        None,
+                    )
+
+                function_name = get_tool_name(messages, id)
+                if function_name is None:
+                    raise Exception(f"No tool response for this tool call with id {id}")
+
+                sanitized_messages.append(
+                    ChatCompletionRequestToolMessage(**message, name=function_name) # type: ignore
+                )
+
+            else:
+                sanitized_messages.append(message)
+
+        return sanitized_messages
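
As a standalone sanity check (illustrative, not from this diff: the model path is a placeholder, and a chat format with tool support, such as llama-cpp-python's "chatml-function-calling", is assumed), the client can be exercised directly:

    from llama_cpp import Llama
    from autotx import LlamaClient

    llm = Llama(model_path="./models/local-model.gguf", chat_format="chatml-function-calling", n_ctx=8192)

    # LlamaClient reads "model" from the config dict and "llm" from kwargs (see __init__ above).
    client = LlamaClient({"model": "local-model"}, llm=llm)
    response = client.create({"messages": [{"role": "user", "content": "Say hi"}], "tools": None, "model": "local-model"})
    print(client.message_retrieval(response)[0]["content"])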
(2 of the 11 changed files are not shown here.)
