-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding supporting players to pub_coordination and puppet agents.
- Add basic_puppet_agent and the puppet_act component, which enable making agents that give fixed responses to certain action calls.
- Add supporting players to pub_coordination, which always choose the same pub. They are used as holdouts.

PiperOrigin-RevId: 669312568
Change-Id: If2f6d391fb5524697ea94692d25084e56034b765
- Loading branch information
1 parent
b6fd371
commit 8001ade
Showing
4 changed files
with
478 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""An acting component that can be set to give fixed responses.""" | ||
|
||
|
||
from collections.abc import Mapping, Sequence | ||
|
||
from concordia.document import interactive_document | ||
from concordia.language_model import language_model | ||
from concordia.typing import clock as game_clock | ||
from concordia.typing import entity as entity_lib | ||
from concordia.typing import entity_component | ||
from concordia.typing import logging | ||
from concordia.utils import helper_functions | ||
from typing_extensions import override | ||
|
||
# Default value of the 'Key' field written to the logging channel by the
# acting component.
DEFAULT_PRE_ACT_KEY = 'Act'


class PuppetActComponent(entity_component.ActingComponent):
  """An acting component that answers chosen calls to action with fixed text.

  When the formatted call to action of an action spec is a key of
  `fixed_responses`, the associated value is returned without querying the
  language model. Otherwise the component concatenates the `pre_act` contexts
  of all components into a prompt and asks the language model to produce the
  action attempt.

  Contexts are assembled in the order given to `__init__`. If no component
  order is specified, they are assembled in the iteration order of the
  `ComponentContextMapping` passed to `get_action_attempt`. Components whose
  context is an empty string are ignored.
  """

  def __init__(
      self,
      model: language_model.LanguageModel,
      clock: game_clock.GameClock,
      fixed_responses: Mapping[str, str],
      component_order: Sequence[str] | None = None,
      pre_act_key: str = DEFAULT_PRE_ACT_KEY,
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initializes the component.

    Args:
      model: The language model to use for generating the action attempt when
        no fixed response applies.
      clock: The game clock; its step size is substituted for `{timedelta}` in
        the call to action.
      fixed_responses: A mapping from call to action to fixed response. Keys
        are compared against the call to action after `{name}` and
        `{timedelta}` have been substituted.
      component_order: The order in which the component contexts will be
        assembled when calling the act component. If None, the contexts will be
        assembled in the iteration order of the `ComponentContextMapping`
        passed to `get_action_attempt`. If the component order is specified,
        but does not contain all the components passed to
        `get_action_attempt`, the missing components will be appended at the
        end, sorted by component name. The same component cannot appear twice
        in the component order. All components in the component order must be
        in the `ComponentContextMapping` passed to `get_action_attempt`.
      pre_act_key: Value stored under 'Key' in every log entry.
      logging_channel: The channel to use for debug logging.

    Raises:
      ValueError: If the component order is not None and contains duplicate
        components.
    """
    self._model = model
    self._clock = clock
    if component_order is None:
      self._component_order = None
    else:
      self._component_order = tuple(component_order)
      if len(set(self._component_order)) != len(self._component_order):
        raise ValueError(
            'The component order contains duplicate components: '
            + ', '.join(self._component_order)
        )

    self._fixed_responses = fixed_responses

    self._pre_act_key = pre_act_key
    self._logging_channel = logging_channel

  def _context_for_action(
      self,
      contexts: entity_component.ComponentContextMapping,
  ) -> str:
    """Joins the non-empty component contexts into one newline-separated string.

    Args:
      contexts: Mapping from component name to the context it produced.

    Returns:
      The non-empty contexts joined by newlines. Without a component order the
      mapping's iteration order is used; otherwise the configured order comes
      first, followed by the remaining components sorted by name.
    """
    if self._component_order is None:
      return '\n'.join(context for context in contexts.values() if context)
    else:
      # Components absent from the configured order are appended in sorted
      # (not iteration) order, which keeps the prompt deterministic.
      order = self._component_order + tuple(
          sorted(set(contexts.keys()) - set(self._component_order))
      )
      return '\n'.join(contexts[name] for name in order if contexts[name])

  @override
  def get_action_attempt(
      self,
      contexts: entity_component.ComponentContextMapping,
      action_spec: entity_lib.ActionSpec,
  ) -> str:
    """Returns the fixed response if one applies, else asks the language model.

    Args:
      contexts: Mapping from component name to the context it produced.
      action_spec: The action specification; its call to action may contain
        `{name}` and `{timedelta}` placeholders.

    Returns:
      The action attempt as a string. For FLOAT output types the string is the
      parsed float, or '0.0' if the text cannot be parsed as a float.

    Raises:
      ValueError: If a fixed response is configured for a CHOICE action but is
        not one of the allowed options.
      NotImplementedError: If no fixed response applies and the output type is
        not FREE, CHOICE or FLOAT.
    """
    prompt = interactive_document.InteractiveDocument(self._model)
    context = self._context_for_action(contexts)
    prompt.statement(context + '\n')

    call_to_action = action_spec.call_to_action.format(
        name=self.get_entity().name,
        timedelta=helper_functions.timedelta_to_readable_str(
            self._clock.get_step_size()
        ),
    )

    if call_to_action in self._fixed_responses:
      output = self._fixed_responses[call_to_action]
      if (
          action_spec.output_type == entity_lib.OutputType.CHOICE
          and output not in action_spec.options
      ):
        raise ValueError(
            f'Fixed response {output} not in options: {action_spec.options}'
        )
      elif action_spec.output_type == entity_lib.OutputType.FLOAT:
        try:
          output = str(float(output))
        except ValueError:
          output = '0.0'
      # Record the fixed response on the logging channel, like every
      # model-generated answer, instead of printing it to stdout.
      self._log(output, prompt)
      return output

    if action_spec.output_type == entity_lib.OutputType.FREE:
      output = self.get_entity().name + ' '
      output += prompt.open_question(
          call_to_action,
          max_tokens=2200,
          answer_prefix=output,
          # This terminator protects against the model providing extra context
          # after the end of a directly spoken response, since it normally
          # puts a space after a quotation mark only in these cases.
          terminators=('" ', '\n'),
          question_label='Exercise',
      )
      self._log(output, prompt)
      return output
    elif action_spec.output_type == entity_lib.OutputType.CHOICE:
      idx = prompt.multiple_choice_question(
          question=call_to_action, answers=action_spec.options
      )
      output = action_spec.options[idx]
      self._log(output, prompt)
      return output
    elif action_spec.output_type == entity_lib.OutputType.FLOAT:
      prefix = self.get_entity().name + ' '
      sampled_text = prompt.open_question(
          call_to_action,
          max_tokens=2200,
          answer_prefix=prefix,
      )
      self._log(sampled_text, prompt)
      try:
        return str(float(sampled_text))
      except ValueError:
        return '0.0'
    else:
      raise NotImplementedError(
          f'Unsupported output type: {action_spec.output_type}. '
          'Supported output types are: FREE, CHOICE, and FLOAT.'
      )

  def _log(
      self, result: str, prompt: interactive_document.InteractiveDocument
  ) -> None:
    """Sends the result and the prompt that produced it to the logging channel."""
    self._logging_channel({
        'Key': self._pre_act_key,
        'Value': result,
        'Prompt': prompt.view().text().splitlines(),
    })
