feat: Update FastAPI endpoint to support nodeSources #30

Merged · 5 commits · Apr 1, 2024
Changes from 2 commits

questions.ts (2 changes: 1 addition & 1 deletion)

@@ -505,7 +505,7 @@ export const askQuestions = async (
 
   if (program.framework === "nextjs" || program.frontend) {
     if (!program.ui) {
-      program.ui = getPrefOrDefault("ui");
+      program.ui = defaults.ui;
     }
   }
 

templates/types/streaming/fastapi/app/api/routers/chat.py (93 changes: 82 additions & 11 deletions)

@@ -1,11 +1,15 @@
-from typing import List
+import json
 from pydantic import BaseModel
+from typing import List, Any, Optional, Dict, Tuple
 from fastapi.responses import StreamingResponse
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index.core.chat_engine.types import BaseChatEngine
+from llama_index.core.chat_engine.types import (
+    BaseChatEngine,
+    StreamingAgentChatResponse,
+)
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.llms import ChatMessage, MessageRole
 from app.engine import get_chat_engine
 from typing import List, Tuple
 
 chat_router = r = APIRouter()
 
@@ -19,8 +23,78 @@ class _ChatData(BaseModel):
     messages: List[_Message]
 
 
+class _SourceNodes(BaseModel):
+    id: str
+    metadata: Dict[str, Any]
+    score: Optional[float]
+
+    @classmethod
+    def from_source_node(cls, source_node: NodeWithScore):
+        return cls(
+            id=source_node.node.node_id,
+            metadata=source_node.node.metadata,
+            score=source_node.score,
+        )
+
+    @classmethod
+    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
+        return [cls.from_source_node(node) for node in source_nodes]
+
+
+class _Result(BaseModel):
+    result: _Message
+    nodes: List[_SourceNodes]
+
+
+class VercelStreamResponse(StreamingResponse):
+    """
+    Class to convert the response from the chat engine to the streaming format expected by Vercel
+    """
+
+    TEXT_PREFIX = "0:"
+    DATA_PREFIX = "2:"
+    VERCEL_HEADERS = {
+        "X-Experimental-Stream-Data": "true",
+        "Content-Type": "text/plain; charset=utf-8",
+        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
+    }
+
+    @classmethod
+    def convert_text(cls, token: str):
+        return f'{cls.TEXT_PREFIX}"{token}"\n'
+
+    @classmethod
+    def convert_data(cls, data: dict):
+        data_str = json.dumps(data)
+        return f"{cls.DATA_PREFIX}[{data_str}]\n"
+
+    @classmethod
+    async def event_generator(
+        cls, request: Request, response: StreamingAgentChatResponse
+    ):
+        # Yield the text response
+        async for token in response.async_response_gen():
+            # If client closes connection, stop sending events
+            if await request.is_disconnected():
+                break
+            yield cls.convert_text(token)
+
+        # Yield the source nodes
+        yield cls.convert_data(
+            {
+                "nodes": [
+                    _SourceNodes.from_source_node(node).dict()
+                    for node in response.source_nodes
+                ]
+            }
+        )
+
+    def __init__(self, content: Any, **kwargs):
+        super().__init__(
+            content=content,
+            headers=self.VERCEL_HEADERS,
+            **kwargs,
+        )
+
+
 async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
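
For reference, the "0:" / "2:" prefixes follow the Vercel AI SDK's experimental stream-data wire protocol: each line of the body is a type code, a colon, and a JSON payload, where "0:" carries a text chunk and "2:" carries a JSON array of data objects. A minimal sketch of what the two helpers above emit (the token and node values are invented for illustration):

# Illustrative output of the helpers above; all values are made up.
assert VercelStreamResponse.convert_text("Hello") == '0:"Hello"\n'
assert (
    VercelStreamResponse.convert_data({"nodes": [{"id": "n1", "score": 0.5}]})
    == '2:[{"nodes": [{"id": "n1", "score": 0.5}]}]\n'
)

Note that convert_text interpolates the token directly, so it assumes tokens contain no double quotes or newlines; JSON-encoding the token (e.g. json.dumps(token)) would be the stricter variant.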

@@ -58,13 +132,9 @@ async def chat(
 
     response = await chat_engine.astream_chat(last_message_content, messages)
 
-    async def event_generator():
-        async for token in response.async_response_gen():
-            if await request.is_disconnected():
-                break
-            yield token
-
-    return StreamingResponse(event_generator(), media_type="text/plain")
+    return VercelStreamResponse(
+        content=VercelStreamResponse.event_generator(request, response)
+    )
 
 
 # non-streaming endpoint - delete if not needed
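
End to end, a hypothetical client for the streaming endpoint could consume both part types like this (the /api/chat URL, port, and httpx dependency are assumptions for illustration, not part of the PR):

import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:8000/api/chat",
    json={"messages": [{"role": "user", "content": "What does this repo do?"}]},
    timeout=None,
) as response:
    for line in response.iter_lines():
        prefix, _, payload = line.partition(":")
        if prefix == "0":
            # Text part: a JSON-encoded token
            print(json.loads(payload), end="", flush=True)
        elif prefix == "2":
            # Data part: a JSON array, here carrying the source nodes
            print("\nsources:", json.loads(payload)[0]["nodes"])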

@@ -77,5 +147,6 @@ async def chat_request(
 
     response = await chat_engine.achat(last_message_content, messages)
     return _Result(
-        result=_Message(role=MessageRole.ASSISTANT, content=response.response)
+        result=_Message(role=MessageRole.ASSISTANT, content=response.response),
+        nodes=_SourceNodes.from_source_nodes(response.source_nodes),
     )
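
With the added nodes field, the non-streaming endpoint's JSON body takes roughly this shape (illustrative values; the exact route depends on how chat_router is mounted):

# Approximate response body; all values are invented for illustration.
{
    "result": {"role": "assistant", "content": "..."},
    "nodes": [
        {
            "id": "4c91...",  # node_id of the retrieved chunk
            "metadata": {"file_name": "README.md"},
            "score": 0.82,
        },
    ],
}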