From b6b9e7ef2b8ca00fa01deb9189510dfd2d241bc3 Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Mon, 29 Jul 2024 04:53:57 -0700 Subject: [PATCH] Improve importance model. The importance model now can be contextualized by retrieving other memories to provide scale for the decision of how to assign importance to each new memory. This change also removes a parameter on the formative memories generator which overrode the importance value assigned by memory.add. If that behavior is desirable, a more consistent way to achieve it would be to use a Constant importance model. PiperOrigin-RevId: 657156830 Change-Id: Icf5425fed4c3519d80c50ba6f4b10ede7df84865 --- .../associative_memory/associative_memory.py | 104 ++++++++++++++++-- .../associative_memory/formative_memories.py | 10 +- .../associative_memory/importance_function.py | 94 +++++++++++----- .../agent/dialectical_reflection.py | 12 +- concordia/components/agent/plan.py | 8 +- concordia/components/agent/reflection.py | 5 +- concordia/environment/game_master.py | 2 +- 7 files changed, 173 insertions(+), 62 deletions(-) diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py index 93bce0aa..108324df 100644 --- a/concordia/associative_memory/associative_memory.py +++ b/concordia/associative_memory/associative_memory.py @@ -20,14 +20,17 @@ preprint arXiv:2304.03442. """ -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence import datetime +import random import threading from concordia.associative_memory import importance_function import numpy as np import pandas as pd +_NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE = 25 + def _check_date_in_range(timestamp: datetime.datetime) -> None: if timestamp < pd.Timestamp.min: @@ -44,9 +47,11 @@ class AssociativeMemory: def __init__( self, sentence_embedder: Callable[[str], np.ndarray], - importance: Callable[[str], float] | None = None, + importance: Callable[[str, Sequence[tuple[str, float]]], + float] | None = None, clock: Callable[[], datetime.datetime] = datetime.datetime.now, clock_step_size: datetime.timedelta | None = None, + seed: int | None = None, ): """Constructor. @@ -57,9 +62,17 @@ def __init__( clock: a callable to get time when adding memories clock_step_size: sets the step size of the clock. If None, assumes precise time + seed: the seed to use for the random number generator if None then use the + current time """ self._memory_bank_lock = threading.Lock() + if seed is None: + self._seed = random.seed(int(datetime.datetime.now().timestamp())) + else: + self._seed = seed self._embedder = sentence_embedder + self._num_to_retrieve_to_contextualize_importance = ( + _NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE) self._importance = ( importance or importance_function.ConstantImportanceModel().importance) @@ -77,7 +90,7 @@ def add( timestamp: datetime.datetime | None = None, tags: Iterable[str] = (), importance: float | None = None, - ): + ) -> None: """Adds nonduplicated entries (time, text, tags, importance) to the memory. Args: @@ -87,7 +100,13 @@ def add( importance: optionally set the importance of the memory. """ if importance is None: - importance = self._importance(text) + with self._memory_bank_lock: + memory_size = len(self._memory_bank) + num_to_retrieve = self._num_to_retrieve_to_contextualize_importance + if memory_size < num_to_retrieve: + num_to_retrieve = memory_size + context = self.retrieve_random_with_importance(k=num_to_retrieve) + importance = self._importance(text, context) if timestamp is None: timestamp = self._clock_now() @@ -119,7 +138,7 @@ def extend( self, texts: Iterable[str], **kwargs, - ): + ) -> None: """Adds the texts to the memory. Args: @@ -129,7 +148,7 @@ def extend( for text in texts: self.add(text, **kwargs) - def get_data_frame(self): + def get_data_frame(self) -> pd.DataFrame: with self._memory_bank_lock: return self._memory_bank.copy() @@ -202,7 +221,7 @@ def _pd_to_text( data: pd.DataFrame, add_time: bool = False, sort_by_time: bool = True, - ): + ) -> Sequence[str]: """Formats a dataframe into list of strings. Args: @@ -240,7 +259,7 @@ def retrieve_associative( use_importance: bool = True, add_time: bool = True, sort_by_time: bool = True, - ): + ) -> Sequence[str]: """Retrieve memories associatively. Args: @@ -270,7 +289,7 @@ def retrieve_by_regex( regex: str, add_time: bool = True, sort_by_time: bool = True, - ): + ) -> Sequence[str]: """Retrieve memories matching a regex. Args: @@ -291,7 +310,7 @@ def retrieve_time_interval( time_from: datetime.datetime, time_until: datetime.datetime, add_time: bool = False, - ): + ) -> Sequence[str]: """Retrieve memories within a time interval. Args: @@ -315,7 +334,7 @@ def retrieve_recent( self, k: int = 1, add_time: bool = False, - ): + ) -> Sequence[str]: """Retrieve memories by recency. Args: @@ -333,7 +352,7 @@ def retrieve_recent_with_importance( self, k: int = 1, add_time: bool = False, - ): + ) -> tuple[Sequence[str], Sequence[float]]: """Retrieve memories by recency and return importance alongside. Args: @@ -350,6 +369,40 @@ def retrieve_recent_with_importance( list(data['importance']), ) + def retrieve_random( + self, + k: int = 1, + add_time: bool = False, + ) -> Sequence[str]: + """Retrieve random memories. + + Args: + k: number of entries to retrieve + add_time: whether to add time stamp to the output + + Returns: + List of strings corresponding to memories + """ + with self._memory_bank_lock: + data = self._memory_bank.sample(k, random_state=self._seed) + return self._pd_to_text(data, add_time=add_time, sort_by_time=True) + + def retrieve_random_with_importance( + self, + k: int = 1, + ) -> Sequence[tuple[str, float]]: + """Retrieve random memories and return importance alongside. + + Args: + k: number of entries to retrieve + + Returns: + List of tuples of (memory, importance) + """ + with self._memory_bank_lock: + data = self._memory_bank.sample(k, random_state=self._seed) + return tuple(zip(list(data['text']), list(data['importance']))) + def __len__(self): """Returns the number of entries in the memory bank. @@ -358,3 +411,30 @@ def __len__(self): """ with self._memory_bank_lock: return len(self._memory_bank) + + def get_mean_importance(self) -> float: + """Returns the mean importance of the memories in the memory bank.""" + with self._memory_bank_lock: + return self._memory_bank['importance'].mean() + + def get_max_importance(self) -> float: + """Returns the max importance of the memories in the memory bank.""" + with self._memory_bank_lock: + return self._memory_bank['importance'].max() + + def get_min_importance(self) -> float: + """Returns the min importance of the memories in the memory bank.""" + with self._memory_bank_lock: + return self._memory_bank['importance'].min() + + def set_num_to_retrieve_to_contextualize_importance( + self, num_to_retrieve: int) -> None: + """Sets the number of memories to retrieve for contextualizing importance. + + Set this to 0 if you want to disable contextualization of importance. + + Args: + num_to_retrieve: the number of memories to retrieve for contextualizing + importance. + """ + self._num_to_retrieve_to_contextualize_importance = num_to_retrieve diff --git a/concordia/associative_memory/formative_memories.py b/concordia/associative_memory/formative_memories.py index 0362f3ed..f98d7a0f 100644 --- a/concordia/associative_memory/formative_memories.py +++ b/concordia/associative_memory/formative_memories.py @@ -22,7 +22,6 @@ import re from typing import Any from concordia.associative_memory import associative_memory -from concordia.associative_memory import importance_function from concordia.document import interactive_document from concordia.language_model import language_model from dateutil.relativedelta import relativedelta # pylint: disable=g-importing-member @@ -31,7 +30,6 @@ DEFAULT_DOB = datetime.datetime(year=1984, month=7, day=3, hour=0, minute=0) DEFAULT_FORMATIVE_AGES = (6, 9, 13, 16, 19, 21, 23) -DEFAULT_IMPORTANT_MODEL = importance_function.ConstantImportanceModel() @dataclasses.dataclass(frozen=True, kw_only=True) @@ -49,7 +47,6 @@ class AgentConfig: goal: defines agents goal. Can be left blank if not used. date_of_birth: the date of birth for the agent. formative_ages: ages at which the formative episodes will be created - formative_memory_importance: the importance value of formative memories. extras: a field for the user to keep any experiment specific data they need to define an agent """ @@ -62,7 +59,6 @@ class AgentConfig: goal: str = '' date_of_birth: datetime.datetime = DEFAULT_DOB formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES - formative_memory_importance: float = 1.0 extras: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -229,7 +225,6 @@ def add_memories( tags=['episode'], timestamp=( agent_config.date_of_birth + relativedelta(years=episode_age)), - importance=agent_config.formative_memory_importance, ) if self._current_date: @@ -238,7 +233,6 @@ def add_memories( f'{agent_config.name} is {age} years old.', tags=['info'], timestamp=self._current_date, - importance=agent_config.formative_memory_importance, ) def make_memories( @@ -262,12 +256,12 @@ def make_memories( context_items = context.split('\n') for item in context_items: if item: - mem.add(item, importance=agent_config.formative_memory_importance) + mem.add(item) if agent_config.specific_memories: specific_memories = agent_config.specific_memories.split('\n') for item in specific_memories: if item: - mem.add(item, importance=agent_config.formative_memory_importance) + mem.add(item) return mem diff --git a/concordia/associative_memory/importance_function.py b/concordia/associative_memory/importance_function.py index 8b8e3f0f..8b28f6ec 100644 --- a/concordia/associative_memory/importance_function.py +++ b/concordia/associative_memory/importance_function.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """Memory importance function.""" import abc @@ -29,11 +28,16 @@ class ImportanceModel(metaclass=abc.ABCMeta): """Memory importance module for generative agents.""" @abc.abstractmethod - def importance(self, memory: str) -> float: + def importance(self, + memory: str, + context: Sequence[tuple[str, float]] = ()) -> float: """Computes importance of a memory. Args: memory: a memory (text) to compute importance of + context: a sequence of tuples of (old memory (str), importance + (float between 0 and 1)) used to provide context and relative scale for + the decision of the importance of the new memory. Returns: Value of importance in the [0,1] interval @@ -62,29 +66,43 @@ def __init__( self._model = model self._importance_scale = [str(i) for i in sorted(importance_scale)] - def importance(self, memory: str) -> float: + def importance( + self, + memory: str, + context: Sequence[tuple[str, float]] = ()) -> float: """Computes importance of a memory by quering LLM. Args: memory: memory to compute importance of + context: a sequence of tuples of (old memory (str), importance + (float between 0 and 1)) used to provide context and relative scale for + the decision of the importance of the new memory. Returns: Value of importance in the [0,1] interval """ zero, *_, one = self._importance_scale prompt = interactive_document.InteractiveDocument(self._model) - action = prompt.multiple_choice_question( - f"On the scale of {zero} to" - f" {one}, where {zero} is" - " purely mundane (e.g., brushing teeth, making bed) and" - f" {one} is extremely poignant (e.g., a break" - " up, college acceptance), rate the likely poignancy of the following" - " piece of memory.\nMemory:" + if context: + context_string = '\n'.join( + f'{context[0]} -- how memorable: {context[1]}' + for context in context) + prompt.statement(context_string) + question = ( + f'on a scale from {zero} to' + f' {one}, where {zero} is' + ' entirely mundane (e.g., brushing teeth, making bed) and' + f' {one} is extremely poignant (e.g., a breakup of a romantic ' + 'relationship, college acceptance, a wedding), rate the likely ' + 'memorableness of the following new memory.\nMemory:' + memory - + "\nRating: ", - answers=self._importance_scale, - ) - return action / (len(self._importance_scale) - 1) + + '\nRating: ') + if context is not None: + question = ( + f'{context}\nRelative to the life memories above, {question}') + action = prompt.multiple_choice_question( + question=question, answers=self._importance_scale) + return action / len(self._importance_scale) class GMImportanceModel(ImportanceModel): @@ -107,29 +125,43 @@ def __init__( self._model = model self._importance_scale = [str(i) for i in sorted(importance_scale)] - def importance(self, memory: str) -> float: + def importance(self, + memory: str, + context: Sequence[tuple[str, float]] = ()) -> float: """Computes importance of a memory by quering LLM. Args: memory: memory to compute importance of + context: a sequence of tuples of (old memory (str), importance + (float between 0 and 1)) used to provide context and relative scale for + the decision of the importance of the new memory. Returns: Value of importance """ zero, *_, one = self._importance_scale chain_of_thought = interactive_document.InteractiveDocument(self._model) - action = chain_of_thought.multiple_choice_question( - f"On the scale of {zero} to " - f"{one}, where {zero} is purely mundane " - f"(e.g., wind blowing, bus arriving) and {one} is " - "extremely poignant (e.g., an earthquake, end of war, " - "revolution), rate the likely poignancy of the " - "following event.\nEvent:" + if context: + context_string = '\n'.join( + f'{context[0]} -- likely importance to the plot: {context[1]}' + for context in context) + chain_of_thought.statement(context_string) + question = ( + f'You are the game master of a tabletop role-playing game. On a ' + f'scale from {zero} to ' + f'{one}, where {zero} is purely mundane ' + f'(e.g., wind blowing, bus arriving) and {one} is ' + 'extremely important (e.g., an earthquake, ' + 'the end of a war, a revolution), rate the likely importance of the ' + 'following event for advancing the overall plot.\nEvent:' + memory - + "\nRating: ", - answers=self._importance_scale, - ) - return action / (len(self._importance_scale) - 1) + + '\nRating: ') + if context is not None: + question = ( + f'{context}\nRelative to the life memories above, {question}') + action = chain_of_thought.multiple_choice_question( + question=question, answers=self._importance_scale) + return action / len(self._importance_scale) class ConstantImportanceModel(ImportanceModel): @@ -149,15 +181,19 @@ def __init__( """ self._fixed_importance = fixed_importance - def importance(self, memory: str) -> float: - """Computes importance of a memory by quering LLM. + def importance(self, + memory: str, + context: Sequence[tuple[str, float]] = ()) -> float: + """Computes importance of a memory by querying the LLM. Args: memory: memory to compute importance of + context: unused Returns: Value of importance """ - del memory + del memory, context return self._fixed_importance + diff --git a/concordia/components/agent/dialectical_reflection.py b/concordia/components/agent/dialectical_reflection.py index f6833a3b..6f568e12 100644 --- a/concordia/components/agent/dialectical_reflection.py +++ b/concordia/components/agent/dialectical_reflection.py @@ -105,27 +105,27 @@ def update(self) -> None: old_state = self._state # The following query looks for conversations using the fact that their # observations are preceded by ' -- "'. - prethoughts = self._memory.retrieve_associative( + prethoughts = list(self._memory.retrieve_associative( ' -- "', self._num_memories_to_retrieve, use_recency=True, add_time=False - ) + )) # The following query looks for memories of reading and learning. - prethoughts += self._memory.retrieve_associative( + prethoughts += list(self._memory.retrieve_associative( 'book, article, read, idea, concept, study, learn, research, theory', k=self._num_memories_to_retrieve, use_recency=False, add_time=False, - ) + )) if self._topic_component: - prethoughts += self._memory.retrieve_associative( + prethoughts += list(self._memory.retrieve_associative( self._topic_component.state(), k=self._num_memories_to_retrieve, use_recency=False, add_time=False, - ) + )) prethoughts = '-' + '\n-'.join(prethoughts) + '\n' diff --git a/concordia/components/agent/plan.py b/concordia/components/agent/plan.py index 6c95472f..963b4002 100644 --- a/concordia/components/agent/plan.py +++ b/concordia/components/agent/plan.py @@ -102,19 +102,19 @@ def update(self): observation = '\n'.join(self._last_observation) self._last_observation = [] - memories = self._memory.retrieve_associative( + memories = list(self._memory.retrieve_associative( observation, k=self._num_memories_to_retrieve, use_recency=True, add_time=True, - ) + )) if self._goal_component: - memories = memories + self._memory.retrieve_associative( + memories = memories + list(self._memory.retrieve_associative( self._goal_component.state(), k=self._num_memories_to_retrieve, use_recency=True, add_time=True, - ) + )) memories = '\n'.join(memories) components = '\n'.join([ diff --git a/concordia/components/agent/reflection.py b/concordia/components/agent/reflection.py index 24f7153a..1b675212 100644 --- a/concordia/components/agent/reflection.py +++ b/concordia/components/agent/reflection.py @@ -75,7 +75,7 @@ def update(self) -> None: return - mems = '\n'.join(mems) + mems = '\n'.join(list(mems)) prompt_questions = interactive_document.InteractiveDocument(self._model) @@ -95,7 +95,8 @@ def update(self) -> None: mems = [] # make sure that the answer comes out of LLM in the right format for question in questions.splitlines(): - mems += self._memory.retrieve_associative(question, 10, add_time=True) + mems += list( + self._memory.retrieve_associative(question, 10, add_time=True)) mems = '\n'.join(mems) diff --git a/concordia/environment/game_master.py b/concordia/environment/game_master.py index d4ec8abd..e2e0e80e 100644 --- a/concordia/environment/game_master.py +++ b/concordia/environment/game_master.py @@ -383,7 +383,7 @@ def step( if self._players_act_simultaneously: self._clock.advance() - def run_episode(self, max_steps: int = 20) -> list[str]: + def run_episode(self, max_steps: int = 20) -> Sequence[str]: for _ in range(max_steps): self.step() for comp in self._components.values():