From 548b9ac29adbbc75c06d67ea70ebcd48486974b8 Mon Sep 17 00:00:00 2001
From: Nicolas Frank
Date: Tue, 3 Sep 2024 17:51:55 +0200
Subject: [PATCH] Small fixes

---
 CHANGELOG.md                                    |  3 +++
 src/scholarag/app/middleware.py                 |  1 -
 src/scholarag/app/streaming.py                  | 18 +++++++++++-------
 src/scholarag/generative_question_answering.py  | 18 +++++++++++-------
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9cb0f5f..eb13668 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+### Changed
+- Use OpenAI response_format instead of separators in the prompt.
+
 ## [v0.0.5]

 ### Changed
diff --git a/src/scholarag/app/middleware.py b/src/scholarag/app/middleware.py
index 6570225..58473fe 100644
--- a/src/scholarag/app/middleware.py
+++ b/src/scholarag/app/middleware.py
@@ -265,7 +265,6 @@ async def get_and_set_cache(
                 "path": request.scope["path"],
             },
         }
-        cached["content"] = cached["content"].split("")[-1]
         cached["content"] = cached["content"].split("")[-1]
diff --git a/src/scholarag/app/streaming.py b/src/scholarag/app/streaming.py
index da6c800..896ba9b 100644
--- a/src/scholarag/app/streaming.py
+++ b/src/scholarag/app/streaming.py
@@ -89,8 +89,8 @@ async def stream_response(
         contexts=contexts_text,
         system_prompt=settings.generative.system_prompt.get_secret_value(),
     )
-    # Get the first chunk and resulting parsing
-    _, parsed = await anext(generator)
+
+    parsed: dict[str, str | bool | list[int]] | GenerativeQAOutput = {}
     # While the model didn't say if it has answer or not, keep consuming
     while parsed.get("has_answer") is None:  # type: ignore
         _, parsed = await anext(generator)
@@ -111,17 +111,20 @@
         parsed_output = GenerativeQAOutput(
             has_answer=False, answer="", paragraphs=[]
         )
-    # Else, we stream the tokens (without the 'answer:')
+    # Else, we stream the tokens.
+    # First ensure not streaming '"answer":'
     else:
         accumulated_text = ""
         while '"answer":' not in accumulated_text:
             chunk, _ = await anext(generator)
             accumulated_text += chunk
-        # Finally we stream the answer
+        # Then we stream the answer
         async for chunk, parsed in generator:
             # While the answer has not finished streaming we yield the tokens.
             if parsed.get("answer") is None:  # type: ignore
                 yield chunk
+            # Stop streaming as soon as the answer is complete
+            # (i.e. don't stream the paragraph ids)
             else:
                 break
@@ -142,7 +145,7 @@
         )
         try:
-            # Extract the final pydantic class
+            # Extract the final pydantic class (last item in generator)
             parsed_output = await anext(
                 (
                     parsed
@@ -162,13 +165,13 @@
                             "status_code": 404,
                             "code": ErrorCode.NO_ANSWER_FOUND.value,
                             "detail": (
-                                "The LLM did not manage to answer the question based on the provided contexts."
+                                "The LLM encountered an error when answering the question."
                             ),
                         }
                     }
                 )
             try:
-                # We "raise" the finish_reason
+                # Finally we "raise" the finish_reason
                 await anext(generator)
             except RuntimeError as err:
                 finish_reason: str = err.args[0]
@@ -189,6 +192,7 @@
                         " retriever_k value by 1 or 2 depending of whether you are"
                         " using the reranker or not."
                     ),
+                    "raw_answer": parsed_output.answer,
                 }
             }
         )
diff --git a/src/scholarag/generative_question_answering.py b/src/scholarag/generative_question_answering.py
index 01778de..de7c096 100644
--- a/src/scholarag/generative_question_answering.py
+++ b/src/scholarag/generative_question_answering.py
@@ -15,16 +15,14 @@ class GenerativeQAOutput(BaseModel):
     """Base class for the expected LLM output."""

-    has_answer: bool  # Here to prevent streaming errors.
+    has_answer: bool  # Here to prevent streaming errors
     answer: str
     paragraphs: list[int]


-# SOURCES_SEPARATOR = ""
-# ERROR_SEPARATOR = ""
 MESSAGES = [
     {
-        "role": "system",
+        "role": "system",  # This one can be overridden through env var
         "content": """Given the following extracted parts of a long document and a question,
 create a final answer with references to the relevant paragraphs.
 If you don't know the answer, just say that you don't know, don't try to make up an answer,
 leave the paragraphs as an empty list and set `has_answer` to False.
@@ -54,7 +52,7 @@ class GenerativeQAOutput(BaseModel):
     """,
     },
     {
-        "role": "user",
+        "role": "user",  # This one cannot be overridden
         "content": """QUESTION: {question}
 =========
 {summaries}
@@ -101,7 +99,7 @@ def run(
         contexts
             Contexts to use to answer the question.
         system_prompt
-            System prompt for the LLM. Leave None for default
+            System prompt for the LLM. Leave None for default.

         Returns
         -------
@@ -227,7 +225,7 @@
         Yields
         ------
         chunks, parsed
-            Chunks of the answer, partially parsed json.
+            Chunks of the answer, (partially) parsed json.

         Returns
         -------
@@ -260,12 +258,15 @@
             response_format=GenerativeQAOutput,
         ) as stream:
             for event in stream:
+                # In between chunks we have accumulated text -> skip
                 if isinstance(event, ContentDeltaEvent):
                     continue
+                # At the end we get the parsed pydantic class
                 if isinstance(event, ContentDoneEvent):
                     if event.parsed is not None:  # mypy
                         yield "", event.parsed
                     continue
+
                 if isinstance(event, ChunkEvent):  # mypy
                     # Only the last chunk contains the usage.
                     if not event.chunk.usage:
@@ -352,12 +353,15 @@
             response_format=GenerativeQAOutput,
         ) as stream:
             async for event in stream:
+                # In between chunks we have accumulated text -> skip
                 if isinstance(event, ContentDeltaEvent):
                     continue
+                # At the end we get the parsed pydantic class
                 if isinstance(event, ContentDoneEvent):
                     if event.parsed is not None:  # mypy
                         yield "", event.parsed
                     continue
+
                 if isinstance(event, ChunkEvent):  # mypy
                     # Only the last chunk contains the usage.
                     if not event.chunk.usage:
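
For context on the pattern this patch adopts, the standalone sketch below shows OpenAI
structured-output streaming with a pydantic response_format, mirroring the event handling
in the patched stream()/astream() methods. It is a minimal illustration under stated
assumptions rather than scholarag code: it assumes the openai Python SDK (>= 1.40) with an
OPENAI_API_KEY in the environment, and the ask_structured helper, the model choice and the
inline prompt are invented for the example.

"""Minimal sketch of the structured-output streaming pattern, not scholarag code."""
from __future__ import annotations

from openai import OpenAI
from openai.lib.streaming.chat import ChunkEvent, ContentDeltaEvent, ContentDoneEvent
from pydantic import BaseModel


class GenerativeQAOutput(BaseModel):
    """Same shape as the model in generative_question_answering.py."""

    has_answer: bool
    answer: str
    paragraphs: list[int]


def ask_structured(client: OpenAI, question: str, summaries: str) -> GenerativeQAOutput | None:
    """Print answer tokens as they stream, then return the validated output."""
    parsed: GenerativeQAOutput | None = None
    with client.beta.chat.completions.stream(
        model="gpt-4o-mini",  # illustrative model choice
        messages=[
            {"role": "user", "content": f"QUESTION: {question}\n=========\n{summaries}"}
        ],
        response_format=GenerativeQAOutput,  # replaces the old prompt separators
    ) as stream:
        for event in stream:
            if isinstance(event, ContentDeltaEvent):
                # Incremental JSON text of the structured answer.
                print(event.delta, end="", flush=True)
            elif isinstance(event, ContentDoneEvent):
                # The final content event carries the parsed pydantic instance.
                parsed = event.parsed
            elif isinstance(event, ChunkEvent) and event.chunk.usage:
                # The last raw chunk can carry token usage.
                print()
                print(event.chunk.usage)
    return parsed


if __name__ == "__main__":
    result = ask_structured(OpenAI(), "What does the paper conclude?", "(retrieved paragraphs here)")
    print(result)

Pushing the output contract into response_format is what lets the patch drop the commented-out
SOURCES_SEPARATOR/ERROR_SEPARATOR constants: the SDK validates the streamed JSON against
GenerativeQAOutput and hands back a parsed instance in the ContentDoneEvent, so callers no
longer need to split the raw text on separators.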