diff --git a/.changeset/eleven-lemons-look.md b/.changeset/eleven-lemons-look.md
new file mode 100644
index 00000000..84d3879a
--- /dev/null
+++ b/.changeset/eleven-lemons-look.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add nodes to the response and support Vercel streaming format
diff --git a/questions.ts b/questions.ts
index 231be855..d81c9726 100644
--- a/questions.ts
+++ b/questions.ts
@@ -505,7 +505,7 @@ export const askQuestions = async (
 
   if (program.framework === "nextjs" || program.frontend) {
     if (!program.ui) {
-      program.ui = getPrefOrDefault("ui");
+      program.ui = defaults.ui;
     }
   }
 
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index 2ef7ff1f..80fa7070 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -1,11 +1,14 @@
-from typing import List
 from pydantic import BaseModel
-from fastapi.responses import StreamingResponse
+from typing import List, Any, Optional, Dict, Tuple
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index.core.chat_engine.types import BaseChatEngine
+from llama_index.core.chat_engine.types import (
+    BaseChatEngine,
+    StreamingAgentChatResponse,
+)
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.llms import ChatMessage, MessageRole
 from app.engine import get_chat_engine
-from typing import List, Tuple
+from app.api.routers.vercel_response import VercelStreamResponse
 
 chat_router = r = APIRouter()
 
@@ -19,8 +22,27 @@ class _ChatData(BaseModel):
     messages: List[_Message]
 
 
+class _SourceNodes(BaseModel):
+    id: str
+    metadata: Dict[str, Any]
+    score: Optional[float]
+
+    @classmethod
+    def from_source_node(cls, source_node: NodeWithScore):
+        return cls(
+            id=source_node.node.node_id,
+            metadata=source_node.node.metadata,
+            score=source_node.score,
+        )
+
+    @classmethod
+    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
+        return [cls.from_source_node(node) for node in source_nodes]
+
+
 class _Result(BaseModel):
     result: _Message
+    nodes: List[_SourceNodes]
 
 
 async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
@@ -58,13 +80,25 @@ async def chat(
 
     response = await chat_engine.astream_chat(last_message_content, messages)
 
-    async def event_generator():
+    async def event_generator(request: Request, response: StreamingAgentChatResponse):
+        # Yield the text response
         async for token in response.async_response_gen():
+            # If client closes connection, stop sending events
            if await request.is_disconnected():
                 break
-            yield token
+            yield VercelStreamResponse.convert_text(token)
+
+        # Yield the source nodes
+        yield VercelStreamResponse.convert_data(
+            {
+                "nodes": [
+                    _SourceNodes.from_source_node(node).dict()
+                    for node in response.source_nodes
+                ]
+            }
+        )
 
-    return StreamingResponse(event_generator(), media_type="text/plain")
+    return VercelStreamResponse(content=event_generator(request, response))
 
 
 # non-streaming endpoint - delete if not needed
@@ -77,5 +111,6 @@ async def chat_request(
     response = await chat_engine.achat(last_message_content, messages)
 
     return _Result(
-        result=_Message(role=MessageRole.ASSISTANT, content=response.response)
+        result=_Message(role=MessageRole.ASSISTANT, content=response.response),
+        nodes=_SourceNodes.from_source_nodes(response.source_nodes),
     )
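With this change, the streaming endpoint no longer emits raw tokens: each text token becomes a `0:`-prefixed JSON string and the source nodes arrive as a final `2:`-prefixed data frame. A minimal client sketch of consuming that stream, assuming the router is mounted at `/api/chat` on a local dev server and that `httpx` is installed (neither is part of this diff):

```python
import json
import httpx

payload = {"messages": [{"role": "user", "content": "What is LlamaIndex?"}]}

with httpx.stream("POST", "http://localhost:8000/api/chat", json=payload) as resp:
    for line in resp.iter_lines():
        if line.startswith("0:"):
            # Text token: a JSON-encoded string after the prefix
            print(json.loads(line[2:]), end="")
        elif line.startswith("2:"):
            # Data frame: a JSON array; here it carries {"nodes": [...]}
            print("\nsources:", json.loads(line[2:]))
```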
diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
new file mode 100644
index 00000000..37392cc9
--- /dev/null
+++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
@@ -0,0 +1,35 @@
+import json
+from typing import Any
+from fastapi.responses import StreamingResponse
+
+
+class VercelStreamResponse(StreamingResponse):
+    """
+    Converts the response from the chat engine to the streaming format expected by Vercel/AI
+    """
+
+    TEXT_PREFIX = "0:"
+    DATA_PREFIX = "2:"
+    VERCEL_HEADERS = {
+        "X-Experimental-Stream-Data": "true",
+        "Content-Type": "text/plain; charset=utf-8",
+        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
+    }
+
+    @classmethod
+    def convert_text(cls, token: str):
+        # json.dumps escapes quotes, backslashes and newlines inside the token
+        token = json.dumps(token)
+        return f"{cls.TEXT_PREFIX}{token}\n"
+
+    @classmethod
+    def convert_data(cls, data: dict):
+        data_str = json.dumps(data)
+        return f"{cls.DATA_PREFIX}[{data_str}]\n"
+
+    def __init__(self, content: Any, **kwargs):
+        super().__init__(
+            content=content,
+            headers=self.VERCEL_HEADERS,
+            **kwargs,
+        )
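The two converters differ only in framing: `convert_text` JSON-encodes a single token, while `convert_data` wraps an arbitrary payload in a one-element JSON array. A quick sketch of their output, assuming the module path used in this diff (values illustrative only):

```python
from app.api.routers.vercel_response import VercelStreamResponse

print(repr(VercelStreamResponse.convert_text('He said "hi"')))
# '0:"He said \\"hi\\""\n'  -- inner quotes survive thanks to json.dumps

print(repr(VercelStreamResponse.convert_data({"nodes": []})))
# '2:[{"nodes": []}]\n'
```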