From 4d9b17939457e51bf5d4d04f88c6e6c64ea06b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Pr=C3=B6schel?= Date: Mon, 7 Aug 2023 21:09:22 +0200 Subject: [PATCH] n player fixes, n player cournot --- pax/envs/cournot.py | 46 +++++++------ pax/experiment.py | 105 ++++++++++------------------- pax/runners/runner_eval_nplayer.py | 3 +- pax/runners/runner_evo.py | 3 +- pax/runners/runner_evo_nplayer.py | 4 +- pax/runners/runner_marl.py | 3 +- pax/runners/runner_marl_nplayer.py | 11 ++- pax/watchers/__init__.py | 3 +- pax/watchers/cournot.py | 13 ++-- requirements.txt | 1 + test/envs/test_cournot.py | 58 ++++++++-------- 11 files changed, 115 insertions(+), 135 deletions(-) diff --git a/pax/envs/cournot.py b/pax/envs/cournot.py index 97b6ced2..b3f2b4b9 100644 --- a/pax/envs/cournot.py +++ b/pax/envs/cournot.py @@ -18,39 +18,43 @@ class EnvParams: b: float marginal_cost: float - -def to_obs_array(params: EnvParams) -> jnp.ndarray: - return jnp.array([params.a, params.b, params.marginal_cost]) - - class CournotGame(environment.Environment): - def __init__(self, num_inner_steps: int): + def __init__(self, num_players: int, num_inner_steps: int): super().__init__() + self.num_players = num_players def _step( key: chex.PRNGKey, state: EnvState, - actions: Tuple[float, float], + actions: Tuple[float, ...], params: EnvParams, ): + assert len(actions) == num_players t = state.outer_t + done = t >= num_inner_steps key, _ = jax.random.split(key, 2) - q1 = actions[0] - q2 = actions[1] - p = params.a - params.b * (q1 + q2) - r1 = jnp.squeeze(p * q1 - params.marginal_cost * q1) - r2 = jnp.squeeze(p * q2 - params.marginal_cost * q2) - obs1 = jnp.concatenate([to_obs_array(params), jnp.array(q2), jnp.array(p)]) - obs2 = jnp.concatenate([to_obs_array(params), jnp.array(q1), jnp.array(p)]) - done = t >= num_inner_steps + actions = jnp.asarray(actions).squeeze() + actions = jnp.clip(actions, a_min=0) + p = params.a - params.b * actions.sum() + + all_obs = [] + all_rewards = [] + for i in range(num_players): + q = actions[i] + r = p * q - params.marginal_cost * q + obs = jnp.concatenate([actions, jnp.array([p])]) + all_obs.append(obs) + all_rewards.append(r) + state = EnvState( inner_t=state.inner_t + 1, outer_t=state.outer_t + 1 ) + return ( - (obs1, obs2), + tuple(all_obs), state, - (r1, r2), + tuple(all_rewards), done, {}, ) @@ -62,9 +66,9 @@ def _reset( inner_t=jnp.zeros((), dtype=jnp.int8), outer_t=jnp.zeros((), dtype=jnp.int8), ) - obs = jax.random.uniform(key, (2,)) - obs = jnp.concatenate([to_obs_array(params), obs]) - return (obs, obs), state + obs = jax.random.uniform(key, (num_players + 1,)) + obs = jnp.concatenate([obs]) + return tuple([obs for _ in range(num_players)]), state self.step = jax.jit(_step) self.reset = jax.jit(_reset) @@ -87,7 +91,7 @@ def action_space( def observation_space(self, params: EnvParams) -> spaces.Box: """Observation space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=5, dtype=jnp.float32) + return spaces.Box(low=0, high=float('inf'), shape=self.num_players + 1, dtype=jnp.float32) @staticmethod def nash_policy(params: EnvParams) -> float: diff --git a/pax/experiment.py b/pax/experiment.py index 44669fed..774dd7df 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -50,21 +50,15 @@ from pax.envs.infinite_matrix_game import InfiniteMatrixGame from pax.envs.iterated_matrix_game import EnvParams as IteratedMatrixGameParams from pax.envs.iterated_matrix_game import IteratedMatrixGame -from pax.envs.iterated_tensor_game import EnvParams as 
IteratedTensorGameParams -from pax.envs.iterated_tensor_game import IteratedTensorGame from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer from pax.envs.iterated_tensor_game_n_player import ( EnvParams as IteratedTensorGameNPlayerParams, ) from pax.runners.runner_eval import EvalRunner -from pax.runners.runner_eval_3player import TensorEvalRunner from pax.runners.runner_eval_nplayer import NPlayerEvalRunner from pax.runners.runner_evo import EvoRunner -from pax.runners.runner_evo_3player import TensorEvoRunner -from pax.runners.runner_evo_nplayer import NPlayerEvoRunner -from pax.runners.runner_ipditm_eval import IPDITMEvalRunner +from pax.runners.runner_evo_nplayer import TensorEvoRunner from pax.runners.runner_marl import RLRunner -from pax.runners.runner_marl_3player import TensorRLRunner from pax.runners.runner_marl_nplayer import NplayerRLRunner from pax.runners.runner_sarl import SARLRunner from pax.runners.runner_ipditm_eval import IPDITMEvalRunner @@ -80,6 +74,7 @@ value_logger_ppo, ) + # NOTE: THIS MUST BE DONE BEFORE IMPORTING JAX # uncomment to debug multi-devices on CPU # os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" @@ -132,19 +127,6 @@ def env_setup(args, logger=None): f"Env Type: {args.env_type} | Inner Episode Length: {args.num_inner_steps}" ) logger.info(f"Outer Episode Length: {args.num_outer_steps}") - elif args.env_id == "iterated_tensor_game": - payoff = jnp.array(args.payoff) - - env = IteratedTensorGame( - num_inner_steps=args.num_inner_steps, - num_outer_steps=args.num_outer_steps, - ) - env_params = IteratedTensorGameParams(payoff_matrix=payoff) - - if logger: - logger.info( - f"Env Type: {args.env_type} s| Inner Episode Length: {args.num_inner_steps}" - ) elif args.env_id == "iterated_nplayer_tensor_game": payoff = jnp.array(args.payoff_table) @@ -210,7 +192,7 @@ def env_setup(args, logger=None): a=args.a, b=args.b, marginal_cost=args.marginal_cost ) env = CournotGame( - num_inner_steps=args.num_inner_steps, + num_players=args.num_players, num_inner_steps=args.num_inner_steps, ) if logger: logger.info( @@ -243,9 +225,6 @@ def runner_setup(args, env, agents, save_dir, logger): if args.runner == "eval": logger.info("Evaluating with EvalRunner") return EvalRunner(agents, env, args) - elif args.runner == "tensor_eval": - logger.info("Training with tensor eval Runner") - return TensorEvalRunner(agents, env, save_dir, args) elif args.runner == "tensor_eval_nplayer": logger.info("Training with n-player tensor eval Runner") return NPlayerEvalRunner(agents, env, save_dir, args) @@ -254,9 +233,9 @@ def runner_setup(args, env, agents, save_dir, logger): return IPDITMEvalRunner(agents, env, save_dir, args) if ( - args.runner == "evo" - or args.runner == "tensor_evo" - or args.runner == "tensor_evo_nplayer" + args.runner == "evo" + or args.runner == "tensor_evo" + or args.runner == "tensor_evo_nplayer" ): agent1 = agents[0] algo = args.es.algo @@ -362,23 +341,10 @@ def get_pgpe_strategy(agent): save_dir, args, ) - elif args.runner == "tensor_evo_nplayer": - return NPlayerEvoRunner( - agents, - env, - strategy, - es_params, - param_reshaper, - save_dir, - args, - ) elif args.runner == "rl": logger.info("Training with RL Runner") return RLRunner(agents, env, save_dir, args) - elif args.runner == "tensor_rl": - logger.info("Training with tensor RL Runner") - return TensorRLRunner(agents, env, save_dir, args) elif args.runner == "tensor_rl_nplayer": return NplayerRLRunner(agents, env, save_dir, args) elif args.runner == "sarl": @@ 
-393,9 +359,9 @@ def agent_setup(args, env, env_params, logger): """Set up agent variables.""" if ( - args.env_id == "iterated_matrix_game" - or args.env_id == "iterated_tensor_game" - or args.env_id == "iterated_nplayer_tensor_game" + args.env_id == "iterated_matrix_game" + or args.env_id == "iterated_tensor_game" + or args.env_id == "iterated_nplayer_tensor_game" ): obs_shape = env.observation_space(env_params).n elif args.env_id == "InTheMatrix": @@ -423,7 +389,8 @@ def get_PPO_memory_agent(seed, player_id): ) def get_PPO_agent(seed, player_id): - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id)) + default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None) + player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args) if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps @@ -559,11 +526,11 @@ def get_stay_agent(seed, player_id): return agent_1 else: - for i in range(1, args.num_players + 1): - assert ( - omegaconf.OmegaConf.select(args, "agent" + str(i)) - in strategies - ) + default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None) + agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in + range(1, args.num_players + 1)] + for strategy in agent_strategies: + assert strategy in strategies seeds = [ seed for seed in range(args.seed, args.seed + args.num_players) @@ -574,14 +541,12 @@ def get_stay_agent(seed, player_id): for seed, i in zip(seeds, range(1, args.num_players + 1)) ] agents = [] - for i in range(args.num_players): + for idx, strategy in enumerate(agent_strategies): agents.append( - strategies[ - omegaconf.OmegaConf.select(args, "agent" + str(i + 1)) - ](seeds[i], pids[i]) + strategies[strategy](seeds[idx], pids[idx]) ) logger.info( - f"Agent Pair: {[omegaconf.OmegaConf.select(args, 'agent' + str(i)) for i in range(1, args.num_players + 1)]}" + f"Agent Pair: {strategies}" ) logger.info(f"Agent seeds: {seeds}") @@ -592,7 +557,7 @@ def watcher_setup(args, logger): """Set up watcher variables.""" def ppo_memory_log( - agent, + agent, ): losses = losses_ppo(agent) if args.env_id not in [ @@ -615,6 +580,8 @@ def ppo_log(agent): losses = losses_ppo(agent) if args.env_id not in [ "coin_game", + "Cournot", + "Fishery", "InTheMatrix", "iterated_matrix_game", "iterated_tensor_game", @@ -697,13 +664,15 @@ def naive_pg_log(agent): assert args.agent1 in strategies agent_1_log = naive_pg_log # strategies[args.agent1] # - return agent_1_log else: agent_log = [] - for i in range(1, args.num_players + 1): - assert getattr(args, f"agent{i}") in strategies - agent_log.append(strategies[getattr(args, f"agent{i}")]) + default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None) + agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in + range(1, args.num_players + 1)] + for strategy in agent_strategies: + assert strategy in strategies + agent_log.append(strategies[strategy]) return agent_log @@ -734,25 +703,25 @@ def main(args): print(f"Number of Training Iterations: {args.num_iters}") if ( - args.runner == "evo" - or args.runner == "tensor_evo" - or args.runner == "tensor_evo_nplayer" + args.runner == "evo" + or args.runner == "tensor_evo" + or args.runner == "tensor_evo_nplayer" ): runner.run_loop(env_params, agent_pair, args.num_iters, watchers) elif ( - args.runner == "rl" - or args.runner == "tensor_rl" - or args.runner == "tensor_rl_nplayer" + 
args.runner == "rl" + or args.runner == "tensor_rl" + or args.runner == "tensor_rl_nplayer" ): # number of episodes print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env_params, agent_pair, args.num_iters, watchers) elif ( - args.runner == "ipditm_eval" - or args.runner == "tensor_eval" - or args.runner == "tensor_eval_nplayer" + args.runner == "ipditm_eval" + or args.runner == "tensor_eval" + or args.runner == "tensor_eval_nplayer" ): runner.run_loop(env_params, agent_pair, watchers) diff --git a/pax/runners/runner_eval_nplayer.py b/pax/runners/runner_eval_nplayer.py index 409fa861..bd8e3a91 100644 --- a/pax/runners/runner_eval_nplayer.py +++ b/pax/runners/runner_eval_nplayer.py @@ -7,11 +7,10 @@ from omegaconf import OmegaConf import wandb -from pax.utils import MemoryState, TrainingState, save, load +from pax.utils import MemoryState, TrainingState, load from pax.watchers import ( ipditm_stats, n_player_ipd_visitation, - tensor_ipd_visitation, ) MAX_WANDB_CALLS = 1000 diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index f7ba4561..5a811d00 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -12,7 +12,8 @@ # TODO: import when evosax library is updated # from evosax.utils import ESLog -from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats, fishery_stats +from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.fishery import fishery_stats from pax.watchers.cournot import cournot_stats MAX_WANDB_CALLS = 1000 diff --git a/pax/runners/runner_evo_nplayer.py b/pax/runners/runner_evo_nplayer.py index 4f4eeff5..304b500a 100644 --- a/pax/runners/runner_evo_nplayer.py +++ b/pax/runners/runner_evo_nplayer.py @@ -12,7 +12,7 @@ # TODO: import when evosax library is updated # from evosax.utils import ESLog -from pax.watchers import ESLog, cg_visitation, tensor_ipd_visitation +from pax.watchers import ESLog, cg_visitation, n_player_ipd_visitation MAX_WANDB_CALLS = 1000 @@ -72,7 +72,7 @@ def __init__( self.top_k = args.top_k self.train_steps = 0 self.train_episodes = 0 - self.ipd_stats = jax.jit(tensor_ipd_visitation) + self.ipd_stats = jax.jit(n_player_ipd_visitation) self.cg_stats = jax.jit(jax.vmap(cg_visitation)) # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index ef430428..3a9fa222 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -7,7 +7,8 @@ import wandb from pax.utils import MemoryState, TrainingState, save -from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats, fishery_stats +from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.fishery import fishery_stats from pax.watchers.cournot import cournot_stats MAX_WANDB_CALLS = 1000 diff --git a/pax/runners/runner_marl_nplayer.py b/pax/runners/runner_marl_nplayer.py index 3b1978f2..d693e33b 100644 --- a/pax/runners/runner_marl_nplayer.py +++ b/pax/runners/runner_marl_nplayer.py @@ -195,7 +195,6 @@ def _inner_rollout(carry, unused): # a2_rng = rngs[:, :, 2, :] rngs = rngs[:, :, 3, :] new_other_agent_mem = [None] * len(other_agents) - actions = [] ( first_action, @@ -206,7 +205,7 @@ def _inner_rollout(carry, unused): first_agent_obs, first_agent_mem, ) - actions.append(first_action) + actions = [first_action] for agent_idx, non_first_agent in enumerate(other_agents): ( non_first_action, @@ -306,6 +305,7 @@ def _outer_rollout(carry, unused): # MFOS has to take a meta-action for each 
episode
             if args.agent1 == "MFOS":
                 first_agent_mem = agent1.meta_policy(first_agent_mem)
+                # TODO update first agent regularly?
 
             # update second agent
             for agent_idx, non_first_agent in enumerate(other_agents):
@@ -539,6 +539,13 @@ def run_loop(self, env_params, agents, num_iters, watchers):
                 agent1._logger.metrics = (
                     agent1._logger.metrics | flattened_metrics_1
                 )
+                for agent, metric in zip(other_agents, other_agent_metrics):
+                    flattened_metrics = jax.tree_util.tree_map(
+                        lambda x: jnp.mean(x), metric
+                    )
+                    agent._logger.metrics = (
+                        agent._logger.metrics | flattened_metrics
+                    )
 
             for watcher, agent in zip(watchers, agents):
                 watcher(agent)
diff --git a/pax/watchers/__init__.py b/pax/watchers/__init__.py
index a92baac6..1263b297 100644
--- a/pax/watchers/__init__.py
+++ b/pax/watchers/__init__.py
@@ -12,8 +12,7 @@
 import pax.agents.hyper.ppo as HyperPPO
 import pax.agents.ppo.ppo as PPO
 from pax.agents.naive_exact import NaiveExact
-from pax.envs.iterated_matrix_game import EnvState, IteratedMatrixGame
-from pax.envs.in_the_matrix import InTheMatrix
+from pax.envs.iterated_matrix_game import EnvState
 
 # five possible states
 START = jnp.array([[0, 0, 0, 0, 1]])
diff --git a/pax/watchers/cournot.py b/pax/watchers/cournot.py
index 8cbfe6d9..a81f418f 100644
--- a/pax/watchers/cournot.py
+++ b/pax/watchers/cournot.py
@@ -1,20 +1,18 @@
-from typing import NamedTuple
-
 from jax import numpy as jnp
 
 from pax.envs.cournot import EnvParams as CournotEnvParams, CournotGame
 
 
-def cournot_stats(traj1: NamedTuple, traj2: NamedTuple, params: CournotEnvParams) -> dict:
+def cournot_stats(observations: jnp.ndarray, params: CournotEnvParams, n_player: int) -> dict:
     opt_quantity = CournotGame.nash_policy(params)
-    average_quantity = (traj1.actions + traj2.actions) / 2
+    # the observation is [q_1, ..., q_n, price], so the first n_player entries are the quantities
+    average_quantity = jnp.mean(observations[..., :n_player], axis=-1)
 
     return {
-        "quantity/1": jnp.mean(traj1.actions),
-        "quantity/2": jnp.mean(traj2.actions),
+        **{f"quantity/{i + 1}": jnp.mean(observations[..., i])
+           for i in range(n_player)},
         "quantity/average": jnp.mean(average_quantity),
-        # How strongly do the joint actions deviate from the optimal quantity?
+        # How strongly does the average quantity deviate from the Nash equilibrium quantity?
         # Since the reward is a linear function of the quantity there is no need to consider it separately.
"quantity/loss": jnp.mean((opt_quantity / 2 - average_quantity) ** 2), - "opt_quantity": opt_quantity, } diff --git a/requirements.txt b/requirements.txt index 273e12be..f7d98cc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ pytest wandb pytest-cov bsuite +yaml diff --git a/test/envs/test_cournot.py b/test/envs/test_cournot.py index e8fcf00f..a989cb59 100644 --- a/test/envs/test_cournot.py +++ b/test/envs/test_cournot.py @@ -7,31 +7,33 @@ def test_single_cournot_game(): rng = jax.random.PRNGKey(0) - env = CournotGame(num_inner_steps=1) - # This means the optimum production quantity is Q = q1 + q2 = 2(a-marginal_cost)/3b = 60 - env_params = EnvParams(a=100, b=1, marginal_cost=10) - nash_q = CournotGame.nash_policy(env_params) - assert nash_q == 60 - nash_action = jnp.array([nash_q / 2]) - - obs, env_state = env.reset(rng, env_params) - obs, env_state, rewards, done, info = env.step( - rng, env_state, (nash_action, nash_action), env_params - ) - - assert rewards[0] == rewards[1] - # p_opt = 100 - (30 + 30) = 40 - # r1_opt = 40 * 30 - 10 * 30 = 900 - nash_reward = CournotGame.nash_reward(env_params) - assert nash_reward == 1800 - assert jnp.isclose(nash_reward / 2, rewards[0], atol=0.01) - assert jnp.allclose(obs[0][3:], jnp.array([30, 40]), atol=0.01) - assert jnp.allclose(obs[1][3:], jnp.array([30, 40]), atol=0.01) - - social_opt_action = jnp.array([45 / 2]) - obs, env_state = env.reset(rng, env_params) - obs, env_state, rewards, done, info = env.step( - rng, env_state, (social_opt_action, social_opt_action), env_params - ) - assert rewards[0] == rewards[1] - assert rewards[0] + rewards[1] == 2025 + for n_player in [2, 3, 12]: + env = CournotGame(num_players=n_player, num_inner_steps=1) + # This means the optimum production quantity is Q = q1 + q2 = 2(a-marginal_cost)/3b = 60 + env_params = EnvParams(a=100, b=1, marginal_cost=10) + nash_q = CournotGame.nash_policy(env_params) + assert nash_q == 60 + nash_action = jnp.array([nash_q / n_player]) + + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, env_state, tuple([nash_action for _ in range(n_player)]), env_params + ) + + assert all(element == rewards[0] for element in rewards) + # p_opt = 100 - (30 + 30) = 40 + # r1_opt = 40 * 30 - 10 * 30 = 900 + nash_reward = CournotGame.nash_reward(env_params) + assert nash_reward == 1800 + assert jnp.isclose(nash_reward / n_player, rewards[0], atol=0.01) + expected_obs = jnp.array([60/n_player for _ in range(n_player)] + [40]) + assert jnp.allclose(obs[0], expected_obs, atol=0.01) + assert jnp.allclose(obs[0], obs[1], atol=0.0) + + social_opt_action = jnp.array([45 / n_player]) + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, env_state, tuple([social_opt_action for _ in range(n_player)]), env_params + ) + assert jnp.asarray(rewards).sum() == 2025 +