From bb280e67971923b2e063dc370267284aa353c5ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Pr=C3=B6schel?= Date: Tue, 24 Oct 2023 17:13:07 +0200 Subject: [PATCH] Bring back old runner_evo, updates to runner docs, finish runner smoke tests --- docs/getting-started/runners.md | 29 +- pax/agents/ppo/ppo.py | 6 +- pax/conf/experiment/c_rice/debug.yaml | 2 +- pax/conf/experiment/c_rice/marl_baseline.yaml | 2 +- .../experiment/c_rice/mediator_gs_ppo.yaml | 2 +- pax/conf/experiment/c_rice/shaper_v_ppo.yaml | 2 +- pax/conf/experiment/cg/mfos.yaml | 63 +- pax/conf/experiment/cg/tabular.yaml | 29 +- .../experiment/cournot/eval_shaper_v_ppo.yaml | 2 +- pax/conf/experiment/cournot/shaper_v_ppo.yaml | 2 +- .../experiment/fishery/marl_baseline.yaml | 2 +- pax/conf/experiment/fishery/mfos_v_ppo.yaml | 2 +- pax/conf/experiment/fishery/shaper_v_ppo.yaml | 2 +- pax/conf/experiment/rice/gs_v_ppo.yaml | 2 +- pax/conf/experiment/rice/mfos_v_ppo.yaml | 2 +- pax/conf/experiment/rice/shaper_v_ppo.yaml | 4 +- pax/envs/iterated_tensor_game_n_player.py | 1 + pax/experiment.py | 17 +- pax/runners/runner_evo.py | 355 +++----- pax/runners/runner_evo_nroles.py | 755 ++++++++++++++++++ pax/runners/runner_marl.py | 2 +- test/runners/test_runners.py | 23 +- 22 files changed, 1033 insertions(+), 273 deletions(-) create mode 100644 pax/runners/runner_evo_nroles.py diff --git a/docs/getting-started/runners.md b/docs/getting-started/runners.md index 899f24fb..43dc241d 100644 --- a/docs/getting-started/runners.md +++ b/docs/getting-started/runners.md @@ -1,9 +1,28 @@ -# Runner +# Runners + +## Evo Runner + +The Evo Runner optimizes the first agent using evolutionary learning. + +See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/cg/mfos.yaml) for an example of how to configure it. + +## Evo Runner N-Roles + +This runner extends the evo runner to `N > 2` agents by letting the first and second agent each assume multiple roles, configured via `agent1_roles` and `agent2_roles` in the experiment configuration. +Each agent receives a separate set of memories for every role it assumes, but its weights are shared across roles. + +- For heterogeneous games, roles can be shuffled for each rollout using the `shuffle_players` flag. +- The `self_play_anneal` flag anneals the self-play probability from 0 to 1 over the course of the experiment. + +See [this experiment](https://github.com/akbir/pax/blob/bb0e69ef71fd01ec9c85753814ffba3c5cb77935/pax/conf/experiment/rice/shaper_v_ppo.yaml) for an example of how to configure it. + +## Weight Sharing Runner + +A simple baseline for MARL experiments is to have one agent assume multiple roles and share its weights between them (but not its memory). +For this approach to work, the observation vector needs to include an entry that indicates the agent's current role (see [Terry et al.](https://arxiv.org/abs/2005.13625v7)). + +See [this experiment](https://github.com/akbir/pax/blob/9d3fa62e34279a338c07cffcbf208edc8a95e7ba/pax/conf/experiment/rice/weight_sharing.yaml) for an example of how to configure it. -## Runner 1 -Lorem ipsum. -## Runner 2 -Lorem ipsum.
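For reference, the flags described in the new "Evo Runner N-Roles" section above would sit in an experiment file roughly as sketched below. This is an illustrative fragment only: the option names are the ones used by the configs and runner code in this patch, but the values are placeholders rather than a shipped experiment.

```yaml
# Hypothetical evo_nroles experiment fragment (placeholder values).
runner: evo_nroles
num_players: 5            # must equal agent1_roles + agent2_roles
agent1_roles: 3           # roles assumed by the evolved (shaper) agent
agent2_roles: 2           # roles assumed by the learned opponent
shuffle_players: True     # reshuffle the role assignment on every rollout
self_play_anneal: True    # anneal the self-play probability from 0 to 1 over training
agent2_reset_interval: 10 # reinitialise the opponent's training state every 10 generations
```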
diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index 6402bf3c..9a098846 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -19,6 +19,7 @@ make_ipd_network, ) from pax.envs.iterated_matrix_game import IteratedMatrixGame +from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer from pax.envs.rice.c_rice import ClubRice from pax.envs.rice.rice import Rice from pax.envs.rice.sarl_rice import SarlRice @@ -517,7 +518,10 @@ def make_agent( network = make_rice_sarl_network(action_spec, agent_args.hidden_size) elif args.runner == "sarl": network = make_sarl_network(action_spec) - elif args.env_id == IteratedMatrixGame.env_id: + elif args.env_id in [ + IteratedMatrixGame.env_id, + IteratedTensorGameNPlayer.env_id, + ]: network = make_ipd_network(action_spec, True, agent_args.hidden_size) else: raise NotImplementedError( diff --git a/pax/conf/experiment/c_rice/debug.yaml b/pax/conf/experiment/c_rice/debug.yaml index a34b6ccc..62bbb95e 100644 --- a/pax/conf/experiment/c_rice/debug.yaml +++ b/pax/conf/experiment/c_rice/debug.yaml @@ -10,7 +10,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/c_rice/marl_baseline.yaml b/pax/conf/experiment/c_rice/marl_baseline.yaml index 5c30f380..dd223cf2 100644 --- a/pax/conf/experiment/c_rice/marl_baseline.yaml +++ b/pax/conf/experiment/c_rice/marl_baseline.yaml @@ -9,7 +9,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True # Training diff --git a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml index f5ca57e2..acd73fdd 100644 --- a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml +++ b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml @@ -12,7 +12,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True agent2_reset_interval: 10 diff --git a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml index 81d2edbf..719e6993 100644 --- a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml +++ b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True default_club_mitigation_rate: 0.1 diff --git a/pax/conf/experiment/cg/mfos.yaml b/pax/conf/experiment/cg/mfos.yaml index d56a2c4b..1b4bb1dc 100644 --- a/pax/conf/experiment/cg/mfos.yaml +++ b/pax/conf/experiment/cg/mfos.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'MFOS' agent2: 'PPO_memory' @@ -11,24 +11,27 @@ egocentric: True env_discount: 0.96 payoff: [[1, 1, -2], [1, 1, -2]] -# Runner +# Runner runner: evo +top_k: 4 +popsize: 1000 #512 # env_batch_size = num_envs * num_opponents num_envs: 250 num_opps: 1 num_outer_steps: 600 -num_inner_steps: 16 -save_interval: 100 +num_inner_steps: 16 +save_interval: 100 +num_steps: '${num_inner_steps}' -# Evaluation +# Evaluation run_path: ucl-dark/cg/12auc9um model_path: exp/sanity-PPO-vs-PPO-parity/run-seed-0/2022-09-08_20.04.17.155963/iteration_500 # PPO agent parameters -ppo: +ppo1: num_minibatches: 8 - num_epochs: 2 + num_epochs: 2 gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 @@ -49,6 +52,52 @@ ppo: separate: True # only works with CNN hidden_size: 16 #50 +ppo2: 
+ num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.01 #0.05 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True # only works with CNN + hidden_size: 16 #50 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/cg/tabular.yaml b/pax/conf/experiment/cg/tabular.yaml index 39277290..9240a691 100644 --- a/pax/conf/experiment/cg/tabular.yaml +++ b/pax/conf/experiment/cg/tabular.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'Tabular' agent2: 'Random' @@ -25,9 +25,32 @@ num_iters: 10000 # train_batch_size = num_envs * num_opponents * num_steps # PPO agent parameters -ppo: +ppo1: num_minibatches: 8 - num_epochs: 2 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: True + learning_rate: 0.01 #0.05 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True # only works with CNN + hidden_size: 16 #50 + +ppo2: + num_minibatches: 8 + num_epochs: 2 gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 diff --git a/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml b/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml index 10a9192b..280b9a3c 100644 --- a/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml +++ b/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml @@ -12,7 +12,7 @@ b: 1 marginal_cost: 10 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/cournot/shaper_v_ppo.yaml b/pax/conf/experiment/cournot/shaper_v_ppo.yaml index 2f84fed8..ef4f0fac 100644 --- a/pax/conf/experiment/cournot/shaper_v_ppo.yaml +++ b/pax/conf/experiment/cournot/shaper_v_ppo.yaml @@ -13,7 +13,7 @@ b: 1 marginal_cost: 10 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/fishery/marl_baseline.yaml b/pax/conf/experiment/fishery/marl_baseline.yaml index f0faa91b..ef42b6bd 100644 --- a/pax/conf/experiment/fishery/marl_baseline.yaml +++ b/pax/conf/experiment/fishery/marl_baseline.yaml @@ -15,7 +15,7 @@ s_0: 0.5 s_max: 1.0 # This means the optimum quantity is 2(a-marginal_cost)/3b = 60 -runner: evo +runner: evo_nroles # env_batch_size = 
num_envs * num_opponents num_envs: 100 diff --git a/pax/conf/experiment/fishery/mfos_v_ppo.yaml b/pax/conf/experiment/fishery/mfos_v_ppo.yaml index 78e4aa97..a3a828ac 100644 --- a/pax/conf/experiment/fishery/mfos_v_ppo.yaml +++ b/pax/conf/experiment/fishery/mfos_v_ppo.yaml @@ -15,7 +15,7 @@ s_0: 0.5 s_max: 1.0 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/fishery/shaper_v_ppo.yaml b/pax/conf/experiment/fishery/shaper_v_ppo.yaml index 93831be4..8370a98b 100644 --- a/pax/conf/experiment/fishery/shaper_v_ppo.yaml +++ b/pax/conf/experiment/fishery/shaper_v_ppo.yaml @@ -14,7 +14,7 @@ w: 0.9 s_0: 0.5 s_max: 1.0 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/rice/gs_v_ppo.yaml b/pax/conf/experiment/rice/gs_v_ppo.yaml index daf31c69..ece5eb89 100644 --- a/pax/conf/experiment/rice/gs_v_ppo.yaml +++ b/pax/conf/experiment/rice/gs_v_ppo.yaml @@ -12,7 +12,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training diff --git a/pax/conf/experiment/rice/mfos_v_ppo.yaml b/pax/conf/experiment/rice/mfos_v_ppo.yaml index 13e5f4f4..a0870424 100644 --- a/pax/conf/experiment/rice/mfos_v_ppo.yaml +++ b/pax/conf/experiment/rice/mfos_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/rice/shaper_v_ppo.yaml b/pax/conf/experiment/rice/shaper_v_ppo.yaml index 22e4c281..78cae9a2 100644 --- a/pax/conf/experiment/rice/shaper_v_ppo.yaml +++ b/pax/conf/experiment/rice/shaper_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True # Training @@ -22,7 +22,7 @@ top_k: 5 popsize: 1000 num_envs: 1 num_opps: 1 -num_outer_steps: 200 +num_outer_steps: 200 num_inner_steps: 200 num_iters: 1500 num_devices: 1 diff --git a/pax/envs/iterated_tensor_game_n_player.py b/pax/envs/iterated_tensor_game_n_player.py index 8335228f..1b060569 100644 --- a/pax/envs/iterated_tensor_game_n_player.py +++ b/pax/envs/iterated_tensor_game_n_player.py @@ -18,6 +18,7 @@ class EnvParams: class IteratedTensorGameNPlayer(environment.Environment): + env_id = "iterated_nplayer_tensor_game" """ JAX Compatible version of tensor game environment.
""" diff --git a/pax/experiment.py b/pax/experiment.py index 55b81065..15493d08 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -60,6 +60,7 @@ from pax.envs.rice.c_rice import ClubRice from pax.envs.rice.rice import Rice, EnvParams as RiceParams from pax.envs.rice.sarl_rice import SarlRice +from pax.runners.runner_evo_nroles import EvoRunnerNRoles from pax.runners.runner_weight_sharing import WeightSharingRunner from pax.runners.runner_eval import EvalRunner from pax.runners.runner_eval_multishaper import MultishaperEvalRunner @@ -281,7 +282,7 @@ def runner_setup(args, env, agents, save_dir, logger): logger.info("Evaluating with ipditmEvalRunner") return IPDITMEvalRunner(agents, env, save_dir, args) - if args.runner == "evo" or args.runner == "multishaper_evo": + if args.runner in ["evo", "multishaper_evo", "evo_nroles"]: agent1 = agents[0] algo = args.es.algo strategies = {"CMA_ES", "OpenES", "PGPE", "SimpleGA"} @@ -378,6 +379,18 @@ def get_pgpe_strategy(agent): args, ) + elif args.runner == "evo_nroles": + logger.info("Training with n_roles EVO runner") + return EvoRunnerNRoles( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) + elif args.runner == "multishaper_evo": logger.info("Training with multishaper EVO runner") return MultishaperEvoRunner( @@ -782,7 +795,7 @@ def main(args): print(f"Number of Training Iterations: {args.num_iters}") - if args.runner == "evo" or args.runner == "multishaper_evo": + if args.runner in ["evo", "evo_nroles", "multishaper_evo"]: runner.run_loop(env_params, agent_pair, args.num_iters, watchers) elif args.runner == "rl" or args.runner == "tensor_rl_nplayer": # number of episodes diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index f1e54f61..9ce590b0 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -1,32 +1,39 @@ import os import time from datetime import datetime -from typing import Any, Callable, NamedTuple, Tuple +from typing import Any, Callable, NamedTuple import jax import jax.numpy as jnp -import numpy as np from evosax import FitnessShaper import wandb -from pax.utils import MemoryState, TrainingState, save, float_precision, Sample +from pax.utils import MemoryState, TrainingState, save # TODO: import when evosax library is updated # from evosax.utils import ESLog from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats -from pax.watchers.fishery import fishery_stats -from pax.watchers.cournot import cournot_stats -from pax.watchers.rice import rice_stats -from pax.watchers.c_rice import c_rice_stats MAX_WANDB_CALLS = 1000 +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + class EvoRunner: """ - Evolutionary Strategy runner provides a convenient example for quickly writing + Evoluationary Strategy runner provides a convenient example for quickly writing a MARL runner for PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against a Reinforcement Learner. + run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. 
@@ -47,8 +54,6 @@ class EvoRunner: A tuple of experiment arguments used (usually provided by HydraConfig). """ - # TODO fix C901 (function too complex) - # flake8: noqa: C901 def __init__( self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): @@ -73,9 +78,6 @@ def __init__( jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) ) - if args.num_players != args.agent1_roles + args.agent2_roles: - raise ValueError("Number of players must match number of roles") - # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) # Evo Runner also has an additional pmap dim (num_devices, ...) # For the env we vmap over the rng but not params @@ -104,7 +106,7 @@ def __init__( ) self.num_outer_steps = args.num_outer_steps - agent1, agent2 = agents[0], agents[1] + agent1, agent2 = agents # vmap agents accordingly # agent 1 is batched over popsize and num_opps @@ -124,7 +126,9 @@ def __init__( ) agent1.batch_policy = jax.jit( - jax.vmap(jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0))), + jax.vmap( + jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), + ) ) if args.agent2 == "NaiveEx": @@ -140,18 +144,6 @@ def __init__( 0, ) ) - # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - - a2_rng = jnp.concatenate( - [jax.random.split(agent2._state.random_key, args.num_opps)] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - - agent2._state, agent2._mem = agent2.batch_init( - a2_rng, - init_hidden, - ) agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) agent2.batch_reset = jax.jit( @@ -167,6 +159,19 @@ def __init__( (1, 0, 0, 0), ) ) + if args.agent2 != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + + a2_rng = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + + agent2._state, agent2._mem = agent2.batch_init( + a2_rng, + init_hidden, + ) # jit evo strategy.ask = jax.jit(strategy.ask) @@ -187,102 +192,66 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order, ) = carry # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, :, 0, :] - rngs = rngs[:, :, :, 3, :] - a1_actions = [] - new_a1_memories = [] - for _obs, _mem in zip(obs1, a1_mem): - a1_action, a1_state, new_a1_memory = agent1.batch_policy( - a1_state, - _obs, - _mem, - ) - a1_actions.append(a1_action) - new_a1_memories.append(new_a1_memory) - - a2_actions = [] - new_a2_memories = [] - for _obs, _mem in zip(obs2, a2_mem): - a2_action, a2_state, new_a2_memory = agent2.batch_policy( - a2_state, - _obs, - _mem, - ) - a2_actions.append(a2_action) - new_a2_memories.append(new_a2_memory) + # a1_rng = rngs[:, :, :, 1, :] + # a2_rng = rngs[:, :, :, 2, :] + rngs = rngs[:, :, :, 3, :] - actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order] - obs, env_state, rewards, done, info = env.step( + a1, a1_state, new_a1_mem = agent1.batch_policy( + a1_state, + obs1, + a1_mem, + ) + a2, a2_state, new_a2_mem = agent2.batch_policy( + a2_state, + obs2, + a2_mem, + ) + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( env_rng, env_state, - tuple(actions), + (a1, a2), env_params, ) - inv_agent_order = jnp.argsort(agent_order) - obs = jnp.asarray(obs)[inv_agent_order] - rewards = jnp.asarray(rewards)[inv_agent_order] - agent1_roles = len(a1_actions) - - a1_trajectories = [ - Sample( - observation, - action, - reward * jnp.logical_not(done), - 
new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden, - ) - for observation, action, reward, new_memory, memory in zip( - obs1, - a1_actions, - rewards[:agent1_roles], - new_a1_memories, - a1_mem, - ) - ] - a2_trajectories = [ - Sample( - observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden, - ) - for observation, action, reward, new_memory, memory in zip( - obs2, - a2_actions, - rewards[agent1_roles:], - new_a2_memories, - a2_mem, - ) - ] - + traj1 = Sample( + obs1, + a1, + rewards[0], + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) + traj2 = Sample( + obs2, + a2, + rewards[1], + new_a2_mem.extras["log_probs"], + new_a2_mem.extras["values"], + done, + a2_mem.hidden, + ) return ( rngs, - tuple(obs[:agent1_roles]), - tuple(obs[agent1_roles:]), - tuple(rewards[:agent1_roles]), - tuple(rewards[agent1_roles:]), + next_obs1, + next_obs2, + rewards[0], + rewards[1], a1_state, - tuple(new_a1_memories), + new_a1_mem, a2_state, - tuple(new_a2_memories), + new_a2_mem, env_state, env_params, - agent_order, ), ( - a1_trajectories, - a2_trajectories, + traj1, + traj2, ) def _outer_rollout(carry, unused): @@ -306,23 +275,18 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": - a1_mem = [agent1.meta_policy(_a1_mem) for _a1_mem in a1_mem] + a1_mem = agent1.meta_policy(a1_mem) # update second agent - new_a2_memories = [] - a2_metrics = None - for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]): - a2_state, a2_mem, a2_metrics = agent2.batch_update( - traj, - _obs, - a2_state, - mem, - ) - new_a2_memories.append(a2_mem) + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], + obs2, + a2_state, + a2_mem, + ) return ( rngs, obs1, @@ -332,10 +296,9 @@ def _outer_rollout(carry, unused): a1_state, a1_mem, a2_state, - tuple(new_a2_memories), + a2_mem, env_state, env_params, - agent_order, ), (*trajectories, a2_metrics) def _rollout( @@ -343,9 +306,7 @@ def _rollout( _rng_run: jnp.ndarray, _a1_state: TrainingState, _a1_mem: MemoryState, - _a2_state: TrainingState, _env_params: Any, - roles: Tuple[int, int], ): # env reset env_rngs = jnp.concatenate( @@ -356,11 +317,9 @@ def _rollout( obs, env_state = env.reset(env_rngs, _env_params) rewards = [ - jnp.zeros( - (args.popsize, args.num_opps, args.num_envs), - dtype=float_precision, - ) - ] * (1 + args.agent2_roles) + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] # Player 1 _a1_state = _a1_state._replace(params=_params) @@ -368,6 +327,7 @@ def _rollout( # Player 2 if args.agent2 == "NaiveEx": a2_state, a2_mem = agent2.batch_init(obs[1]) + else: # meta-experiments - init 2nd agent per trial a2_rng = jnp.concatenate( @@ -378,31 +338,19 @@ def _rollout( agent2._mem.hidden, ) - if _a2_state is not None: - a2_state = _a2_state - - agent_order = jnp.arange(args.num_players) - if args.shuffle_players: - agent_order = jax.random.permutation(_rng_run, agent_order) - - agent1_roles, agent2_roles = roles # run trials vals, stack = jax.lax.scan( _outer_rollout, ( env_rngs, - # Split obs and rewards between agents - tuple(obs[:agent1_roles]), - tuple(obs[agent1_roles:]), - tuple(rewards[:agent1_roles]), - tuple(rewards[agent1_roles:]), + *obs, + *rewards, _a1_state, - (_a1_mem,) * agent1_roles, + 
_a1_mem, a2_state, - (a2_mem,) * agent2_roles, + a2_mem, env_state, _env_params, - agent_order, ), None, length=self.num_outer_steps, @@ -420,22 +368,12 @@ def _rollout( a2_mem, env_state, _env_params, - agent_order, ) = vals traj_1, traj_2, a2_metrics = stack # Fitness - agent_1_rewards = jnp.concatenate( - [traj.rewards for traj in traj_1] - ) - fitness = agent_1_rewards.mean(axis=(0, 1, 3, 4)) - agent_2_rewards = jnp.concatenate( - [traj.rewards for traj in traj_2] - ) - other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) - rewards_1 = agent_1_rewards.mean() - rewards_2 = agent_2_rewards.mean() - + fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) + other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) # Stats if args.env_id == "coin_game": env_stats = jax.tree_util.tree_map( @@ -445,6 +383,7 @@ def _rollout( rewards_1 = traj_1.rewards.sum(axis=1).mean() rewards_2 = traj_2.rewards.sum(axis=1).mean() + elif args.env_id in [ "iterated_matrix_game", ]: @@ -456,6 +395,9 @@ def _rollout( obs1, ), ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + elif args.env_id == "InTheMatrix": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -466,30 +408,12 @@ def _rollout( args.num_envs, ), ) - elif args.env_id == "Cournot": - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - cournot_stats( - traj_1[0].observations, _env_params, args.num_players - ), - ) - elif args.env_id == "Fishery": - env_stats = fishery_stats(traj_1 + traj_2, args.num_players) - elif args.env_id == "Rice-N": - env_stats = rice_stats( - traj_1 + traj_2, args.num_players, args.has_mediator - ) - elif args.env_id == "C-Rice-N": - env_stats = c_rice_stats( - traj_1 + traj_2, args.num_players, args.has_mediator - ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() else: env_stats = {} - - env_stats = env_stats | { - "train/agent1_roles": agent1_roles, - "train/agent2_roles": agent2_roles, - } + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() return ( fitness, @@ -498,13 +422,11 @@ def _rollout( rewards_1, rewards_2, a2_metrics, - a2_state, ) self.rollout = jax.pmap( _rollout, - in_axes=(0, None, None, None, None, None, None), - static_broadcasted_argnums=6, + in_axes=(0, None, None, None, None), ) print( @@ -530,7 +452,7 @@ def run_loop( print(f"Log Interval: {log_interval}") print("------------------------------") # Initialize agents and RNG - agent1, agent2 = agents[0], agents[1] + agent1, agent2 = agents rng, _ = jax.random.split(self.random_key) # Initialize evolution @@ -567,7 +489,6 @@ def run_loop( ) a1_state, a1_mem = agent1._state, agent1._mem - a2_state = None for gen in range(num_gens): rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) @@ -579,28 +500,6 @@ def run_loop( params = jax.tree_util.tree_map( lambda x: jax.lax.expand_dims(x, (0,)), params ) - - if gen % self.args.agent2_reset_interval == 0: - a2_state = None - - if self.args.num_devices == 1 and a2_state is not None: - # The first rollout returns a2_state with an extra batch dim that - # will cause issues when passing it back to the vmapped batch_policy - a2_state = jax.tree_util.tree_map( - lambda w: jnp.squeeze(w, axis=0), a2_state - ) - - self_play_prob = gen / num_gens - agent1_roles = 1 - if self.args.self_play_anneal: - agent1_roles = np.random.binomial( - self.args.num_players, self_play_prob - ) - agent1_roles = np.maximum( - agent1_roles, 1 - ) # Ensure at least one agent 1 - agent2_roles = self.args.num_players - agent1_roles - # Evo Rollout ( fitness, @@ 
-609,21 +508,10 @@ def run_loop( rewards_1, rewards_2, a2_metrics, - a2_state, - ) = self.rollout( - params, - rng_run, - a1_state, - a1_mem, - a2_state, - env_params, - (agent1_roles, agent2_roles), - ) + ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) # Aggregate over devices - fitness = jnp.reshape( - fitness, popsize * self.args.num_devices - ).astype(dtype=jnp.float32) + fitness = jnp.reshape(fitness, popsize * self.args.num_devices) env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) # Tell @@ -636,12 +524,9 @@ def run_loop( # Logging log = es_logging.update(log, x, fitness) - is_last_loop = gen == num_iters - 1 # Saving - if gen % self.args.save_interval == 0 or is_last_loop: - log_savepath1 = os.path.join( - self.save_dir, f"generation_{gen}" - ) + if gen % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"generation_{gen}") if self.args.num_devices > 1: top_params = param_reshaper.reshape( log["top_gen_params"][0 : self.args.num_devices] @@ -656,19 +541,15 @@ def run_loop( top_params = jax.tree_util.tree_map( lambda x: x.reshape(x.shape[1:]), top_params ) - save(top_params, log_savepath1) - log_savepath2 = os.path.join( - self.save_dir, f"agent2_iteration_{gen}" - ) - save(a2_state.params, log_savepath2) + save(top_params, log_savepath) if watchers: - print(f"Saving iteration {gen} locally and to WandB") - wandb.save(log_savepath1) - wandb.save(log_savepath2) + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) else: print(f"Saving iteration {gen} locally") - if gen % log_interval == 0 or is_last_loop: - print(f"Generation: {gen}/{num_iters}") + + if gen % log_interval == 0: + print(f"Generation: {gen}") print( "--------------------------------------------------------------------------" ) @@ -727,19 +608,17 @@ def run_loop( zip(log["top_fitness"], log["top_gen_fitness"]) ): wandb_log[ - f"train/fitness/top_overall_agent_{idx + 1}" + f"train/fitness/top_overall_agent_{idx+1}" ] = overall_fitness wandb_log[ - f"train/fitness/top_gen_agent_{idx + 1}" + f"train/fitness/top_gen_agent_{idx+1}" ] = gen_fitness # player 2 metrics # metrics [outer_timesteps, num_opps] - flattened_metrics = {} - if a2_metrics is not None: - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics - ) + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) agent2._logger.metrics.update(flattened_metrics) for watcher, agent in zip(watchers, agents): diff --git a/pax/runners/runner_evo_nroles.py b/pax/runners/runner_evo_nroles.py new file mode 100644 index 00000000..aacccfbc --- /dev/null +++ b/pax/runners/runner_evo_nroles.py @@ -0,0 +1,755 @@ +import os +import time +from datetime import datetime +from typing import Any, Callable, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from evosax import FitnessShaper + +import wandb +from pax.utils import MemoryState, TrainingState, save, float_precision, Sample + +# TODO: import when evosax library is updated +# from evosax.utils import ESLog +from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.fishery import fishery_stats +from pax.watchers.cournot import cournot_stats +from pax.watchers.rice import rice_stats +from pax.watchers.c_rice import c_rice_stats + +MAX_WANDB_CALLS = 1000 + + +class EvoRunnerNRoles: + """ + This Runner extends the EvoRunner class with three features: + 1. 
Allow for both the first and second agent to assume multiple roles in the game. + 2. Allow for shuffling of these roles for each rollout. + 3. Enable annealed self_play via the self_play_anneal flag. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The meta-environment that the agents will run in. + strategy (evosax.Strategy): + The evolutionary strategy that will be used to train the agents. + param_reshaper (evosax.param_reshaper.ParameterReshaper): + A function that reshapes the parameters of the agents into a format that can be + used by the strategy. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + # TODO fix C901 (function too complex) + # flake8: noqa: C901 + def __init__( + self, agents, env, strategy, es_params, param_reshaper, save_dir, args + ): + self.args = args + self.algo = args.es.algo + self.es_params = es_params + self.generations = 0 + self.num_opps = args.num_opps + self.param_reshaper = param_reshaper + self.popsize = args.popsize + self.random_key = jax.random.PRNGKey(args.seed) + self.start_datetime = datetime.now() + self.save_dir = save_dir + self.start_time = time.time() + self.strategy = strategy + self.top_k = args.top_k + self.train_steps = 0 + self.train_episodes = 0 + self.ipd_stats = jax.jit(ipd_visitation) + self.cg_stats = jax.jit(jax.vmap(cg_visitation)) + self.ipditm_stats = jax.jit( + jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) + ) + + if args.num_players != args.agent1_roles + args.agent2_roles: + raise ValueError("Number of players must match number of roles") + + # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) + # Evo Runner also has an additional pmap dim (num_devices, ...) + # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + + self.num_outer_steps = args.num_outer_steps + agent1, agent2 = agents[0], agents[1] + + # vmap agents accordingly + # agent 1 is batched over popsize and num_opps + agent1.batch_init = jax.vmap( + jax.vmap( + agent1.make_initial_state, + (None, 0), # (params, rng) + (None, 0), # (TrainingState, MemoryState) + ), + # both for Population + ) + agent1.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent1.batch_policy = jax.jit( + jax.vmap(jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0))), + ) + + if args.agent2 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent2.batch_init = jax.jit( + jax.vmap(jax.vmap(agent2.make_initial_state)) + ) + else: + agent2.batch_init = jax.jit( + jax.vmap( + jax.vmap(agent2.make_initial_state, (0, None), 0), + (0, None), + 0, + ) + ) + # NaiveEx requires env first step to init. 
+ init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + + a2_rng = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + + agent2._state, agent2._mem = agent2.batch_init( + a2_rng, + init_hidden, + ) + + agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) + agent2.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent2.batch_update = jax.jit( + jax.vmap( + jax.vmap(agent2.update, (1, 0, 0, 0)), + (1, 0, 0, 0), + ) + ) + + # jit evo + strategy.ask = jax.jit(strategy.ask) + strategy.tell = jax.jit(strategy.tell) + param_reshaper.reshape = jax.jit(param_reshaper.reshape) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + agent_order, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, :, 0, :] + rngs = rngs[:, :, :, 3, :] + + a1_actions = [] + new_a1_memories = [] + for _obs, _mem in zip(obs1, a1_mem): + a1_action, a1_state, new_a1_memory = agent1.batch_policy( + a1_state, + _obs, + _mem, + ) + a1_actions.append(a1_action) + new_a1_memories.append(new_a1_memory) + + a2_actions = [] + new_a2_memories = [] + for _obs, _mem in zip(obs2, a2_mem): + a2_action, a2_state, new_a2_memory = agent2.batch_policy( + a2_state, + _obs, + _mem, + ) + a2_actions.append(a2_action) + new_a2_memories.append(new_a2_memory) + + actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order] + obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + tuple(actions), + env_params, + ) + + inv_agent_order = jnp.argsort(agent_order) + obs = jnp.asarray(obs)[inv_agent_order] + rewards = jnp.asarray(rewards)[inv_agent_order] + agent1_roles = len(a1_actions) + + a1_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs1, + a1_actions, + rewards[:agent1_roles], + new_a1_memories, + a1_mem, + ) + ] + a2_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[agent1_roles:], + new_a2_memories, + a2_mem, + ) + ] + + return ( + rngs, + tuple(obs[:agent1_roles]), + tuple(obs[agent1_roles:]), + tuple(rewards[:agent1_roles]), + tuple(rewards[agent1_roles:]), + a1_state, + tuple(new_a1_memories), + a2_state, + tuple(new_a2_memories), + env_state, + env_params, + agent_order, + ), ( + a1_trajectories, + a2_trajectories, + ) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=args.num_inner_steps, + ) + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + agent_order, + ) = vals + # MFOS has to take a meta-action for each episode + if args.agent1 == "MFOS": + a1_mem = [agent1.meta_policy(_a1_mem) for _a1_mem in a1_mem] + + # update second agent + new_a2_memories = [] + a2_metrics = None + for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]): + a2_state, a2_mem, a2_metrics = 
agent2.batch_update( + traj, + _obs, + a2_state, + mem, + ) + new_a2_memories.append(a2_mem) + return ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + tuple(new_a2_memories), + env_state, + env_params, + agent_order, + ), (*trajectories, a2_metrics) + + def _rollout( + _params: jnp.ndarray, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _a2_state: TrainingState, + _env_params: Any, + roles: Tuple[int, int], + ): + # env reset + env_rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + * args.num_opps + * args.popsize + ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) + + obs, env_state = env.reset(env_rngs, _env_params) + rewards = [ + jnp.zeros( + (args.popsize, args.num_opps, args.num_envs), + dtype=float_precision, + ) + ] * (1 + args.agent2_roles) + + # Player 1 + _a1_state = _a1_state._replace(params=_params) + _a1_mem = agent1.batch_reset(_a1_mem, False) + # Player 2 + if args.agent2 == "NaiveEx": + a2_state, a2_mem = agent2.batch_init(obs[1]) + else: + # meta-experiments - init 2nd agent per trial + a2_rng = jnp.concatenate( + [jax.random.split(_rng_run, args.num_opps)] * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + a2_state, a2_mem = agent2.batch_init( + a2_rng, + agent2._mem.hidden, + ) + + if _a2_state is not None: + a2_state = _a2_state + + agent_order = jnp.arange(args.num_players) + if args.shuffle_players: + agent_order = jax.random.permutation(_rng_run, agent_order) + + agent1_roles, agent2_roles = roles + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + env_rngs, + # Split obs and rewards between agents + tuple(obs[:agent1_roles]), + tuple(obs[agent1_roles:]), + tuple(rewards[:agent1_roles]), + tuple(rewards[agent1_roles:]), + _a1_state, + (_a1_mem,) * agent1_roles, + a2_state, + (a2_mem,) * agent2_roles, + env_state, + _env_params, + agent_order, + ), + None, + length=self.num_outer_steps, + ) + + ( + env_rngs, + obs1, + obs2, + r1, + r2, + _a1_state, + _a1_mem, + a2_state, + a2_mem, + env_state, + _env_params, + agent_order, + ) = vals + traj_1, traj_2, a2_metrics = stack + + # Fitness + agent_1_rewards = jnp.concatenate( + [traj.rewards for traj in traj_1] + ) + fitness = agent_1_rewards.mean(axis=(0, 1, 3, 4)) + # At the end of self play annealing there will be no agent2 reward + if agent2_roles > 0: + agent_2_rewards = jnp.concatenate( + [traj.rewards for traj in traj_2] + ) + else: + agent_2_rewards = jnp.zeros_like(agent_1_rewards) + other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) + rewards_1 = agent_1_rewards.mean() + rewards_2 = agent_2_rewards.mean() + + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + elif args.env_id in [ + "iterated_matrix_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + elif args.env_id == "InTheMatrix": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipditm_stats( + env_state, + traj_1, + traj_2, + args.num_envs, + ), + ) + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + cournot_stats( + traj_1[0].observations, _env_params, args.num_players + ), + ) + elif args.env_id == "Fishery": + env_stats = fishery_stats(traj_1 + traj_2, args.num_players) + elif args.env_id == "Rice-N": + 
env_stats = rice_stats( + traj_1 + traj_2, args.num_players, args.has_mediator + ) + elif args.env_id == "C-Rice-N": + env_stats = c_rice_stats( + traj_1 + traj_2, args.num_players, args.has_mediator + ) + else: + env_stats = {} + + env_stats = env_stats | { + "train/agent1_roles": agent1_roles, + "train/agent2_roles": agent2_roles, + } + + return ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + a2_state, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None, None, None), + static_broadcasted_argnums=6, + ) + + print( + f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" + ) + + def run_loop( + self, + env_params, + agents, + num_iters: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + print(f"Number of Generations: {num_iters}") + print(f"Number of Meta Episodes: {self.num_outer_steps}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + agent1, agent2 = agents[0], agents[1] + rng, _ = jax.random.split(self.random_key) + + # Initialize evolution + num_gens = num_iters + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_state = strategy.initialize(rng, es_params) + fit_shaper = FitnessShaper( + maximize=self.args.es.maximise, + centered_rank=self.args.es.centered_rank, + w_decay=self.args.es.w_decay, + z_score=self.args.es.z_score, + ) + es_logging = ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + log = es_logging.initialize() + + # Reshape a single agent's params before vmapping + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) + a1_rng = jax.random.split(rng, popsize) + agent1._state, agent1._mem = agent1.batch_init( + a1_rng, + init_hidden, + ) + + a1_state, a1_mem = agent1._state, agent1._mem + a2_state = None + + for gen in range(num_gens): + rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) + + # Ask + x, evo_state = strategy.ask(rng_evo, evo_state, es_params) + params = param_reshaper.reshape(x) + if self.args.num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + + if gen % self.args.agent2_reset_interval == 0: + a2_state = None + + if self.args.num_devices == 1 and a2_state is not None: + # The first rollout returns a2_state with an extra batch dim that + # will cause issues when passing it back to the vmapped batch_policy + a2_state = jax.tree_util.tree_map( + lambda w: jnp.squeeze(w, axis=0), a2_state + ) + + self_play_prob = gen / num_gens + agent1_roles = self.args.agent1_roles + if self.args.self_play_anneal: + agent1_roles = np.random.binomial( + self.args.num_players, self_play_prob + ) + agent1_roles = np.maximum( + agent1_roles, 1 + ) # Ensure at least one agent 1 + agent2_roles = self.args.num_players - agent1_roles + + # Evo Rollout + ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + a2_state, + ) = self.rollout( + params, + rng_run, + a1_state, + a1_mem, + a2_state, + env_params, + (agent1_roles, agent2_roles), + ) + + # Aggregate over devices + fitness = 
jnp.reshape( + fitness, popsize * self.args.num_devices + ).astype(dtype=jnp.float32) + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Tell + fitness_re = fit_shaper.apply(x, fitness) + + if self.args.es.mean_reduce: + fitness_re = fitness_re - fitness_re.mean() + evo_state = strategy.tell(x, fitness_re, evo_state, es_params) + + # Logging + log = es_logging.update(log, x, fitness) + + is_last_loop = gen == num_iters - 1 + # Saving + if gen % self.args.save_interval == 0 or is_last_loop: + log_savepath1 = os.path.join( + self.save_dir, f"generation_{gen}" + ) + if self.args.num_devices > 1: + top_params = param_reshaper.reshape( + log["top_gen_params"][0 : self.args.num_devices] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + log["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath1) + log_savepath2 = os.path.join( + self.save_dir, f"agent2_iteration_{gen}" + ) + save(a2_state.params, log_savepath2) + if watchers: + print(f"Saving iteration {gen} locally and to WandB") + wandb.save(log_savepath1) + wandb.save(log_savepath2) + else: + print(f"Saving iteration {gen} locally") + if gen % log_interval == 0 or is_last_loop: + print(f"Generation: {gen}/{num_iters}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" + ) + print( + f"Reward Per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print( + f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" + ) + print( + "--------------------------------------------------------------------------" + ) + print( + f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" + f" | Std: {log['log_top_gen_std'][gen]}" + ) + print( + "--------------------------------------------------------------------------" + ) + print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") + print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") + print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") + print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") + print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") + print() + + if watchers: + wandb_log = { + "train_iteration": gen, + "train/fitness/player_1": float(fitness.mean()), + "train/fitness/player_2": float(other_fitness.mean()), + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/reward_per_timestep/player_1": float( + rewards_1.mean() + ), + "train/reward_per_timestep/player_2": float( + rewards_2.mean() + ), + } + wandb_log.update(env_stats) + # loop through population + for idx, (overall_fitness, gen_fitness) in enumerate( + zip(log["top_fitness"], log["top_gen_fitness"]) + ): + wandb_log[ + f"train/fitness/top_overall_agent_{idx + 1}" + ] = overall_fitness + wandb_log[ + f"train/fitness/top_gen_agent_{idx + 1}" + ] = gen_fitness + + # player 2 metrics + # metrics [outer_timesteps, num_opps] + flattened_metrics = {} + if a2_metrics is not None: 
+ flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + + agent2._logger.metrics.update(flattened_metrics) + for watcher, agent in zip(watchers, agents): + watcher(agent) + wandb_log = jax.tree_util.tree_map( + lambda x: x.item() if isinstance(x, jax.Array) else x, + wandb_log, + ) + wandb.log(wandb_log) + + return agents diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 92d1511c..d2c1e9e3 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -329,7 +329,7 @@ def _rollout( [jax.random.split(_rng_run, args.num_envs)] * args.num_opps ).reshape((args.num_opps, args.num_envs, -1)) - obs, env_state = env.reset(rngs, _env_params) + obs, env_state = env.batch_reset(rngs, _env_params) rewards = [ jnp.zeros((args.num_opps, args.num_envs)), jnp.zeros((args.num_opps, args.num_envs)), diff --git a/test/runners/test_runners.py b/test/runners/test_runners.py index e32bcde5..c74dc3b1 100644 --- a/test/runners/test_runners.py +++ b/test/runners/test_runners.py @@ -10,7 +10,7 @@ "++num_iters=1", "++popsize=2", "++num_outer_steps=1", - "++num_inner_steps=4", # required for ppo + "++num_inner_steps=8", # required for ppo minibatch size "++num_devices=1", "++num_envs=1", "++num_epochs=1", @@ -32,10 +32,14 @@ def _test_runner(overrides): main(cfg) -def test_runner_evo_runs(): +def test_runner_evo_nroles_runs(): _test_runner(["+experiment/rice=shaper_v_ppo"]) +def test_runner_evo_runs(): + _test_runner(["+experiment/cg=mfos"]) + + def test_runner_sarl_runs(): _test_runner(["+experiment/sarl=cartpole"]) @@ -45,14 +49,27 @@ def test_runner_eval_runs(): [ "+experiment/c_rice=eval_mediator_gs_ppo", "++model_path=test/runners/files/eval_mediator/generation_1499", + # Eval requires a full episode to be played "++num_inner_steps=20", ] ) def test_runner_marl_runs(): - _test_runner(["+experiment/imp=ppo_v_all_heads"]) + _test_runner(["+experiment/cg=tabular"]) def test_runner_weight_sharing(): _test_runner(["+experiment/rice=weight_sharing"]) + + +def test_runner_evo_multishaper(): + _test_runner( + ["+experiment/multiplayer_ipd=3pl_2shap_ipd", "++num_inner_steps=10"] + ) + + +def test_runner_marl_nplayer(): + _test_runner( + ["+experiment/multiplayer_ipd=lola_vs_ppo_ipd", "++num_inner_steps=10"] + )
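A short note on the `++num_inner_steps=8` change in the shared smoke-test overrides: the PPO configs introduced in this patch use `num_minibatches: 8`, and the comment in the test file ("required for ppo minibatch size") suggests the rollout must be long enough for every minibatch to receive at least one transition. The sketch below spells out that arithmetic under the assumption that the per-update batch is roughly `num_envs * num_inner_steps` transitions; the exact batching logic is not shown in this diff.

```yaml
# Assumed PPO batching arithmetic behind the smoke-test override (illustrative only).
num_envs: 1
num_minibatches: 8
# num_inner_steps: 4 -> 4 * 1 = 4 transitions, 4 // 8 = 0 per minibatch (update breaks)
num_inner_steps: 8   # 8 * 1 = 8 transitions, 8 // 8 = 1 per minibatch (minimal but valid)
```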