Skip to content

Commit

Permalink
add rational entity agent main role
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655100776
Change-Id: I2ce096d6be1b4a32127d5e3fb93df3160f5ca1d2
  • Loading branch information
jzleibo authored and copybara-github committed Jul 23, 2024
1 parent f7d1656 commit 7efcc4d
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 3 deletions.
1 change: 1 addition & 0 deletions concordia/components/agent/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from concordia.components.agent.v2 import memory_component
from concordia.components.agent.v2 import no_op_context_processor
from concordia.components.agent.v2 import observation
from concordia.components.agent.v2 import options_perception
from concordia.components.agent.v2 import person_by_situation
from concordia.components.agent.v2 import plan
from concordia.components.agent.v2 import report_function
Expand Down
11 changes: 8 additions & 3 deletions concordia/factory/agent/factories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from concordia.factory.agent import basic_agent__supporting_role
from concordia.factory.agent import basic_entity_agent__main_role
from concordia.factory.agent import rational_agent__main_role
from concordia.factory.agent import rational_entity_agent__main_role
from concordia.language_model import no_language_model
from concordia.typing import agent as agent_lib
from concordia.typing import entity as entity_lib
Expand All @@ -52,8 +53,9 @@
AGENT_FACTORIES = {
'basic_agent__main_role': basic_agent__main_role,
'basic_agent__supporting_role': basic_agent__supporting_role,
'rational_agent__main_role': rational_agent__main_role,
'basic_entity_agent__main_role': basic_entity_agent__main_role,
'rational_agent__main_role': rational_agent__main_role,
'rational_entity_agent__main_role': rational_entity_agent__main_role,
}


Expand All @@ -71,11 +73,14 @@ class AgentFactoriesTest(parameterized.TestCase):
dict(testcase_name='basic_agent__supporting_role',
agent_name='basic_agent__supporting_role',
main_role=False),
dict(testcase_name='basic_entity_agent__main_role',
agent_name='basic_entity_agent__main_role',
main_role=True),
dict(testcase_name='rational_agent__main_role',
agent_name='rational_agent__main_role',
main_role=True),
dict(testcase_name='basic_entity_agent__main_role',
agent_name='basic_entity_agent__main_role',
dict(testcase_name='rational_entity_agent__main_role',
agent_name='rational_entity_agent__main_role',
main_role=True),
)
def test_output_in_right_format(self, agent_name: str, main_role: bool):
Expand Down
192 changes: 192 additions & 0 deletions concordia/factory/agent/rational_entity_agent__main_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An Agent Factory."""

import datetime

from concordia.agents import basic_agent
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components.agent import v2 as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib


def _get_class_name(object_: object) -> str:
return object_.__class__.__name__


def build_agent(
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta,
) -> basic_agent.BasicAgent:
"""Build an agent.
Args:
config: The agent config to use.
model: The language model to use.
memory: The agent's memory object.
clock: The clock to use.
update_time_interval: Agent calls update every time this interval passes.
Returns:
An agent.
"""
del update_time_interval
if not config.extras.get('main_character', False):
raise ValueError('This function is meant for a main character '
'but it was called on a supporting character.')

agent_name = config.name

raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

measurements = measurements_lib.Measurements()
instructions = agent_components.instructions.Instructions(
agent_name=agent_name,
logging_channel=measurements.get_channel('Instructions').on_next,
)

time_display = agent_components.report_function.ReportFunction(
function=clock.current_time_interval_str,
pre_act_key='\nCurrent time',
logging_channel=measurements.get_channel('TimeDisplay').on_next,
)

observation_label = '\nObservation'
observation = agent_components.observation.Observation(
clock_now=clock.now,
timeframe=clock.get_step_size(),
pre_act_key=observation_label,
logging_channel=measurements.get_channel('Observation').on_next,
)
observation_summary_label = 'Summary of recent observations'
observation_summary = agent_components.observation.ObservationSummary(
model=model,
clock_now=clock.now,
timeframe_delta_from=datetime.timedelta(hours=4),
timeframe_delta_until=datetime.timedelta(hours=1),
pre_act_key=observation_summary_label,
logging_channel=measurements.get_channel('ObservationSummary').on_next,
)

relevant_memories_label = '\nRecalled memories and observations'
relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
model=model,
components={
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(time_display): 'The current date/time is'},
num_memories_to_retrieve=10,
pre_act_key=relevant_memories_label,
logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
)

options_perception_components = {}
if config.goal:
goal_label = '\nOverarching goal'
overarching_goal = agent_components.constant.Constant(
state=config.goal,
pre_act_key=goal_label,
logging_channel=measurements.get_channel(goal_label).on_next)
options_perception_components[goal_label] = goal_label
else:
goal_label = None
overarching_goal = None

options_perception_components.update({
_get_class_name(observation): observation_label,
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(relevant_memories): relevant_memories_label,
})
options_perception_label = (
f'\nQuestion: Which options are available to {agent_name} '
'right now?\nAnswer')
options_perception = (
agent_components.options_perception.AvailableOptionsPerception(
model=model,
components=options_perception_components,
clock_now=clock.now,
pre_act_key=options_perception_label,
logging_channel=measurements.get_channel(
'AvailableOptionsPerception').on_next,
)
)
best_option_perception_label = (
f'\nQuestion: Of the options available to {agent_name}, and '
'given their goal, which choice of action or strategy is '
f'best for {agent_name} to take right now?\nAnswer')
best_option_perception = {}
if config.goal:
best_option_perception[goal_label] = goal_label
best_option_perception.update({
_get_class_name(observation): observation_label,
_get_class_name(observation_summary): observation_summary_label,
_get_class_name(relevant_memories): relevant_memories_label,
_get_class_name(options_perception): options_perception_label,
})
best_option_perception = (
agent_components.options_perception.BestOptionPerception(
model=model,
components=best_option_perception,
clock_now=clock.now,
pre_act_key=best_option_perception_label,
logging_channel=measurements.get_channel(
'BestOptionPerception').on_next,
)
)

entity_components = (
# Components that provide pre_act context.
instructions,
time_display,
observation,
observation_summary,
relevant_memories,
options_perception,
best_option_perception,
)
components_of_agent = {_get_class_name(component): component
for component in entity_components}
components_of_agent[
agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
agent_components.memory_component.MemoryComponent(raw_memory))

component_order = list(components_of_agent.keys())
if overarching_goal is not None:
components_of_agent[goal_label] = overarching_goal
# Place goal after the instructions.
component_order.insert(1, goal_label)

act_component = agent_components.concat_act_component.ConcatActComponent(
model=model,
clock=clock,
component_order=component_order,
logging_channel=measurements.get_channel('ActComponent').on_next,
)

agent = entity_agent_with_logging.EntityAgentWithLogging(
agent_name=agent_name,
act_component=act_component,
context_components=components_of_agent,
component_logging=measurements,
)

return agent

0 comments on commit 7efcc4d

Please sign in to comment.