Skip to content

Commit

Permalink
Make the delimiter symbol used in the formative memories generator co…
Browse files Browse the repository at this point in the history
…nfigurable.

PiperOrigin-RevId: 635413943
Change-Id: I178dcb9c1e8cdc40e78f0d132b77fbac8118cdf5
  • Loading branch information
jzleibo authored and copybara-github committed May 20, 2024
1 parent 0e411ed commit 94c3154
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions concordia/associative_memory/formative_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,22 @@ def __init__(
*,
model: language_model.LanguageModel,
shared_memories: Sequence[str] = (),
delimiter_symbol: str = '\n\n\n',
blank_memory_factory_call: Callable[
[], associative_memory.AssociativeMemory
],
):
"""Initializes the formative memory factory.
Args:
model: the language model to use for generating memories
shared_memories: memories to be added to all agents
delimiter_symbol: the delimiter to use when splitting the generated
episodes
blank_memory_factory_call: a function that returns a new blank memory
"""
self._model = model
self._delimiter_symbol = delimiter_symbol
self._blank_memory_factory_call = blank_memory_factory_call
self._shared_memories = shared_memories

Expand Down Expand Up @@ -192,8 +203,9 @@ def add_memories(
'mention their age at the time the event occurred using language such '
f'as "When {agent_config.name} was 5 years old, they experienced..." . '
'Use past tense. Write no more than three sentences per episode. '
'Separate episodes from one another by the delimiter "\n\n\n". Do not '
'apply any other special formatting besides these delimiters.'
'Separate episodes from one another by the delimiter '
f'"{self._delimiter_symbol}". Do not apply any other '
'special formatting besides these delimiters.'
)
if agent_config.traits:
question += (
Expand All @@ -213,7 +225,7 @@ def add_memories(
terminators=[],
)

episodes = aggregated_result.split('\n\n\n')
episodes = aggregated_result.split(self._delimiter_symbol)

if len(episodes) != len(list(agent_config.formative_ages)):
logger.warning(
Expand Down

0 comments on commit 94c3154

Please sign in to comment.