Skip to content

Commit

Permalink
feat: Update FastAPI endpoint to support nodeSources (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj authored Apr 1, 2024
1 parent 2739714 commit c06d4af
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 9 deletions.
5 changes: 5 additions & 0 deletions .changeset/eleven-lemons-look.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Add nodes to the response and support Vercel streaming format
2 changes: 1 addition & 1 deletion questions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ export const askQuestions = async (

if (program.framework === "nextjs" || program.frontend) {
if (!program.ui) {
program.ui = getPrefOrDefault("ui");
program.ui = defaults.ui;
}
}

Expand Down
51 changes: 43 additions & 8 deletions templates/types/streaming/fastapi/app/api/routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import List
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from typing import List, Tuple
from app.api.routers.vercel_response import VercelStreamResponse

chat_router = r = APIRouter()

Expand All @@ -19,8 +22,27 @@ class _ChatData(BaseModel):
messages: List[_Message]


class _SourceNodes(BaseModel):
    """Serializable view of a retrieved source node returned to the client."""

    id: str  # node_id of the underlying node
    metadata: Dict[str, Any]  # node metadata, passed through verbatim
    score: Optional[float]  # retrieval score; may be absent

    @classmethod
    def from_source_node(cls, source_node: NodeWithScore):
        """Build one response entry from a single retrieved node."""
        inner = source_node.node
        return cls(
            id=inner.node_id,
            metadata=inner.metadata,
            score=source_node.score,
        )

    @classmethod
    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
        """Convert a list of retrieved nodes into response-ready entries."""
        return list(map(cls.from_source_node, source_nodes))


class _Result(BaseModel):
    """Response body of the non-streaming chat endpoint."""

    result: _Message  # assistant reply built from the chat engine's response text
    nodes: List[_SourceNodes]  # source nodes taken from the chat engine response


async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
Expand Down Expand Up @@ -58,13 +80,25 @@ async def chat(

response = await chat_engine.astream_chat(last_message_content, messages)

async def event_generator():
async def event_generator(request: Request, response: StreamingAgentChatResponse):
# Yield the text response
async for token in response.async_response_gen():
# If client closes connection, stop sending events
if await request.is_disconnected():
break
yield token
yield VercelStreamResponse.convert_text(token)

# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"nodes": [
_SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
}
)

return StreamingResponse(event_generator(), media_type="text/plain")
return VercelStreamResponse(content=event_generator(request, response))


# non-streaming endpoint - delete if not needed
Expand All @@ -77,5 +111,6 @@ async def chat_request(

response = await chat_engine.achat(last_message_content, messages)
return _Result(
result=_Message(role=MessageRole.ASSISTANT, content=response.response)
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
from typing import Any
from fastapi.responses import StreamingResponse


class VercelStreamResponse(StreamingResponse):
    """
    Class to convert the response from the chat engine to the streaming format
    expected by Vercel/AI: newline-terminated frames, each tagged with a
    numeric type prefix.
    """

    # Frame-type prefixes of the Vercel/AI stream protocol.
    TEXT_PREFIX = "0:"  # frame carries a JSON-encoded text token
    DATA_PREFIX = "2:"  # frame carries a JSON array of data objects
    # Headers that tell the Vercel/AI client this stream may carry data frames.
    VERCEL_HEADERS = {
        "X-Experimental-Stream-Data": "true",
        "Content-Type": "text/plain; charset=utf-8",
        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
    }

    @classmethod
    def convert_text(cls, token: str):
        """Encode a text token as a Vercel text frame.

        Use json.dumps instead of naive f-string quoting so embedded
        double quotes, backslashes and newlines are escaped — otherwise a
        single such character in a token corrupts the frame.
        """
        token_json = json.dumps(token, ensure_ascii=False)
        return f"{cls.TEXT_PREFIX}{token_json}\n"

    @classmethod
    def convert_data(cls, data: dict):
        """Encode a JSON-serializable dict as a Vercel data frame.

        The protocol expects the payload to be a JSON array, so the single
        object is wrapped in brackets.
        """
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    def __init__(self, content: Any, **kwargs):
        """Stream *content* with the headers the Vercel/AI client expects."""
        super().__init__(
            content=content,
            headers=self.VERCEL_HEADERS,
            **kwargs,
        )

0 comments on commit c06d4af

Please sign in to comment.