From 09f3a478a1754ebe4b373e948493fb236b43f1eb Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Wed, 22 May 2024 15:57:30 -0700 Subject: [PATCH] Fix bug in the case where the user produces a conversation by repeatedly calling `agent.say`. Previously the agent would not observe the conversation in that case since the usual way of observing the conversation happens in the conversation scene, which this approach bypasses. The method that uses the conversation scene bypasses 'say', so this won't observe twice. This change just makes the behavior of 'say' closer to the behavior of the conversation scene. PiperOrigin-RevId: 636322219 Change-Id: I90bc997e29b795e40604483caf9070a8df034393 --- concordia/agents/basic_agent.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py index 43f4c554..772cc5d4 100644 --- a/concordia/agents/basic_agent.py +++ b/concordia/agents/basic_agent.py @@ -88,6 +88,7 @@ def __init__( self._user_controlled = user_controlled self._update_interval = update_interval + self._conversation_prefix = '' self._under_interrogation = False self._state_lock = threading.Lock() @@ -287,7 +288,22 @@ def get_externality(externality): return output + def observe_latest(self, conversation: str): + # If the prefix is not found then `find` returns -1. + prefix_start_index = conversation.find(self._conversation_prefix) + if prefix_start_index >= 0: + # Get the part of the conversation the agent heard since their last turn. + start_index = prefix_start_index + len(self._conversation_prefix) + conversation_suffix = conversation[start_index:] + # Replace newline characters with commas. + conversation_suffix = conversation_suffix.replace('\n', ', ') + # Observe the new part of the conversation. + self.observe(conversation_suffix) + # Store the whole conversation thus far as the new prefix. + self._conversation_prefix = conversation + def say(self, conversation: str) -> str: + self.observe_latest(conversation) convo_context = ( f'{self._agent_name} is in the following' f' conversation:\n{conversation}\n'