196 changes: 196 additions & 0 deletions
196
concordia/factory/agent/basic_puppet_agent__supporting_role.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
# Copyright 2024 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""A Generic Agent Factory.""" | ||
|
||
from collections.abc import Mapping | ||
import datetime | ||
import types | ||
|
||
from concordia.agents import entity_agent_with_logging | ||
from concordia.associative_memory import associative_memory | ||
from concordia.associative_memory import formative_memories | ||
from concordia.clocks import game_clock | ||
from concordia.components import agent as agent_components | ||
from concordia.language_model import language_model | ||
from concordia.memory_bank import legacy_associative_memory | ||
from concordia.typing import entity_component | ||
from concordia.utils import measurements as measurements_lib | ||
|
||
|
||
def _get_class_name(object_: object) -> str: | ||
return object_.__class__.__name__ | ||
|
||
|
||
def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta,
    fixed_response_by_call_to_action: Mapping[str, str],
    additional_components: Mapping[
        entity_component.ComponentName,
        entity_component.ContextComponent,
    ] = types.MappingProxyType({}),
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Builds a supporting-role agent that acts through a puppet act component.

  Args:
    config: The agent config to use. Its `extras` must not mark the agent as a
      main character.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Unused; it is discarded immediately.
    fixed_response_by_call_to_action: A mapping from call to action to fixed
      response, forwarded to the puppet act component.
    additional_components: Additional components to add to the agent. They are
      merged last, so an entry replaces any component already registered under
      the same name.

  Returns:
    An agent.

  Raises:
    ValueError: If `config.extras` has a truthy 'main_character' entry.
  """
  del update_time_interval
  is_main_character = config.extras.get('main_character', False)
  if is_main_character:
    raise ValueError('This function is meant for a supporting character '
                     'but it was called on a main character.')

  name = config.name
  memory_bank = legacy_associative_memory.AssociativeMemoryBank(memory)
  metrics = measurements_lib.Measurements()

  def _channel(channel_name: str):
    """Returns the `on_next` callback of the named measurements channel."""
    return metrics.get_channel(channel_name).on_next

  # Labels under which each component's context appears in the prompt.
  observation_key = '\nObservation'
  somatic_key = '\nSensations and feelings'
  summary_key = '\nSummary of recent observations'
  recalled_key = '\nRecalled memories and observations'
  self_key = f'\nQuestion: What kind of person is {name}?\nAnswer'
  situation_key = (
      f'\nQuestion: What kind of situation is {name} in '
      'right now?\nAnswer'
  )
  person_situation_key = (
      f'\nQuestion: What would a person like {name} do in '
      'a situation like this?\nAnswer'
  )

  # Components are constructed in a fixed order; later ones refer to earlier
  # ones by class name.
  instructions = agent_components.instructions.Instructions(
      agent_name=name,
      logging_channel=_channel('Instructions'),
  )
  time_display = agent_components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=_channel('TimeDisplay'),
  )
  observation = agent_components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_key,
      logging_channel=_channel('Observation'),
  )
  somatic_cls = (
      agent_components.question_of_query_associated_memories
      .SomaticStateWithoutPreAct
  )
  somatic_state = somatic_cls(
      model=model,
      clock_now=clock.now,
      logging_channel=_channel('SomaticState'),
      pre_act_key=somatic_key,
  )
  observation_summary = agent_components.observation.ObservationSummary(
      model=model,
      clock_now=clock.now,
      timeframe_delta_from=datetime.timedelta(hours=4),
      timeframe_delta_until=datetime.timedelta(hours=0),
      components={_get_class_name(somatic_state): somatic_key},
      pre_act_key=summary_key,
      logging_channel=_channel('ObservationSummary'),
  )
  relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
      model=model,
      components={
          _get_class_name(observation_summary): summary_key,
          _get_class_name(somatic_state): somatic_key,
          _get_class_name(time_display): 'The current date/time is',
      },
      num_memories_to_retrieve=10,
      pre_act_key=recalled_key,
      logging_channel=_channel('AllSimilarMemories'),
  )
  recent = agent_components.question_of_recent_memories
  self_perception = recent.SelfPerception(
      model=model,
      pre_act_key=self_key,
      logging_channel=_channel('SelfPerception'),
  )
  situation_perception = recent.SituationPerception(
      model=model,
      components={
          _get_class_name(observation): observation_key,
          _get_class_name(somatic_state): somatic_key,
          _get_class_name(observation_summary): summary_key,
          _get_class_name(relevant_memories): recalled_key,
      },
      clock_now=clock.now,
      pre_act_key=situation_key,
      logging_channel=_channel('SituationPerception'),
  )
  person_by_situation = recent.PersonBySituation(
      model=model,
      components={
          _get_class_name(self_perception): self_key,
          _get_class_name(situation_perception): situation_key,
      },
      clock_now=clock.now,
      pre_act_key=person_situation_key,
      logging_channel=_channel('PersonBySituation'),
  )

  # Register components under their class names. Insertion order is kept
  # because the act component is built without an explicit component order.
  context_components = {}
  for component in (
      # Components that provide pre_act context.
      instructions,
      time_display,
      observation,
      observation_summary,
      relevant_memories,
      self_perception,
      situation_perception,
      person_by_situation,
      # Components that do not provide pre_act context.
      somatic_state,
  ):
    context_components[_get_class_name(component)] = component

  memory_module = agent_components.memory_component
  context_components[memory_module.DEFAULT_MEMORY_COMPONENT_NAME] = (
      memory_module.MemoryComponent(memory_bank)
  )
  context_components.update(additional_components)

  puppet_act = agent_components.puppet_act_component.PuppetActComponent(
      model=model,
      clock=clock,
      logging_channel=_channel('ActComponent'),
      fixed_responses=fixed_response_by_call_to_action,
  )

  return entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=name,
      act_component=puppet_act,
      context_components=context_components,
      component_logging=metrics,
  )
Oops, something went wrong.