From 8ae8eb2e4b6fa2b74e386cd422c2e1e7d0f05407 Mon Sep 17 00:00:00 2001 From: Nicolas Frank <58003267+WonderPG@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:48:14 +0200 Subject: [PATCH] Structured output (#8) * Remove separators and use OpenAI's response_format instead * Add unit tests * Small fixes * Change default model in app --------- Co-authored-by: Nicolas Frank --- .env.example | 2 +- CHANGELOG.md | 1 + src/scholarag/app/config.py | 6 +- src/scholarag/app/routers/qa.py | 20 +- src/scholarag/app/schemas.py | 2 +- src/scholarag/app/streaming.py | 146 +++-- .../generative_question_answering.py | 443 ++++++------- tests/app/dependencies_overrides.py | 30 +- tests/app/test_qa.py | 136 ++-- tests/test_generative_question_answering.py | 598 +++++++++--------- 10 files changed, 681 insertions(+), 703 deletions(-) diff --git a/.env.example b/.env.example index 9676c3b..2acb915 100644 --- a/.env.example +++ b/.env.example @@ -20,7 +20,7 @@ SCHOLARAG__RETRIEVAL__MAX_LENGTH= SCHOLARAG__GENERATIVE__OPENAI__MODEL= SCHOLARAG__GENERATIVE__OPENAI__TEMPERATURE= SCHOLARAG__GENERATIVE__OPENAI__MAX_TOKENS= -SCHOLARAG__GENERATIVE__PROMPT_TEMPLATE= +SCHOLARAG__GENERATIVE__SYSTEM_PROMPT= SCHOLARAG__METADATA__EXTERNAL_APIS= SCHOLARAG__METADATA__TIMEOUT= diff --git a/CHANGELOG.md b/CHANGELOG.md index 55db25b..8c3d432 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed +- Use OpenAI response_format instead of separators in the prompt. - Switch to cohere reranker v3 and `retriever_k = 500`. ## [v0.0.5] diff --git a/src/scholarag/app/config.py b/src/scholarag/app/config.py index 8107536..813bc6f 100644 --- a/src/scholarag/app/config.py +++ b/src/scholarag/app/config.py @@ -9,7 +9,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self -from scholarag.generative_question_answering import PROMPT_TEMPLATE +from scholarag.generative_question_answering import MESSAGES class SettingsKeycloak(BaseModel): @@ -84,7 +84,7 @@ class SettingsOpenAI(BaseModel): """OpenAI settings.""" token: SecretStr | None = None - model: str = "gpt-3.5-turbo" + model: str = "gpt-4o-mini" temperature: float = 0 max_tokens: int | None = None @@ -95,7 +95,7 @@ class SettingsGenerative(BaseModel): """Generative QA settings.""" openai: SettingsOpenAI = SettingsOpenAI() - prompt_template: SecretStr = SecretStr(PROMPT_TEMPLATE) + system_prompt: SecretStr = SecretStr(MESSAGES[0]["content"]) model_config = ConfigDict(frozen=True) diff --git a/src/scholarag/app/routers/qa.py b/src/scholarag/app/routers/qa.py index ab9e5ec..7316f78 100644 --- a/src/scholarag/app/routers/qa.py +++ b/src/scholarag/app/routers/qa.py @@ -293,7 +293,7 @@ async def generative_qa( **{ "query": request.query, "contexts": contexts_text, - "prompt_template": settings.generative.prompt_template.get_secret_value(), + "system_prompt": settings.generative.system_prompt.get_secret_value(), }, ) try: @@ -325,17 +325,17 @@ async def generative_qa( " its answer. Please decrease the reranker_k or retriever_k value" " by 1 or 2 depending of whether you are using the reranker or not." ), - "raw_answer": answer["raw_answer"], + "answer": answer.answer, }, ) - if answer["paragraphs"] is None or len(answer["paragraphs"]) == 0: + if not answer.has_answer: raise HTTPException( status_code=500, detail={ "code": ErrorCode.NO_ANSWER_FOUND.value, "detail": "The LLM did not provide any source to answer the question.", - "raw_answer": answer["raw_answer"], + "answer": answer.answer, }, ) else: @@ -344,7 +344,7 @@ async def generative_qa( impact_factors = fetched_metadata["get_impact_factors"] abstracts = fetched_metadata["recreate_abstract"] - context_ids: list[int] = answer["paragraphs"] + context_ids: list[int] = answer.paragraphs logger.info("Adding article metadata to the answers") metadata = [] for context_id in context_ids: @@ -375,11 +375,13 @@ async def generative_qa( } ) ) - answer["metadata"] = metadata - - del answer["paragraphs"] + output = { + "answer": answer.answer, + "paragraphs": answer.paragraphs, + "metadata": metadata, + } logger.info(f"Total time to generate a complete answer: {time.time() - start}") - return GenerativeQAResponse(**answer) + return GenerativeQAResponse(**output) @router.post( diff --git a/src/scholarag/app/schemas.py b/src/scholarag/app/schemas.py index 8af920a..5b13087 100644 --- a/src/scholarag/app/schemas.py +++ b/src/scholarag/app/schemas.py @@ -115,5 +115,5 @@ class GenerativeQAResponse(BaseModel): """Response for the generative QA endpoint.""" answer: str | None - raw_answer: str + paragraphs: list[int] metadata: list[ParagraphMetadata] diff --git a/src/scholarag/app/streaming.py b/src/scholarag/app/streaming.py index 742ac21..896ba9b 100644 --- a/src/scholarag/app/streaming.py +++ b/src/scholarag/app/streaming.py @@ -1,7 +1,6 @@ """Utilities to stream openai response.""" import json -from collections import deque from typing import Any, AsyncIterable from httpx import AsyncClient @@ -9,10 +8,10 @@ from scholarag.app.config import Settings from scholarag.app.dependencies import ErrorCode +from scholarag.app.schemas import GenerativeQAResponse from scholarag.document_stores import AsyncBaseSearch from scholarag.generative_question_answering import ( - ERROR_SEPARATOR, - SOURCES_SEPARATOR, + GenerativeQAOutput, GenerativeQAWithSources, ) from scholarag.retrieve_metadata import MetaDataRetriever @@ -84,30 +83,50 @@ async def stream_response( openai_client = AsyncOpenAI(api_key=api_key) qas.client = openai_client - token_queue: deque[str] = deque( - maxlen=4 - ) # Needs to be EXACTLY the number of tokens of the separator. See https://platform.openai.com/tokenizer - # Some black magic. - raw_string = "" try: - async for chunk in qas.astream( + generator = qas.astream( query=query, contexts=contexts_text, - prompt_template=settings.generative.prompt_template.get_secret_value(), - ): - raw_string += chunk - if len(token_queue) == token_queue.maxlen: - queued_text = "".join(list(token_queue)) - if ( - f"{SOURCES_SEPARATOR}:" in queued_text - or ERROR_SEPARATOR in queued_text - ): # Might change if we change the separator. # This condition is subject to change based on the separator we use. - continue - yield token_queue.popleft() - token_queue.append(chunk) - except RuntimeError as e: - # Since finish reason is raised, it has to be recovered as such. - finish_reason = e.args[0] + system_prompt=settings.generative.system_prompt.get_secret_value(), + ) + + parsed: dict[str, str | bool | list[int]] | GenerativeQAOutput = {} + # While the model didn't say if it has answer or not, keep consuming + while parsed.get("has_answer") is None: # type: ignore + _, parsed = await anext(generator) + # If the LLM doesn't know answer, no need to further iterate + if not parsed["has_answer"]: # type: ignore + yield "" + yield json.dumps( + { + "Error": { + "status_code": 404, + "code": ErrorCode.NO_ANSWER_FOUND.value, + "detail": ( + "The LLM did not manage to answer the question based on the provided contexts." + ), + } + } + ) + parsed_output = GenerativeQAOutput( + has_answer=False, answer="", paragraphs=[] + ) + # Else, we stream the tokens. + # First ensure not streaming '"answer":' + else: + accumulated_text = "" + while '"answer":' not in accumulated_text: + chunk, _ = await anext(generator) + accumulated_text += chunk + # Then we stream the answer + async for chunk, parsed in generator: + # While the answer has not finished streaming we yield the tokens. + if parsed.get("answer") is None: # type: ignore + yield chunk + # Stop streaming as soon as the answer is complete + # (i.e. don't stream the paragraph ids) + else: + break # Errors cannot be raised due to the streaming nature of the endpoint. They're yielded as a dict instead, # leaving the charge to the user to catch them. The stream is simply interrupted in case of error. @@ -124,9 +143,40 @@ async def stream_response( } } ) - if not interrupted: - # Post process the output to get citations. - answer = qas._process_raw_output(raw_output=raw_string) + + try: + # Extract the final pydantic class (last item in generator) + parsed_output = await anext( + ( + parsed + async for (_, parsed) in generator + if isinstance(parsed, GenerativeQAOutput) + ) + ) + except StopAsyncIteration: + # If it is not present, we had an issue. + # By default if there was an issue we nullify the potential partial answer. + # Feel free to suggest another approach. + parsed_output = GenerativeQAOutput(has_answer=False, answer="", paragraphs=[]) + yield "" + yield json.dumps( + { + "Error": { + "status_code": 404, + "code": ErrorCode.NO_ANSWER_FOUND.value, + "detail": ( + "The LLM encountered an error when answering the question." + ), + } + } + ) + try: + # Finally we "raise" the finish_reason + await anext(generator) + except RuntimeError as err: + finish_reason: str = err.args[0] + + if not interrupted and parsed_output.has_answer: # Ensure that the model finished yielding. if finish_reason == "length": # Adding a separator before raising the error. @@ -142,22 +192,7 @@ async def stream_response( " retriever_k value by 1 or 2 depending of whether you are" " using the reranker or not." ), - "raw_answer": answer["raw_answer"], - } - } - ) - elif answer["paragraphs"] is None or len(answer["paragraphs"]) == 0: # type: ignore - # Adding a separator before raising the error. - yield "" - yield json.dumps( - { - "Error": { - "status_code": 404, - "code": ErrorCode.NO_ANSWER_FOUND.value, - "detail": ( - "The LLM did not provide any source to answer the question." - ), - "raw_answer": answer["raw_answer"], + "raw_answer": parsed_output.answer, } } ) @@ -165,7 +200,7 @@ async def stream_response( # Adding a separator between the streamed answer and the processed response. yield "" complete_answer = await retrieve_metadata( - answer=answer, + answer=parsed_output, ds_client=ds_client, index_journals=index_journals, index_paragraphs=index_paragraphs, @@ -175,11 +210,11 @@ async def stream_response( indices=indices, scores=scores, ) - yield json.dumps(complete_answer) + yield complete_answer.model_dump_json() async def retrieve_metadata( - answer: dict[str, Any], + answer: GenerativeQAOutput, ds_client: AsyncBaseSearch, index_journals: str | None, index_paragraphs: str, @@ -188,13 +223,13 @@ async def retrieve_metadata( contexts: list[dict[str, Any]], indices: tuple[int, ...], scores: tuple[float, ...] | None, -) -> dict[str, Any]: +) -> GenerativeQAResponse: """Retrieve the metadata and display them nicely. Parameters ---------- answer - Answer generated by the model. + Parsed answer returned by the model. ds_client Document store client. index_journals @@ -216,6 +251,7 @@ async def retrieve_metadata( ------- Nicely formatted answer with metadata. """ + contexts = [contexts[i] for i in answer.paragraphs] fetched_metadata = await metadata_retriever.retrieve_metadata( contexts, ds_client, index_journals, index_paragraphs, httpx_client ) @@ -225,11 +261,8 @@ async def retrieve_metadata( impact_factors = fetched_metadata["get_impact_factors"] abstracts = fetched_metadata["recreate_abstract"] - context_ids: list[int] = answer["paragraphs"] metadata = [] - for context_id in context_ids: - context = contexts[context_id] - + for context, context_id in zip(contexts, answer.paragraphs): metadata.append( { "impact_factor": impact_factors[context.get("journal")], @@ -251,7 +284,10 @@ async def retrieve_metadata( "pubmed_id": context["pubmed_id"], } ) - answer["metadata"] = metadata - del answer["paragraphs"] - return answer + output = { + "answer": answer.answer, + "paragraphs": answer.paragraphs, + "metadata": metadata, + } + return GenerativeQAResponse(**output) diff --git a/src/scholarag/generative_question_answering.py b/src/scholarag/generative_question_answering.py index f044071..448d519 100644 --- a/src/scholarag/generative_question_answering.py +++ b/src/scholarag/generative_question_answering.py @@ -1,53 +1,65 @@ """Generative question answering with sources.""" +import copy import logging -import re from typing import AsyncGenerator, Generator from openai import AsyncOpenAI, OpenAI +from openai.lib.streaming.chat import ChunkEvent, ContentDeltaEvent, ContentDoneEvent from openai.types.completion_usage import CompletionUsage -from pydantic.v1 import BaseModel, Extra, validate_arguments +from pydantic import BaseModel, ConfigDict logger = logging.getLogger(__name__) -SOURCES_SEPARATOR = "" -ERROR_SEPARATOR = "" - -PROMPT_TEMPLATE = """Given the following extracted parts of a long document and a question, create a final answer with references {SOURCES_SEPARATOR}. -If you don't know the answer, just say that you don't know, don't try to make up an answer and start your response with {ERROR_SEPARATOR}. -ALWAYS return a {SOURCES_SEPARATOR} part at the end of your answer, just leave it empty if you don't know the answer. - -QUESTION: Which state/country's law governs the interpretation of the contract? -========= -Content: This Agreement is governed by English law and the parties submit to the exclusive jurisdiction of the English courts in relation to any dispute (contractual or non-contractual) concerning this Agreement save that either party may apply to any court for an injunction or other relief to protect its Intellectual Property Rights. -Source: 28 -Content: No Waiver. Failure or delay in exercising any right or remedy under this Agreement shall not constitute a waiver of such (or any other) right or remedy.\n\n11.7 Severability. The invalidity, illegality or unenforceability of any term (or part of a term) of this Agreement shall not affect the continuation in force of the remainder of the term (if any) and this Agreement.\n\n11.8 No Agency. Except as expressly stated otherwise, nothing in this Agreement shall create an agency, partnership or joint venture of any kind between the parties.\n\n11.9 No Third-Party Beneficiaries. -Source: 30 -Content: (b) if Google believes, in good faith, that the Distributor has violated or caused Google to violate any Anti-Bribery Laws (as defined in Clause 8.5) or that such a violation is reasonably likely to occur, -Source: 4 -========= -FINAL ANSWER: This Agreement is governed by English law. -{SOURCES_SEPARATOR}: 28 - -QUESTION: What did the president say about Michael Jackson? -========= -Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. -Source: 0 -Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. -Source: 24 -Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. -Source: 5 -Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. -Source: 34 -========= -FINAL ANSWER: {ERROR_SEPARATOR} The president did not mention Michael Jackson. -{SOURCES_SEPARATOR}: - -QUESTION: {question} -========= -{summaries} -========= -FINAL ANSWER:""" + +class GenerativeQAOutput(BaseModel): + """Base class for the expected LLM output.""" + + has_answer: bool # Here to prevent streaming errors + answer: str + paragraphs: list[int] + + +MESSAGES = [ + { + "role": "system", # This one can be overriden through env var + "content": """Given the following extracted parts of a long document and a question, create a final answer with references to the relevant paragraphs. + If you don't know the answer, just say that you don't know, don't try to make up an answer, leave the paragraphs as an empty list and set `has_answer` to False. + + QUESTION: Which state/country's law governs the interpretation of the contract? + ========= + Content: This Agreement is governed by English law and the parties submit to the exclusive jurisdiction of the English courts in relation to any dispute (contractual or non-contractual) concerning this Agreement save that either party may apply to any court for an injunction or other relief to protect its Intellectual Property Rights. + Source: 28 + Content: No Waiver. Failure or delay in exercising any right or remedy under this Agreement shall not constitute a waiver of such (or any other) right or remedy.\n\n11.7 Severability. The invalidity, illegality or unenforceability of any term (or part of a term) of this Agreement shall not affect the continuation in force of the remainder of the term (if any) and this Agreement.\n\n11.8 No Agency. Except as expressly stated otherwise, nothing in this Agreement shall create an agency, partnership or joint venture of any kind between the parties.\n\n11.9 No Third-Party Beneficiaries. + Source: 30 + Content: (b) if Google believes, in good faith, that the Distributor has violated or caused Google to violate any Anti-Bribery Laws (as defined in Clause 8.5) or that such a violation is reasonably likely to occur, + Source: 4 + ========= + FINAL ANSWER: {'has_answer': True, 'answer': 'This Agreement is governed by English law.', 'paragraphs': [28]} + + QUESTION: What did the president say about Michael Jackson? + ========= + Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. + Source: 0 + Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. + Source: 24 + Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. + Source: 5 + Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. + Source: 34 + ========= + FINAL ANSWER: {'has_answer': False, 'answer': The president did not mention Michael Jackson., 'paragraphs': []} + """, + }, + { + "role": "user", # This one cannot be overriden + "content": """QUESTION: {question} + ========= + {summaries} + ========= + FINAL ANSWER:""", + }, +] class GenerativeQAWithSources(BaseModel): @@ -66,23 +78,18 @@ class GenerativeQAWithSources(BaseModel): """ client: OpenAI | AsyncOpenAI - model: str = "gpt-3.5-turbo" + model: str = "gpt-4o-mini" temperature: float = 0.0 max_tokens: int | None = None - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - @validate_arguments def run( self, query: str, contexts: list[str], - prompt_template: str = PROMPT_TEMPLATE, - ) -> tuple[dict[str, str | list[str] | int | list[int] | None], str]: + system_prompt: str | None = None, + ) -> tuple[GenerativeQAOutput, str]: """Answer the question given the contexts. Parameters @@ -91,37 +98,33 @@ def run( Question to answer. contexts Contexts to use to answer the question. - prompt_template - Custom prompt used by the LLM + system_prompt + System prompt for the LLM. Leave None for default. Returns ------- - generated_text - Answers to the question (in theory). + generated_text, finish_reason + Answers to the question (in theory), reason for the LLM to stop. """ - # Put the documents in the prompt with the correct formats. + # Put the documents in the prompt with the correct formats docs = self._process_retrieved_contexts(contexts) - prompt = prompt_template.format( - SOURCES_SEPARATOR=SOURCES_SEPARATOR, - ERROR_SEPARATOR=ERROR_SEPARATOR, - question=query, - summaries=docs, + # Deep copying to avoid replacing completely the placeholders + messages = copy.deepcopy(MESSAGES) + if system_prompt: + messages[0]["content"] = system_prompt + messages[1]["content"] = messages[1]["content"].format( + question=query, summaries=docs ) # Run the chain. logger.info("Sending generative reader request.") if isinstance(self.client, OpenAI): - response = self.client.chat.completions.create( - messages=[ - { - "role": "user", - "content": prompt, - } - ], + response = self.client.beta.chat.completions.parse( + messages=messages, # type: ignore model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, - stream=False, + response_format=GenerativeQAOutput, ) else: raise RuntimeError( @@ -137,16 +140,15 @@ def run( f" {response.usage.completion_tokens}\nTotal tokens:" f" {response.usage.total_tokens}\nFinish reason: {finish_reason}" ) - output = response.choices[0].message.content - return self._process_raw_output(output), finish_reason # type: ignore + output = response.choices[0].message.parsed + return output, finish_reason # type: ignore - @validate_arguments async def arun( self, query: str, contexts: list[str], - prompt_template: str = PROMPT_TEMPLATE, - ) -> tuple[dict[str, str | list[str] | int | list[int] | None], str]: + system_prompt: str | None = None, + ) -> tuple[GenerativeQAOutput, str]: """Answer the question given the contexts. Parameters @@ -155,36 +157,32 @@ async def arun( Question to answer. contexts Contexts to use to answer the question. - prompt_template - Custom prompt used by the LLM + system_prompt + System prompt for the LLM. Leave None for default. Returns ------- - generated_text - Answers to the question (in theory). + generated_text, finish_reason + Answers to the question (in theory), reason for the LLM to stop. """ # Put the documents in the prompt with the correct formats. docs = self._process_retrieved_contexts(contexts) - prompt = prompt_template.format( - SOURCES_SEPARATOR=SOURCES_SEPARATOR, - ERROR_SEPARATOR=ERROR_SEPARATOR, - question=query, - summaries=docs, + messages = copy.deepcopy(MESSAGES) + if system_prompt: + messages[0]["content"] = system_prompt + messages[1]["content"] = messages[1]["content"].format( + question=query, summaries=docs ) + # Run the chain. logger.info("Sending generative reader request.") if isinstance(self.client, AsyncOpenAI): - response = await self.client.chat.completions.create( - messages=[ - { - "role": "user", - "content": prompt, - } - ], + response = await self.client.beta.chat.completions.parse( + messages=messages, # type: ignore model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, - stream=False, + response_format=GenerativeQAOutput, ) else: raise RuntimeError( @@ -200,16 +198,19 @@ async def arun( f" {response.usage.completion_tokens}\nTotal tokens:" f" {response.usage.total_tokens}\nFinish reason: {finish_reason}" ) - output = response.choices[0].message.content - return self._process_raw_output(output), finish_reason # type: ignore + output = response.choices[0].message.parsed + return output, finish_reason # type: ignore - @validate_arguments def stream( self, query: str, contexts: list[str], - prompt_template: str = PROMPT_TEMPLATE, - ) -> Generator[str, None, str | None]: + system_prompt: str | None = None, + ) -> Generator[ + tuple[str, dict[str, bool | str | list[int]] | GenerativeQAOutput], + None, + str | None, + ]: """Answer the question given the contexts. Parameters @@ -218,61 +219,76 @@ def stream( Question to answer. contexts Contexts to use to answer the question. - prompt_template - Custom prompt used by the LLM + system_prompt + System prompt for the LLM. Leave None for default + + Yields + ------ + chunks, parsed + Chunks of the answer, (partially) parsed json. Returns ------- - generated_text - Answers to the question (in theory). + finish_reason + The reason for the LLM to stop generating. """ # Put the documents in the prompt with the correct formats. docs = self._process_retrieved_contexts(contexts) - prompt = prompt_template.format( - SOURCES_SEPARATOR=SOURCES_SEPARATOR, - ERROR_SEPARATOR=ERROR_SEPARATOR, - question=query, - summaries=docs, + messages = copy.deepcopy(MESSAGES) + if system_prompt: + messages[0]["content"] = system_prompt + messages[1]["content"] = messages[1]["content"].format( + question=query, summaries=docs ) # Run the chain. logger.info("Sending generative reader request.") - if isinstance(self.client, OpenAI): - response = self.client.chat.completions.create( - messages=[ - { - "role": "user", - "content": prompt, - } - ], - model=self.model, - temperature=self.temperature, - max_tokens=self.max_tokens, - stream=True, - stream_options={"include_usage": True}, - ) - else: + if not isinstance(self.client, OpenAI): raise RuntimeError( "The OpenAI client might be an async one. Ensure that you are using a" - " sync OpenAI client to call run." + " sync OpenAI client to call stream." ) finish_reason = None - for chunk in response: - # Only the last chunk contains the usage. - if not chunk.usage: - # case where a token is streamed - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content - else: - # No usage and no token -> finish reason is there. - finish_reason = chunk.choices[0].finish_reason - else: - logger.info( - "Information about our OpenAI request:\n Input tokens:" - f" {chunk.usage.prompt_tokens}\nOutput tokens:" - f" {chunk.usage.completion_tokens}\nTotal tokens:" - f" {chunk.usage.total_tokens}\nFinish reason: {finish_reason}" - ) + with self.client.beta.chat.completions.stream( + messages=messages, # type: ignore + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + stream_options={"include_usage": True}, + response_format=GenerativeQAOutput, + ) as stream: + for event in stream: + # Inbetween chunks we have accumulated text -> skip + if isinstance(event, ContentDeltaEvent): + continue + # At the end we get the parsed pydantic class + if isinstance(event, ContentDoneEvent): + if event.parsed is not None: # mypy + yield "", event.parsed + continue + + if isinstance(event, ChunkEvent): # mypy + # Only the last chunk contains the usage. + if not event.chunk.usage: + # case where a token is streamed + if event.chunk.choices[0].delta.content: + yield ( # type: ignore + event.chunk.choices[0].delta.content, + event.snapshot.choices[0].message.parsed, + ) + else: + # No usage and no token -> finish reason is there. + # The first chunk might be empty and have no finish reason, + # it will be overriden later anyway. + finish_reason = event.chunk.choices[0].finish_reason + else: + logger.info( + "Information about our OpenAI request:\n Input tokens:" + f" {event.chunk.usage.prompt_tokens}\nOutput tokens:" + f" {event.chunk.usage.completion_tokens}\nTotal tokens:" + f" {event.chunk.usage.total_tokens}\nFinish reason: {finish_reason}" + ) + # In sync generators you can return a value, which will raise StopIteration and the returned # value can be retrieved as such: # ```python @@ -287,13 +303,14 @@ def stream( return finish_reason return None # for mypy - @validate_arguments async def astream( self, query: str, contexts: list[str], - prompt_template: str = PROMPT_TEMPLATE, - ) -> AsyncGenerator[str, None]: + system_prompt: str | None = None, + ) -> AsyncGenerator[ + tuple[str, dict[str, bool | str | list[int]] | GenerativeQAOutput], None + ]: """Answer the question given the contexts. Parameters @@ -302,64 +319,73 @@ async def astream( Question to answer. contexts Contexts to use to answer the question. - prompt_template - Custom prompt used by the LLM + system_prompt + System prompt for the LLM. Leave None for default - Returns - ------- - generated_text - Answers to the question (in theory). + Yields + ------ + chunks, parsed, finish_reason + Answers to the question (in theory), partially parsed json. Final token is the reason for the LLM to stop. """ # Put the documents in the prompt with the correct formats. docs = self._process_retrieved_contexts(contexts) - prompt = prompt_template.format( - SOURCES_SEPARATOR=SOURCES_SEPARATOR, - ERROR_SEPARATOR=ERROR_SEPARATOR, - question=query, - summaries=docs, + messages = copy.deepcopy(MESSAGES) + if system_prompt: + messages[0]["content"] = system_prompt + messages[1]["content"] = messages[1]["content"].format( + question=query, summaries=docs ) # Run the chain. logger.info("Sending generative reader request.") - if isinstance(self.client, AsyncOpenAI): - response = await self.client.chat.completions.create( - messages=[ - { - "role": "user", - "content": prompt, - } - ], - model=self.model, - temperature=self.temperature, - max_tokens=self.max_tokens, - stream=True, - stream_options={"include_usage": True}, - ) - else: + if not isinstance(self.client, AsyncOpenAI): raise RuntimeError( "The OpenAI client might be a sync one. Ensure that you are using a" - " async AsyncOpenAI client to call arun." + " async AsyncOpenAI client to call astream." ) finish_reason = None - async for chunk in response: - # Only the last chunk contains the usage. - if not chunk.usage: - # case where a token is streamed - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content - else: - # No usage and no token -> finish reason is there. - finish_reason = chunk.choices[0].finish_reason - else: - logger.info( - "Information about our OpenAI request:\n Input tokens:" - f" {chunk.usage.prompt_tokens}\nOutput tokens:" - f" {chunk.usage.completion_tokens}\nTotal tokens:" - f" {chunk.usage.total_tokens}\nFinish reason: {finish_reason}" - ) - # It is considered a syntax error to return in an async iterator. This is a hack to do it anyway. - if finish_reason: - raise RuntimeError(finish_reason) + async with self.client.beta.chat.completions.stream( + messages=messages, # type: ignore + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + stream_options={"include_usage": True}, + response_format=GenerativeQAOutput, + ) as stream: + async for event in stream: + # Inbetween chunks we have accumulated text -> skip + if isinstance(event, ContentDeltaEvent): + continue + # At the end we get the parsed pydantic class + if isinstance(event, ContentDoneEvent): + if event.parsed is not None: # mypy + yield "", event.parsed + continue + + if isinstance(event, ChunkEvent): # mypy + # Only the last chunk contains the usage. + if not event.chunk.usage: + # case where a token is streamed + if event.chunk.choices[0].delta.content: + yield ( # type: ignore + event.chunk.choices[0].delta.content, + event.snapshot.choices[0].message.parsed, + ) + else: + # No usage and no token -> finish reason is there. + # The first chunk might be empty and have no finish reason, + # it will be overriden later anyway. + finish_reason = event.chunk.choices[0].finish_reason + else: + logger.info( + "Information about our OpenAI request:\n Input tokens:" + f" {event.chunk.usage.prompt_tokens}\nOutput tokens:" + f" {event.chunk.usage.completion_tokens}\nTotal tokens:" + f" {event.chunk.usage.total_tokens}\nFinish reason: {finish_reason}" + ) + # It is considered a syntax error to return in an async iterator. This is a hack to do it anyway. + if finish_reason: + raise RuntimeError(finish_reason) @staticmethod def _process_retrieved_contexts(contexts: list[str]) -> str: @@ -382,52 +408,3 @@ def _process_retrieved_contexts(contexts: list[str]) -> str: documents = "\n".join(contexts) return documents - - @staticmethod - def _process_raw_output( - raw_output: str, - ) -> dict[str, str | list[str] | int | list[int] | None]: - """Process raw output. - - Parameters - ---------- - raw_output - Raw output. - - Returns - ------- - dict[str, str | list[str] | int | list[int] | None] - Processed output. - """ - separator_pattern = rf"(?i)\n?{SOURCES_SEPARATOR}:" - answer_source_pair = re.split(separator_pattern, raw_output, maxsplit=1) - - # Case where the output prompt is somewhat well formatted. - if len(answer_source_pair) == 2: - answer, sources = answer_source_pair - source_list = sources.split(",") - # Safety measure, but in theory shouldn't happen - answer = answer.replace(ERROR_SEPARATOR, "") - try: - # Perfect format - formatted_output = { - "answer": answer, - "paragraphs": [ - int(re.search(r"\d+", source).group(0)) # type: ignore - for source in source_list - ], - "raw_answer": raw_output, - } - except AttributeError: - return { - "answer": None, - "paragraphs": None, - "raw_answer": raw_output, - } - else: - return { - "answer": None, - "paragraphs": None, - "raw_answer": raw_output, - } - return formatted_output diff --git a/tests/app/dependencies_overrides.py b/tests/app/dependencies_overrides.py index b603729..7ee3361 100644 --- a/tests/app/dependencies_overrides.py +++ b/tests/app/dependencies_overrides.py @@ -19,7 +19,7 @@ ) from scholarag.app.main import app from scholarag.document_stores import AsyncBaseSearch -from scholarag.generative_question_answering import ERROR_SEPARATOR, SOURCES_SEPARATOR +from scholarag.generative_question_answering import GenerativeQAOutput def override_rts(has_context=True): @@ -105,34 +105,24 @@ def override_generative_qas(has_answer=True, complete_answer=True): FakeQAS.arun.__name__ = "arun" if has_answer: FakeQAS.arun.side_effect = lambda **params: ( - { - "answer": "This is a perfect answer.", - "paragraphs": [0, 1, 2], - "raw_answer": ( - f"This is a perfect answer \n{SOURCES_SEPARATOR}: 0, 1, 2" - ), - }, + GenerativeQAOutput( + has_answer=True, + answer="This is a perfect answer.", + paragraphs=[0, 1, 2], + ), "stop", ) else: if complete_answer: FakeQAS.arun.side_effect = lambda **params: ( - { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"{ERROR_SEPARATOR}I don't know \n{SOURCES_SEPARATOR}:" - ), - }, + GenerativeQAOutput( + has_answer=False, answer="I don't know.", paragraphs=[] + ), "stop", ) else: FakeQAS.arun.side_effect = lambda **params: ( - { - "answer": None, - "paragraphs": None, - "raw_answer": f"{ERROR_SEPARATOR}I don't", - }, + GenerativeQAOutput(has_answer=False, answer="I don't", paragraphs=[]), "length", ) app.dependency_overrides[get_generative_qas] = lambda: FakeQAS diff --git a/tests/app/test_qa.py b/tests/app/test_qa.py index 4678c63..f640033 100644 --- a/tests/app/test_qa.py +++ b/tests/app/test_qa.py @@ -13,9 +13,7 @@ from scholarag.app.routers.qa import GenerativeQAResponse, ParagraphMetadata from scholarag.document_stores import AsyncOpenSearch from scholarag.generative_question_answering import ( - ERROR_SEPARATOR, - SOURCES_SEPARATOR, - GenerativeQAWithSources, + GenerativeQAOutput, ) from app.dependencies_overrides import ( @@ -83,14 +81,11 @@ def test_generative_qa_reranker(app_client, reranker_k, mock_http_calls): _, list_document_ids = override_rts(has_context=True) fakeqas.arun.side_effect = lambda **params: ( - { - "answer": "This is a perfect answer.", - "paragraphs": list(range(reranker_k)), - "raw_answer": ( - f"This is a perfect answer /n{SOURCES_SEPARATOR}:" - f" {', '.join([str(i) for i in range(reranker_k)])}" - ), - }, + GenerativeQAOutput( + has_answer=True, + answer="This is a perfect answer.", + paragraphs=list(range(reranker_k)), + ), None, ) @@ -162,13 +157,10 @@ def test_generative_qa_no_answer_code_2(app_client, mock_http_calls): assert response_body["detail"].keys() == { "code", "detail", - "raw_answer", + "answer", } assert response_body["detail"]["code"] == 2 - assert ( - response_body["detail"]["raw_answer"] - == f"{ERROR_SEPARATOR}I don't know \n{SOURCES_SEPARATOR}:" - ) + assert response_body["detail"]["answer"] == "I don't know." def test_generative_qa_code_6(app_client, mock_http_calls): @@ -188,10 +180,10 @@ def test_generative_qa_code_6(app_client, mock_http_calls): assert response_body["detail"].keys() == { "code", "detail", - "raw_answer", + "answer", } assert response_body["detail"]["code"] == 6 - assert response_body["detail"]["raw_answer"] == f"{ERROR_SEPARATOR}I don't" + assert response_body["detail"]["answer"] == "I don't" def test_generative_qa_context_too_long(app_client, mock_http_calls): @@ -308,60 +300,77 @@ async def get_citation_count(doi, httpx_client): async def streamed_response(**kwargs): response = [ - "This", - " is", - " an", - " amazingly", - " well", - " streamed", - " response", - ".", - " I", - " can", - "'t", - " believe", - " how", - " good", - " it", - " is", + '{"has_answer": ', + "true, ", + '"answer": ', + "This ", + "is ", + "an ", + "amazingly ", + "well ", + "streamed ", + "response", + ". ", + "I ", + "can", + "'t ", + "believe ", + "how ", + "good ", + "it ", + "is", "!", - "\n", - ":", - " ", - "0", - ",", - " ", - "1", - ",", - " ", - "2", + ", ", + '"paragraphs": ', + "[0,1,2]}", ] + parsed = {} for word in response: - yield word + if word == "true, ": + parsed["has_answer"] = True + if word == ", ": + parsed["answer"] = ( + "This is an amazingly well streamed response . I can't believe how good it is !" + ) + if word == "[0,1,2]}": + parsed["paragraphs"] = [0, 1, 2] + yield word, parsed + yield ( + "", + GenerativeQAOutput( + has_answer=True, + answer="This is an amazingly well streamed response . I can't believe how good it is !", + paragraphs=[0, 1, 2], + ), + ) raise RuntimeError("stop") async def streamed_response_no_answer(**kwargs): response = [ - "I", - " don", - r"'t", - " know", - r" \"", - "n", - ":", + '{"has_answer": ', + "false, ", + '"answer": ', + "I ", + "don", + "'t ", + "know.", + ", " '"paragraphs": ', + "[]}", ] + parsed = {} for word in response: - yield word + if word == "false, ": + parsed["has_answer"] = False + if word == "'t ": + parsed["answer"] = "I don't know." + if word == "[]}": + parsed["paragraphs"] = [] + yield word, parsed + yield ( + "", + GenerativeQAOutput(has_answer=False, answer="I don't know.", paragraphs=[]), + ) raise RuntimeError("stop") @@ -373,7 +382,6 @@ async def test_streamed_generative_qa(app_client, redis_fixture, mock_http_calls override_rts(has_context=True) FakeQAS = override_generative_qas() FakeQAS.astream = streamed_response - FakeQAS._process_raw_output = GenerativeQAWithSources._process_raw_output params = {} expected_tokens = ( @@ -407,7 +415,6 @@ async def test_streamed_generative_qa(app_client, redis_fixture, mock_http_calls resp = "".join(resp) assert "" in resp - assert resp.split("")[0].endswith("\n") index_response = resp.split("") response_str = index_response[1] @@ -445,7 +452,6 @@ async def test_streamed_generative_qa_error(app_client, redis_fixture): override_rts(has_context=True) FakeQAS = override_generative_qas() FakeQAS.astream = streamed_response_no_answer - FakeQAS._process_raw_output = GenerativeQAWithSources._process_raw_output params = {} expected_tokens = "" # We expect an empty answer. diff --git a/tests/test_generative_question_answering.py b/tests/test_generative_question_answering.py index a3e3f0f..25ee67d 100644 --- a/tests/test_generative_question_answering.py +++ b/tests/test_generative_question_answering.py @@ -1,13 +1,23 @@ -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest from openai import AsyncOpenAI, OpenAI -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice as RunChoice -from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta + +# from openai.types.chat.chat_completion import +from openai.lib.streaming.chat import ChunkEvent, ContentDoneEvent +from openai.types.chat import ( + ParsedChatCompletion, + ParsedChatCompletionMessage, + ParsedChoice, +) +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, +) from openai.types.completion_usage import CompletionUsage from scholarag.generative_question_answering import ( - SOURCES_SEPARATOR, + GenerativeQAOutput, GenerativeQAWithSources, ) @@ -21,122 +31,70 @@ def test_run(): "I really enjoyed this context.", "That's really an amazing context.", ] - fake_response = ChatCompletion( - id="chatcmpl-9k86ZK0uY65mPo1PvdnueBmD8hFLK", + fake_response = ParsedChatCompletion[GenerativeQAOutput]( + id="chatcmpl-A3L1rdVJUsgqGkDku7H0Lv2cjpEUS", choices=[ - RunChoice( + ParsedChoice[GenerativeQAOutput]( finish_reason="stop", index=0, logprobs=None, - message=ChatCompletionMessage( - content=f"Very nice.\n{SOURCES_SEPARATOR}: 0", + message=ParsedChatCompletionMessage[GenerativeQAOutput]( + content='{"has_answer":true,"answer":"Very nice.","paragraphs":[0]}', + refusal=None, role="assistant", function_call=None, - tool_calls=None, + tool_calls=[], + parsed=GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0] + ), ), ) ], - created=1720781271, - model="bbp-3.5", + created=1725359183, + model="gpt-4o-mini-2024-07-18", object="chat.completion", service_tier=None, - system_fingerprint=None, + system_fingerprint="fp_f33667828e", usage=CompletionUsage( - completion_tokens=87, prompt_tokens=12236, total_tokens=12323 + completion_tokens=107, prompt_tokens=2706, total_tokens=2813 ), ) + # Test with well formated output. - fake_llm.chat.completions.create.return_value = fake_response + fake_llm.beta.chat.completions.parse.return_value = fake_response result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": "Very nice.", - "paragraphs": [0], - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: 0", - } + assert result == GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0] + ) assert finish_reason == "stop" - # Test with a badly returned context that still makes sense. + + # Test multiple sources. Correct format + fake_response.choices[0].finish_reason = None fake_response.choices[ 0 - ].message.content = f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice." - result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice.", - } - assert finish_reason == "stop" - # Test with a badly formatted output but the context is in the source. - fake_response.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice.\nQUESTION: How" - " nice is this context ?" + ].message.content = '{"has_answer":true,"answer":"Very nice.","paragraphs":[0,1]}' + fake_response.choices[0].message.parsed = GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0, 1] ) - fake_response.choices[0].finish_reason = "length" result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very" - " nice.\nQUESTION: How nice is this context ?" - ), - } - assert finish_reason == "length" - # Test with a completely messed up output format - fake_response.choices[ - 0 - ].message.content = ( - "Very nice. This context is very nice.\nQUESTION: How nice is this context ?" + assert result == GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0, 1] ) - result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - "Very nice. This context is very nice.\nQUESTION: How nice is this" - " context ?" - ), - } - assert finish_reason == "length" - # Test multiple sources. Correct format - fake_response.choices[0].finish_reason = None - fake_response.choices[0].message.content = f"Very nice.\n{SOURCES_SEPARATOR}: 0, 1" - result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": "Very nice.", - "paragraphs": [0, 1], - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: 0, 1", - } assert finish_reason is None - # Test multiple sources. Sources returned not number - fake_response.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice., I really enjoyed" - " this context." + # No answer. + fake_response.choices[0].finish_reason = "stop" + fake_response.choices[ + 0 + ].message.content = '{"has_answer":false,"answer":"I dont know","paragraphs":[]}' + fake_response.choices[0].message.parsed = GenerativeQAOutput( + has_answer=False, answer="I dont know.", paragraphs=[] ) result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice., I really" - " enjoyed this context." - ), - } - - # Test multiple sources. Wrong layout but sources still present. - fake_response.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice. I really enjoyed" - " this context.\nQUESTION: How nice is this context ?" + assert result == GenerativeQAOutput( + has_answer=False, answer="I dont know.", paragraphs=[] ) - result, finish_reason = gaq.run(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice. I really" - " enjoyed this context.\nQUESTION: How nice is this context ?" - ), - } + assert finish_reason == "stop" @pytest.mark.asyncio @@ -150,298 +108,306 @@ async def test_arun(): "I really enjoyed this context.", "That's really an amazing context.", ] - create_output.return_value = ChatCompletion( - id="chatcmpl-9k86ZK0uY65mPo1PvdnueBmD8hFLK", + create_output.return_value = ParsedChatCompletion[GenerativeQAOutput]( + id="chatcmpl-A3L1rdVJUsgqGkDku7H0Lv2cjpEUS", choices=[ - RunChoice( + ParsedChoice[GenerativeQAOutput]( finish_reason="stop", index=0, logprobs=None, - message=ChatCompletionMessage( - content=f"Very nice.\n{SOURCES_SEPARATOR}: 0", + message=ParsedChatCompletionMessage[GenerativeQAOutput]( + content='{"has_answer":true,"answer":"Very nice.","paragraphs":[0]}', + refusal=None, role="assistant", function_call=None, - tool_calls=None, + tool_calls=[], + parsed=GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0] + ), ), ) ], - created=1720781271, - model="bbp-3.5", + created=1725359183, + model="gpt-4o-mini-2024-07-18", object="chat.completion", service_tier=None, - system_fingerprint=None, + system_fingerprint="fp_f33667828e", usage=CompletionUsage( - completion_tokens=87, prompt_tokens=12236, total_tokens=12323 + completion_tokens=107, prompt_tokens=2706, total_tokens=2813 ), ) - # Test with well formated output. - fake_llm.chat.completions.create = create_output - result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": "Very nice.", - "paragraphs": [0], - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: 0", - } - assert finish_reason == "stop" - # Test with a badly returned context that still makes sense. - create_output.return_value.choices[ - 0 - ].message.content = f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice." + + # Single source. + fake_llm.beta.chat.completions.parse = create_output result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice.", - } - assert finish_reason == "stop" - # Test with a badly formatted output but the context is in the source. - create_output.return_value.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice.\nQUESTION: How" - " nice is this context ?" + assert result == GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0] ) + assert finish_reason == "stop" + + # Multiple sources. create_output.return_value.choices[0].finish_reason = "length" - result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very" - " nice.\nQUESTION: How nice is this context ?" - ), - } - assert finish_reason == "length" - # Test with a completely messed up output format create_output.return_value.choices[ 0 - ].message.content = ( - "Very nice. This context is very nice.\nQUESTION: How nice is this context ?" + ].message.content = '{"has_answer":true,"answer":"Very nice.","paragraphs":[0,1]}' + create_output.return_value.choices[0].message.parsed = GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0, 1] ) result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - "Very nice. This context is very nice.\nQUESTION: How nice is this" - " context ?" - ), - } + assert result == GenerativeQAOutput( + has_answer=True, answer="Very nice.", paragraphs=[0, 1] + ) assert finish_reason == "length" - # Test multiple sources. Correct format - create_output.return_value.choices[0].finish_reason = None + + # No answer. + create_output.return_value.choices[0].finish_reason = "stop" create_output.return_value.choices[ 0 - ].message.content = f"Very nice.\n{SOURCES_SEPARATOR}: 0, 1" - result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": "Very nice.", - "paragraphs": [0, 1], - "raw_answer": f"Very nice.\n{SOURCES_SEPARATOR}: 0, 1", - } - assert finish_reason is None - - # Test multiple sources. Sources returned not number - create_output.return_value.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice., I really enjoyed" - " this context." + ].message.content = '{"has_answer":false,"answer":"I dont know","paragraphs":[]}' + create_output.return_value.choices[0].message.parsed = GenerativeQAOutput( + has_answer=False, answer="I dont know.", paragraphs=[] ) result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice., I really" - " enjoyed this context." - ), - } - - # Test multiple sources. Wrong layout but sources still present. - create_output.return_value.choices[0].message.content = ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice. I really enjoyed" - " this context.\nQUESTION: How nice is this context ?" + assert result == GenerativeQAOutput( + has_answer=False, answer="I dont know.", paragraphs=[] ) - result, finish_reason = await gaq.arun(query=query, contexts=context) - assert result == { - "answer": None, - "paragraphs": None, - "raw_answer": ( - f"Very nice.\n{SOURCES_SEPARATOR}: This context is very nice. I really" - " enjoyed this context.\nQUESTION: How nice is this context ?" - ), - } + assert finish_reason == "stop" def stream(**kwargs): - base_response = ChatCompletionChunk( - id="chatcmpl-9kBJNLoYybxyGe9pxsZkY3XgZMg8t", - choices=[ - Choice( - delta=ChoiceDelta( - content="", function_call=None, role="assistant", tool_calls=None - ), - finish_reason=None, - index=0, - logprobs=None, - ) - ], - created=1720793597, - model="gpt-ni.colas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=None, + base_response = ChunkEvent( + type="chunk", + chunk=ChatCompletionChunk( + id="chatcmpl-A3NV8ibLOiAzjjYSJtj7qV4fqx1Tc", + choices=[ + Choice( + delta=ChoiceDelta( + content="", + function_call=None, + refusal=None, + role="assistant", + tool_calls=None, + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1725368686, + model="gpt-schola.rag-maxi", + object="chat.completion.chunk", + service_tier=None, + system_fingerprint="fp_5bd87c427a", + usage=None, + ), + snapshot=ParsedChatCompletion[object]( + id="chatcmpl-A3NV8ibLOiAzjjYSJtj7qV4fqx1Tc", + choices=[ + ParsedChoice[object]( + finish_reason="stop", + index=0, + logprobs=None, + message=ParsedChatCompletionMessage[object]( + content="", + refusal=None, + role="assistant", + function_call=None, + tool_calls=None, + parsed=None, + ), + ) + ], + created=1725368686, + model="gpt-4o-mini-2024-07-18", + object="chat.completion", + service_tier=None, + system_fingerprint="fp_5bd87c427a", + usage=None, + ), + ) + to_stream = ( + '{"has_answer": true, "answer": "I am a great answer.", "paragraphs": [0,1]}' ) - to_stream = f"I am a great answer. {SOURCES_SEPARATOR}\n 0, 1" yield base_response for word in to_stream.split(" "): - base_response.choices[0].delta.content = word + " " if word != "1" else word + base_response.chunk.choices[0].delta.content = ( + word + " " if word != "[0,1]}" else word + ) yield base_response - yield ChatCompletionChunk( - id="chatcmpl-9kBOdgQkrwK9yRbx7VXIiBXwTnKea", - choices=[ - Choice( - delta=ChoiceDelta( - content=None, function_call=None, role=None, tool_calls=None - ), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], - created=1720793923, - model="gpt-ni.colas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=None, - ) - yield ChatCompletionChunk( - id="chatcmpl-9kBOdgQkrwK9yRbx7VXIiBXwTnKea", - choices=[], - created=1720793923, - model="gpt-ni.ciolas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=15, prompt_tokens=3935, total_tokens=3950 + + # Chunk containing the finish reasom + yield ChunkEvent( + type="chunk", + chunk=ChatCompletionChunk( + id="chatcmpl-A3OP6fNcncOnNJUcMVVxWaPenNBWl", + choices=[ + Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=None, + ), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + created=1725372156, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier=None, + system_fingerprint="fp_f33667828e", + usage=None, + ), + snapshot=ParsedChatCompletion[object]( + id="chatcmpl-A3OP6fNcncOnNJUcMVVxWaPenNBWl", + choices=[ + ParsedChoice[object]( + finish_reason="stop", + index=0, + logprobs=None, + message=ParsedChatCompletionMessage[object]( + content='{"has_answer": true, "answer": "I am a great answer.", "paragraphs": [0,1]}', + refusal=None, + role="assistant", + function_call=None, + tool_calls=None, + parsed=GenerativeQAOutput( + has_answer=True, + answer="I am a great answer.", + paragraphs=[0, 1], + ), + ), + ) + ], + created=1725372156, + model="gpt-schola.rag-maxi", + object="chat.completion", + service_tier=None, + system_fingerprint="fp_f33667828e", + usage=None, ), ) - -async def astream(**kwargs): - base_response = ChatCompletionChunk( - id="chatcmpl-9kBJNLoYybxyGe9pxsZkY3XgZMg8t", - choices=[ - Choice( - delta=ChoiceDelta( - content="", function_call=None, role="assistant", tool_calls=None - ), - finish_reason=None, - index=0, - logprobs=None, - ) - ], - created=1720793597, - model="gpt-ni.colas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=None, - ) - to_stream = f"I am a great answer. {SOURCES_SEPARATOR}\n 0, 1" - yield base_response - for word in to_stream.split(" "): - base_response.choices[0].delta.content = word + " " if word != "1" else word - yield base_response - yield ChatCompletionChunk( - id="chatcmpl-9kBOdgQkrwK9yRbx7VXIiBXwTnKea", - choices=[ - Choice( - delta=ChoiceDelta( - content=None, function_call=None, role=None, tool_calls=None - ), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], - created=1720793923, - model="gpt-ni.colas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=None, + # Chunk containing the parsed output + yield ContentDoneEvent[GenerativeQAOutput]( + type="content.done", + content='{"has_answer": true, "answer": "I am a great answer.","paragraphs": [0,1]}', + parsed=GenerativeQAOutput( + has_answer=True, answer="I am a great answer.", paragraphs=[0, 1] + ), ) - yield ChatCompletionChunk( - id="chatcmpl-9kBOdgQkrwK9yRbx7VXIiBXwTnKea", - choices=[], - created=1720793923, - model="gpt-ni.ciolas-turbo", - object="chat.completion.chunk", - service_tier=None, - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=15, prompt_tokens=3935, total_tokens=3950 + + # Chunk containing the usage + yield ChunkEvent( + type="chunk", + chunk=ChatCompletionChunk( + id="chatcmpl-A3NZvqwHHblDdW19Vs1RyFd06xRc3", + choices=[], + created=1725368983, + model="gpt-schola.rag-maxi", + object="chat.completion.chunk", + service_tier=None, + system_fingerprint="fp_f33667828e", + usage=CompletionUsage( + completion_tokens=85, prompt_tokens=2677, total_tokens=2762 + ), + ), + snapshot=ParsedChatCompletion[object]( + id="chatcmpl-A3NZvqwHHblDdW19Vs1RyFd06xRc3", + choices=[ + ParsedChoice[object]( + finish_reason="stop", + index=0, + logprobs=None, + message=ParsedChatCompletionMessage[object]( + content='{"has_answer": true, "answer": "I am a great answer.", "paragraphs": [0,1]}', + refusal=None, + role="assistant", + function_call=None, + tool_calls=None, + parsed=GenerativeQAOutput( + has_answer=True, + answer="I am a great answer.", + paragraphs=[0, 1], + ), + ), + ) + ], + created=1725368983, + model="gpt-schola.rag-maxi", + object="chat.completion", + service_tier=None, + system_fingerprint="fp_f33667828e", + usage=CompletionUsage( + completion_tokens=85, prompt_tokens=2677, total_tokens=2762 + ), ), ) +async def astream(**kwargs): + for elem in stream(**kwargs): + yield elem + + def test_stream(): fake_llm = Mock(spec=OpenAI(api_key="assdas")) - gaq = GenerativeQAWithSources(client=fake_llm, model="gpt-ni.colas-turbo") + gaq = GenerativeQAWithSources(client=fake_llm, model="gpt-schola.rag-maxi") query = "How nice is this context ?" context = [ "This context is very nice.", "I really enjoyed this context.", "That's really an amazing context.", ] - fake_llm.chat.completions.create = stream + stream_mock = MagicMock() + stream_mock.__enter__.return_value = stream() + stream_mock.__exit__.return_value = None + + fake_llm.beta.chat.completions.stream.return_value = stream_mock + streamed_gen = gaq.stream(query, context) try: partial_text = "" while True: - partial_text += next(streamed_gen) + chunk, _ = next(streamed_gen) + partial_text += chunk except StopIteration as err: finish_reason = err.value - assert partial_text == f"I am a great answer. {SOURCES_SEPARATOR}\n 0, 1" + assert ( + partial_text + == '{"has_answer": true, "answer": "I am a great answer.", "paragraphs": [0,1]}' + ) assert finish_reason == "stop" @pytest.mark.asyncio async def test_astream(): fake_llm = AsyncMock(spec=AsyncOpenAI(api_key="assdas")) - gaq = GenerativeQAWithSources(client=fake_llm, model="gpt-ni.colas-turbo") + gaq = GenerativeQAWithSources(client=fake_llm, model="gpt-schola.rag-maxi") query = "How nice is this context ?" context = [ "This context is very nice.", "I really enjoyed this context.", "That's really an amazing context.", ] - fake_create = AsyncMock() - fake_create.return_value = astream() - fake_llm.chat.completions.create = fake_create + stream_mock = MagicMock() + stream_mock.__aenter__.return_value = astream() + stream_mock.__aexit__.return_value = None + + # fake_create = AsyncMock() + # fake_create.return_value = astream() + fake_llm.beta.chat.completions.stream.return_value = stream_mock try: partial_text = "" - async for word in gaq.astream(query, context): - partial_text += word + async for chunk, _ in gaq.astream(query, context): + partial_text += chunk except RuntimeError as err: finish_reason = err.args[0] - assert partial_text == f"I am a great answer. {SOURCES_SEPARATOR}\n 0, 1" + assert ( + partial_text + == '{"has_answer": true, "answer": "I am a great answer.", "paragraphs": [0,1]}' + ) assert finish_reason == "stop" - - -@pytest.mark.parametrize( - "raw_input, expected_answer, expected_paragraphs", - [ - (f"Very nice.\n{SOURCES_SEPARATOR.upper()}: 0,1", "Very nice.", [0, 1]), - (f"Very nice.{SOURCES_SEPARATOR.upper()}: 0,1", "Very nice.", [0, 1]), - (f"Very nice.{SOURCES_SEPARATOR.lower()}: 0,1", "Very nice.", [0, 1]), - (f"Very nice.\n{SOURCES_SEPARATOR.lower()}: 0,1", "Very nice.", [0, 1]), - ("Wrongly formatted input.", None, None), - ], -) -def test_process_raw_output(raw_input, expected_answer, expected_paragraphs): - gaq = GenerativeQAWithSources(client=OpenAI(api_key="asaas")) - res = gaq._process_raw_output(raw_input) - assert res["answer"] == expected_answer - assert res["paragraphs"] == expected_paragraphs