Small fixes
Nicolas Frank committed Sep 3, 2024
1 parent eefc78d commit 548b9ac
Showing 4 changed files with 25 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- Use OpenAI response_format instead of separators in the prompt.

## [v0.0.5]

### Changed
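For context, the CHANGELOG entry above refers to OpenAI's structured-output support: passing a Pydantic model as `response_format` makes the SDK return a parsed object directly, so the prompt no longer needs custom separator tags. A minimal sketch of that pattern, assuming the openai Python SDK (>=1.40) and a placeholder model name:

```python
from openai import OpenAI
from pydantic import BaseModel


class GenerativeQAOutput(BaseModel):
    """Schema mirroring the class defined in generative_question_answering.py."""

    has_answer: bool
    answer: str
    paragraphs: list[int]


client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# With response_format the SDK parses the completion into the Pydantic class,
# so separators such as <bbs_json_data> no longer have to be baked into the prompt.
completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",  # hypothetical model choice
    messages=[
        {"role": "system", "content": "Answer using only the provided contexts."},
        {"role": "user", "content": "QUESTION: ...\n=========\n<contexts here>"},
    ],
    response_format=GenerativeQAOutput,
)
parsed = completion.choices[0].message.parsed  # GenerativeQAOutput | None
```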
1 change: 0 additions & 1 deletion src/scholarag/app/middleware.py
@@ -265,7 +265,6 @@ async def get_and_set_cache(
"path": request.scope["path"],
},
}

cached["content"] = cached["content"].split("<bbs_json_error>")[-1]
cached["content"] = cached["content"].split("<bbs_json_data>")[-1]

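For reference, the two `split` calls kept in the cache middleware above drop everything up to and including the last separator tag in a cached response. A small illustration of that behavior (the cached payload here is made up):

```python
cached = {"content": '<bbs_json_data>{"answer": "42"}'}

# str.split(sep)[-1] keeps the text after the last occurrence of sep,
# or the whole string when the separator is absent.
cached["content"] = cached["content"].split("<bbs_json_error>")[-1]
cached["content"] = cached["content"].split("<bbs_json_data>")[-1]

print(cached["content"])  # {"answer": "42"}
```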
18 changes: 11 additions & 7 deletions src/scholarag/app/streaming.py
@@ -89,8 +89,8 @@ async def stream_response(
contexts=contexts_text,
system_prompt=settings.generative.system_prompt.get_secret_value(),
)
# Get the first chunk and resulting parsing
_, parsed = await anext(generator)

parsed: dict[str, str | bool | list[int]] | GenerativeQAOutput = {}
# While the model hasn't said whether it has an answer, keep consuming
while parsed.get("has_answer") is None: # type: ignore
_, parsed = await anext(generator)
@@ -111,17 +111,20 @@
parsed_output = GenerativeQAOutput(
has_answer=False, answer="", paragraphs=[]
)
# Else, we stream the tokens (without the 'answer:')
# Else, we stream the tokens.
# First ensure not streaming '"answer":'
else:
accumulated_text = ""
while '"answer":' not in accumulated_text:
chunk, _ = await anext(generator)
accumulated_text += chunk
# Finally we stream the answer
# Then we stream the answer
async for chunk, parsed in generator:
# While the answer has not finished streaming we yield the tokens.
if parsed.get("answer") is None: # type: ignore
yield chunk
# Stop streaming as soon as the answer is complete
# (i.e. don't stream the paragraph ids)
else:
break

@@ -142,7 +145,7 @@
)

try:
# Extract the final pydantic class
# Extract the final pydantic class (last item in generator)
parsed_output = await anext(
(
parsed
@@ -162,13 +165,13 @@
"status_code": 404,
"code": ErrorCode.NO_ANSWER_FOUND.value,
"detail": (
"The LLM did not manage to answer the question based on the provided contexts."
"The LLM encountered an error when answering the question."
),
}
}
)
try:
# We "raise" the finish_reason
# Finally we "raise" the finish_reason
await anext(generator)
except RuntimeError as err:
finish_reason: str = err.args[0]
@@ -189,6 +192,7 @@
" retriever_k value by 1 or 2 depending of whether you are"
" using the reranker or not."
),
"raw_answer": parsed_output.answer,
}
}
)
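To make the control flow in stream_response easier to follow, here is a standalone sketch of consuming a (chunk, parsed) async generator of the same shape: wait until `has_answer` shows up in the partial parse, skip ahead until the `"answer":` key has been emitted, then forward tokens until the answer field is complete. The generator below is a hypothetical stand-in, not the project's actual `astream`:

```python
import asyncio
from typing import Any, AsyncIterator


async def fake_stream() -> AsyncIterator[tuple[str, dict[str, Any]]]:
    """Hypothetical stand-in yielding (chunk, partially parsed JSON) pairs."""
    yield '{"has_answer": true, ', {"has_answer": True}
    yield '"answer": ', {"has_answer": True}
    yield '"Par', {"has_answer": True}
    yield 'is is the capital of France.', {"has_answer": True}
    yield '", ', {"has_answer": True, "answer": "Paris is the capital of France."}
    yield '"paragraphs": [0]}', {"has_answer": True, "answer": "Paris is the capital of France.", "paragraphs": [0]}


async def main() -> None:
    generator = fake_stream()
    parsed: dict[str, Any] = {}
    # Keep consuming until the model has committed to has_answer.
    while parsed.get("has_answer") is None:
        _, parsed = await anext(generator)
    # Skip chunks until the '"answer":' key itself has been emitted.
    accumulated_text = ""
    while '"answer":' not in accumulated_text:
        chunk, _ = await anext(generator)
        accumulated_text += chunk
    # Forward answer tokens; stop once the answer field is fully parsed.
    async for chunk, parsed in generator:
        if parsed.get("answer") is None:
            print(chunk, end="", flush=True)
        else:
            break


asyncio.run(main())
```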
18 changes: 11 additions & 7 deletions src/scholarag/generative_question_answering.py
@@ -15,16 +15,14 @@
class GenerativeQAOutput(BaseModel):
"""Base class for the expected LLM output."""

has_answer: bool # Here to prevent streaming errors.
has_answer: bool # Here to prevent streaming errors
answer: str
paragraphs: list[int]


# SOURCES_SEPARATOR = "<bbs_sources>"
# ERROR_SEPARATOR = "<bbs_error>"
MESSAGES = [
{
"role": "system",
"role": "system", # This one can be overriden through env var
"content": """Given the following extracted parts of a long document and a question, create a final answer with references to the relevant paragraphs.
If you don't know the answer, just say that you don't know, don't try to make up an answer, leave the paragraphs as an empty list and set `has_answer` to False.
@@ -54,7 +52,7 @@ class GenerativeQAOutput(BaseModel):
""",
},
{
"role": "user",
"role": "user", # This one cannot be overriden
"content": """QUESTION: {question}
=========
{summaries}
@@ -101,7 +99,7 @@ def run(
contexts
Contexts to use to answer the question.
system_prompt
System prompt for the LLM. Leave None for default
System prompt for the LLM. Leave None for default.
Returns
-------
@@ -227,7 +225,7 @@ def stream(
Yields
------
chunks, parsed
Chunks of the answer, partially parsed json.
Chunks of the answer, (partially) parsed json.
Returns
-------
@@ -260,12 +258,15 @@
response_format=GenerativeQAOutput,
) as stream:
for event in stream:
# In between chunks we have accumulated text -> skip
if isinstance(event, ContentDeltaEvent):
continue
# At the end we get the parsed pydantic class
if isinstance(event, ContentDoneEvent):
if event.parsed is not None: # mypy
yield "", event.parsed
continue

if isinstance(event, ChunkEvent): # mypy
# Only the last chunk contains the usage.
if not event.chunk.usage:
Expand Down Expand Up @@ -352,12 +353,15 @@ async def astream(
response_format=GenerativeQAOutput,
) as stream:
async for event in stream:
# In between chunks we have accumulated text -> skip
if isinstance(event, ContentDeltaEvent):
continue
# At the end we get the parsed pydantic class
if isinstance(event, ContentDoneEvent):
if event.parsed is not None: # mypy
yield "", event.parsed
continue

if isinstance(event, ChunkEvent): # mypy
# Only the last chunk contains the usage.
if not event.chunk.usage:
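The event handling shown above relies on the openai SDK's streaming helper for structured outputs. A self-contained sketch of that pattern outside the project, with a placeholder model name and prompt:

```python
from openai import OpenAI
from openai.lib.streaming.chat import ChunkEvent, ContentDeltaEvent, ContentDoneEvent
from pydantic import BaseModel


class GenerativeQAOutput(BaseModel):
    has_answer: bool
    answer: str
    paragraphs: list[int]


client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

with client.beta.chat.completions.stream(
    model="gpt-4o-mini",  # hypothetical model choice
    messages=[
        {"role": "system", "content": "Answer using only the provided contexts."},
        {"role": "user", "content": "QUESTION: What is the capital of France?"},
    ],
    response_format=GenerativeQAOutput,
) as stream:
    for event in stream:
        if isinstance(event, ContentDeltaEvent):
            # Incremental text: event.delta holds the newest token(s).
            print(event.delta, end="", flush=True)
        elif isinstance(event, ContentDoneEvent):
            # Emitted once at the end, carrying the fully parsed Pydantic object.
            print("\nparsed:", event.parsed)
        elif isinstance(event, ChunkEvent) and event.chunk.usage:
            # Only the final raw chunk carries token usage.
            print("usage:", event.chunk.usage)
```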
