A new method added to the InteractiveDocument class, open_question_diversified, which takes a question as input and returns a random answer from a set of 10 possible answers. This method can be used to increase the diversity of the answers that the agent provides.

PiperOrigin-RevId: 658034201
Change-Id: Ie765ae56cffb3582d9413cd29d0ba7ddee413d8f
vezhnick authored and copybara-github committed Jul 31, 2024
1 parent dc37406 commit 92231b6
Showing 1 changed file with 98 additions and 0 deletions.
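For context (not part of the commit), here is a minimal usage sketch of the new method. It assumes an InteractiveDocument is constructed around a language model, as elsewhere in Concordia; my_model is a hypothetical stand-in for a concrete language_model.LanguageModel implementation.

  from concordia.document import interactive_document

  doc = interactive_document.InteractiveDocument(my_model)
  answer = doc.open_question_diversified(
      'What does the agent say next?',
      num_samples=10,  # the model is asked for 10 candidate answers
      terminators=('\n',),  # the randomly chosen answer is truncated at a newline
  )

The full exchange (task prompt, all candidates, and the final answer) is appended to the document, and only the single selected answer is returned.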
concordia/document/interactive_document.py
@@ -17,6 +17,8 @@

from collections.abc import Collection, Iterable, Iterator, Sequence
import contextlib
import random
import re

from concordia.document import document
from concordia.language_model import language_model
@@ -183,6 +185,102 @@ def open_question(
    self._response(f'{answer_suffix}\n')
    return response

  def open_question_diversified(
      self,
      question: str,
      *,
      forced_response: str | None = None,
      num_samples: int = 10,
      max_tokens: int = DEFAULT_MAX_TOKENS,
      terminators: Collection[str] = (),
      question_label: str = 'Question',
      answer_label: str = 'Answer',
  ) -> str:
    """Asks the agent an open question and appends it to the document.

    The agent is asked to provide multiple answers, from which one is selected
    randomly. This increases the diversity of the answers.

    Args:
      question: the question to ask.
      forced_response: forces the document to provide this response. The LLM
        will not be consulted. If answer_prefix is in the forced response then
        remove it.
      num_samples: how many samples to generate.
      max_tokens: the maximum number of tokens to sample from the model.
      terminators: strings that must not be present in the model's response.
        If emitted by the model, the response will be truncated before them.
        Importantly, the truncation is done on the final sample only and does
        not affect the intermediate samples.
      question_label: the label to use for the question, typically "Question".
      answer_label: the label to use for the answer, typically "Answer".

    Returns:
      The agent's truncated response (or `forced_response` if provided).

    Raises:
      Warning: if the LLM does not generate the expected number of answers.
    """

    def truncate_string(s, tr):
      """Truncates a string to the first occurrence of any of the terminators.

      Args:
        s: The string to truncate.
        tr: A collection of strings representing the terminators.

      Returns:
        The truncated string, or the original string if no terminator is
        found.
      """
      # Find the earliest index at which any terminator appears.
      earliest_index = len(s)  # Initialize to the end of the string.
      for terminator in tr:
        index = s.find(terminator)
        if index != -1 and index < earliest_index:
          earliest_index = index

      # Truncate the string if a terminator was found.
      if earliest_index < len(s):
        return s[:earliest_index]
      else:
        return s

    self._question(
        f'Task: generate {num_samples} {answer_label}s to the following'
        f' {question_label}:\nQuestion: {question}\n'
    )
    if forced_response is None:
      self._response(f'{answer_label}s:\n1. ')
      candidates = self._model.sample_text(
          prompt=self._model_view.text(),
          max_tokens=max_tokens * num_samples,
          terminators=[],
      )
      self.statement(candidates)

      candidates = candidates.splitlines()

      if len(candidates) != num_samples:
        self.debug(
            f'LLM generated {len(candidates)} answers instead of {num_samples}'
        )
        if len(candidates) < 2:
          raise Warning(
              f'LLM generated only {len(candidates)} initial answers.'
          )
      # Strip the leading enumeration (e.g. "2. ") from each candidate line.
      candidates = [re.sub(r'^\d+\.\s*', '', line) for line in candidates]
      response = random.choice(candidates)
      response = truncate_string(response, terminators)

    else:
      response = forced_response

    self._response(f'Final {answer_label}: ')
    self._model_response(f'{response}\n')
    return response

  def multiple_choice_question(
      self, question: str, answers: Sequence[str]
  ) -> int:
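To make the selection step concrete, here is a self-contained sketch (illustrative only; the sample text and terminator below are made up) of the post-processing the method applies to the sampled text:

  import random
  import re

  raw = '1. Go north.\n2. Wait and listen.\n3. Light the torch. THE END'
  candidates = raw.splitlines()
  # Strip the leading enumeration produced by the numbered-list prompt.
  candidates = [re.sub(r'^\d+\.\s*', '', line) for line in candidates]
  response = random.choice(candidates)
  # Truncate at the first terminator, mirroring truncate_string above.
  for terminator in ('THE END',):
    index = response.find(terminator)
    if index != -1:
      response = response[:index]
  print(response)

Note that truncation is applied only to the randomly selected answer; the intermediate candidates recorded in the document keep any terminator text.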
