Skip to content

Commit

Permalink
Fix type annotations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661217470
Change-Id: I55dd5779eaf395c6fd2ccdc42d857d49162dd063
  • Loading branch information
jagapiou authored and copybara-github committed Aug 9, 2024
1 parent db59748 commit 94ccb88
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 39 deletions.
67 changes: 33 additions & 34 deletions concordia/environment/game_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

"""A Generic Game Master."""

from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
import dataclasses
import datetime
import random
from typing import Any, Mapping, Union
from typing import Any

from concordia import components as generic_components
from concordia.agents import basic_agent
Expand Down Expand Up @@ -79,6 +79,7 @@ class GameMaster(simulacrum_game_master.GameMaster):

def __init__(
self,
*,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.GameClock,
Expand All @@ -93,10 +94,11 @@ def __init__(
[interactive_document.InteractiveDocument, str, str], str
]
]
| None
) = None,
components: Sequence[component.Component] | None = None,
action_spec: agent_lib.ActionSpec | None = None,
) = DEFAULT_THOUGHTS,
components: Sequence[component.Component] = (),
action_spec: agent_lib.ActionSpec | Mapping[str, agent_lib.ActionSpec] = (
agent_lib.DEFAULT_ACTION_SPEC
),
randomise_initiative: bool = False,
player_observes_event: bool = True,
players_act_simultaneously: bool = True,
Expand All @@ -117,8 +119,7 @@ def __init__(
name: name of the game master.
update_thought_chain: chain of thoughts for update from player
components: components to condition on
action_spec: specific action_spec to pass to agents, default is used if
None
action_spec: action_specs to pass to agents
randomise_initiative: whether to randomise initiative (who goes first )
order
player_observes_event: send outcome of the players action back as
Expand All @@ -143,10 +144,13 @@ def __init__(
self._randomise_initiative = randomise_initiative
self._player_observes_event = player_observes_event
self._players_act_simultaneously = players_act_simultaneously
self._action_spec = action_spec or agent_lib.DEFAULT_ACTION_SPEC
if isinstance(action_spec, agent_lib.ActionSpec):
self._action_spec = {player.name: action_spec for player in players}
else:
self._action_spec = dict(action_spec)
self._concurrent_action = concurrent_action

components = list(components or [])
components = list(components)
if use_default_instructions:
instructions_component = generic_components.constant.ConstantComponent(
state=DEFAULT_GAME_MASTER_INSTRUCTIONS, name='Instructions'
Expand All @@ -162,7 +166,7 @@ def __init__(

self._verbose = verbose

self._update_from_player_thoughts = update_thought_chain or DEFAULT_THOUGHTS
self._update_from_player_thoughts = update_thought_chain

self._players_by_name = {player.name: player for player in players}
if len(self._players_by_name) != len(players):
Expand All @@ -177,7 +181,7 @@ def __init__(
def name(self) -> str:
return self._name

def get_history(self):
def get_history(self) -> Sequence[Mapping[str, Any]]:
return self._log.copy()

def insert_history(self, log_entry: LogEntry):
Expand All @@ -189,7 +193,7 @@ def insert_history(self, log_entry: LogEntry):
}
self._log.append(update_log)

def extend_history(self, new_history: Sequence[Any]):
def extend_history(self, new_history: Sequence[Mapping[str, Any]]):
self._log.extend(new_history)

def get_memory(self) -> associative_memory.AssociativeMemory:
Expand Down Expand Up @@ -323,10 +327,10 @@ def _step_player(
self.update_components()
self.view_for_player(player_name=player.name)

if action_spec:
action_spec_this_time = action_spec
if action_spec is None:
action_spec_this_time = self._action_spec[player.name]
else:
action_spec_this_time = self._action_spec
action_spec_this_time = action_spec

action = player.act(action_spec_this_time)
action_spec_this_time.validate(action)
Expand All @@ -338,11 +342,7 @@ def step(
*,
active_players: Sequence[basic_agent.BasicAgent] | None = None,
action_spec: (
Union[
Mapping[str, agent_lib.ActionSpec],
agent_lib.ActionSpec,
]
| None
agent_lib.ActionSpec | Mapping[str, agent_lib.ActionSpec] | None
) = None,
):
"""Steps the game.
Expand All @@ -353,27 +353,26 @@ def step(
Args:
active_players: Optionally specify players to take turns in this round.
action_spec: Optionally specify what kind of action to ask the agent to
action_spec: Optionally specify what kind of actions to ask the agents to
generate.
"""
if active_players:
players = list(active_players)
else:
players = list(self._players_by_name.values())

if action_spec:
if isinstance(action_spec, Mapping):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec[player.name]
)
elif isinstance(action_spec, agent_lib.ActionSpec):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec
)
else:
raise TypeError('Invalid action_spec parameter type')
else:
if action_spec is None:
step_player_fn = lambda player: self._step_player(player=player)
elif isinstance(action_spec, Mapping):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec[player.name]
)
elif isinstance(action_spec, agent_lib.ActionSpec):
step_player_fn = lambda player: self._step_player(
player=player, action_spec=action_spec
)
else:
raise TypeError('Invalid action_spec parameter type')

if self._randomise_initiative:
random.shuffle(players)
Expand Down
1 change: 1 addition & 0 deletions concordia/factory/agent/basic_entity_agent__main_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _get_class_name(object_: object) -> str:


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def _get_class_name(object_: object) -> str:


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta,
additional_components: Mapping[
entity_component.ComponentName, str
entity_component.ComponentName,
entity_component.ContextComponent,
] = types.MappingProxyType({}),
) -> entity_agent_with_logging.EntityAgentWithLogging:
"""Build an agent.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _get_class_name(object_: object) -> str:


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def _get_class_name(object_: object) -> str:


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
clock: game_clock.MultiIntervalClock,
update_time_interval: datetime.timedelta,
additional_components: Mapping[
entity_component.ComponentName, str
entity_component.ComponentName,
entity_component.ContextComponent,
] = types.MappingProxyType({}),
) -> entity_agent_with_logging.EntityAgentWithLogging:
"""Build an agent.
Expand Down
1 change: 1 addition & 0 deletions concordia/factory/agent/synthetic_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _get_class_name(object_: object) -> str:


def build_agent(
*,
config: formative_memories.AgentConfig,
model: language_model.LanguageModel,
memory: associative_memory.AssociativeMemory,
Expand Down
7 changes: 4 additions & 3 deletions examples/modular/environment/labor_collective_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from concordia.language_model import language_model
from concordia.thought_chains import thought_chains as thought_chains_lib
from concordia.typing import agent as agent_lib
from concordia.typing import component
from concordia.typing import scene as scene_lib
from concordia.utils import concurrency
from concordia.utils import measurements as measurements_lib
Expand Down Expand Up @@ -518,7 +517,10 @@ def get_inventories_component(
main_players: Sequence[basic_agent.BasicAgent],
player_configs: Sequence[formative_memories.AgentConfig],
clock_now: Callable[[], datetime.datetime] = datetime.datetime.now,
) -> tuple[component.Component, gm_components.inventory_based_score.Score]:
) -> tuple[
gm_components.inventory.Inventory,
gm_components.inventory_based_score.Score,
]:
"""Get the inventory tracking component for the game master."""
money_config = ItemTypeConfig(name='coin')
player_initial_endowments = {
Expand All @@ -536,7 +538,6 @@ def get_inventories_component(
name='possessions',
verbose=True,
)

score = gm_components.inventory_based_score.Score(
inventory=inventories,
players=main_players, # Only main players get a score.
Expand Down

0 comments on commit 94ccb88

Please sign in to comment.