
Commit

Structured output (#8)
* Remove separators and use OpenAI's response_format instead

* Add unit tests

* Small fixes

* Change default model in app

---------

Co-authored-by: Nicolas Frank <[email protected]>
WonderPG and Nicolas Frank authored Sep 17, 2024
1 parent 5fa7dd3 commit 8ae8eb2
Showing 10 changed files with 681 additions and 703 deletions.
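A note on the technique behind the first bullet above: OpenAI's structured-output support lets you pass a Pydantic model as response_format so the SDK returns a validated object, instead of asking the model to delimit sources with separators and post-processing the raw string. Below is a minimal sketch of that pattern, assuming the openai Python SDK's beta parse helper; the GenerativeQAOutput fields mirror the ones used in the diffs further down, while the real class and prompt definitions live in generative_question_answering.py, which is not shown in this excerpt.

    from openai import OpenAI
    from pydantic import BaseModel


    class GenerativeQAOutput(BaseModel):
        """Field names mirrored from the diffs below; illustrative only."""

        has_answer: bool        # did the provided contexts contain an answer?
        answer: str             # the generated answer
        paragraphs: list[int]   # indices of the contexts cited as sources


    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # parse() validates the model's JSON output against the Pydantic schema,
    # replacing the old separator-based post-processing of the raw answer.
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "Answer from the numbered contexts and cite them."},
            {"role": "user", "content": "Contexts: ...\n\nQuestion: ..."},
        ],
        response_format=GenerativeQAOutput,
    )
    parsed = completion.choices[0].message.parsed  # a GenerativeQAOutput instance

The streaming endpoint applies the same schema incrementally; that is what the streaming.py changes below handle.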
2 changes: 1 addition & 1 deletion .env.example
@@ -20,7 +20,7 @@ SCHOLARAG__RETRIEVAL__MAX_LENGTH=
SCHOLARAG__GENERATIVE__OPENAI__MODEL=
SCHOLARAG__GENERATIVE__OPENAI__TEMPERATURE=
SCHOLARAG__GENERATIVE__OPENAI__MAX_TOKENS=
SCHOLARAG__GENERATIVE__PROMPT_TEMPLATE=
SCHOLARAG__GENERATIVE__SYSTEM_PROMPT=

SCHOLARAG__METADATA__EXTERNAL_APIS=
SCHOLARAG__METADATA__TIMEOUT=
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Changed
- Use OpenAI response_format instead of separators in the prompt.
- Switch to cohere reranker v3 and `retriever_k = 500`.

## [v0.0.5]
6 changes: 3 additions & 3 deletions src/scholarag/app/config.py
@@ -9,7 +9,7 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import Self

from scholarag.generative_question_answering import PROMPT_TEMPLATE
from scholarag.generative_question_answering import MESSAGES


class SettingsKeycloak(BaseModel):
@@ -84,7 +84,7 @@ class SettingsOpenAI(BaseModel):
"""OpenAI settings."""

token: SecretStr | None = None
model: str = "gpt-3.5-turbo"
model: str = "gpt-4o-mini"
temperature: float = 0
max_tokens: int | None = None

@@ -95,7 +95,7 @@ class SettingsGenerative(BaseModel):
"""Generative QA settings."""

openai: SettingsOpenAI = SettingsOpenAI()
prompt_template: SecretStr = SecretStr(PROMPT_TEMPLATE)
system_prompt: SecretStr = SecretStr(MESSAGES[0]["content"])

model_config = ConfigDict(frozen=True)

20 changes: 11 additions & 9 deletions src/scholarag/app/routers/qa.py
@@ -293,7 +293,7 @@ async def generative_qa(
**{
"query": request.query,
"contexts": contexts_text,
"prompt_template": settings.generative.prompt_template.get_secret_value(),
"system_prompt": settings.generative.system_prompt.get_secret_value(),
},
)
try:
@@ -325,17 +325,17 @@ async def generative_qa(
" its answer. Please decrease the reranker_k or retriever_k value"
" by 1 or 2 depending of whether you are using the reranker or not."
),
"raw_answer": answer["raw_answer"],
"answer": answer.answer,
},
)

if answer["paragraphs"] is None or len(answer["paragraphs"]) == 0:
if not answer.has_answer:
raise HTTPException(
status_code=500,
detail={
"code": ErrorCode.NO_ANSWER_FOUND.value,
"detail": "The LLM did not provide any source to answer the question.",
"raw_answer": answer["raw_answer"],
"answer": answer.answer,
},
)
else:
@@ -344,7 +344,7 @@ async def generative_qa(
impact_factors = fetched_metadata["get_impact_factors"]
abstracts = fetched_metadata["recreate_abstract"]

context_ids: list[int] = answer["paragraphs"]
context_ids: list[int] = answer.paragraphs
logger.info("Adding article metadata to the answers")
metadata = []
for context_id in context_ids:
@@ -375,11 +375,13 @@ async def generative_qa(
}
)
)
answer["metadata"] = metadata

del answer["paragraphs"]
output = {
"answer": answer.answer,
"paragraphs": answer.paragraphs,
"metadata": metadata,
}
logger.info(f"Total time to generate a complete answer: {time.time() - start}")
return GenerativeQAResponse(**answer)
return GenerativeQAResponse(**output)


@router.post(
2 changes: 1 addition & 1 deletion src/scholarag/app/schemas.py
@@ -115,5 +115,5 @@ class GenerativeQAResponse(BaseModel):
"""Response for the generative QA endpoint."""

answer: str | None
raw_answer: str
paragraphs: list[int]
metadata: list[ParagraphMetadata]
146 changes: 91 additions & 55 deletions src/scholarag/app/streaming.py
@@ -1,18 +1,17 @@
"""Utilities to stream openai response."""

import json
from collections import deque
from typing import Any, AsyncIterable

from httpx import AsyncClient
from openai import AsyncOpenAI, BadRequestError

from scholarag.app.config import Settings
from scholarag.app.dependencies import ErrorCode
from scholarag.app.schemas import GenerativeQAResponse
from scholarag.document_stores import AsyncBaseSearch
from scholarag.generative_question_answering import (
ERROR_SEPARATOR,
SOURCES_SEPARATOR,
GenerativeQAOutput,
GenerativeQAWithSources,
)
from scholarag.retrieve_metadata import MetaDataRetriever
@@ -84,30 +83,50 @@ async def stream_response(
openai_client = AsyncOpenAI(api_key=api_key)
qas.client = openai_client

token_queue: deque[str] = deque(
maxlen=4
) # Needs to be EXACTLY the number of tokens of the separator. See https://platform.openai.com/tokenizer
# Some black magic.
raw_string = ""
try:
async for chunk in qas.astream(
generator = qas.astream(
query=query,
contexts=contexts_text,
prompt_template=settings.generative.prompt_template.get_secret_value(),
):
raw_string += chunk
if len(token_queue) == token_queue.maxlen:
queued_text = "".join(list(token_queue))
if (
f"{SOURCES_SEPARATOR}:" in queued_text
or ERROR_SEPARATOR in queued_text
): # Might change if we change the separator.
continue
yield token_queue.popleft()
token_queue.append(chunk)
except RuntimeError as e:
# Since finish reason is raised, it has to be recovered as such.
finish_reason = e.args[0]
system_prompt=settings.generative.system_prompt.get_secret_value(),
)

parsed: dict[str, str | bool | list[int]] | GenerativeQAOutput = {}
# While the model didn't say if it has answer or not, keep consuming
while parsed.get("has_answer") is None: # type: ignore
_, parsed = await anext(generator)
# If the LLM doesn't know answer, no need to further iterate
if not parsed["has_answer"]: # type: ignore
yield "<bbs_json_error>"
yield json.dumps(
{
"Error": {
"status_code": 404,
"code": ErrorCode.NO_ANSWER_FOUND.value,
"detail": (
"The LLM did not manage to answer the question based on the provided contexts."
),
}
}
)
parsed_output = GenerativeQAOutput(
has_answer=False, answer="", paragraphs=[]
)
# Else, we stream the tokens.
# First ensure not streaming '"answer":'
else:
accumulated_text = ""
while '"answer":' not in accumulated_text:
chunk, _ = await anext(generator)
accumulated_text += chunk
# Then we stream the answer
async for chunk, parsed in generator:
# While the answer has not finished streaming we yield the tokens.
if parsed.get("answer") is None: # type: ignore
yield chunk
# Stop streaming as soon as the answer is complete
# (i.e. don't stream the paragraph ids)
else:
break

# Errors cannot be raised due to the streaming nature of the endpoint. They're yielded as a dict instead,
# leaving the charge to the user to catch them. The stream is simply interrupted in case of error.
@@ -124,9 +143,40 @@ async def stream_response(
}
}
)
if not interrupted:
# Post process the output to get citations.
answer = qas._process_raw_output(raw_output=raw_string)

try:
# Extract the final pydantic class (last item in generator)
parsed_output = await anext(
(
parsed
async for (_, parsed) in generator
if isinstance(parsed, GenerativeQAOutput)
)
)
except StopAsyncIteration:
# If it is not present, we had an issue.
# By default if there was an issue we nullify the potential partial answer.
# Feel free to suggest another approach.
parsed_output = GenerativeQAOutput(has_answer=False, answer="", paragraphs=[])
yield "<bbs_json_error>"
yield json.dumps(
{
"Error": {
"status_code": 404,
"code": ErrorCode.NO_ANSWER_FOUND.value,
"detail": (
"The LLM encountered an error when answering the question."
),
}
}
)
try:
# Finally we "raise" the finish_reason
await anext(generator)
except RuntimeError as err:
finish_reason: str = err.args[0]

if not interrupted and parsed_output.has_answer:
# Ensure that the model finished yielding.
if finish_reason == "length":
# Adding a separator before raising the error.
@@ -142,30 +192,15 @@ async def stream_response(
" retriever_k value by 1 or 2 depending of whether you are"
" using the reranker or not."
),
"raw_answer": answer["raw_answer"],
}
}
)
elif answer["paragraphs"] is None or len(answer["paragraphs"]) == 0: # type: ignore
# Adding a separator before raising the error.
yield "<bbs_json_error>"
yield json.dumps(
{
"Error": {
"status_code": 404,
"code": ErrorCode.NO_ANSWER_FOUND.value,
"detail": (
"The LLM did not provide any source to answer the question."
),
"raw_answer": answer["raw_answer"],
"raw_answer": parsed_output.answer,
}
}
)
else:
# Adding a separator between the streamed answer and the processed response.
yield "<bbs_json_data>"
complete_answer = await retrieve_metadata(
answer=answer,
answer=parsed_output,
ds_client=ds_client,
index_journals=index_journals,
index_paragraphs=index_paragraphs,
Expand All @@ -175,11 +210,11 @@ async def stream_response(
indices=indices,
scores=scores,
)
yield json.dumps(complete_answer)
yield complete_answer.model_dump_json()


async def retrieve_metadata(
answer: dict[str, Any],
answer: GenerativeQAOutput,
ds_client: AsyncBaseSearch,
index_journals: str | None,
index_paragraphs: str,
@@ -188,13 +223,13 @@ async def stream_response(
contexts: list[dict[str, Any]],
indices: tuple[int, ...],
scores: tuple[float, ...] | None,
) -> dict[str, Any]:
) -> GenerativeQAResponse:
"""Retrieve the metadata and display them nicely.
Parameters
----------
answer
Answer generated by the model.
Parsed answer returned by the model.
ds_client
Document store client.
index_journals
@@ -216,6 +251,7 @@ async def retrieve_metadata(
-------
Nicely formatted answer with metadata.
"""
contexts = [contexts[i] for i in answer.paragraphs]
fetched_metadata = await metadata_retriever.retrieve_metadata(
contexts, ds_client, index_journals, index_paragraphs, httpx_client
)
@@ -225,11 +261,8 @@ async def retrieve_metadata(
impact_factors = fetched_metadata["get_impact_factors"]
abstracts = fetched_metadata["recreate_abstract"]

context_ids: list[int] = answer["paragraphs"]
metadata = []
for context_id in context_ids:
context = contexts[context_id]

for context, context_id in zip(contexts, answer.paragraphs):
metadata.append(
{
"impact_factor": impact_factors[context.get("journal")],
Expand All @@ -251,7 +284,10 @@ async def retrieve_metadata(
"pubmed_id": context["pubmed_id"],
}
)
answer["metadata"] = metadata

del answer["paragraphs"]
return answer
output = {
"answer": answer.answer,
"paragraphs": answer.paragraphs,
"metadata": metadata,
}
return GenerativeQAResponse(**output)
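For readers following the new streaming contract in the code above: qas.astream appears to yield (chunk, parsed) pairs, where parsed is a partially parsed dict while the JSON is still arriving and a final GenerativeQAOutput once it is complete, and the endpoint emits the answer tokens followed by a <bbs_json_error> or <bbs_json_data> marker and a JSON payload. Below is a hedged sketch of a client consuming that stream; the URL, endpoint path, and request body are assumptions, while the markers and the shape of the trailing payload come from the diff above.

    import json

    import httpx


    async def stream_qa(query: str, url: str = "http://localhost:8000/qa/streamed_generative"):
        """Collect streamed answer tokens, then parse the trailing JSON payload.

        For simplicity this assumes each marker arrives within a single chunk.
        """
        tokens: list[str] = []
        tail = ""
        in_tail = False
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, json={"query": query}) as response:
                async for chunk in response.aiter_text():
                    if in_tail:
                        tail += chunk
                        continue
                    for marker in ("<bbs_json_data>", "<bbs_json_error>"):
                        if marker in chunk:
                            head, tail = chunk.split(marker, maxsplit=1)
                            tokens.append(head)
                            in_tail = True
                            break
                    else:
                        tokens.append(chunk)
        # The payload is either a GenerativeQAResponse dump or an {"Error": ...} dict.
        return "".join(tokens), json.loads(tail) if tail else None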