Skip to content

Commit

Permalink
Add factory for a rational supporting character.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657516116
Change-Id: Ie5dd60f898ab71e8e68763fbafcc61c613ef4c89
  • Loading branch information
jzleibo authored and copybara-github committed Jul 30, 2024
1 parent 7a86b2d commit 153610a
Showing 1 changed file with 196 additions and 0 deletions.
196 changes: 196 additions & 0 deletions concordia/factory/agent/rational_entity_agent__supporting_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Generic Agent Factory."""

from collections.abc import Mapping
import datetime
import types

from concordia.agents import basic_agent
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components.agent import v2 as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import entity_component
from concordia.utils import measurements as measurements_lib


def _get_class_name(object_: object) -> str:
return object_.__class__.__name__


def build_agent(
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta,
additional_components: Mapping[
entity_component.ComponentName, str
] = types.MappingProxyType({}),
) -> basic_agent.BasicAgent:
"""Build an agent.
Args:
config: The agent config to use.
model: The language model to use.
memory: The agent's memory object.
clock: The clock to use.
update_time_interval: Agent calls update every time this interval passes.
additional_components: Additional components to add to the agent.
Returns:
An agent.
"""
del update_time_interval
if config.extras.get('main_character', False):
raise ValueError('This function is meant for a supporting character '
'but it was called on a main character.')

agent_name = config.name
raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)
measurements = measurements_lib.Measurements()

instructions = agent_components.instructions.Instructions(
agent_name=agent_name,
logging_channel=measurements.get_channel('Instructions').on_next,
)
time_display = agent_components.report_function.ReportFunction(
function=clock.current_time_interval_str,
pre_act_key='\nCurrent time',
logging_channel=measurements.get_channel('TimeDisplay').on_next,
)
observation_label = '\nObservation'
observation = agent_components.observation.Observation(
clock_now=clock.now,
timeframe=clock.get_step_size(),
pre_act_key=observation_label,
logging_channel=measurements.get_channel('Observation').on_next,
)
observation_summary_label = '\nSummary of recent observations'
observation_summary = agent_components.observation.ObservationSummary(
model=model,
clock_now=clock.now,
timeframe_delta_from=datetime.timedelta(hours=4),
timeframe_delta_until=datetime.timedelta(hours=0),
pre_act_key=observation_summary_label,
logging_channel=measurements.get_channel('ObservationSummary').on_next,
)
relevant_memories_label = '\nRecalled memories and observations'
relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
model=model,
components={
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(time_display): 'The current date/time is'},
num_memories_to_retrieve=10,
pre_act_key=relevant_memories_label,
logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
)

options_perception_components = {}
if config.goal:
goal_label = '\nOverarching goal'
overarching_goal = agent_components.constant.Constant(
state=config.goal,
pre_act_key=goal_label,
logging_channel=measurements.get_channel(goal_label).on_next)
options_perception_components[goal_label] = goal_label
else:
goal_label = None
overarching_goal = None

options_perception_components.update({
_get_class_name(observation): observation_label,
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(relevant_memories): relevant_memories_label,
})
options_perception_label = (
f'\nQuestion: Which options are available to {agent_name} '
'right now?\nAnswer')
options_perception = (
agent_components.options_perception.AvailableOptionsPerception(
model=model,
components=options_perception_components,
clock_now=clock.now,
pre_act_key=options_perception_label,
logging_channel=measurements.get_channel(
'AvailableOptionsPerception').on_next,
)
)
best_option_perception_label = (
f'\nQuestion: Of the options available to {agent_name}, and '
'given their goal, which choice of action or strategy is '
f'best for {agent_name} to take right now?\nAnswer')
best_option_perception = {}
if config.goal:
best_option_perception[goal_label] = goal_label
best_option_perception.update({
_get_class_name(observation): observation_label,
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(relevant_memories): relevant_memories_label,
_get_class_name(options_perception): options_perception_label,
})
best_option_perception = (
agent_components.options_perception.BestOptionPerception(
model=model,
components=best_option_perception,
clock_now=clock.now,
pre_act_key=best_option_perception_label,
logging_channel=measurements.get_channel(
'BestOptionPerception').on_next,
)
)

entity_components = (
# Components that provide pre_act context.
instructions,
time_display,
observation,
observation_summary,
relevant_memories,
options_perception,
best_option_perception,
)

components_of_agent = {_get_class_name(component): component
for component in entity_components}
components_of_agent[
agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
agent_components.memory_component.MemoryComponent(raw_memory))
components_of_agent.update(additional_components)

component_order = list(components_of_agent.keys())
if overarching_goal is not None:
components_of_agent[goal_label] = overarching_goal
# Place goal after the instructions.
component_order.insert(1, goal_label)

act_component = agent_components.concat_act_component.ConcatActComponent(
model=model,
clock=clock,
logging_channel=measurements.get_channel('ActComponent').on_next,
)

agent = entity_agent_with_logging.EntityAgentWithLogging(
agent_name=agent_name,
act_component=act_component,
context_components=components_of_agent,
component_logging=measurements,
)

return agent

0 comments on commit 153610a

Please sign in to comment.