Skip to content

Commit

Permalink
create abstract class for simulation, which has a method to return pl…
Browse files Browse the repository at this point in the history
…ayer memories. This fixes failing pytype tests

PiperOrigin-RevId: 675526545
Change-Id: I3668d69c1f51dd7e54b22f99c8ab2805cd108b1f
  • Loading branch information
vezhnick authored and copybara-github committed Sep 17, 2024
1 parent 400f8fc commit 8bbffdb
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/modular/environment/forbidden_fruit.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def outcome_summary_fn(
return result


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular/environment/haggling.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def outcome_summary_fn(
return results


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular/environment/labor_collective_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def get_inventories_component(
return inventories, score


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def get_inventories_component(
return inventories, score


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular/environment/pub_coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def outcome_summary_fn(
return results


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular/environment/reality_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def outcome_summary_fn(
return result


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular/environment/state_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def outcome_summary_fn(
return result


class Simulation(scenarios_lib.Runnable):
class Simulation(scenarios_lib.RunnableSimulationWithMemories):
"""Define the simulation API object for the launch script to interact with."""

def __init__(
Expand Down
30 changes: 23 additions & 7 deletions examples/modular/scenario/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

"""Define specific scenarios for the Concordia Challenge."""

import abc
from collections.abc import Callable, Collection, Mapping
import dataclasses
import importlib
import types

from concordia.associative_memory import associative_memory
from examples.modular import environment as environment_lib
from examples.modular.environment import supporting_agent_factory
from examples.modular.scenario import supporting_agents as bots
Expand All @@ -33,6 +35,17 @@

Runnable = Callable[[], tuple[logging_lib.SimulationOutcome, str]]


class RunnableSimulationWithMemories(Runnable):
"""Define the simulation API object for the launch script to interact with."""

@abc.abstractmethod
def get_all_player_memories(
self,
) -> Mapping[str, associative_memory.AssociativeMemory]:
raise NotImplementedError


DEFAULT_IMPORT_ENV_BASE_MODULE = environment_lib.__name__
DEFAULT_IMPORT_AGENT_BASE_MODULE = agent_lib.__name__
DEFAULT_IMPORT_SUPPORT_AGENT_MODULE = supporting_agent_factory.__name__
Expand Down Expand Up @@ -74,10 +87,12 @@ class ScenarioConfig:
labor_collective_action__fixed_rule_boss=SubstrateConfig(
description=(
'labor organization collective action with a boss who applies '
'a fixed decision-making rule'),
'a fixed decision-making rule'
),
environment='labor_collective_action',
supporting_agent_module=bots.SUPPORTING_AGENT_CONFIGS[
'labor_collective_action__fixed_rule_boss'],
'labor_collective_action__fixed_rule_boss'
],
),
pub_coordination=SubstrateConfig(
description=(
Expand All @@ -95,7 +110,6 @@ class ScenarioConfig:
environment='pub_coordination_mini',
supporting_agent_module='basic_puppet_agent',
),

pub_coordination_closures=SubstrateConfig(
description=(
'pub attendance coordination with one pub sometimes being closed'
Expand Down Expand Up @@ -304,13 +318,15 @@ def build_simulation(

if substrate_config.supporting_agent_module is None:
supporting_agent_module = None
elif isinstance(substrate_config.supporting_agent_module,
bots.SupportingAgentConfig):
elif isinstance(
substrate_config.supporting_agent_module, bots.SupportingAgentConfig
):
supporting_agent_module = bots_lib.SupportingAgentFactory(
module=importlib.import_module(
f'{support_agent_base_module}.'
f'{substrate_config.supporting_agent_module.module_name}'),
overrides=substrate_config.supporting_agent_module.overrides
f'{substrate_config.supporting_agent_module.module_name}'
),
overrides=substrate_config.supporting_agent_module.overrides,
)
else:
supporting_agent_module = importlib.import_module(
Expand Down

0 comments on commit 8bbffdb

Please sign in to comment.