From 81dcee101202b692f80dfb868c960da5a484c9eb Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Wed, 20 Dec 2023 03:07:31 -0800 Subject: [PATCH] Refactoring observation component that uses memory retrieval by time to represent observations. PiperOrigin-RevId: 592503275 Change-Id: I9a68bb7457ae597d5701349d1b2a012e6d263aa2 --- .../associative_memory/associative_memory.py | 25 ++++ concordia/components/agent/observation.py | 137 +++++++++++------- .../components/game_master/conversation.py | 8 +- concordia/tests/concordia_integration_test.py | 107 +++++++++++++- examples/cyberball/cyberball.ipynb | 16 +- examples/magic_beans_for_sale.ipynb | 18 ++- examples/phone/calendar.ipynb | 15 +- examples/three_key_questions.ipynb | 18 ++- examples/village/day_in_riverbend.ipynb | 14 +- examples/village/riverbend_elections.ipynb | 16 +- 10 files changed, 296 insertions(+), 78 deletions(-) diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py index 28361648..6e6672e9 100644 --- a/concordia/associative_memory/associative_memory.py +++ b/concordia/associative_memory/associative_memory.py @@ -268,6 +268,31 @@ def retrieve_by_regex( return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time) + def retrieve_time_interval( + self, + time_from: datetime.datetime, + time_until: datetime.datetime, + add_time: bool = False, + ): + """Retrieve memories within a time interval. 
+ + Args: + time_from: the start time of the interval + time_until: the end time of the interval + add_time: whether to add time stamp to the output + + Returns: + List of strings corresponding to memories + """ + + with self._memory_bank_lock: + data = self._memory_bank[ + (self._memory_bank['time'] >= time_from) + & (self._memory_bank['time'] <= time_until) + ] + + return self._pd_to_text(data, add_time=add_time, sort_by_time=True) + def retrieve_recent( self, k: int = 1, diff --git a/concordia/components/agent/observation.py b/concordia/components/agent/observation.py index 2231f015..2b34853b 100644 --- a/concordia/components/agent/observation.py +++ b/concordia/components/agent/observation.py @@ -15,6 +15,8 @@ """Agent components for representing observation stream.""" +from collections.abc import Callable +import datetime from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model @@ -23,31 +25,36 @@ class Observation(component.Component): - """Component that stacks current observations together, clears on update.""" + """Component that displays and adds observations to memory.""" def __init__( self, agent_name: str, + clock_now: Callable[[], datetime.datetime], + timeframe: datetime.timedelta, memory: associative_memory.AssociativeMemory, component_name: str = 'Current observation', verbose: bool = False, log_colour='green', ): - """Initialize the observation component. + """Initializes the component. Args: - agent_name: the name of the agent - memory: memory for writing observations into - component_name: the name of this component - verbose: whether or not to print intermediate reasoning steps - log_colour: colour for logging + agent_name: Name of the agent. + clock_now: Function that returns the current time. + timeframe: Delta from current moment to display observations from, e.g. 1h + would display all observations made in the last hour. 
+ memory: Associative memory to add and retrieve observations. + component_name: Name of this component. + verbose: Whether to print the observations. + log_colour: Colour to print the log. """ self._agent_name = agent_name self._log_colour = log_colour self._name = component_name self._memory = memory - - self._last_observation = [] + self._timeframe = timeframe + self._clock_now = clock_now self._verbose = verbose @@ -55,57 +62,74 @@ def name(self) -> str: return self._name def state(self): + mems = self._memory.retrieve_time_interval( + self._clock_now() - self._timeframe, self._clock_now(), add_time=True + ) if self._verbose: - self._log('\n'.join(self._last_observation) + '\n') - return '\n'.join(self._last_observation) + '\n' + self._log('\n'.join(mems) + '\n') + return '\n'.join(mems) + '\n' def _log(self, entry: str): print(termcolor.colored(entry, self._log_colour), end='') def observe(self, observation: str): - self._last_observation.append(observation) self._memory.add( f'[observation] {observation}', tags=['observation'], ) - def update(self): - self._last_observation = [] - return '' - class ObservationSummary(component.Component): - """Component that summarises current observations on update.""" + """Component that summarises observations from a segment of time.""" def __init__( self, - model: language_model.LanguageModel, agent_name: str, + model: language_model.LanguageModel, + clock_now: Callable[[], datetime.datetime], + timeframe_delta_from: datetime.timedelta, + timeframe_delta_until: datetime.timedelta, + memory: associative_memory.AssociativeMemory, components: list[component.Component], + component_name: str = 'Summary of observations', + display_timeframe: bool = True, verbose: bool = False, log_colour='green', ): - """Summarize the agent's observations. + """Initializes the component. 
Args: - model: a language model - agent_name: the name of the agent - components: components to condition observation summarisation - verbose: whether or not to print intermediate reasoning steps - log_colour: colour for logging + agent_name: Name of the agent. + model: Language model to summarise the observations. + clock_now: Function that returns the current time. + timeframe_delta_from: delta from the current moment to the beginning of the + segment to summarise, e.g. 4h would summarise all observations that + happened from 4h ago until clock_now minus timeframe_delta_until. + timeframe_delta_until: delta from the current moment to the end of the + segment to summarise. + memory: Associative memory to retrieve observations from. + components: List of components to summarise. + component_name: Name of the component. + display_timeframe: Whether to display the time interval as text. + verbose: Whether to print the observations. + log_colour: Colour to print the log. """ self._model = model - self._state = '' self._agent_name = agent_name self._log_colour = log_colour + self._name = component_name + self._memory = memory + self._timeframe_delta_from = timeframe_delta_from + self._timeframe_delta_until = timeframe_delta_until + self._clock_now = clock_now self._components = components - - self._last_observation = [] + self._state = '' + self._display_timeframe = display_timeframe self._verbose = verbose def name(self) -> str: - return 'Summary of recent observations' + return self._name def state(self): return self._state @@ -113,36 +137,45 @@ def state(self): def _log(self, entry: str): print(termcolor.colored(entry, self._log_colour), end='') - def observe(self, observation: str): - self._last_observation.append(observation) - def update(self): - context = '\n'.join( - [ - f"{self._agent_name}'s " - + (comp.name() + ':\n' + comp.state()) - for comp in self._components - ] + context = '\n'.join([ + f"{self._agent_name}'s " + (comp.name() + ':\n' + comp.state()) + for 
comp in self._components + ]) + + segment_start = self._clock_now() - self._timeframe_delta_from + segment_end = self._clock_now() - self._timeframe_delta_until + + mems = self._memory.retrieve_time_interval( + segment_start, + segment_end, + add_time=True, ) - numbered_observations = [ - f'{i}. {observation}' - for i, observation in enumerate(self._last_observation) - ] - current_observations = '\n'.join(numbered_observations) - prompt = interactive_document.InteractiveDocument(self._model) prompt.statement(context + '\n') - prompt.statement( - 'Current observations, numbered in chronological order:\n' - + f'{current_observations}\n' - ) - self._state = prompt.open_question( - 'Summarize the observations into one sentence.' + prompt.statement(f'Recent memories of {self._agent_name}:\n' + f'{mems}\n') + self._state = ( + self._agent_name + + ' ' + + prompt.open_question( + 'Summarize the memories above into one sentence about' + f' {self._agent_name}.', + answer_prefix=f'{self._agent_name} ', + max_characters=500, + ) ) - self._last_observation = [] + if self._display_timeframe: + if segment_start.date() == segment_end.date(): + interval = segment_start.strftime( + '%d %b %Y [%H:%M:%S ' + ) + segment_end.strftime('- %H:%M:%S]: ') + else: + interval = segment_start.strftime( + '[%d %b %Y %H:%M:%S ' + ) + segment_end.strftime('- %d %b %Y %H:%M:%S]: ') + self._state = f'{interval} {self._state}' if self._verbose: - self._log('\nObservation summary:') - self._log('\n' + prompt.view().text() + '\n') + self._log(self._state) diff --git a/concordia/components/game_master/conversation.py b/concordia/components/game_master/conversation.py index 21a3e4e3..0253336b 100644 --- a/concordia/components/game_master/conversation.py +++ b/concordia/components/game_master/conversation.py @@ -16,6 +16,7 @@ """Externality for the Game Master, which generates conversations.""" from collections.abc import Sequence +import datetime from concordia import components as generic_components 
from concordia.agents import basic_agent @@ -123,7 +124,12 @@ def _make_npc( generic_components.constant.ConstantComponent( name='General knowledge:', state=context ), - sim_components.observation.Observation(agent_name=name, memory=mem), + sim_components.observation.Observation( + agent_name=name, + memory=mem, + clock_now=scene_clock.now, + timeframe=datetime.timedelta(days=1), + ), ], verbose=True, ) diff --git a/concordia/tests/concordia_integration_test.py b/concordia/tests/concordia_integration_test.py index 92c544d8..e87a2782 100644 --- a/concordia/tests/concordia_integration_test.py +++ b/concordia/tests/concordia_integration_test.py @@ -47,12 +47,98 @@ def _make_agent( ) -> basic_agent.BasicAgent: """Creates two agents with the same game master instructions.""" mem = mem_factory.make_blank_memory() + goal_metric = goal_achievement.GoalAchievementMetric( - model=model, player_name=name, player_goal='win', clock=clock, + model=model, + player_name=name, + player_goal='win', + clock=clock, ) morality_metric = common_sense_morality.CommonSenseMoralityMetric( - model=model, player_name=name, clock=clock, + model=model, + player_name=name, + clock=clock, + ) + + time = components.report_function.ReportFunction( + name='Current time', + function=clock.current_time_interval_str, + ) + somatic_state = agent_components.somatic_state.SomaticState( + model=model, + memory=mem, + agent_name=name, + clock_now=clock.now, + ) + identity = agent_components.identity.SimIdentity( + model=model, + memory=mem, + agent_name=name, + ) + goal_component = components.constant.ConstantComponent(state='test') + plan = agent_components.plan.SimPlan( + model=model, + memory=mem, + agent_name=name, + components=[identity], + goal=goal_component, + verbose=False, + ) + + self_perception = agent_components.self_perception.SelfPerception( + name='self perception', + model=model, + memory=mem, + agent_name=name, + clock_now=clock.now, + verbose=True, ) + situation_perception = ( + 
agent_components.situation_perception.SituationPerception( + name='situation perception', + model=model, + memory=mem, + agent_name=name, + clock_now=clock.now, + verbose=True, + ) + ) + person_by_situation = agent_components.person_by_situation.PersonBySituation( + name='person by situation', + model=model, + memory=mem, + agent_name=name, + clock_now=clock.now, + components=[self_perception, situation_perception], + verbose=True, + ) + persona = components.sequential.Sequential( + name='persona', + components=[ + self_perception, + situation_perception, + person_by_situation, + ], + ) + + observation = agent_components.observation.Observation( + agent_name=name, + clock_now=clock.now, + memory=mem, + timeframe=clock.get_step_size(), + component_name='current observations', + ) + observation_summary = agent_components.observation.ObservationSummary( + agent_name=name, + model=model, + clock_now=clock.now, + memory=mem, + timeframe_delta_from=datetime.timedelta(hours=4), + timeframe_delta_until=datetime.timedelta(hours=1), + components=[identity], + component_name='summary of observations', + ) + agent = basic_agent.BasicAgent( model, mem, @@ -62,18 +148,23 @@ def _make_agent( components.constant.ConstantComponent( 'Instructions:', game_master_instructions ), - components.constant.ConstantComponent( - 'General knowledge:', 'this is a test' - ), - agent_components.observation.Observation('Alice', mem), + persona, + observation, + observation_summary, + plan, + somatic_state, + time, goal_metric, morality_metric, ], verbose=True, ) reputation_metric = opinion_of_others.OpinionOfOthersMetric( - model=model, player_name=name, player_names=player_names, - context_fn=agent.state, clock=clock, + model=model, + player_name=name, + player_names=player_names, + context_fn=agent.state, + clock=clock, ) agent.add_component(reputation_metric) diff --git a/examples/cyberball/cyberball.ipynb b/examples/cyberball/cyberball.ipynb index 0448bf21..09ed2e5d 100644 --- 
a/examples/cyberball/cyberball.ipynb +++ b/examples/cyberball/cyberball.ipynb @@ -250,14 +250,22 @@ " agent_name=agent_config.name,\n", " )\n", "\n", - " current_obs = agent_components.observation.Observation(\n", - " agent_name=agent_config.name,\n", + " current_obs = components.observation.Observation(\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + " component_name='current observations',\n", " )\n", - " summary_obs = agent_components.observation.ObservationSummary(\n", - " model=model,\n", + " summary_obs = components.observation.ObservationSummary(\n", " agent_name=agent_config.name,\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", " components=[identity],\n", + " component_name='summary of observations',\n", " )\n", "\n", " morality_metric = common_sense_morality.CommonSenseMoralityMetric(\n", diff --git a/examples/magic_beans_for_sale.ipynb b/examples/magic_beans_for_sale.ipynb index cd4e7c9f..5e929824 100644 --- a/examples/magic_beans_for_sale.ipynb +++ b/examples/magic_beans_for_sale.ipynb @@ -303,12 +303,24 @@ " ]\n", " )\n", "\n", - " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " current_obs = components.observation.Observation(\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + " component_name='current observations',\n", + " )\n", " summary_obs = components.observation.ObservationSummary(\n", - " model=model,\n", " agent_name=agent_config.name,\n", - " components=[persona],\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", + " components=[identity],\n", + " 
component_name='summary of observations',\n", " )\n", + "\n", " goal_metric = goal_achievement.GoalAchievementMetric(\n", " model=model,\n", " player_name=agent_config.name,\n", diff --git a/examples/phone/calendar.ipynb b/examples/phone/calendar.ipynb index 2d23352e..e996d3ca 100644 --- a/examples/phone/calendar.ipynb +++ b/examples/phone/calendar.ipynb @@ -255,11 +255,22 @@ " goal=goal_component,\n", " verbose=False,\n", " )\n", - " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " current_obs = components.observation.Observation(\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + " component_name='current observations',\n", + " )\n", " summary_obs = components.observation.ObservationSummary(\n", - " model=model,\n", " agent_name=agent_config.name,\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", " components=[identity],\n", + " component_name='summary of observations',\n", " )\n", " agent = basic_agent.BasicAgent(\n", " model,\n", diff --git a/examples/three_key_questions.ipynb b/examples/three_key_questions.ipynb index e939d792..29f37ebe 100644 --- a/examples/three_key_questions.ipynb +++ b/examples/three_key_questions.ipynb @@ -346,13 +346,25 @@ " name='current_time', function=clock.current_time_interval_str\n", " )\n", "\n", - " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " current_obs = components.observation.Observation(\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + " component_name='current observations',\n", + " )\n", " summary_obs = components.observation.ObservationSummary(\n", - " model=model,\n", " agent_name=agent_config.name,\n", - " components=[persona],\n", + " 
model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", + " components=[identity],\n", + " component_name='summary of observations',\n", " )\n", "\n", + "\n", " goal_metric = goal_achievement.GoalAchievementMetric(\n", " model=model,\n", " player_name=agent_config.name,\n", diff --git a/examples/village/day_in_riverbend.ipynb b/examples/village/day_in_riverbend.ipynb index 331cddfa..6869322d 100644 --- a/examples/village/day_in_riverbend.ipynb +++ b/examples/village/day_in_riverbend.ipynb @@ -290,14 +290,24 @@ " verbose=False,\n", " )\n", " current_obs = components.observation.Observation(\n", - " agent_config.name, memory=mem\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + " component_name='current observations',\n", " )\n", " summary_obs = components.observation.ObservationSummary(\n", - " model=model,\n", " agent_name=agent_config.name,\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", " components=[identity],\n", + " component_name='summary of observations',\n", " )\n", "\n", + "\n", " goal_metric = goal_achievement.GoalAchievementMetric(\n", " model=model,\n", " player_name=agent_config.name,\n", diff --git a/examples/village/riverbend_elections.ipynb b/examples/village/riverbend_elections.ipynb index 2fffba47..8607cf7f 100644 --- a/examples/village/riverbend_elections.ipynb +++ b/examples/village/riverbend_elections.ipynb @@ -276,13 +276,23 @@ " verbose=False,\n", " )\n", " current_obs = components.observation.Observation(\n", - " agent_name=agent_config.name,\n", - " memory=mem)\n", + " agent_name=agent_config.name,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe=clock.get_step_size(),\n", + 
" component_name='current observations',\n", + " )\n", " summary_obs = components.observation.ObservationSummary(\n", - " model=model,\n", " agent_name=agent_config.name,\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=mem,\n", + " timeframe_delta_from=datetime.timedelta(hours=4),\n", + " timeframe_delta_until=datetime.timedelta(hours=1),\n", " components=[identity],\n", + " component_name='summary of observations',\n", " )\n", + "\n", " goal_metric = goal_achievement.GoalAchievementMetric(\n", " model=model,\n", " player_name=agent_config.name,\n",