Add unit tests
Nicolas Frank committed Sep 3, 2024
1 parent 63bacab commit eefc78d
Showing 7 changed files with 514 additions and 503 deletions.
1 change: 1 addition & 0 deletions src/scholarag/app/middleware.py
@@ -265,6 +265,7 @@ async def get_and_set_cache(
                 "path": request.scope["path"],
             },
         }
+
         cached["content"] = cached["content"].split("<bbs_json_error>")[-1]
         cached["content"] = cached["content"].split("<bbs_json_data>")[-1]
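
The two `split` calls above make the cache keep only the final payload: `stream_response` (see streaming.py below) prefixes an in-band error document with a `<bbs_json_error>` sentinel, and `<bbs_json_data>` presumably marks the final JSON payload the same way. Since `str.split(tag)[-1]` returns everything after the last occurrence of the tag, and the whole string when the tag is absent, the cached content ends up sentinel-free. A minimal sketch of the behaviour, with an invented cached value:

```python
# Minimal sketch of the sentinel stripping above; the cached value is invented.
cached = {"content": 'streamed tokens<bbs_json_error>{"Error": {"status_code": 404}}'}

# split(tag)[-1] keeps everything after the last sentinel and is a no-op
# when the sentinel is absent.
cached["content"] = cached["content"].split("<bbs_json_error>")[-1]
cached["content"] = cached["content"].split("<bbs_json_data>")[-1]

print(cached["content"])  # {"Error": {"status_code": 404}}
```
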
10 changes: 8 additions & 2 deletions src/scholarag/app/routers/qa.py
@@ -293,7 +293,7 @@ async def generative_qa(
         **{
             "query": request.query,
             "contexts": contexts_text,
-            "messages": settings.generative.system_prompt.get_secret_value(),
+            "system_prompt": settings.generative.system_prompt.get_secret_value(),
         },
     )
     try:
@@ -325,6 +325,7 @@ async def generative_qa(
                     " its answer. Please decrease the reranker_k or retriever_k value"
                     " by 1 or 2 depending of whether you are using the reranker or not."
                 ),
+                "answer": answer.answer,
             },
         )

@@ -334,6 +335,7 @@ async def generative_qa(
             detail={
                 "code": ErrorCode.NO_ANSWER_FOUND.value,
                 "detail": "The LLM did not provide any source to answer the question.",
+                "answer": answer.answer,
             },
         )
     else:
@@ -373,7 +375,11 @@ async def generative_qa(
             }
         )
     )
-    output = {"answer": answer.answer, "paragraphs": answer.paragraphs, "metadata": metadata}
+    output = {
+        "answer": answer.answer,
+        "paragraphs": answer.paragraphs,
+        "metadata": metadata,
+    }
     logger.info(f"Total time to generate a complete answer: {time.time() - start}")
     return GenerativeQAResponse(**output)

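Three changes to qa.py are visible here: the prompt keyword now matches the downstream signature (`system_prompt` instead of `messages`), both error branches expose the model's raw answer inside the exception detail, and the `output` dict is reformatted. Returning `answer` with the error means a caller can still show the unsourced text. A hedged sketch of how a client might use this (the endpoint URL, method, and payload are illustrative, not taken from the repo):

```python
import requests  # any HTTP client works; this one is for illustration

resp = requests.post(
    "http://localhost:8000/qa/generative",  # hypothetical deployment URL
    json={"query": "What is the role of astrocytes?"},
)
if resp.status_code in (404, 413):
    detail = resp.json()["detail"]
    # The error detail now carries the partial answer next to code/detail.
    print(f"{detail['code']}: {detail.get('answer')}")
else:
    print(resp.json()["answer"])
```
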
70 changes: 41 additions & 29 deletions src/scholarag/app/streaming.py
@@ -1,20 +1,19 @@
"""Utilities to stream openai response."""

import json
from collections import deque
from typing import Any, AsyncIterable

from httpx import AsyncClient
from openai import AsyncOpenAI, BadRequestError

from scholarag.app.config import Settings
from scholarag.app.dependencies import ErrorCode
from scholarag.app.schemas import GenerativeQAResponse
from scholarag.document_stores import AsyncBaseSearch
from scholarag.generative_question_answering import (
GenerativeQAWithSources,
GenerativeQAOutput,
GenerativeQAWithSources,
)
from scholarag.app.schemas import GenerativeQAResponse
from scholarag.retrieve_metadata import MetaDataRetriever


@@ -85,27 +84,33 @@ async def stream_response(
     qas.client = openai_client
 
     try:
-        generator = qas.astream(query=query, contexts=contexts_text, system_prompt=settings.generative.system_prompt.get_secret_value())
+        generator = qas.astream(
+            query=query,
+            contexts=contexts_text,
+            system_prompt=settings.generative.system_prompt.get_secret_value(),
+        )
         # Get the first chunk and resulting parsing
         _, parsed = await anext(generator)
         # While the model didn't say if it has answer or not, keep consuming
-        while parsed.get("has_answer") is None:
+        while parsed.get("has_answer") is None:  # type: ignore
             _, parsed = await anext(generator)
         # If the LLM doesn't know answer, no need to further iterate
-        if not parsed["has_answer"]:
+        if not parsed["has_answer"]:  # type: ignore
             yield "<bbs_json_error>"
             yield json.dumps(
                 {
                     "Error": {
                         "status_code": 404,
                         "code": ErrorCode.NO_ANSWER_FOUND.value,
                         "detail": (
                             "The LLM did not manage to answer the question based on the provided contexts."
                         ),
                     }
                 }
             )
-            parsed_output = GenerativeQAOutput(has_answer=False, answer="", paragraphs=[])
+            parsed_output = GenerativeQAOutput(
+                has_answer=False, answer="", paragraphs=[]
+            )
         # Else, we stream the tokens (without the 'answer:')
         else:
             accumulated_text = ""
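
Because HTTP status codes can no longer be changed once a streaming response has begun, the endpoint signals failure in-band: it yields the `<bbs_json_error>` sentinel followed by a serialized error document, then stops. A consumer therefore has to watch for the sentinel itself; a minimal sketch, where the stream is any async iterable of text chunks:

```python
import json

async def consume(stream):
    """Accumulate streamed chunks, separating an in-band error if one occurs."""
    buffer = ""
    async for chunk in stream:
        buffer += chunk
    if "<bbs_json_error>" in buffer:
        # Everything after the sentinel is the JSON error document.
        return None, json.loads(buffer.split("<bbs_json_error>")[-1])
    return buffer, None
```
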
@@ -115,12 +120,11 @@ async def stream_response(
             # Finally we stream the answer
             async for chunk, parsed in generator:
                 # While the answer has not finished streaming we yield the tokens.
-                if parsed.get("answer") is None:
+                if parsed.get("answer") is None:  # type: ignore
                     yield chunk
                 else:
                     break
-            # Yield the final token
-            # yield chunk
 
     # Errors cannot be raised due to the streaming nature of the endpoint. They're yielded as a dict instead,
     # leaving the charge to the user to catch them. The stream is simply interrupted in case of error.
     except BadRequestError as err:
@@ -137,10 +141,15 @@ async def stream_response(
                 }
             )
 
     try:
         # Extract the final pydantic class
-        parsed_output = await anext((parsed async for (_, parsed) in generator if isinstance(parsed, GenerativeQAOutput)))
+        parsed_output = await anext(
+            (
+                parsed
+                async for (_, parsed) in generator
+                if isinstance(parsed, GenerativeQAOutput)
+            )
+        )
     except StopAsyncIteration:
         # If it is not present, we had an issue.
         # By default if there was an issue we nullify the potential partial answer.
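
The reformatted call above is a compact way to drain an async generator until it yields a value of a given type: `anext` on an async generator expression returns the first `GenerativeQAOutput` and raises `StopAsyncIteration` if the stream ends without one. A self-contained toy version of the pattern (all names are invented; the `anext` builtin needs Python 3.10+):

```python
import asyncio


class Final:
    """Toy stand-in for GenerativeQAOutput."""

    def __init__(self, answer: str) -> None:
        self.answer = answer


async def tokens():
    yield "chunk-1", {"has_answer": True}
    yield "", Final("done")


async def main() -> None:
    generator = tokens()
    # Drain the generator until the parsed item is a Final instance;
    # StopAsyncIteration here would mean the stream ended without one.
    final = await anext(
        (parsed async for (_, parsed) in generator if isinstance(parsed, Final))
    )
    print(final.answer)  # done


asyncio.run(main())
```
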
@@ -159,10 +168,10 @@ async def stream_response(
             }
         )
     try:
-        # Last item in generator should be finish reason.
-        finish_reason = await anext(generator)
-    except StopAsyncIteration:
-        finish_reason = None
+        # We "raise" the finish_reason
+        await anext(generator)
+    except RuntimeError as err:
+        finish_reason: str = err.args[0]
 
     if not interrupted and parsed_output.has_answer:
         # Ensure that the model finished yielding.
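
The finish-reason handling also changed contract: instead of treating the last yielded item as the finish reason (and catching `StopAsyncIteration` when there is none), the code now expects the generator to raise a `RuntimeError` whose first argument is the finish reason. A toy sketch of that contract, assuming `qas.astream` really does end by raising (the generator below is invented):

```python
import asyncio


async def astream_like():
    # Invented stand-in: after the last chunk, surface the finish reason
    # by raising, mirroring what the diff assumes qas.astream does.
    yield "token-1"
    yield "token-2"
    raise RuntimeError("stop")


async def main() -> None:
    generator = astream_like()
    print(await anext(generator))  # token-1
    print(await anext(generator))  # token-2
    try:
        # The generator "raises" the finish reason.
        await anext(generator)
    except RuntimeError as err:
        finish_reason: str = err.args[0]
        print(finish_reason)  # stop


asyncio.run(main())
```
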
@@ -250,7 +259,6 @@ async def retrieve_metadata(
 
     metadata = []
     for context, context_id in zip(contexts, answer.paragraphs):
-
         metadata.append(
             {
                 "impact_factor": impact_factors[context.get("journal")],
@@ -273,5 +281,9 @@ async def retrieve_metadata(
             }
         )
 
-    output = {"answer": answer.answer, "paragraphs": answer.paragraphs, "metadata": metadata}
+    output = {
+        "answer": answer.answer,
+        "paragraphs": answer.paragraphs,
+        "metadata": metadata,
+    }
     return GenerativeQAResponse(**output)
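
Both the plain and streaming paths now build the same three-field dict and validate it as `GenerativeQAResponse`. The schema itself is not part of this diff; a plausible shape inferred from the `output` dict, with field types that are assumptions only:

```python
from typing import Any

from pydantic import BaseModel


class GenerativeQAResponse(BaseModel):
    # Field names mirror the `output` dict above; the types are guesses.
    answer: str
    paragraphs: list[int]  # assumed to be indices of the cited contexts
    metadata: list[dict[str, Any]]
```
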
(4 more changed files not shown)
