diff --git a/shared/open-ai-integration/quickstart.mdx b/shared/open-ai-integration/quickstart.mdx index 0ffe7a453..124e0f2d5 100644 --- a/shared/open-ai-integration/quickstart.mdx +++ b/shared/open-ai-integration/quickstart.mdx @@ -22,33 +22,26 @@ The following figure illustrates the integration topology: This guide walks you through the core elements of the [Agora Conversational AI Demo](https://github.com/AgoraIO/agora-openai-converse) integrating Agora's Python SDK with OpenAI's Realtime API: -1. Download the [Agora Conversational AI Demo code](https://github.com/AgoraIO/agora-openai-converse). +1. Create a new folder for the project: -1. The project is structured as follows: + ``` + mkdir realtime-agent + cd realtime-agent/ + ``` + +1. Create the following structure for your project: ``` /realtime_agent - ├── __init__.py - ├── agent.py - ├── agora - │   ├── __init__.py - │   ├── requirements.txt - │   ├── rtc.py - │   └── token_builder - │   ├── AccessToken2.py - │   ├── Packer.py - │   ├── RtcTokenBuilder2.py - │   └── realtimekit_token_builder.py - ├── parse_args.py - └── realtimeapi - ├── __init__.py - ├── call_tool.py - ├── client.py - ├── messages.py - ├── mic_to_websocket.py - ├── push_to_talk.py - ├── send_audio_to_websocket.py - └── util.py + ├── __init__.py + ├── .env + ├── agent.py + ├── requirements.txt + └── realtimeapi + ├── __init__.py + ├── client.py + ├── messages.py + └── util.py ``` @@ -59,17 +52,62 @@ This guide walks you through the core elements of the [Agora Conversational AI D - `agent.py`: The primary script responsible for executing the `RealtimeKitAgent`. It integrates Agora's functionality from the `rtc.py` module and OpenAI's capabilities from the `realtimeapi` package. - `rtc.py`: Contains an implementation of the server-side Agora Python Voice SDK. - - `parse_args.py`: Handles command-line argument parsing for the application. - `realtimeapi/`: Contains the classes and methods that interact with OpenAI's Realtime API. -1. Create the `.env` file by copying the `.env.example` in the root of the repo - ```bash - cp .env.example .env - ``` + The [complete code](#complete-integration-code) for `agent.py` is provided at the bottom of this page. + +1. Add the following dependencies to the `requirements.txt` file: + + ``` + aiohttp[speedups] + annotated-types==0.7.0 + anyio==4.4.0 + attrs==23.2.0 + black==24.4.2 + certifi==2024.7.4 + click==8.1.7 + distro==1.9.0 + frozenlist==1.4.1 + h11==0.14.0 + httpcore==1.0.5 + httpx==0.27.0 + idna==3.7 + iniconfig==2.0.0 + multidict==6.0.5 + mypy==1.10.1 + mypy-extensions==1.0.0 + numpy>=1.21.0 + openai==1.37.1 + packaging==24.1 + pathspec==0.12.1 + platformdirs==4.2.2 + pluggy==1.5.0 + protobuf==5.27.2 + psutil==5.9.8 + pydantic==2.8.2 + pydantic_core==2.20.1 + pyaudio>=0.2.11 + pydub==0.25.1 + pyee==12.0.0 + PyJWT==2.8.0 + pytest==8.2.2 + python-dotenv==1.0.1 + ruff==0.5.2 + sniffio==1.3.1 + sounddevice>=0.4.6 + tqdm==4.66.4 + types-protobuf==4.25.0.20240417 + typing_extensions==4.12.2 + watchfiles==0.22.0 + yarl==1.9.4 + agora-python-server-sdk>=2.0.0 + agora-realtime-ai-api==1.0.2 + ``` + +1. Create the `.env` file and fill in the values for the environment variables: -1. Fill in the values for the environment variables: ```python - # Agora RTC app ID + # Agora RTC app ID and app certificate AGORA_APP_ID= AGORA_APP_CERT= @@ -87,10 +125,12 @@ This guide walks you through the core elements of the [Agora Conversational AI D pip install -r requirements.txt ``` -1. 
Run the demo server:
-    ```bash
-    python -m realtime_agent.agent --channel_name= --uid=
-    ```
+1. Install the Agora Realtime AI API package:
+
+    ```bash
+    pip3 install agora-realtime-ai-api
+    ```
+

 ## Implementation

@@ -172,7 +212,7 @@ async def setup_and_run_agent(

 ### Initialize the RealtimeKitAgent

-The `RealtimeKitAgent` class constructor accepts an OpenAI `RealtimeApiClient`, an optional `ToolContext` for function registration, and an Agora channel for managing audio communication. This setup initializes the agent to process audio streams, register tools (if provided), and interacts with the AI model.
+The `RealtimeKitAgent` class constructor accepts an OpenAI `RealtimeApiClient`, an optional `ToolContext` for function registration, and an Agora channel for managing audio communication. This setup initializes the agent to process audio streams, register tools (if provided), and interact with the AI model.

 ```python
     def __init__(
@@ -393,6 +433,447 @@ until you have received the response to the tool call.\
         )
 ```

+### Complete integration code
+
+The `agent.py` script imports key classes from `rtc.py`, which implements the server-side Agora Python Voice SDK, facilitating communication and managing audio streams.
+
+Complete code for `agent.py` + +{`import abc +import asyncio +import base64 +import json +import logging +import os +from builtins import anext +from typing import Any, Callable, assert_never + +from agora.rtc.rtc_connection import RTCConnection, RTCConnInfo +from attr import dataclass +from dotenv import load_dotenv +from pydantic import BaseModel + +from realtime_agent.realtimeapi import messages +from realtime_agent.realtimeapi.client import RealtimeApiClient +from realtime_agent.realtimeapi.util import SAMPLE_RATE,CHANNELS + +from .agora.rtc import Channel, ChatMessage, RtcEngine, RtcOptions +from .parse_args import parse_args_realtimekit + +logger = logging.getLogger(__name__) + +async def wait_for_remote_user(channel: Channel) -> int: + remote_users = list(channel.remote_users.keys()) + if len(remote_users) > 0: + return remote_users[0] + + future = asyncio.Future[int]() + + channel.once("user_joined", lambda conn, user_id: future.set_result(user_id)) + + try: + remote_user = await future + return remote_user + except Exception as e: + logger.error(f"Error waiting for remote user: {e}") + raise + +@dataclass(frozen=True, kw_only=True) +class InferenceConfig: + system_message: str | None = None + turn_detection: messages.TurnDetection | None = None # MARK: CHECK! + voice: messages.Voices | None = None + + +@dataclass(frozen=True, kw_only=True) +class LocalFunctionToolDeclaration: + """Declaration of a tool that can be called by the model, and runs a function locally on the tool context.""" + + name: str + description: str + parameters: dict[str, Any] + function: Callable[..., Any] + + def model_description(self) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + +@dataclass(frozen=True, kw_only=True) +class PassThroughFunctionToolDeclaration: + """Declaration of a tool that can be called by the model.""" + + name: str + description: str + parameters: dict[str, Any] + + def model_description(self) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + +ToolDeclaration = LocalFunctionToolDeclaration | PassThroughFunctionToolDeclaration + + +@dataclass(frozen=True, kw_only=True) +class LocalToolCallExecuted: + json_encoded_output: str + + +@dataclass(frozen=True, kw_only=True) +class ShouldPassThroughToolCall: + decoded_function_args: dict[str, Any] + + +ExecuteToolCallResult = LocalToolCallExecuted | ShouldPassThroughToolCall + + +class ToolContext(abc.ABC): + _tool_declarations: dict[str, ToolDeclaration] + + def __init__(self) -> None: + # TODO should be an ordered dict + self._tool_declarations = {} + + def register_function( + self, + *, + name: str, + description: str = "", + parameters: dict[str, Any], + fn: Callable[..., Any], + ) -> None: + self._tool_declarations[name] = LocalFunctionToolDeclaration( + name=name, description=description, parameters=parameters, function=fn + ) + + def register_client_function( + self, + *, + name: str, + description: str = "", + parameters: dict[str, Any], + ) -> None: + self._tool_declarations[name] = PassThroughFunctionToolDeclaration( + name=name, description=description, parameters=parameters + ) + + async def execute_tool( + self, tool_name: str, encoded_function_args: str + ) -> ExecuteToolCallResult | None: + tool = self._tool_declarations.get(tool_name) + if not tool: + return None + + args = 
json.loads(encoded_function_args) + assert isinstance(args, dict) + + if isinstance(tool, LocalFunctionToolDeclaration): + logger.info(f"Executing tool {tool_name} with args {args}") + result = await tool.function(**args) + logger.info(f"Tool {tool_name} executed with result {result}") + return LocalToolCallExecuted(json_encoded_output=json.dumps(result)) + + if isinstance(tool, PassThroughFunctionToolDeclaration): + return ShouldPassThroughToolCall(decoded_function_args=args) + + assert_never(tool) + + def model_description(self) -> list[dict[str, Any]]: + return [v.model_description() for v in self._tool_declarations.values()] + + +class ClientToolCallResponse(BaseModel): + tool_call_id: str + result: dict[str, Any] | str | float | int | bool | None = None + + +@dataclass(frozen=True, kw_only=True) +class RTCConfigure: + APP_ID: str + TOKEN: str = "" + +class RealtimeKitAgent: + engine: RtcEngine + channel: Channel + client: RealtimeApiClient + audio_queue: asyncio.Queue[bytes] = asyncio.Queue() + message_queue: asyncio.Queue[messages.ResponseAudioTranscriptDelta] = asyncio.Queue() + message_done_queue: asyncio.Queue[messages.ResponseAudioTranscriptDone] = asyncio.Queue() + tools: ToolContext | None = None + + _client_tool_futures: dict[str, asyncio.Future[ClientToolCallResponse]] + + @classmethod + async def setup_and_run_agent( + cls, + *, + engine: RtcEngine, + options: RtcOptions, + inference_config: InferenceConfig, + tools: ToolContext | None, + ) -> None: + channel = engine.create_channel(options) + await channel.connect() + + try: + async with RealtimeApiClient( + base_uri=os.getenv("REALTIME_API_BASE_URI", "wss://api.openai.com"), + api_key=os.getenv("OPENAI_API_KEY"), + verbose=False, + ) as client: + await client.send_message( + messages.SessionUpdate( + session=messages.SessionUpdateParams( + #MARK: check this + turn_detection=inference_config.turn_detection, + tools=tools.model_description() if tools else None, + tool_choice="auto", + instructions=inference_config.system_message, + ) + ) + ) + + [start_session_message, _] = await asyncio.gather( + *[ + anext(client.listen()), + client.send_message( + messages.UpdateConversationConfig( + system_message=inference_config.system_message, + output_audio_format=messages.AudioFormats.PCM16, + voice=inference_config.voice, + tools=tools.model_description() if tools else None, + transcribe_input=False, + ) + ), + ] + ) + # assert isinstance(start_session_message, messages.StartSession) + logger.info( + f"Session started: {start_session_message.session.id} model: {start_session_message.session.model}" + ) + + agent = cls( + client=client, + tools=tools, + channel=channel, + ) + await agent.run() + + finally: + engine.destroy() + + @classmethod + async def entry_point( + cls, + *, + engine: RtcEngine, + options: RtcOptions, + inference_config: InferenceConfig, + tools: ToolContext | None = None, + ) -> None: + await cls.setup_and_run_agent( + engine=engine, options=options, inference_config=inference_config, tools=tools + ) + + def __init__( + self, + *, + client: RealtimeApiClient, + tools: ToolContext | None, + channel: Channel, + ) -> None: + self.client = client + self.tools = tools + self._client_tool_futures = {} + self.channel = channel + self.subscribe_user = None + + async def run(self) -> None: + try: + def log_exception(t: asyncio.Task[Any]) -> None: + if not t.cancelled() and t.exception(): + logger.error( + "unhandled exception", + exc_info=t.exception(), + ) + + logger.info("Waiting for remote user to join") + 
self.subscribe_user = await wait_for_remote_user(self.channel) + logger.info(f"Subscribing to user {self.subscribe_user}") + await self.channel.subscribe_audio(self.subscribe_user) + + async def on_user_left(agora_rtc_conn: RTCConnection, user_id: int, reason: int): + logger.info(f"User left: {user_id}") + if self.subscribe_user == user_id: + self.subscribe_user = None + logger.info("Subscribed user left, disconnecting") + await self.channel.disconnect() + + self.channel.on("user_left", on_user_left) + + disconnected_future = asyncio.Future[None]() + + def callback(agora_rtc_conn: RTCConnection, conn_info: RTCConnInfo, reason): + logger.info(f"Connection state changed: {conn_info.state}") + if conn_info.state == 1: + if not disconnected_future.done(): + disconnected_future.set_result(None) + + self.channel.on("connection_state_changed", callback) + + asyncio.create_task(self._stream_input_audio_to_model()).add_done_callback( + log_exception + ) + asyncio.create_task( + self._stream_audio_queue_to_audio_output() + ).add_done_callback(log_exception) + + asyncio.create_task(self._process_model_messages()).add_done_callback( + log_exception + ) + + await disconnected_future + logger.info("Agent finished running") + except asyncio.CancelledError: + logger.info("Agent cancelled") + + async def _stream_input_audio_to_model(self) -> None: + while self.subscribe_user is None: + await asyncio.sleep(0.1) + audio_frames = self.channel.get_audio_frames(self.subscribe_user) + async for audio_frame in audio_frames: + try: + # send the frame to the model via the API client + await self.client.send_audio_data(audio_frame.data) + except Exception as e: + logger.error(f"Error sending audio data to model: {e}") + + async def _stream_audio_queue_to_audio_output(self) -> None: + while True: + # audio queue contains audio data from the model, send it the end-user via our local audio source + frame = await self.audio_queue.get() + await self.channel.push_audio_frame(frame) + await asyncio.sleep(0) # allow other tasks to run + + + async def _process_model_messages(self) -> None: + async for message in self.client.listen(): + # logger.info(f"Received message {message=}") + match message: + case messages.ResponseAudioDelta(): + # logger.info("Received audio message") + await self.audio_queue.put(base64.b64decode(message.delta)) + + case messages.ResponseAudioTranscriptDelta(): + logger.info(f"Received text message {message=}") + await self.channel.chat.send_message(ChatMessage(message=message.model_dump_json(), msg_id=message.item_id)) + + case messages.ResponseAudioTranscriptDone(): + logger.info(f"Text message done: {message=}") + await self.channel.chat.send_message(ChatMessage(message=message.model_dump_json(), msg_id=message.item_id)) + + # case messages.MessageAdded(): + # pass + case messages.InputAudioBufferSpeechStarted(): + pass + case messages.InputAudioBufferSpeechStopped(): + pass + # InputAudioBufferCommitted + case messages.InputAudioBufferCommitted(): + pass + # case messages.ServerAddMessage(): + # pass + # ItemCreated + case messages.ItemCreated(): + pass + # ResponseCreated + case messages.ResponseCreated(): + pass + + # ResponseOutputItemAdded + case messages.ResponseOutputItemAdded(): + pass + + # ResponseContenPartAdded + case messages.ResponseContenPartAdded(): + pass + # ResponseAudioDone + case messages.ResponseAudioDone(): + pass + # ResponseContentPartDone + case messages.ResponseContentPartDone(): + pass + # ResponseOutputItemDone + case messages.ResponseOutputItemDone(): + pass + case _: 
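+                    # Any server event not matched above is logged here, so unhandled message types are easy to spot during development.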
+ logger.warning(f"Unhandled message {message=}") + +async def shutdown(loop, signal=None): + """Gracefully shut down the application.""" + if signal: + logger.info(f"Received exit signal {signal.name}...") + + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + + logger.info(f"Cancelling {len(tasks)} outstanding tasks") + for task in tasks: + task.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) + loop.stop() + +if __name__ == "__main__": + load_dotenv() + + options = parse_args_realtimekit() + logger.info(f"app_id: channel_id: {options['channel_name']}, uid: {options['uid']}") + + if not os.environ.get("AGORA_APP_ID") : + raise ValueError("Need to set environment variable AGORA_APP_ID") + + asyncio.run( + RealtimeKitAgent.entry_point( + engine=RtcEngine(appid=os.environ.get("AGORA_APP_ID"), appcert=os.environ.get("AGORA_APP_CERT")), + options=RtcOptions( + channel_name=options['channel_name'], + uid=options['uid'], + sample_rate=SAMPLE_RATE, + channels=CHANNELS + ), + inference_config=InferenceConfig( + system_message="""\ +You are a helpful assistant. If asked about the weather make sure to use the provided tool to get that information. \\ +If you are asked a question that requires a tool, say something like "working on that" and dont provide a concrete response \\ +until you have received the response to the tool call.\\ +""", + voice=messages.Voices.Alloy, + turn_detection=messages.ServerVAD( + threshold=0.5, + prefix_padding_ms=500, + suffix_padding_ms=200, + ), + ), + ) + ) +`} + +
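+
+The system message in the code above refers to a weather tool, but the sample never registers one. If you want to experiment with function calling, the following sketch shows one way to subclass the `ToolContext` defined in `agent.py` and register a local function; the `get_current_weather` name, parameters, and return values are hypothetical placeholders rather than part of the demo. Pass an instance of the subclass as the `tools` argument to `RealtimeKitAgent.entry_point` to make the tool available to the model.
+
+```python
+class WeatherToolContext(ToolContext):
+    def __init__(self) -> None:
+        super().__init__()
+        # Describe the tool to the model using JSON Schema parameters
+        self.register_function(
+            name="get_current_weather",
+            description="Get the current weather for a city.",
+            parameters={
+                "type": "object",
+                "properties": {"city": {"type": "string"}},
+                "required": ["city"],
+            },
+            fn=self.get_current_weather,
+        )
+
+    # execute_tool awaits the registered function, so it must be async
+    async def get_current_weather(self, city: str) -> dict:
+        # Placeholder data; replace with a real weather lookup
+        return {"city": city, "temperature_c": 21, "conditions": "sunny"}
+```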
+
 ## Test the code

 1. **Update the values for** `AGORA_APP_ID`, `AGORA_APP_CERT`, **and** `OPENAI_API_KEY` **in the project's** `.env` **file**.
@@ -411,5 +892,6 @@ until you have received the response to the tool call.\
 This section contains additional information or links to relevant documentation that complements the current page or explains other aspects of the product.

+- Check out the [Demo project on GitHub](https://github.com/AgoraIO/agora-openai-converse)
 - [API reference for `rtc.py`](https://api-reference-git-python-voice-implementation-agora-gdxe.vercel.app/voice-sdk/python/rtc-py-api.html)
 - [Voice calling quickstart (Python)](/voice-calling/get-started/get-started-sdk?platform=python)
\ No newline at end of file