diff --git a/.github/workflows/cd.pypi.yaml b/.github/workflows/cd.pypi.yaml
new file mode 100644
index 00000000..1699058e
--- /dev/null
+++ b/.github/workflows/cd.pypi.yaml
@@ -0,0 +1,30 @@
+name: Publish to Pypi
+
+on:
+  push:
+    tags:
+      - "v*.*.*"
+
+jobs:
+  Publish:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v2
+        with:
+          python-version: "3.10"
+
+      - name: Install Poetry
+        run: curl -sSL https://install.python-poetry.org | python3 -
+
+      - name: Install dependencies
+        run: poetry install
+
+      - name: Check types
+        run: poetry run build-check
+
+      - name: Release to pypi
+        run: poetry publish --build --username __token__ --password ${{ secrets.PYPI_ACCESS_TOKEN }}
\ No newline at end of file
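The workflow above runs on every `v*.*.*` tag. Roughly the same steps can be exercised locally before cutting a release; this is a sketch under the assumption that Poetry is on PATH and that `build-check` (the script the workflow invokes) is runnable outside CI. Poetry's `--dry-run` flag builds the package without uploading.

```python
# Rough local equivalent of the CI steps above, for pre-release sanity checks.
# Assumes Poetry is installed; `poetry publish --dry-run` builds but does not upload.
import subprocess

for step in (
    ["poetry", "install"],
    ["poetry", "run", "build-check"],               # same type check the workflow runs
    ["poetry", "publish", "--build", "--dry-run"],  # build only, no upload
):
    subprocess.run(step, check=True)  # stop at the first failing step
```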
"config_list" in available_config: + print("Available LLM configurations:") + for config in available_config["config_list"]: + if "model" in config: + print(f"LLM model: {config['model']}") + if "base_url" in config: + print(f"LLM API URL: {config['base_url']}") + print("==" * 10) while True: result = await self.try_run(prompt, non_interactive, summary_method) @@ -175,22 +194,26 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str agents_information = self.get_agents_information(self.agents) - user_proxy_agent = user_proxy.build(prompt, agents_information, self.get_llm_config) - clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config, self.notify_user) + user_proxy_agent = user_proxy.build(prompt, agents_information, self.get_llm_config, self.custom_model) helper_agents: list[AutogenAgent] = [ user_proxy_agent, ] if not non_interactive: + clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config, self.notify_user, self.custom_model) helper_agents.append(clarifier_agent) - autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config()) for agent in self.agents] + autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config(), self.custom_model) for agent in self.agents] - manager_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config) + recipient_agent = None + if len(autogen_agents) > 1: + recipient_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config, self.custom_model) + else: + recipient_agent = autogen_agents[0] chat = await user_proxy_agent.a_initiate_chat( - manager_agent, + recipient_agent, message=dedent( f""" I am currently connected with the following wallet: {self.wallet.address}, on network: {self.network.chain_id.name} diff --git a/autotx/__init__.py b/autotx/__init__.py index a5d73445..de3105cd 100644 --- a/autotx/__init__.py +++ b/autotx/__init__.py @@ -1,5 +1,6 @@ from autotx.AutoTx import AutoTx from autotx.autotx_agent import AutoTxAgent from autotx.autotx_tool import AutoTxTool +from autotx.utils.LlamaClient import LlamaClient -__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool'] +__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool', 'LlamaClient'] diff --git a/autotx/agents/SendTokensAgent.py b/autotx/agents/SendTokensAgent.py index b915aa27..88ac55ae 100644 --- a/autotx/agents/SendTokensAgent.py +++ b/autotx/agents/SendTokensAgent.py @@ -97,7 +97,7 @@ def run( autotx.notify_user(f"Prepared transaction: {intent.summary}") - return intent.summary + return f"{intent.summary} has been prepared." 
diff --git a/autotx/autotx_agent.py b/autotx/autotx_agent.py
index d93250a6..a3784015 100644
--- a/autotx/autotx_agent.py
+++ b/autotx/autotx_agent.py
@@ -4,7 +4,7 @@
 from autotx.utils.color import Color
 
 if TYPE_CHECKING:
     from autotx.autotx_tool import AutoTxTool
-    from autotx.AutoTx import AutoTx
+    from autotx.AutoTx import AutoTx, CustomModel
 
 class AutoTxAgent():
     name: str
@@ -18,7 +18,7 @@ def __init__(self) -> None:
             f"{tool.name}: {tool.description}" for tool in self.tools
         ]
 
-    def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]]) -> autogen.Agent:
+    def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]], custom_model: Optional['CustomModel'] = None) -> autogen.Agent:
         system_message = None
         if isinstance(self.system_message, str):
             system_message = self.system_message
@@ -58,4 +58,7 @@ def send_message_hook(
         for tool in self.tools:
             tool.register_tool(autotx, agent, user_proxy)
 
+        if custom_model:
+            agent.register_model_client(model_client_cls=custom_model.client, **custom_model.arguments)
+
         return agent
\ No newline at end of file
diff --git a/autotx/helper_agents/clarifier.py b/autotx/helper_agents/clarifier.py
index 97c117c0..0aad7f05 100644
--- a/autotx/helper_agents/clarifier.py
+++ b/autotx/helper_agents/clarifier.py
@@ -1,10 +1,13 @@
 from textwrap import dedent
-from typing import Annotated, Any, Callable, Dict, Optional
-from autogen import UserProxyAgent, AssistantAgent
+from typing import TYPE_CHECKING, Annotated, Any, Callable, Dict, Optional
+from autogen import UserProxyAgent, AssistantAgent, ModelClient
 from autotx.utils.color import Color
 
-def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], notify_user: Callable[[str, Color | None], None]) -> AssistantAgent:
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
+
+def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], notify_user: Callable[[str, Color | None], None], custom_model: Optional['CustomModel']) -> AssistantAgent:
     missing_1 = dedent("""
         If the goal is not clear or missing information, you MUST ask for more information by calling the request_user_input tool.
         Always ensure you have all the information needed to define the goal that can be executed without prior context.
@@ -77,5 +80,8 @@ def goal_outside_scope(
     clarifier_agent.register_for_llm(name="goal_outside_scope", description="Notify the user about their goal not being in the scope of the agents")(goal_outside_scope)
     user_proxy.register_for_execution(name="goal_outside_scope")(goal_outside_scope)
 
-
+
+    if custom_model:
+        clarifier_agent.register_model_client(model_client_cls=custom_model.client, **custom_model.arguments)
+
     return clarifier_agent
\ No newline at end of file
diff --git a/autotx/helper_agents/manager.py b/autotx/helper_agents/manager.py
index 7fa153f2..2c1cf6d5 100644
--- a/autotx/helper_agents/manager.py
+++ b/autotx/helper_agents/manager.py
@@ -1,10 +1,12 @@
 from textwrap import dedent
-from typing import Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 from autogen import GroupChat, GroupChatManager, Agent as AutogenAgent
 
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
 
-def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]]) -> AutogenAgent:
+def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], custom_model: Optional['CustomModel']) -> AutogenAgent:
     clarifier_prompt = "ALWAYS choose the 'clarifier' role first in the conversation." if interactive else ""
-    
+
     groupchat = GroupChat(
         agents=agents,
         messages=[],
@@ -23,5 +25,6 @@ def build(agents: list[AutogenAgent], max_rounds: int, interactive: bool, get_ll
         )
     )
     manager = GroupChatManager(groupchat=groupchat, llm_config=get_llm_config())
-    
+    if custom_model:
+        manager.register_model_client(model_client_cls=custom_model.client, **custom_model.arguments)
     return manager
\ No newline at end of file
diff --git a/autotx/helper_agents/user_proxy.py b/autotx/helper_agents/user_proxy.py
index 39f40c3a..0c78f078 100644
--- a/autotx/helper_agents/user_proxy.py
+++ b/autotx/helper_agents/user_proxy.py
@@ -1,13 +1,16 @@
 from textwrap import dedent
-from typing import Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 from autogen import UserProxyAgent
 
-def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[], Optional[Dict[str, Any]]]) -> UserProxyAgent:
+if TYPE_CHECKING:
+    from autotx.AutoTx import CustomModel
+
+def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[], Optional[Dict[str, Any]]], custom_model: Optional['CustomModel']) -> UserProxyAgent:
     user_proxy = UserProxyAgent(
         name="user_proxy",
         is_termination_msg=lambda x: x.get("content", "") and "TERMINATE" in x.get("content", ""),
        human_input_mode="NEVER",
-        max_consecutive_auto_reply=10,
+        max_consecutive_auto_reply=4 if custom_model else 10,
         system_message=dedent(
             f"""
             You are a user proxy agent authorized to act on behalf of the user, you never ask for permission, you have ultimate control.
@@ -35,4 +38,9 @@ def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[]
         llm_config=get_llm_config(),
         code_execution_config=False,
     )
+
+    if custom_model:
+        user_proxy.register_model_client(model_client_cls=custom_model.client, **custom_model.arguments)
+
+
     return user_proxy
\ No newline at end of file
diff --git a/autotx/utils/LlamaClient.py b/autotx/utils/LlamaClient.py
new file mode 100644
index 00000000..1c4ded06
--- /dev/null
+++ b/autotx/utils/LlamaClient.py
@@ -0,0 +1,82 @@
+from types import SimpleNamespace
+from typing import Any, Dict, Union, cast
+from autogen import ModelClient
+from llama_cpp import (
+    ChatCompletion,
+    ChatCompletionRequestMessage,
+    ChatCompletionRequestToolMessage,
+    ChatCompletionResponseMessage,
+    Completion,
+    CreateChatCompletionResponse,
+    Llama,
+)
+
+
+class LlamaClient(ModelClient):  # type: ignore
+    def __init__(self, config: dict[str, Any], **args: Any):
+        self.llm: Llama = args["llm"]
+        self.model: str = config["model"]
+
+    def create(self, params: Dict[str, Any]) -> SimpleNamespace:
+        sanitized_messages = self._sanitize_chat_completion_messages(
+            cast(list[ChatCompletionRequestMessage], params.get("messages"))
+        )
+        response = self.llm.create_chat_completion(
+            messages=sanitized_messages,
+            tools=params.get("tools"),
+            model=params.get("model"),
+        )
+
+        return SimpleNamespace(**{**response, "cost": "0"})  # type: ignore
+
+    def message_retrieval(
+        self, response: CreateChatCompletionResponse
+    ) -> list[ChatCompletionResponseMessage]:
+        choices = response.choices  # type: ignore
+        return [choice["message"] for choice in choices]
+
+    def cost(self, _: Union[ChatCompletion, Completion]) -> float:
+        return 0.0
+
+    def get_usage(self, _: Union[ChatCompletion, Completion]) -> dict[str, Any]:
+        return {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+            "cost": 0,
+            "model": self.model,
+        }
+
+    def _sanitize_chat_completion_messages(
+        self, messages: list[ChatCompletionRequestMessage]
+    ) -> list[ChatCompletionRequestMessage]:
+        sanitized_messages: list[ChatCompletionRequestMessage] = []
+        for message in messages:
+            if "tool_call_id" in message:
+                id: str = cast(ChatCompletionRequestToolMessage, message)[
+                    "tool_call_id"
+                ]
+
+                def get_tool_name(messages, id: str) -> Union[str, None]:  # type: ignore
+                    return next(
+                        (
+                            message["tool_calls"][0]["function"]["name"]
+                            for message in messages
+                            if "tool_calls" in message
+                            and message["tool_calls"][0]["id"] == id
+                        ),
+                        None,
+                    )
+
+                function_name = get_tool_name(messages, id)
+                if function_name is None:
+                    raise Exception(f"No tool response for this tool call with id {id}")
+
+                sanitized_messages.append(
+                    ChatCompletionRequestToolMessage(**message, name=function_name)  # type: ignore
+                )
+
+            else:
+                sanitized_messages.append(message)
+
+        return sanitized_messages
diff --git a/poetry.lock b/poetry.lock
index 9cd75a46..fff9099f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1355,6 +1355,28 @@ files = [
 [package.dependencies]
 referencing = ">=0.31.0"
 
+[[package]]
+name = "llama-cpp-python"
+version = "0.2.78"
+description = "Python bindings for the llama.cpp library"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "llama_cpp_python-0.2.78.tar.gz", hash = "sha256:3df7cfde84287faaf29675fba8939060c3ab3f0ce8db875dabf7df5d83bd8751"},
+]
+
+[package.dependencies]
+diskcache = ">=5.6.1"
+jinja2 = ">=2.11.3"
+numpy = ">=1.20.0"
+typing-extensions = ">=4.5.0"
+
+[package.extras]
+all = ["llama_cpp_python[dev,server,test]"]
"httpx (>=0.24.1)", "mkdocs (>=1.4.3)", "mkdocs-material (>=9.1.18)", "mkdocstrings[python] (>=0.22.0)", "pytest (>=7.4.0)", "twine (>=4.0.2)"] +server = ["PyYAML (>=5.1)", "fastapi (>=0.100.0)", "pydantic-settings (>=2.0.1)", "sse-starlette (>=1.6.1)", "starlette-context (>=0.3.6,<0.4)", "uvicorn (>=0.22.0)"] +test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)", "scipy (>=1.10)"] + [[package]] name = "lru-dict" version = "1.2.0" @@ -2336,6 +2358,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3549,4 +3572,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "57c468f2b88c0a5b9b152634cafef530545725cdeaa9e499c7e7b6c13cd5ce41" +content-hash = "5dfc1bc10f28a24e09c829fbdf2469dfcd0ab7dcd8c51312e7f15e8ffd02601f" diff --git a/pyproject.toml b/pyproject.toml index ccab2597..9b270787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "autotx" -version = "0.1.0" +version = "0.1.1" description = "" authors = ["Nestor Amesty "] readme = "README.md" @@ -19,6 +19,7 @@ web3 = "^6.19.0" safe-eth-py = "^5.8.0" uvicorn = "^0.29.0" supabase = "^2.5.0" +llama-cpp-python = "^0.2.78" [tool.poetry.group.dev.dependencies] mypy = "^1.8.0"