From 2aac400dd3ce90678cbf17ad274a97c8a9465ad7 Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Thu, 14 Dec 2023 06:59:15 -0800 Subject: [PATCH] Improving player_status by querying the memory for exact matches to the agents name PiperOrigin-RevId: 590926673 Change-Id: I9817e992cdffb84af5d6b4d6ed1940b65c26b066 --- .../associative_memory/associative_memory.py | 21 +++++++++++++++++++ .../components/game_master/player_status.py | 9 +------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py index 35c9e172..28361648 100644 --- a/concordia/associative_memory/associative_memory.py +++ b/concordia/associative_memory/associative_memory.py @@ -247,6 +247,27 @@ def retrieve_associative( return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time) + def retrieve_by_regex( + self, + regex: str, + add_time: bool = True, + sort_by_time: bool = True, + ): + """Retrieve memories matching a regex. + + Args: + regex: a regex to match + add_time: whether to add time stamp to the output + sort_by_time: whether to sort the result by time + + Returns: + List of strings corresponding to memories + """ + with self._memory_bank_lock: + data = self._memory_bank[self._memory_bank['text'].str.contains(regex)] + + return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time) + def retrieve_recent( self, k: int = 1, diff --git a/concordia/components/game_master/player_status.py b/concordia/components/game_master/player_status.py index 9c365022..f2248784 100644 --- a/concordia/components/game_master/player_status.py +++ b/concordia/components/game_master/player_status.py @@ -71,14 +71,7 @@ def update(self) -> None: self._partial_states = {name: '' for name in self._player_names} per_player_prompt = {} for player_name in self._player_names: - query = f'{player_name}' - mems = ( - '\n'.join( - self._memory.retrieve_associative( - query, k=self._num_memories_to_retrieve, add_time=True) - ) - + '\n' - ) + mems = '\n'.join(self._memory.retrieve_by_regex(player_name)) + '\n' prompt = interactive_document.InteractiveDocument(self._model) prompt.statement(f'Events:\n{mems}') time_now = self._clock_now().strftime('[%d %b %Y %H:%M:%S]')