Adding Unify AI Support #1446

Open · wants to merge 3 commits into base: main
5 changes: 5 additions & 0 deletions config/examples/unify.yaml
@@ -0,0 +1,5 @@
llm:
  api_type: "unify"
  model: "llama-3-8b-chat@together-ai"  # or get a list of models here: https://docs.unify.ai/python/utils#list-models
  base_url: "https://api.unify.ai/v0"
  api_key: "Enter your Unify API key here"  # get your API key from https://console.unify.ai
Collaborator: keep a placeholder value such as YOUR_API_KEY

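Following the comment above, the example could keep a recognizable placeholder for the key instead of prose; a possible revision, with all other values unchanged from the example:

```yaml
llm:
  api_type: "unify"
  model: "llama-3-8b-chat@together-ai"  # list of models: https://docs.unify.ai/python/utils#list-models
  base_url: "https://api.unify.ai/v0"
  api_key: "YOUR_API_KEY"  # get your API key from https://console.unify.ai
```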
1 change: 1 addition & 0 deletions metagpt/configs/llm_config.py
@@ -34,6 +34,7 @@ class LLMType(Enum):
    OPENROUTER = "openrouter"
    BEDROCK = "bedrock"
    ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk
    UNIFY = "unify"

    def __missing__(self, key):
        return self.OPENAI
1 change: 1 addition & 0 deletions metagpt/provider/openai_api.py
@@ -50,6 +50,7 @@
        LLMType.MISTRAL,
        LLMType.YI,
        LLMType.OPENROUTER,
        LLMType.UNIFY,
    ]
)
class OpenAILLM(BaseLLM):
122 changes: 122 additions & 0 deletions metagpt/provider/unify.py
@@ -0,0 +1,122 @@
from typing import Optional, Dict, List, Union
from openai.types import Completion, CompletionUsage
from openai.types.chat import ChatCompletion

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import count_message_tokens, OPENAI_TOKEN_COSTS
from unify.clients import Unify, AsyncUnify

@register_provider([LLMType.UNIFY])
class UnifyLLM(BaseLLM):
    def __init__(self, config: LLMConfig):
        self.config = config
        self._init_client()
        self.cost_manager = CostManager(token_costs=OPENAI_TOKEN_COSTS)  # Using OpenAI costs as Unify is compatible
Collaborator: what about non-OpenAI models?


    def _init_client(self):
        self.model = self.config.model
        self.client = Unify(
            api_key=self.config.api_key,
            endpoint=f"{self.config.model}@{self.config.provider}",
Collaborator: no provider field in LLMConfig

Collaborator: suggest adding only the async client.

        )
        self.async_client = AsyncUnify(
            api_key=self.config.api_key,
            endpoint=f"{self.config.model}@{self.config.provider}",
        )

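Following the two review comments above, a minimal sketch of an async-only `_init_client` that needs no provider field: the example config already stores the full "model@provider" string in `model`, so it can be passed to the client as the endpoint directly. An illustration, not the PR's code:

```python
    def _init_client(self):
        # config.model already carries the full Unify endpoint, e.g. "llama-3-8b-chat@together-ai"
        self.model = self.config.model
        self.async_client = AsyncUnify(
            api_key=self.config.api_key,
            endpoint=self.config.model,
        )
```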
    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        return {
            "messages": messages,
            "max_tokens": self.config.max_token,
"temperature": self.config.temperature,
"stream": stream,
}

def get_choice_text(self, resp: Union[ChatCompletion, str]) -> str:
if isinstance(resp, str):
return resp
return resp.choices[0].message.content if resp.choices else ""

def _update_costs(self, usage: dict):
Collaborator: no need to add this; it is already implemented in BaseLLM.

        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        try:
            response = await self.async_client.generate(
                messages=messages,
Collaborator: _const_kwargs not used?

                max_tokens=self.config.max_token,
                temperature=self.config.temperature,
                stream=False,
            )
            # Construct a ChatCompletion object to match OpenAI's format
            chat_completion = ChatCompletion(
                id="unify_chat_completion",
                object="chat.completion",
                created=0,  # Unify doesn't provide this, so we use 0
                model=self.model,
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response,
                    },
                    "finish_reason": "stop",
                }],
                usage=CompletionUsage(
                    prompt_tokens=count_message_tokens(messages, self.model),
Collaborator: the current main branch has no count_message_tokens function; also, suggest supporting common Unify models, not only OpenAI ones.

                    completion_tokens=count_message_tokens([{"role": "assistant", "content": response}], self.model),
                    total_tokens=0,  # Will be calculated below
                ),
            )
            chat_completion.usage.total_tokens = chat_completion.usage.prompt_tokens + chat_completion.usage.completion_tokens
            self._update_costs(chat_completion.usage.model_dump())
            return chat_completion
        except Exception as e:
            logger.error(f"Error in Unify chat completion: {str(e)}")
            raise

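One way to address the comments above is to route the request parameters through `_const_kwargs` and leave cost bookkeeping to BaseLLM. A rough sketch, with a placeholder character-based token estimate since `count_message_tokens` is not available on main (the 4-characters-per-token ratio is only an assumption):

```python
    # Hypothetical revision of _achat_completion.
    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        # All request parameters come from _const_kwargs, so they are defined in one place.
        response = await self.async_client.generate(**self._const_kwargs(messages, stream=False))
        # Placeholder token accounting: a crude character-based estimate.
        prompt_tokens = sum(len(m.get("content", "")) for m in messages) // 4
        completion_tokens = len(response) // 4
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        self._update_costs(usage)  # provided by BaseLLM, per the review comment
        # ...construct and return the ChatCompletion exactly as in the diff above...
```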
    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
        try:
            stream = self.client.generate(
Collaborator: should use async_client

                messages=messages,
                max_tokens=self.config.max_token,
                temperature=self.config.temperature,
                stream=True,
            )
            collected_content = []
            for chunk in stream:
                log_llm_stream(chunk)
                collected_content.append(chunk)

            full_content = "".join(collected_content)
            usage = {
                "prompt_tokens": count_message_tokens(messages, self.model),
                "completion_tokens": count_message_tokens([{"role": "assistant", "content": full_content}], self.model),
            }
            self._update_costs(usage)
            return full_content
        except Exception as e:
            logger.error(f"Error in Unify chat completion stream: {str(e)}")
            raise

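A sketch of the streaming path on the async client, as the comment above suggests. Whether AsyncUnify.generate returns an async iterator when stream=True, and whether the call itself must be awaited first, are assumptions about the SDK:

```python
    # Hypothetical async variant of _achat_completion_stream.
    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
        stream = await self.async_client.generate(**self._const_kwargs(messages, stream=True))
        collected_content = []
        async for chunk in stream:  # assumes the async client yields text chunks
            log_llm_stream(chunk)
            collected_content.append(chunk)
        log_llm_stream("\n")
        return "".join(collected_content)
```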
    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
        return await self._achat_completion(messages, timeout=timeout)

    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
        if stream:
            return await self._achat_completion_stream(messages, timeout=timeout)
        response = await self._achat_completion(messages, timeout=timeout)
        return self.get_choice_text(response)

    def get_model_name(self):
        return self.model

    def get_usage(self) -> Optional[Dict[str, int]]:
        return self.cost_manager.get_latest_usage()
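For reference, a minimal usage sketch of the new provider with the example config above. The LLMConfig constructor arguments mirror the YAML fields; their exact names and accepted types are assumptions here, and YOUR_API_KEY is a placeholder:

```python
import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.unify import UnifyLLM

# Field names mirror config/examples/unify.yaml.
config = LLMConfig(
    api_type="unify",
    model="llama-3-8b-chat@together-ai",
    base_url="https://api.unify.ai/v0",
    api_key="YOUR_API_KEY",
)
llm = UnifyLLM(config)
reply = asyncio.run(llm.acompletion_text([{"role": "user", "content": "Hello"}]))
print(reply)
```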