diff --git a/templates/components/multiagent/python/app/api/routers/vercel_response.py b/templates/components/multiagent/python/app/api/routers/vercel_response.py
index fce55230..12082496 100644
--- a/templates/components/multiagent/python/app/api/routers/vercel_response.py
+++ b/templates/components/multiagent/python/app/api/routers/vercel_response.py
@@ -1,64 +1,50 @@
 import json
 import logging
+from abc import ABC
 from typing import AsyncGenerator, List
 
 from aiostream import stream
 from app.agents.single import AgentRunEvent, AgentRunResult
-from app.api.routers.events import EventCallbackHandler
 from app.api.routers.models import ChatData, Message
 from app.api.services.suggestion import NextQuestionSuggestion
 from fastapi import Request
 from fastapi.responses import StreamingResponse
-from llama_index.core.chat_engine.types import StreamingAgentChatResponse
 
 logger = logging.getLogger("uvicorn")
 
 
-class VercelStreamResponse(StreamingResponse):
+class VercelStreamResponse(StreamingResponse, ABC):
     """
-    Class to convert the response from the chat engine to the streaming format expected by Vercel
+    Base class to convert the response from the chat engine to the streaming format expected by Vercel
     """
 
     TEXT_PREFIX = "0:"
     DATA_PREFIX = "8:"
 
-    @classmethod
-    def convert_text(cls, token: str):
-        # Escape newlines and double quotes to avoid breaking the stream
-        token = json.dumps(token)
-        return f"{cls.TEXT_PREFIX}{token}\n"
+    def __init__(self, request: Request, chat_data: ChatData, *args, **kwargs):
+        self.request = request
 
-    @classmethod
-    def convert_data(cls, data: dict):
-        data_str = json.dumps(data)
-        return f"{cls.DATA_PREFIX}[{data_str}]\n"
+        stream = self._create_stream(request, chat_data, *args, **kwargs)
+        content = self.content_generator(stream)
 
-    def __init__(
-        self,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: StreamingAgentChatResponse,
-        chat_data: ChatData,
-    ):
-        content = VercelStreamResponse.content_generator(
-            request, event_handler, response, chat_data
-        )
         super().__init__(content=content)
 
-    @staticmethod
-    async def _generate_next_questions(chat_history: List[Message], response: str):
-        questions = await NextQuestionSuggestion.suggest_next_questions(
-            chat_history, response
-        )
-        if questions:
-            return {
-                "type": "suggested_questions",
-                "data": questions,
-            }
-        return None
+    async def content_generator(self, stream):
+        is_stream_started = False
 
-    @classmethod
-    def content_generator(
+        async with stream.stream() as streamer:
+            async for output in streamer:
+                if not is_stream_started:
+                    is_stream_started = True
+                    # Stream a blank message to start the stream
+                    yield self.convert_text("")
+
+                yield output
+
+                if await self.request.is_disconnected():
+                    break
+
+    def _create_stream(
         self,
         request: Request,
         chat_data: ChatData,
@@ -108,3 +94,26 @@ def _event_to_response(event: AgentRunEvent) -> dict:
             "type": "agent",
             "data": {"agent": event.name, "text": event.msg},
         }
+
+    @classmethod
+    def convert_text(cls, token: str):
+        # Escape newlines and double quotes to avoid breaking the stream
+        token = json.dumps(token)
+        return f"{cls.TEXT_PREFIX}{token}\n"
+
+    @classmethod
+    def convert_data(cls, data: dict):
+        data_str = json.dumps(data)
+        return f"{cls.DATA_PREFIX}[{data_str}]\n"
+
+    @staticmethod
+    async def _generate_next_questions(chat_history: List[Message], response: str):
+        questions = await NextQuestionSuggestion.suggest_next_questions(
+            chat_history, response
+        )
+        if questions:
+            return {
+                "type": "suggested_questions",
+                "data": questions,
+            }
+        return None
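
For reviewers skimming the protocol framing: `convert_text` and `convert_data` each emit one line of the Vercel streaming format, `<prefix><json>\n`, and `json.dumps` does the escaping so a token can never break the line-oriented stream. A quick illustration with hypothetical values (the classmethods stay callable on the base class even though it is now abstract):

```python
from app.api.routers.vercel_response import VercelStreamResponse

# Text frames use the "0:" prefix; quotes and newlines are JSON-escaped.
VercelStreamResponse.convert_text('Hi "there"')
# -> '0:"Hi \\"there\\""\n'

# Data frames use the "8:" prefix and wrap the payload in a one-element list.
VercelStreamResponse.convert_data(
    {"type": "agent", "data": {"agent": "writer", "text": "done"}}
)
# -> '8:[{"type": "agent", "data": {"agent": "writer", "text": "done"}}]\n'
```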
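Since the class now extends `ABC` and funnels all stream construction through `_create_stream`, a template variant can swap in a different source by overriding that single hook; the real `_create_stream` body sits in the unchanged lines below the first hunk and is not shown here. A minimal sketch of such an override, assuming a hypothetical subclass fed by a plain list of tokens (the only contract the base class relies on is that `_create_stream` returns an `aiostream` source, since `content_generator` consumes it via `async with stream.stream() as streamer`):

```python
from typing import List

from aiostream import stream
from app.api.routers.models import ChatData
from app.api.routers.vercel_response import VercelStreamResponse
from fastapi import Request


class EchoStreamResponse(VercelStreamResponse):
    """Hypothetical subclass: streams a fixed list of tokens."""

    def _create_stream(self, request: Request, chat_data: ChatData, tokens: List[str]):
        async def _text_source():
            # Frame each token here so content_generator can forward
            # items from the merged stream verbatim.
            for token in tokens:
                yield self.convert_text(token)

        # stream.merge returns an aiostream Stream; its .stream() context
        # manager is exactly what content_generator opens.
        return stream.merge(_text_source())


# Usage: extra positional args to __init__ are forwarded to _create_stream.
# response = EchoStreamResponse(request, chat_data, ["Hello", " world"])
```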