Skip to content

Commit

Permalink
Improving player_status by querying the memory for exact matches to t…
Browse files Browse the repository at this point in the history
…he agents name

PiperOrigin-RevId: 590926673
Change-Id: I9817e992cdffb84af5d6b4d6ed1940b65c26b066
  • Loading branch information
vezhnick authored and copybara-github committed Dec 14, 2023
1 parent c5e6260 commit 2aac400
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
21 changes: 21 additions & 0 deletions concordia/associative_memory/associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,27 @@ def retrieve_associative(

return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time)

def retrieve_by_regex(
self,
regex: str,
add_time: bool = True,
sort_by_time: bool = True,
):
"""Retrieve memories matching a regex.
Args:
regex: a regex to match
add_time: whether to add time stamp to the output
sort_by_time: whether to sort the result by time
Returns:
List of strings corresponding to memories
"""
with self._memory_bank_lock:
data = self._memory_bank[self._memory_bank['text'].str.contains(regex)]

return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time)

def retrieve_recent(
self,
k: int = 1,
Expand Down
9 changes: 1 addition & 8 deletions concordia/components/game_master/player_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,7 @@ def update(self) -> None:
self._partial_states = {name: '' for name in self._player_names}
per_player_prompt = {}
for player_name in self._player_names:
query = f'{player_name}'
mems = (
'\n'.join(
self._memory.retrieve_associative(
query, k=self._num_memories_to_retrieve, add_time=True)
)
+ '\n'
)
mems = '\n'.join(self._memory.retrieve_by_regex(player_name)) + '\n'
prompt = interactive_document.InteractiveDocument(self._model)
prompt.statement(f'Events:\n{mems}')
time_now = self._clock_now().strftime('[%d %b %Y %H:%M:%S]')
Expand Down

0 comments on commit 2aac400

Please sign in to comment.