Skip to content

Commit

Permalink
Improve clarity of GameMaster and fix type annotations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673317752
Change-Id: I395953ec45f469ffb7ea8afae3a39d66227f59aa
  • Loading branch information
jagapiou authored and copybara-github committed Sep 11, 2024
1 parent c7ffc64 commit bd22016
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 86 deletions.
140 changes: 56 additions & 84 deletions concordia/environment/game_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@

"""A Generic Game Master."""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
import dataclasses
import datetime
import functools
import random
from typing import Any

from concordia import components as generic_components
from concordia.agents import deprecated_agent
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.document import interactive_document
from concordia.language_model import language_model
Expand Down Expand Up @@ -83,10 +81,7 @@ def __init__(
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.GameClock,
players: Sequence[
deprecated_agent.BasicAgent
| entity_agent_with_logging.EntityAgentWithLogging
],
players: Sequence[agent_lib.GenerativeAgent],
name: str = 'Game Master',
update_thought_chain: Sequence[
Callable[[interactive_document.InteractiveDocument, str, str], str]
Expand Down Expand Up @@ -149,7 +144,7 @@ def __init__(
)
components.insert(0, instructions_component)

self._components = {}
self._components: dict[str, component.Component] = {}
for comp in components:
if comp.name() in self._components:
raise ValueError(f'Duplicate component name: {comp.name()}')
Expand All @@ -167,9 +162,7 @@ def __init__(
self._concurrent_externalities = concurrent_externalities
self._log = []

self.reset()

@property
@functools.cached_property
def name(self) -> str:
return self._name

Expand All @@ -194,36 +187,31 @@ def get_memory(self) -> associative_memory.AssociativeMemory:
def _print(self, entry, color=None):
print(termcolor.colored(entry, color or self._log_color))

def reset(self):
self._last_chain = None
self._num_players = len(self._players_by_name.keys())

def get_player_names(self):
return list(self._players_by_name.keys())

def _update_from_player(self, player_name: str, action_attempt: str):
prompt = interactive_document.InteractiveDocument(self._model)
def _handle_action(self, player_name: str, action_attempt: str) -> None:
"""Resolves a given action attempt."""

concurrency.map_parallel(
lambda construct: construct.update_before_event(
f'{player_name}: {action_attempt}'
),
self._components.values(),
)
concurrency.run_tasks({
name: functools.partial(
component.update_before_event, f'{player_name}: {action_attempt}'
)
for name, component in self._components.items()
})

# Produce the event that has happened as the result of the action attempt
prompt = interactive_document.InteractiveDocument(self._model)
for comp in self._components.values():
state_of_component = comp.state()
if state_of_component:
prompt.statement(comp.name() + ': ' + state_of_component + '\n')

prompt.statement(f"\n{player_name}'s attempted action: {action_attempt}")

# Produce the event that has happened as the result of the action attempt
prompt, event_statement = thought_chains.run_chain_of_thought(
self._update_from_player_thoughts,
action_attempt,
prompt,
player_name,
thoughts=self._update_from_player_thoughts,
premise=action_attempt,
document=prompt,
active_player_name=player_name,
)

self._memory.add(event_statement)
Expand Down Expand Up @@ -259,41 +247,36 @@ def _update_from_player(self, player_name: str, action_attempt: str):
}

# Consequences
def get_externality(externality):
return externality.update_after_event(event_statement)

if self._concurrent_externalities:
concurrency.map_parallel(get_externality, self._components.values())
concurrency.run_tasks({
name: functools.partial(
externality.update_after_event, event_statement
)
for name, externality in self._components.items()
})
else:
for externality in self._components.values():
externality.update_after_event(event_statement)

self._last_chain = prompt

for externality in self._components.values():
last_log = externality.get_last_log()
if last_log:
if 'date' in last_log.keys():
last_log.pop('date')
if 'Event statement' in last_log.keys():
last_log.pop('Event statement')

last_log = dict(last_log)
last_log.pop('date', None)
last_log.pop('Event statement', None)
update_log[externality.name()] = last_log

self._log.append(update_log)

return event_statement

def _view_for_player(self, player_name):
"""Send observations to a player."""
def _player_observations(self, player_name) -> Iterator[str]:
"""Yields the players observations."""
for comp in self._components.values():
state_of_component = comp.partial_state(player_name)
if state_of_component:
for observation in state_of_component.splitlines():
if observation:
self._players_by_name[player_name].observe(observation)

return
if not state_of_component:
continue
for observation in state_of_component.splitlines():
if observation:
yield observation

def _update_components(self) -> None:
concurrency.run_tasks({
Expand All @@ -307,30 +290,24 @@ def _update_components(self) -> None:

def _step_player(
self,
player: deprecated_agent.BasicAgent,
action_spec: agent_lib.ActionSpec | None = None,
):
player: agent_lib.GenerativeAgent,
action_spec: agent_lib.ActionSpec,
) -> None:
self._update_components()
self._view_for_player(player_name=player.name)

if action_spec is None:
action_spec_this_time = self._action_spec[player.name]
else:
action_spec_this_time = action_spec

action = player.act(action_spec_this_time)
action_spec_this_time.validate(action)

self._update_from_player(action_attempt=action, player_name=player.name)
for observation in self._player_observations(player.name):
player.observe(observation)
action = player.act(action_spec)
action_spec.validate(action)
self._handle_action(player.name, action)

def step(
self,
*,
active_players: Sequence[deprecated_agent.BasicAgent] | None = None,
action_spec: (
active_players: Sequence[agent_lib.GenerativeAgent] | None = None,
action_spec_override: (
agent_lib.ActionSpec | Mapping[str, agent_lib.ActionSpec] | None
) = None,
):
) -> None:
"""Steps the game.
At each step players all take a turn 'quasisimultaneously' with regard to
Expand All @@ -339,32 +316,27 @@ def step(
Args:
active_players: Optionally specify players to take turns in this round.
action_spec: Optionally specify what kind of actions to ask the agents to
generate.
action_spec_override: Optionally override what kind of actions to ask the
agents to generate.
"""
if active_players:
players = list(active_players)
else:
players = list(self._players_by_name.values())
if self._randomise_initiative:
random.shuffle(players)

if action_spec is None:
step_player_fn = lambda player: self._step_player(player=player)
elif isinstance(action_spec, Mapping):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec[player.name]
)
elif isinstance(action_spec, agent_lib.ActionSpec):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec
)
if action_spec_override is None:
action_spec = self._action_spec
elif isinstance(action_spec_override, Mapping):
action_spec = self._action_spec | action_spec_override
elif isinstance(action_spec_override, agent_lib.ActionSpec):
action_spec = {player.name: action_spec_override for player in players}
else:
raise TypeError('Invalid action_spec parameter type')

if self._randomise_initiative:
random.shuffle(players)

for player in players:
step_player_fn(player)
self._step_player(player, action_spec=action_spec[player.name])
if not self._players_act_simultaneously:
self._clock.advance()
if self._players_act_simultaneously:
Expand Down
6 changes: 4 additions & 2 deletions concordia/environment/scenes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ def run_scenes(
# Run the scene
for _ in range(scene.num_rounds):
this_scene_game_master_memory.add(f'[scene type] {scene.scene_type.name}')
this_scene_environment.step(active_players=participants,
action_spec=scene.scene_type.action_spec)
this_scene_environment.step(
active_players=participants,
action_spec_override=scene.scene_type.action_spec,
)
if this_scene_environment.terminate_episode():
break

Expand Down

0 comments on commit bd22016

Please sign in to comment.