diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py index 43f4c554..772cc5d4 100644 --- a/concordia/agents/basic_agent.py +++ b/concordia/agents/basic_agent.py @@ -88,6 +88,7 @@ def __init__( self._user_controlled = user_controlled self._update_interval = update_interval + self._conversation_prefix = '' self._under_interrogation = False self._state_lock = threading.Lock() @@ -287,7 +288,22 @@ def get_externality(externality): return output + def observe_latest(self, conversation: str): + # If the prefix is not found then `find` returns -1. + prefix_start_index = conversation.find(self._conversation_prefix) + if prefix_start_index >= 0: + # Get the part of the conversation the agent heard since their last turn. + start_index = prefix_start_index + len(self._conversation_prefix) + conversation_suffix = conversation[start_index:] + # Replace newline characters with commas. + conversation_suffix = conversation_suffix.replace('\n', ', ') + # Observe the new part of the conversation. + self.observe(conversation_suffix) + # Store the whole conversation thus far as the new prefix. + self._conversation_prefix = conversation + def say(self, conversation: str) -> str: + self.observe_latest(conversation) convo_context = ( f'{self._agent_name} is in the following' f' conversation:\n{conversation}\n'