From bd220166a33a6e3834c3561ae03561fbfb8f119d Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Wed, 11 Sep 2024 03:21:19 -0700 Subject: [PATCH] Improve clarity of GameMaster and fix type annotations PiperOrigin-RevId: 673317752 Change-Id: I395953ec45f469ffb7ea8afae3a39d66227f59aa --- concordia/environment/game_master.py | 140 ++++++++++--------------- concordia/environment/scenes/runner.py | 6 +- 2 files changed, 60 insertions(+), 86 deletions(-) diff --git a/concordia/environment/game_master.py b/concordia/environment/game_master.py index 99e2b471..583a81a2 100644 --- a/concordia/environment/game_master.py +++ b/concordia/environment/game_master.py @@ -14,7 +14,7 @@ """A Generic Game Master.""" -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence import dataclasses import datetime import functools @@ -22,8 +22,6 @@ from typing import Any from concordia import components as generic_components -from concordia.agents import deprecated_agent -from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model @@ -83,10 +81,7 @@ def __init__( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.GameClock, - players: Sequence[ - deprecated_agent.BasicAgent - | entity_agent_with_logging.EntityAgentWithLogging - ], + players: Sequence[agent_lib.GenerativeAgent], name: str = 'Game Master', update_thought_chain: Sequence[ Callable[[interactive_document.InteractiveDocument, str, str], str] @@ -149,7 +144,7 @@ def __init__( ) components.insert(0, instructions_component) - self._components = {} + self._components: dict[str, component.Component] = {} for comp in components: if comp.name() in self._components: raise ValueError(f'Duplicate component name: {comp.name()}') @@ -167,9 +162,7 @@ def __init__( self._concurrent_externalities = concurrent_externalities self._log = [] - self.reset() - - @property + @functools.cached_property def name(self) -> str: return self._name @@ -194,36 +187,31 @@ def get_memory(self) -> associative_memory.AssociativeMemory: def _print(self, entry, color=None): print(termcolor.colored(entry, color or self._log_color)) - def reset(self): - self._last_chain = None - self._num_players = len(self._players_by_name.keys()) - def get_player_names(self): return list(self._players_by_name.keys()) - def _update_from_player(self, player_name: str, action_attempt: str): - prompt = interactive_document.InteractiveDocument(self._model) + def _handle_action(self, player_name: str, action_attempt: str) -> None: + """Resolves a given action attempt.""" - concurrency.map_parallel( - lambda construct: construct.update_before_event( - f'{player_name}: {action_attempt}' - ), - self._components.values(), - ) + concurrency.run_tasks({ + name: functools.partial( + component.update_before_event, f'{player_name}: {action_attempt}' + ) + for name, component in self._components.items() + }) + # Produce the event that has happened as the result of the action attempt + prompt = interactive_document.InteractiveDocument(self._model) for comp in self._components.values(): state_of_component = comp.state() if state_of_component: prompt.statement(comp.name() + ': ' + state_of_component + '\n') - prompt.statement(f"\n{player_name}'s attempted action: {action_attempt}") - - # Produce the event that has happened as the result of the action attempt prompt, event_statement = thought_chains.run_chain_of_thought( - self._update_from_player_thoughts, - action_attempt, - prompt, - player_name, + thoughts=self._update_from_player_thoughts, + premise=action_attempt, + document=prompt, + active_player_name=player_name, ) self._memory.add(event_statement) @@ -259,41 +247,36 @@ def _update_from_player(self, player_name: str, action_attempt: str): } # Consequences - def get_externality(externality): - return externality.update_after_event(event_statement) - if self._concurrent_externalities: - concurrency.map_parallel(get_externality, self._components.values()) + concurrency.run_tasks({ + name: functools.partial( + externality.update_after_event, event_statement + ) + for name, externality in self._components.items() + }) else: for externality in self._components.values(): externality.update_after_event(event_statement) - self._last_chain = prompt - for externality in self._components.values(): last_log = externality.get_last_log() if last_log: - if 'date' in last_log.keys(): - last_log.pop('date') - if 'Event statement' in last_log.keys(): - last_log.pop('Event statement') - + last_log = dict(last_log) + last_log.pop('date', None) + last_log.pop('Event statement', None) update_log[externality.name()] = last_log self._log.append(update_log) - return event_statement - - def _view_for_player(self, player_name): - """Send observations to a player.""" + def _player_observations(self, player_name) -> Iterator[str]: + """Yields the players observations.""" for comp in self._components.values(): state_of_component = comp.partial_state(player_name) - if state_of_component: - for observation in state_of_component.splitlines(): - if observation: - self._players_by_name[player_name].observe(observation) - - return + if not state_of_component: + continue + for observation in state_of_component.splitlines(): + if observation: + yield observation def _update_components(self) -> None: concurrency.run_tasks({ @@ -307,30 +290,24 @@ def _update_components(self) -> None: def _step_player( self, - player: deprecated_agent.BasicAgent, - action_spec: agent_lib.ActionSpec | None = None, - ): + player: agent_lib.GenerativeAgent, + action_spec: agent_lib.ActionSpec, + ) -> None: self._update_components() - self._view_for_player(player_name=player.name) - - if action_spec is None: - action_spec_this_time = self._action_spec[player.name] - else: - action_spec_this_time = action_spec - - action = player.act(action_spec_this_time) - action_spec_this_time.validate(action) - - self._update_from_player(action_attempt=action, player_name=player.name) + for observation in self._player_observations(player.name): + player.observe(observation) + action = player.act(action_spec) + action_spec.validate(action) + self._handle_action(player.name, action) def step( self, *, - active_players: Sequence[deprecated_agent.BasicAgent] | None = None, - action_spec: ( + active_players: Sequence[agent_lib.GenerativeAgent] | None = None, + action_spec_override: ( agent_lib.ActionSpec | Mapping[str, agent_lib.ActionSpec] | None ) = None, - ): + ) -> None: """Steps the game. At each step players all take a turn 'quasisimultaneously' with regard to @@ -339,32 +316,27 @@ def step( Args: active_players: Optionally specify players to take turns in this round. - action_spec: Optionally specify what kind of actions to ask the agents to - generate. + action_spec_override: Optionally override what kind of actions to ask the + agents to generate. """ if active_players: players = list(active_players) else: players = list(self._players_by_name.values()) + if self._randomise_initiative: + random.shuffle(players) - if action_spec is None: - step_player_fn = lambda player: self._step_player(player=player) - elif isinstance(action_spec, Mapping): - step_player_fn = lambda player: self._step_player( - player=player, action_spec=action_spec[player.name] - ) - elif isinstance(action_spec, agent_lib.ActionSpec): - step_player_fn = lambda player: self._step_player( - player=player, action_spec=action_spec - ) + if action_spec_override is None: + action_spec = self._action_spec + elif isinstance(action_spec_override, Mapping): + action_spec = self._action_spec | action_spec_override + elif isinstance(action_spec_override, agent_lib.ActionSpec): + action_spec = {player.name: action_spec_override for player in players} else: raise TypeError('Invalid action_spec parameter type') - if self._randomise_initiative: - random.shuffle(players) - for player in players: - step_player_fn(player) + self._step_player(player, action_spec=action_spec[player.name]) if not self._players_act_simultaneously: self._clock.advance() if self._players_act_simultaneously: diff --git a/concordia/environment/scenes/runner.py b/concordia/environment/scenes/runner.py index 9eb64137..6f2d3c33 100644 --- a/concordia/environment/scenes/runner.py +++ b/concordia/environment/scenes/runner.py @@ -116,8 +116,10 @@ def run_scenes( # Run the scene for _ in range(scene.num_rounds): this_scene_game_master_memory.add(f'[scene type] {scene.scene_type.name}') - this_scene_environment.step(active_players=participants, - action_spec=scene.scene_type.action_spec) + this_scene_environment.step( + active_players=participants, + action_spec_override=scene.scene_type.action_spec, + ) if this_scene_environment.terminate_episode(): break