Skip to content

Commit

Permalink
Add an agent with a very simple prompting strategy, just observe and …
Browse files Browse the repository at this point in the history
…recall relevant memories. Nothing else.

PiperOrigin-RevId: 675064536
Change-Id: I4c9e23984f3d4274e0874754877378640e86cb54
  • Loading branch information
jzleibo authored and copybara-github committed Sep 16, 2024
1 parent dc06b41 commit 64a8cd1
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions concordia/factory/agent/observe_recall_prompt_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An Agent Factory."""

import datetime

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib


def _get_class_name(object_: object) -> str:
return object_.__class__.__name__


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta,
) -> entity_agent_with_logging.EntityAgentWithLogging:
"""Build an agent.
Args:
config: The agent config to use.
model: The language model to use.
memory: The agent's memory object.
clock: The clock to use.
update_time_interval: Agent calls update every time this interval passes.
Returns:
An agent.
"""
del update_time_interval
if not config.extras.get('main_character', False):
raise ValueError('This function is meant for a main character '
'but it was called on a supporting character.')

agent_name = config.name

raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

measurements = measurements_lib.Measurements()
instructions = agent_components.instructions.Instructions(
agent_name=agent_name,
logging_channel=measurements.get_channel('Instructions').on_next,
)

time_display = agent_components.report_function.ReportFunction(
function=clock.current_time_interval_str,
pre_act_key='\nCurrent time',
logging_channel=measurements.get_channel('TimeDisplay').on_next,
)

observation_label = '\nObservation'
observation = agent_components.observation.Observation(
clock_now=clock.now,
timeframe=clock.get_step_size(),
pre_act_key=observation_label,
logging_channel=measurements.get_channel('Observation').on_next,
)
observation_summary_label = 'Summary of recent observations'
observation_summary = agent_components.observation.ObservationSummary(
model=model,
clock_now=clock.now,
timeframe_delta_from=datetime.timedelta(hours=4),
timeframe_delta_until=datetime.timedelta(hours=0),
pre_act_key=observation_summary_label,
logging_channel=measurements.get_channel('ObservationSummary').on_next,
)

relevant_memories_label = '\nRecalled memories and observations'
relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
model=model,
components={
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(time_display): 'The current date/time is'},
num_memories_to_retrieve=10,
pre_act_key=relevant_memories_label,
logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
)

if config.goal:
goal_label = '\nOverarching goal'
overarching_goal = agent_components.constant.Constant(
state=config.goal,
pre_act_key=goal_label,
logging_channel=measurements.get_channel(goal_label).on_next)
else:
goal_label = None
overarching_goal = None

entity_components = (
# Components that provide pre_act context.
instructions,
time_display,
observation,
observation_summary,
relevant_memories,
)
components_of_agent = {_get_class_name(component): component
for component in entity_components}
components_of_agent[
agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
agent_components.memory_component.MemoryComponent(raw_memory))

component_order = list(components_of_agent.keys())
if overarching_goal is not None:
components_of_agent[goal_label] = overarching_goal
# Place goal after the instructions.
component_order.insert(1, goal_label)

act_component = agent_components.concat_act_component.ConcatActComponent(
model=model,
clock=clock,
component_order=component_order,
logging_channel=measurements.get_channel('ActComponent').on_next,
)

agent = entity_agent_with_logging.EntityAgentWithLogging(
agent_name=agent_name,
act_component=act_component,
context_components=components_of_agent,
component_logging=measurements,
)

return agent

0 comments on commit 64a8cd1

Please sign in to comment.