diff --git a/docs/getting-started/runners.md b/docs/getting-started/runners.md index 899f24fb..43dc241d 100644 --- a/docs/getting-started/runners.md +++ b/docs/getting-started/runners.md @@ -1,9 +1,28 @@ -# Runner +# Runners + +## Evo Runner + +The Evo Runner optimizes the first agent using an evolutionary strategy (e.g. OpenES or CMA-ES) while the second agent is trained with reinforcement learning. + +See [this experiment](https://github.com/akbir/pax/blob/9a01bae33dcb2f812977be388751393f570957e9/pax/conf/experiment/cg/mfos.yaml) for an example of how to configure it. + +## Evo Runner N-Roles + +This runner extends the Evo Runner to games with more than two players (`N > 2`) by letting the first and second agent each assume multiple roles, configured via `agent1_roles` and `agent2_roles` in the experiment configuration. +Each agent keeps a separate memory for every role it assumes but shares its weights across roles. + +- For heterogeneous games, roles can be shuffled for each rollout using the `shuffle_players` flag. +- Using the `self_play_anneal` flag, the self-play probability can be annealed from 0 to 1 over the course of the experiment. + +A sketch of how these keys fit together is given in the example at the end of this page. +See [this experiment](https://github.com/akbir/pax/blob/bb0e69ef71fd01ec9c85753814ffba3c5cb77935/pax/conf/experiment/rice/shaper_v_ppo.yaml) for an example of how to configure it. + +## Weight Sharing Runner + +A simple baseline for MARL experiments is to have a single agent assume multiple roles, sharing the weights between them (but not the memory). +For this approach to work, the observation vector needs to include an entry that indicates which role the agent is currently playing (see [Terry et al.](https://arxiv.org/abs/2005.13625v7)). + +See [this experiment](https://github.com/akbir/pax/blob/9d3fa62e34279a338c07cffcbf208edc8a95e7ba/pax/conf/experiment/rice/weight_sharing.yaml) for an example of how to configure it. -## Runner 1 -Lorem ipsum. -## Runner 2 -Lorem ipsum.
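+## Example: Evo Runner N-Roles configuration keys
+
+The snippet below is a minimal sketch of how the role-related keys fit together, assuming a hypothetical 5-player game; the values are illustrative rather than taken from a shipped experiment file, so see the experiments linked above for complete configurations.
+Note that `agent1_roles + agent2_roles` must equal `num_players`, otherwise the runner raises an error.
+
+```yaml
+runner: evo_nroles
+num_players: 5
+agent1_roles: 1           # the evolved agent controls one role
+agent2_roles: 4           # the RL opponent fills the remaining roles (shared weights, separate memories)
+shuffle_players: True     # reshuffle the role assignment on every rollout
+self_play_anneal: False   # set True to anneal the self-play probability from 0 to 1 over training
+agent2_reset_interval: 10 # re-initialise agent 2 every 10 generations
+```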
diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index a703c70e..459d6c01 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -18,6 +18,8 @@ make_rice_sarl_network, make_ipd_network, ) +from pax.envs.iterated_matrix_game import IteratedMatrixGame +from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer from pax.envs.rice.c_rice import ClubRice from pax.envs.rice.rice import Rice from pax.envs.rice.sarl_rice import SarlRice @@ -526,6 +528,11 @@ def make_agent( network = make_rice_sarl_network(action_spec, agent_args.hidden_size) elif args.runner == "sarl": network = make_sarl_network(action_spec) + elif args.env_id in [ + IteratedMatrixGame.env_id, + IteratedTensorGameNPlayer.env_id, + ]: + network = make_ipd_network(action_spec, True, agent_args.hidden_size) else: raise NotImplementedError( f"No ppo network implemented for env {args.env_id}" diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index e7fabcc0..e25e998e 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -30,6 +30,11 @@ agent1_roles: 1 agent2_roles: 1 # Make agent 2 assume multiple roles in an n-player game agent2_reset_interval: 1 # Reset agent 2 every rollout +# When True: runner_evo will replace the opponent by the agent itself +# at a linearly increasing probability during training +self_play_anneal: False + + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/c_rice/debug.yaml b/pax/conf/experiment/c_rice/debug.yaml index a34b6ccc..62bbb95e 100644 --- a/pax/conf/experiment/c_rice/debug.yaml +++ b/pax/conf/experiment/c_rice/debug.yaml @@ -10,7 +10,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml index 9fe6b239..9a8799d5 100644 --- a/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml +++ b/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml @@ -28,8 +28,8 @@ num_devices: 1 num_steps: 10 # Train to convergence -agent2_reset_interval: 1000 -# Regular mediator +agent2_reset_interval: 2000 +# Reward objective #run_path: chrismatix/c-rice/runs/3w7d59ug #model_path: exp/mediator/c_rice-mediator-gs-ppo-interval10_seed0/2023-10-09_17.00.59.872280/generation_1499 @@ -37,6 +37,23 @@ agent2_reset_interval: 1000 run_path: chrismatix/c-rice/runs/ovss1ahd model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0-climate-obj/2023-10-14_17.23.35.878225/generation_1499 +# 0.9 climate 0.1 reward +#run_path: chrismatix/c-rice/runs/mmtc40ja +#model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0-c.9-u.1/2023-10-17_17.03.26.660387/generation_1499 + + +# 0.7 climate 0.3 reward +#run_path: chrismatix/c-rice/runs/sdpc3s71 +#model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0-c.7-u.3/2023-10-20_17.12.09.658666/generation_1499 + +# 0.5 climate 0.5 reward +#run_path: chrismatix/c-rice/runs/6wpuz6i2 +#model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0-c.5-u.5/2023-10-20_15.48.04.605509/generation_1499 + +# high reward +#run_path: chrismatix/c-rice/runs/l4enoiku +#model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0/2023-10-02_18.01.15.434206/generation_1499 + # PPO agent parameters ppo_default: num_minibatches: 4 diff --git a/pax/conf/experiment/c_rice/marl_baseline.yaml b/pax/conf/experiment/c_rice/marl_baseline.yaml index 5c30f380..dd223cf2 100644 --- a/pax/conf/experiment/c_rice/marl_baseline.yaml +++ 
b/pax/conf/experiment/c_rice/marl_baseline.yaml @@ -9,7 +9,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True # Training diff --git a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml index f5ca57e2..acd73fdd 100644 --- a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml +++ b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml @@ -12,7 +12,7 @@ env_type: meta num_players: 6 has_mediator: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True agent2_reset_interval: 10 diff --git a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml index 81d2edbf..719e6993 100644 --- a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml +++ b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True default_club_mitigation_rate: 0.1 diff --git a/pax/conf/experiment/cg/mfos.yaml b/pax/conf/experiment/cg/mfos.yaml index d56a2c4b..1b4bb1dc 100644 --- a/pax/conf/experiment/cg/mfos.yaml +++ b/pax/conf/experiment/cg/mfos.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'MFOS' agent2: 'PPO_memory' @@ -11,24 +11,27 @@ egocentric: True env_discount: 0.96 payoff: [[1, 1, -2], [1, 1, -2]] -# Runner +# Runner runner: evo +top_k: 4 +popsize: 1000 #512 # env_batch_size = num_envs * num_opponents num_envs: 250 num_opps: 1 num_outer_steps: 600 -num_inner_steps: 16 -save_interval: 100 +num_inner_steps: 16 +save_interval: 100 +num_steps: '${num_inner_steps}' -# Evaluation +# Evaluation run_path: ucl-dark/cg/12auc9um model_path: exp/sanity-PPO-vs-PPO-parity/run-seed-0/2022-09-08_20.04.17.155963/iteration_500 # PPO agent parameters -ppo: +ppo1: num_minibatches: 8 - num_epochs: 2 + num_epochs: 2 gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 @@ -49,6 +52,52 @@ ppo: separate: True # only works with CNN hidden_size: 16 #50 +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.01 #0.05 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True # only works with CNN + hidden_size: 16 #50 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + + # Logging setup wandb: entity: "ucl-dark" diff --git 
a/pax/conf/experiment/cg/tabular.yaml b/pax/conf/experiment/cg/tabular.yaml index 39277290..9240a691 100644 --- a/pax/conf/experiment/cg/tabular.yaml +++ b/pax/conf/experiment/cg/tabular.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'Tabular' agent2: 'Random' @@ -25,9 +25,32 @@ num_iters: 10000 # train_batch_size = num_envs * num_opponents * num_steps # PPO agent parameters -ppo: +ppo1: num_minibatches: 8 - num_epochs: 2 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: True + learning_rate: 0.01 #0.05 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True # only works with CNN + hidden_size: 16 #50 + +ppo2: + num_minibatches: 8 + num_epochs: 2 gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 diff --git a/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml b/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml new file mode 100644 index 00000000..280b9a3c --- /dev/null +++ b/pax/conf/experiment/cournot/eval_shaper_v_ppo.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent_default: 'PPO' + +# Environment +env_id: Cournot +env_type: meta +a: 100 +b: 1 +marginal_cost: 10 + +# Runner +runner: evo_nroles + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game +num_iters: 1000 +num_devices: 1 +num_steps: '${num_inner_steps}' + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + hidden_size: 16 + + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: cournot + group: 'shaper' + name: 'cournot-SHAPER-${num_players}p-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/cournot/shaper_v_ppo.yaml b/pax/conf/experiment/cournot/shaper_v_ppo.yaml index 1d44629d..ef4f0fac 100644 --- a/pax/conf/experiment/cournot/shaper_v_ppo.yaml +++ b/pax/conf/experiment/cournot/shaper_v_ppo.yaml @@ -4,6 +4,7 @@ agent1: 'PPO_memory' agent_default: 'PPO' + # Environment env_id: Cournot env_type: meta @@ -12,7 +13,7 @@ b: 1 marginal_cost: 10 # Runner -runner: tensor_evo +runner: evo_nroles # Training top_k: 5 @@ 
-27,28 +28,7 @@ num_steps: '${num_inner_steps}' # PPO agent parameters -ppo1: - num_minibatches: 4 - num_epochs: 2 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.02 - entropy_coeff_horizon: 2000000 - entropy_coeff_end: 0.001 - lr_scheduling: False - learning_rate: 1 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - hidden_size: 16 - -# PPO agent parameters -ppo2: +ppo_default: num_minibatches: 4 num_epochs: 2 gamma: 0.96 diff --git a/pax/conf/experiment/fishery/marl_baseline.yaml b/pax/conf/experiment/fishery/marl_baseline.yaml index f0faa91b..ef42b6bd 100644 --- a/pax/conf/experiment/fishery/marl_baseline.yaml +++ b/pax/conf/experiment/fishery/marl_baseline.yaml @@ -15,7 +15,7 @@ s_0: 0.5 s_max: 1.0 # This means the optimum quantity is 2(a-marginal_cost)/3b = 60 -runner: evo +runner: evo_nroles # env_batch_size = num_envs * num_opponents num_envs: 100 diff --git a/pax/conf/experiment/fishery/mfos_v_ppo.yaml b/pax/conf/experiment/fishery/mfos_v_ppo.yaml index 78e4aa97..a3a828ac 100644 --- a/pax/conf/experiment/fishery/mfos_v_ppo.yaml +++ b/pax/conf/experiment/fishery/mfos_v_ppo.yaml @@ -15,7 +15,7 @@ s_0: 0.5 s_max: 1.0 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/fishery/shaper_v_ppo.yaml b/pax/conf/experiment/fishery/shaper_v_ppo.yaml index a44819b6..8370a98b 100644 --- a/pax/conf/experiment/fishery/shaper_v_ppo.yaml +++ b/pax/conf/experiment/fishery/shaper_v_ppo.yaml @@ -14,22 +14,21 @@ w: 0.9 s_0: 0.5 s_max: 1.0 # Runner -runner: evo +runner: evo_nroles # Training top_k: 5 popsize: 1000 num_envs: 2 num_opps: 1 -num_outer_steps: 4%0 -num_inner_steps: 300 +num_outer_steps: 20 +num_inner_steps: 2000 num_iters: 1500 num_devices: 1 -num_steps: 1100 # PPO agent parameters -ppo1: +ppo_default: num_minibatches: 4 num_epochs: 2 gamma: 0.96 diff --git a/pax/conf/experiment/fishery/weight_sharing.yaml b/pax/conf/experiment/fishery/weight_sharing.yaml index 930dad6e..20709f6c 100644 --- a/pax/conf/experiment/fishery/weight_sharing.yaml +++ b/pax/conf/experiment/fishery/weight_sharing.yaml @@ -19,7 +19,7 @@ runner: weight_sharing # env_batch_size = num_envs * num_opponents num_envs: 50 -num_inner_steps: 300 +num_inner_steps: 2100 num_iters: 4e6 save_interval: 100 num_steps: 2100 diff --git a/pax/conf/experiment/imp/ppo_v_all_heads.yaml b/pax/conf/experiment/imp/ppo_v_all_heads.yaml index d0b09182..1dfc0cd1 100644 --- a/pax/conf/experiment/imp/ppo_v_all_heads.yaml +++ b/pax/conf/experiment/imp/ppo_v_all_heads.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'PPO' agent2: 'Altruistic' @@ -10,7 +10,7 @@ env_type: sequential env_discount: 0.99 payoff: [[1, -1], [-1, 1], [-1, 1], [1, -1]] -# Runner +# Runner runner: rl # Training hyperparameters @@ -19,20 +19,20 @@ num_opps: 1 num_steps: 150 # number of steps per episode num_iters: 1000 -# Evaluation +# Evaluation run_path: ucl-dark/ipd/w1x0vqb7 model_path: exp/PPO-vs-TitForTat-ipd-parity/PPO-vs-TitForTat-ipd-parity-run-seed-0/2022-09-08_15.56.38.018596/iteration_260 # PPO agent parameters -ppo: +ppo_default: num_minibatches: 4 num_epochs: 2 gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 value_coeff: 0.5 - clip_value: True + clip_value: True max_gradient_norm: 0.5 anneal_entropy: False entropy_coeff_start: 0.02 diff --git a/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml b/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml index 
ceca98d7..3444e9fd 100644 --- a/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml +++ b/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml @@ -18,8 +18,8 @@ rice_v2_network: True runner: eval -run_path: chrismatix/rice/runs/yg67hb4e -model_path: exp/shaper/rice-SHAPER-PPO_memory-seed-0-interval_20/2023-10-09_12.14.28.778753/generation_1499 +run_path: chrismatix/rice/runs/k7nw6h8j +model_path: exp/shaper/rice-SHAPER-PPO_memory-seed-0-interval_10/2023-10-16_20.08.43.003976/generation_1499 # Better run but with old network #run_path: chrismatix/rice/runs/btpdx3d2 @@ -39,10 +39,10 @@ popsize: 1000 num_devices: 1 num_envs: 20 num_opps: 1 -num_inner_steps: 20 +num_inner_steps: 2000 num_outer_steps: 1 -num_iters: 100 -num_steps: 200 +num_iters: 1000 +num_steps: 20 agent2_reset_interval: 1000 diff --git a/pax/conf/experiment/rice/eval_weight_sharing.yaml b/pax/conf/experiment/rice/eval_weight_sharing.yaml new file mode 100644 index 00000000..f999a919 --- /dev/null +++ b/pax/conf/experiment/rice/eval_weight_sharing.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: Rice-N +env_type: sequential +num_players: 5 +has_mediator: False +config_folder: pax/envs/rice/5_regions +runner: eval +rice_v2_network: True +# Training hyperparameters + +run_path: chrismatix/rice/runs/l6ug3nod +model_path: exp/weight_sharing/rice-weight_sharing-PPO_memory-seed-0/2023-10-12_18.54.03.092581/iteration_119999 + + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 20 +num_iters: 6e6 +save_interval: 100 +num_steps: 2000 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: rice + group: 'weight_sharing' + name: 'rice-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/conf/experiment/rice/gs_v_ppo.yaml b/pax/conf/experiment/rice/gs_v_ppo.yaml index daf31c69..ece5eb89 100644 --- a/pax/conf/experiment/rice/gs_v_ppo.yaml +++ b/pax/conf/experiment/rice/gs_v_ppo.yaml @@ -12,7 +12,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training diff --git a/pax/conf/experiment/rice/mfos_v_ppo.yaml b/pax/conf/experiment/rice/mfos_v_ppo.yaml index 13e5f4f4..a0870424 100644 --- a/pax/conf/experiment/rice/mfos_v_ppo.yaml +++ b/pax/conf/experiment/rice/mfos_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: False config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles # Training top_k: 5 diff --git a/pax/conf/experiment/rice/shaper_v_ppo.yaml b/pax/conf/experiment/rice/shaper_v_ppo.yaml index 4cc0a5b0..78cae9a2 100644 --- a/pax/conf/experiment/rice/shaper_v_ppo.yaml +++ b/pax/conf/experiment/rice/shaper_v_ppo.yaml @@ -14,7 +14,7 @@ num_players: 5 has_mediator: False shuffle_players: True config_folder: pax/envs/rice/5_regions -runner: evo +runner: evo_nroles rice_v2_network: True # Training @@ -22,11 +22,13 @@ top_k: 5 popsize: 1000 num_envs: 1 num_opps: 1 -num_outer_steps: 200 +num_outer_steps: 200 num_inner_steps: 200 num_iters:
1500 num_devices: 1 +agent2_reset_interval: 10 + # PPO agent parameters ppo_default: diff --git a/pax/conf/experiment/sarl/cartpole.yaml b/pax/conf/experiment/sarl/cartpole.yaml index e5e31973..a8e33bf5 100644 --- a/pax/conf/experiment/sarl/cartpole.yaml +++ b/pax/conf/experiment/sarl/cartpole.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'PPO' # Environment @@ -15,19 +15,19 @@ runner: sarl # env_batch_size = num_envs * num_opponents num_envs: 8 -num_steps: 500 # 500 Cartpole +num_steps: 500 # 500 Cartpole num_iters: 1e6 -save_interval: 100 +save_interval: 100 -# Evaluation +# Evaluation run_path: ucl-dark/cg/3sp0y2cy model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 # PPO agent parameters -ppo: +ppo0: num_minibatches: 4 - num_epochs: 4 - gamma: 0.99 + num_epochs: 4 + gamma: 0.99 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 value_coeff: 0.5 @@ -56,7 +56,7 @@ ppo: # learning_rate: 1.0 # adam_epsilon: 1e-5 # entropy_coeff: 0 - + # Logging setup @@ -65,4 +65,4 @@ wandb: project: synq group: 'sanity-${agent1}-vs-${agent2}-parity' name: run-seed-${seed} - log: False \ No newline at end of file + log: False diff --git a/pax/envs/fishery.py b/pax/envs/fishery.py index c86dc23e..8f6f1697 100644 --- a/pax/envs/fishery.py +++ b/pax/envs/fishery.py @@ -48,9 +48,10 @@ def to_obs_array(params: EnvParams) -> jnp.ndarray: class Fishery(environment.Environment): env_id: str = "Fishery" - def __init__(self, num_players: int, num_inner_steps: int): + def __init__(self, num_players: int): super().__init__() self.num_players = num_players + self.num_inner_steps = 300 def _step( key: chex.PRNGKey, @@ -60,7 +61,7 @@ def _step( ): t = state.inner_t + 1 key, _ = jax.random.split(key, 2) - done = t >= num_inner_steps + done = t >= self.num_inner_steps actions = jnp.asarray(actions).squeeze() actions = jnp.clip(actions, a_min=0) @@ -78,7 +79,9 @@ def _step( all_obs = [] all_rewards = [] for i in range(num_players): - obs = jnp.concatenate([actions, jnp.array([s_next])]) + obs = jnp.concatenate( + [actions, jnp.array([s_next]), jnp.array([i])] + ) obs = jax.lax.select(done, reset_obs[i], obs) all_obs.append(obs) @@ -116,8 +119,17 @@ def _reset( s=params.s_0, ) obs = jax.random.uniform(key, (num_players,)) - obs = jnp.concatenate([obs, jnp.array([state.s])]) - return tuple([obs for _ in range(num_players)]), state + return ( + tuple( + [ + jnp.concatenate( + [obs, jnp.array([state.s]), jnp.array([i])] + ) + for i in range(num_players) + ] + ), + state, + ) self.step = jax.jit(_step) self.reset = jax.jit(_reset) @@ -136,7 +148,7 @@ def observation_space(self, params: EnvParams) -> spaces.Box: return spaces.Box( low=0, high=float("inf"), - shape=self.num_players + 1, + shape=self.num_players + 2, # + 1 index of player, + 1 stock dtype=jnp.float32, ) diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index 218edc07..e84b9b9d 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -18,6 +18,8 @@ class EnvParams: class IteratedMatrixGame(environment.Environment): + env_id = "iterated_matrix_game" + """ JAX Compatible version of matrix game environment. 
""" diff --git a/pax/envs/iterated_tensor_game_n_player.py b/pax/envs/iterated_tensor_game_n_player.py index 8335228f..1b060569 100644 --- a/pax/envs/iterated_tensor_game_n_player.py +++ b/pax/envs/iterated_tensor_game_n_player.py @@ -18,6 +18,7 @@ class EnvParams: class IteratedTensorGameNPlayer(environment.Environment): + env_id = "iterated_nplayer_tensor_game" """ JAX Compatible version of tensor game environment. """ diff --git a/pax/experiment.py b/pax/experiment.py index ac6de950..719f1c24 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -77,6 +77,7 @@ from pax.envs.rice.c_rice import ClubRice from pax.envs.rice.rice import Rice, EnvParams as RiceParams from pax.envs.rice.sarl_rice import SarlRice +from pax.runners.runner_evo_nroles import EvoRunnerNRoles from pax.runners.runner_weight_sharing import WeightSharingRunner from pax.runners.runner_ipditm_eval import IPDITMEvalRunner @@ -231,7 +232,6 @@ def env_setup(args, logger=None): ) env = Fishery( num_players=args.num_players, - num_inner_steps=args.num_inner_steps, ) if logger: logger.info( @@ -307,8 +307,8 @@ def runner_setup(args, env, agents, save_dir, logger): return IPDITMEvalRunner(agents, env, save_dir, args) if args.runner in ["evo", "evo_mixed_lr", "evo_hardstop", "evo_mixed_payoff", "evo_mixed_ipd_payoff", - "evo_mixed_payoff_gen", "evo_mixed_payoff_input", "evo_scanned", "evo_mixed_payoff_only_opp", "multishaper_evo"]: - agent1, _ = agents + "evo_mixed_payoff_gen", "evo_mixed_payoff_input", "evo_scanned", "evo_mixed_payoff_only_opp", "multishaper_evo", "evo_nroles"]: + agent1 = agents[0] algo = args.es.algo strategies = {"CMA_ES", "OpenES", "PGPE", "SimpleGA"} assert algo in strategies, f"{algo} not in evolution strategies" @@ -433,6 +433,19 @@ def get_pgpe_strategy(agent): return EvoScannedRunner( agents, env, strategy, es_params, param_reshaper, save_dir, args ) + + elif args.runner == "evo_nroles": + logger.info("Training with n_roles EVO runner") + return EvoRunnerNRoles( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) + elif args.runner == "multishaper_evo": logger.info("Training with multishaper EVO runner") return MultishaperEvoRunner( @@ -864,7 +877,7 @@ def main(args): print(f"Number of Training Iterations: {args.num_iters}") if args.runner in ["evo", "evo_mixed_lr", "evo_hardstop", "evo_mixed_payoff", "evo_mixed_ipd_payoff", - "evo_mixed_payoff_gen", "evo_mixed_payoff_input", "evo_mixed_payoff_pred", "evo_scanned", "evo_mixed_payoff_only_opp", "multishaper_evo"]: + "evo_mixed_payoff_gen", "evo_mixed_payoff_input", "evo_mixed_payoff_pred", "evo_scanned", "evo_mixed_payoff_only_opp", "multishaper_evo", "evo_nroles"]: print(f"Running {args.runner}") runner.run_loop(env_params, agent_pair, args.num_iters, watchers) diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index f10c65da..43c4a0bd 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -8,24 +8,32 @@ from evosax import FitnessShaper import wandb -from pax.utils import MemoryState, TrainingState, save, float_precision, Sample +from pax.utils import MemoryState, TrainingState, save # TODO: import when evosax library is updated # from evosax.utils import ESLog from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats -from pax.watchers.fishery import fishery_stats -from pax.watchers.cournot import cournot_stats -from pax.watchers.rice import rice_stats -from pax.watchers.c_rice import c_rice_stats MAX_WANDB_CALLS = 1000 +class Sample(NamedTuple): + """Object 
containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + class EvoRunner: """ - Evolutionary Strategy runner provides a convenient example for quickly writing + Evoluationary Strategy runner provides a convenient example for quickly writing a MARL runner for PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against a Reinforcement Learner. + run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. @@ -46,8 +54,6 @@ class EvoRunner: A tuple of experiment arguments used (usually provided by HydraConfig). """ - # TODO fix C901 (function too complex) - # flake8: noqa: C901 def __init__( self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): @@ -71,7 +77,6 @@ def __init__( self.ipditm_stats = jax.jit( jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) ) - self.cournot_stats = cournot_stats # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) # Evo Runner also has an additional pmap dim (num_devices, ...) @@ -101,7 +106,7 @@ def __init__( ) self.num_outer_steps = args.num_outer_steps - agent1, agent2 = agents[0], agents[1] + agent1, agent2 = agents # vmap agents accordingly # agent 1 is batched over popsize and num_opps @@ -187,48 +192,33 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order, ) = carry # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, :, 0, :] - rngs = rngs[:, :, :, 3, :] - a1_actions = [] - new_a1_memories = [] - for _obs, _mem in zip(obs1, a1_mem): - a1_action, a1_state, new_a1_memory = agent1.batch_policy( - a1_state, - _obs, - _mem, - ) - a1_actions.append(a1_action) - new_a1_memories.append(new_a1_memory) - - a2_actions = [] - new_a2_memories = [] - for _obs, _mem in zip(obs2, a2_mem): - a2_action, a2_state, new_a2_memory = agent2.batch_policy( - a2_state, - _obs, - _mem, - ) - a2_actions.append(a2_action) - new_a2_memories.append(new_a2_memory) + # a1_rng = rngs[:, :, :, 1, :] + # a2_rng = rngs[:, :, :, 2, :] + rngs = rngs[:, :, :, 3, :] - actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order] - obs, env_state, rewards, done, info = env.step( + a1, a1_state, new_a1_mem = agent1.batch_policy( + a1_state, + obs1, + a1_mem, + ) + a2, a2_state, new_a2_mem = agent2.batch_policy( + a2_state, + obs2, + a2_mem, + ) + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( env_rng, env_state, - tuple(actions), + (a1, a2), env_params, ) - inv_agent_order = jnp.argsort(agent_order) - obs = jnp.asarray(obs)[inv_agent_order] - rewards = jnp.asarray(rewards)[inv_agent_order] - traj1 = Sample( obs1, a1, @@ -238,41 +228,30 @@ def _inner_rollout(carry, unused): done, a1_mem.hidden, ) - a2_trajectories = [ - Sample( - observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden, - ) - for observation, action, reward, new_memory, memory in zip( - obs2, - a2_actions, - rewards[1:], - new_a2_memories, - a2_mem, - ) - ] - + traj2 = Sample( + obs2, + a2, + rewards[1], + new_a2_mem.extras["log_probs"], + new_a2_mem.extras["values"], + done, + a2_mem.hidden, + ) return ( rngs, - obs[0], - tuple(obs[1:]), + next_obs1, + 
next_obs2, rewards[0], - tuple(rewards[1:]), + rewards[1], a1_state, new_a1_mem, a2_state, - tuple(new_a2_memories), + new_a2_mem, env_state, env_params, - agent_order, ), ( traj1, - a2_trajectories, + traj2, ) def _outer_rollout(carry, unused): @@ -296,22 +275,18 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) # update second agent - new_a2_memories = [] - for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]): - a2_state, a2_mem, a2_metrics = agent2.batch_update( - traj, - _obs, - a2_state, - mem, - ) - new_a2_memories.append(a2_mem) + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], + obs2, + a2_state, + a2_mem, + ) return ( rngs, obs1, @@ -321,10 +296,9 @@ def _outer_rollout(carry, unused): a1_state, a1_mem, a2_state, - tuple(new_a2_memories), + a2_mem, env_state, env_params, - agent_order, ), (*trajectories, a2_metrics) def _rollout( @@ -332,7 +306,6 @@ def _rollout( _rng_run: jnp.ndarray, _a1_state: TrainingState, _a1_mem: MemoryState, - _a2_state: TrainingState, _env_params: Any, ): # env reset @@ -344,11 +317,9 @@ def _rollout( obs, env_state = env.reset(env_rngs, _env_params) rewards = [ - jnp.zeros( - (args.popsize, args.num_opps, args.num_envs), - dtype=float_precision, - ) - ] * (1 + args.agent2_roles) + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] # Player 1 _a1_state = _a1_state._replace(params=_params) @@ -356,6 +327,7 @@ def _rollout( # Player 2 if args.agent2 == "NaiveEx": a2_state, a2_mem = agent2.batch_init(obs[1]) + else: # meta-experiments - init 2nd agent per trial a2_rng = jnp.concatenate( @@ -372,29 +344,19 @@ def _rollout( # a2_state.opt_state[2].hyperparams['step_size'] = learning_rates # jax.debug.breakpoint() - if _a2_state is not None: - a2_state = _a2_state - - agent_order = jnp.arange(args.num_players) - if args.shuffle_players: - agent_order = jax.random.permutation(_rng_run, agent_order) - # run trials vals, stack = jax.lax.scan( _outer_rollout, ( env_rngs, - obs[0], - tuple(obs[1:]), - rewards[0], - tuple(rewards[1:]), + *obs, + *rewards, _a1_state, _a1_mem, a2_state, - (a2_mem,) * args.agent2_roles, + a2_mem, env_state, _env_params, - agent_order, ), None, length=self.num_outer_steps, @@ -412,19 +374,12 @@ def _rollout( a2_mem, env_state, _env_params, - agent_order, ) = vals traj_1, traj_2, a2_metrics = stack # Fitness fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - agent_2_rewards = jnp.concatenate( - [traj.rewards for traj in traj_2] - ) - other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) - rewards_1 = traj_1.rewards.mean() - rewards_2 = agent_2_rewards.mean() - + other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) # Stats if args.env_id == "coin_game": env_stats = jax.tree_util.tree_map( @@ -434,6 +389,7 @@ def _rollout( rewards_1 = traj_1.rewards.sum(axis=1).mean() rewards_2 = traj_2.rewards.sum(axis=1).mean() + elif args.env_id in [ "iterated_matrix_game", ]: @@ -445,6 +401,9 @@ def _rollout( obs1, ), ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + elif args.env_id == "InTheMatrix": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -455,23 +414,12 @@ def _rollout( args.num_envs, ), ) - elif args.env_id == "Cournot": - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.cournot_stats(traj_1.observations, _env_params, 2), - ) - elif 
args.env_id == "Fishery": - env_stats = fishery_stats([traj_1] + traj_2, args.num_players) - elif args.env_id == "Rice-N": - env_stats = rice_stats( - [traj_1] + traj_2, args.num_players, args.has_mediator - ) - elif args.env_id == "C-Rice-N": - env_stats = c_rice_stats( - [traj_1] + traj_2, args.num_players, args.has_mediator - ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() else: env_stats = {} + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() return ( fitness, @@ -480,12 +428,11 @@ def _rollout( rewards_1, rewards_2, a2_metrics, - a2_state, ) self.rollout = jax.pmap( _rollout, - in_axes=(0, None, None, None, None, None), + in_axes=(0, None, None, None, None), ) print( @@ -511,7 +458,7 @@ def run_loop( print(f"Log Interval: {log_interval}") print("------------------------------") # Initialize agents and RNG - agent1, agent2 = agents[0], agents[1] + agent1, agent2 = agents rng, _ = jax.random.split(self.random_key) # Initialize evolution @@ -548,7 +495,6 @@ def run_loop( ) a1_state, a1_mem = agent1._state, agent1._mem - a2_state = None for gen in range(num_gens): rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) @@ -560,17 +506,6 @@ def run_loop( params = jax.tree_util.tree_map( lambda x: jax.lax.expand_dims(x, (0,)), params ) - - if gen % self.args.agent2_reset_interval == 0: - a2_state = None - - if self.args.num_devices == 1 and a2_state is not None: - # The first rollout returns a2_state with an extra batch dim that - # will cause issues when passing it back to the vmapped batch_policy - a2_state = jax.tree_util.tree_map( - lambda w: jnp.squeeze(w, axis=0), a2_state - ) - # Evo Rollout ( fitness, @@ -579,15 +514,10 @@ def run_loop( rewards_1, rewards_2, a2_metrics, - a2_state, - ) = self.rollout( - params, rng_run, a1_state, a1_mem, a2_state, env_params - ) + ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) # Aggregate over devices - fitness = jnp.reshape( - fitness, popsize * self.args.num_devices - ).astype(dtype=jnp.float32) + fitness = jnp.reshape(fitness, popsize * self.args.num_devices) env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) # Tell @@ -600,12 +530,9 @@ def run_loop( # Logging log = es_logging.update(log, x, fitness) - is_last_loop = gen == num_iters - 1 # Saving - if gen % self.args.save_interval == 0 or is_last_loop: - log_savepath1 = os.path.join( - self.save_dir, f"generation_{gen}" - ) + if gen % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"generation_{gen}") if self.args.num_devices > 1: top_params = param_reshaper.reshape( log["top_gen_params"][0 : self.args.num_devices] @@ -620,19 +547,15 @@ def run_loop( top_params = jax.tree_util.tree_map( lambda x: x.reshape(x.shape[1:]), top_params ) - save(top_params, log_savepath1) - log_savepath2 = os.path.join( - self.save_dir, f"agent2_iteration_{gen}" - ) - save(a2_state.params, log_savepath2) + save(top_params, log_savepath) if watchers: - print(f"Saving iteration {gen} locally and to WandB") - wandb.save(log_savepath1) - wandb.save(log_savepath2) + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) else: print(f"Saving iteration {gen} locally") - if gen % log_interval == 0 or is_last_loop: - print(f"Generation: {gen}/{num_iters}") + + if gen % log_interval == 0: + print(f"Generation: {gen}") print( "--------------------------------------------------------------------------" ) @@ -691,10 +614,10 @@ def run_loop( zip(log["top_fitness"], log["top_gen_fitness"]) ): 
wandb_log[ - f"train/fitness/top_overall_agent_{idx + 1}" + f"train/fitness/top_overall_agent_{idx+1}" ] = overall_fitness wandb_log[ - f"train/fitness/top_gen_agent_{idx + 1}" + f"train/fitness/top_gen_agent_{idx+1}" ] = gen_fitness # player 2 metrics diff --git a/pax/runners/runner_evo_nroles.py b/pax/runners/runner_evo_nroles.py new file mode 100644 index 00000000..aacccfbc --- /dev/null +++ b/pax/runners/runner_evo_nroles.py @@ -0,0 +1,755 @@ +import os +import time +from datetime import datetime +from typing import Any, Callable, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from evosax import FitnessShaper + +import wandb +from pax.utils import MemoryState, TrainingState, save, float_precision, Sample + +# TODO: import when evosax library is updated +# from evosax.utils import ESLog +from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.fishery import fishery_stats +from pax.watchers.cournot import cournot_stats +from pax.watchers.rice import rice_stats +from pax.watchers.c_rice import c_rice_stats + +MAX_WANDB_CALLS = 1000 + + +class EvoRunnerNRoles: + """ + This Runner extends the EvoRunner class with three features: + 1. Allow for both the first and second agent to assume multiple roles in the game. + 2. Allow for shuffling of these roles for each rollout. + 3. Enable annealed self_play via the self_play_anneal flag. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The meta-environment that the agents will run in. + strategy (evosax.Strategy): + The evolutionary strategy that will be used to train the agents. + param_reshaper (evosax.param_reshaper.ParameterReshaper): + A function that reshapes the parameters of the agents into a format that can be + used by the strategy. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + # TODO fix C901 (function too complex) + # flake8: noqa: C901 + def __init__( + self, agents, env, strategy, es_params, param_reshaper, save_dir, args + ): + self.args = args + self.algo = args.es.algo + self.es_params = es_params + self.generations = 0 + self.num_opps = args.num_opps + self.param_reshaper = param_reshaper + self.popsize = args.popsize + self.random_key = jax.random.PRNGKey(args.seed) + self.start_datetime = datetime.now() + self.save_dir = save_dir + self.start_time = time.time() + self.strategy = strategy + self.top_k = args.top_k + self.train_steps = 0 + self.train_episodes = 0 + self.ipd_stats = jax.jit(ipd_visitation) + self.cg_stats = jax.jit(jax.vmap(cg_visitation)) + self.ipditm_stats = jax.jit( + jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) + ) + + if args.num_players != args.agent1_roles + args.agent2_roles: + raise ValueError("Number of players must match number of roles") + + # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) + # Evo Runner also has an additional pmap dim (num_devices, ...) 
+ # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + + self.num_outer_steps = args.num_outer_steps + agent1, agent2 = agents[0], agents[1] + + # vmap agents accordingly + # agent 1 is batched over popsize and num_opps + agent1.batch_init = jax.vmap( + jax.vmap( + agent1.make_initial_state, + (None, 0), # (params, rng) + (None, 0), # (TrainingState, MemoryState) + ), + # both for Population + ) + agent1.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent1.batch_policy = jax.jit( + jax.vmap(jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0))), + ) + + if args.agent2 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent2.batch_init = jax.jit( + jax.vmap(jax.vmap(agent2.make_initial_state)) + ) + else: + agent2.batch_init = jax.jit( + jax.vmap( + jax.vmap(agent2.make_initial_state, (0, None), 0), + (0, None), + 0, + ) + ) + # NaiveEx requires env first step to init. + init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + + a2_rng = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + + agent2._state, agent2._mem = agent2.batch_init( + a2_rng, + init_hidden, + ) + + agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) + agent2.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent2.batch_update = jax.jit( + jax.vmap( + jax.vmap(agent2.update, (1, 0, 0, 0)), + (1, 0, 0, 0), + ) + ) + + # jit evo + strategy.ask = jax.jit(strategy.ask) + strategy.tell = jax.jit(strategy.tell) + param_reshaper.reshape = jax.jit(param_reshaper.reshape) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + agent_order, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, :, 0, :] + rngs = rngs[:, :, :, 3, :] + + a1_actions = [] + new_a1_memories = [] + for _obs, _mem in zip(obs1, a1_mem): + a1_action, a1_state, new_a1_memory = agent1.batch_policy( + a1_state, + _obs, + _mem, + ) + a1_actions.append(a1_action) + new_a1_memories.append(new_a1_memory) + + a2_actions = [] + new_a2_memories = [] + for _obs, _mem in zip(obs2, a2_mem): + a2_action, a2_state, new_a2_memory = agent2.batch_policy( + a2_state, + _obs, + _mem, + ) + a2_actions.append(a2_action) + new_a2_memories.append(new_a2_memory) + + actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order] + obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + tuple(actions), + env_params, + ) + + inv_agent_order = jnp.argsort(agent_order) + obs = jnp.asarray(obs)[inv_agent_order] + rewards = jnp.asarray(rewards)[inv_agent_order] + agent1_roles = len(a1_actions) + + a1_trajectories = [ + Sample( + observation, + 
action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs1, + a1_actions, + rewards[:agent1_roles], + new_a1_memories, + a1_mem, + ) + ] + a2_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[agent1_roles:], + new_a2_memories, + a2_mem, + ) + ] + + return ( + rngs, + tuple(obs[:agent1_roles]), + tuple(obs[agent1_roles:]), + tuple(rewards[:agent1_roles]), + tuple(rewards[agent1_roles:]), + a1_state, + tuple(new_a1_memories), + a2_state, + tuple(new_a2_memories), + env_state, + env_params, + agent_order, + ), ( + a1_trajectories, + a2_trajectories, + ) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=args.num_inner_steps, + ) + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + agent_order, + ) = vals + # MFOS has to take a meta-action for each episode + if args.agent1 == "MFOS": + a1_mem = [agent1.meta_policy(_a1_mem) for _a1_mem in a1_mem] + + # update second agent + new_a2_memories = [] + a2_metrics = None + for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]): + a2_state, a2_mem, a2_metrics = agent2.batch_update( + traj, + _obs, + a2_state, + mem, + ) + new_a2_memories.append(a2_mem) + return ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + tuple(new_a2_memories), + env_state, + env_params, + agent_order, + ), (*trajectories, a2_metrics) + + def _rollout( + _params: jnp.ndarray, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _a2_state: TrainingState, + _env_params: Any, + roles: Tuple[int, int], + ): + # env reset + env_rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + * args.num_opps + * args.popsize + ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) + + obs, env_state = env.reset(env_rngs, _env_params) + rewards = [ + jnp.zeros( + (args.popsize, args.num_opps, args.num_envs), + dtype=float_precision, + ) + ] * (1 + args.agent2_roles) + + # Player 1 + _a1_state = _a1_state._replace(params=_params) + _a1_mem = agent1.batch_reset(_a1_mem, False) + # Player 2 + if args.agent2 == "NaiveEx": + a2_state, a2_mem = agent2.batch_init(obs[1]) + else: + # meta-experiments - init 2nd agent per trial + a2_rng = jnp.concatenate( + [jax.random.split(_rng_run, args.num_opps)] * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + a2_state, a2_mem = agent2.batch_init( + a2_rng, + agent2._mem.hidden, + ) + + if _a2_state is not None: + a2_state = _a2_state + + agent_order = jnp.arange(args.num_players) + if args.shuffle_players: + agent_order = jax.random.permutation(_rng_run, agent_order) + + agent1_roles, agent2_roles = roles + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + env_rngs, + # Split obs and rewards between agents + tuple(obs[:agent1_roles]), + tuple(obs[agent1_roles:]), + tuple(rewards[:agent1_roles]), + tuple(rewards[agent1_roles:]), + _a1_state, + (_a1_mem,) * agent1_roles, + a2_state, + (a2_mem,) * agent2_roles, + env_state, + _env_params, + agent_order, + ), + None, + length=self.num_outer_steps, + ) + + ( + env_rngs, + 
obs1, + obs2, + r1, + r2, + _a1_state, + _a1_mem, + a2_state, + a2_mem, + env_state, + _env_params, + agent_order, + ) = vals + traj_1, traj_2, a2_metrics = stack + + # Fitness + agent_1_rewards = jnp.concatenate( + [traj.rewards for traj in traj_1] + ) + fitness = agent_1_rewards.mean(axis=(0, 1, 3, 4)) + # At the end of self play annealing there will be no agent2 reward + if agent2_roles > 0: + agent_2_rewards = jnp.concatenate( + [traj.rewards for traj in traj_2] + ) + else: + agent_2_rewards = jnp.zeros_like(agent_1_rewards) + other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) + rewards_1 = agent_1_rewards.mean() + rewards_2 = agent_2_rewards.mean() + + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + elif args.env_id in [ + "iterated_matrix_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + elif args.env_id == "InTheMatrix": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipditm_stats( + env_state, + traj_1, + traj_2, + args.num_envs, + ), + ) + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + cournot_stats( + traj_1[0].observations, _env_params, args.num_players + ), + ) + elif args.env_id == "Fishery": + env_stats = fishery_stats(traj_1 + traj_2, args.num_players) + elif args.env_id == "Rice-N": + env_stats = rice_stats( + traj_1 + traj_2, args.num_players, args.has_mediator + ) + elif args.env_id == "C-Rice-N": + env_stats = c_rice_stats( + traj_1 + traj_2, args.num_players, args.has_mediator + ) + else: + env_stats = {} + + env_stats = env_stats | { + "train/agent1_roles": agent1_roles, + "train/agent2_roles": agent2_roles, + } + + return ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + a2_state, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None, None, None), + static_broadcasted_argnums=6, + ) + + print( + f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" + ) + + def run_loop( + self, + env_params, + agents, + num_iters: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + print(f"Number of Generations: {num_iters}") + print(f"Number of Meta Episodes: {self.num_outer_steps}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + agent1, agent2 = agents[0], agents[1] + rng, _ = jax.random.split(self.random_key) + + # Initialize evolution + num_gens = num_iters + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_state = strategy.initialize(rng, es_params) + fit_shaper = FitnessShaper( + maximize=self.args.es.maximise, + centered_rank=self.args.es.centered_rank, + w_decay=self.args.es.w_decay, + z_score=self.args.es.z_score, + ) + es_logging = ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + log = es_logging.initialize() + + # Reshape a single agent's params before 
vmapping + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) + a1_rng = jax.random.split(rng, popsize) + agent1._state, agent1._mem = agent1.batch_init( + a1_rng, + init_hidden, + ) + + a1_state, a1_mem = agent1._state, agent1._mem + a2_state = None + + for gen in range(num_gens): + rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) + + # Ask + x, evo_state = strategy.ask(rng_evo, evo_state, es_params) + params = param_reshaper.reshape(x) + if self.args.num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + + if gen % self.args.agent2_reset_interval == 0: + a2_state = None + + if self.args.num_devices == 1 and a2_state is not None: + # The first rollout returns a2_state with an extra batch dim that + # will cause issues when passing it back to the vmapped batch_policy + a2_state = jax.tree_util.tree_map( + lambda w: jnp.squeeze(w, axis=0), a2_state + ) + + self_play_prob = gen / num_gens + agent1_roles = self.args.agent1_roles + if self.args.self_play_anneal: + agent1_roles = np.random.binomial( + self.args.num_players, self_play_prob + ) + agent1_roles = np.maximum( + agent1_roles, 1 + ) # Ensure at least one agent 1 + agent2_roles = self.args.num_players - agent1_roles + + # Evo Rollout + ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + a2_state, + ) = self.rollout( + params, + rng_run, + a1_state, + a1_mem, + a2_state, + env_params, + (agent1_roles, agent2_roles), + ) + + # Aggregate over devices + fitness = jnp.reshape( + fitness, popsize * self.args.num_devices + ).astype(dtype=jnp.float32) + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Tell + fitness_re = fit_shaper.apply(x, fitness) + + if self.args.es.mean_reduce: + fitness_re = fitness_re - fitness_re.mean() + evo_state = strategy.tell(x, fitness_re, evo_state, es_params) + + # Logging + log = es_logging.update(log, x, fitness) + + is_last_loop = gen == num_iters - 1 + # Saving + if gen % self.args.save_interval == 0 or is_last_loop: + log_savepath1 = os.path.join( + self.save_dir, f"generation_{gen}" + ) + if self.args.num_devices > 1: + top_params = param_reshaper.reshape( + log["top_gen_params"][0 : self.args.num_devices] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + log["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath1) + log_savepath2 = os.path.join( + self.save_dir, f"agent2_iteration_{gen}" + ) + save(a2_state.params, log_savepath2) + if watchers: + print(f"Saving iteration {gen} locally and to WandB") + wandb.save(log_savepath1) + wandb.save(log_savepath2) + else: + print(f"Saving iteration {gen} locally") + if gen % log_interval == 0 or is_last_loop: + print(f"Generation: {gen}/{num_iters}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" + ) + print( + f"Reward Per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print( + f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" + ) + print( + "--------------------------------------------------------------------------" + ) + print( + f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" + f" | Std: {log['log_top_gen_std'][gen]}" + ) + print( + 
"--------------------------------------------------------------------------" + ) + print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") + print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") + print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") + print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") + print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") + print() + + if watchers: + wandb_log = { + "train_iteration": gen, + "train/fitness/player_1": float(fitness.mean()), + "train/fitness/player_2": float(other_fitness.mean()), + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/reward_per_timestep/player_1": float( + rewards_1.mean() + ), + "train/reward_per_timestep/player_2": float( + rewards_2.mean() + ), + } + wandb_log.update(env_stats) + # loop through population + for idx, (overall_fitness, gen_fitness) in enumerate( + zip(log["top_fitness"], log["top_gen_fitness"]) + ): + wandb_log[ + f"train/fitness/top_overall_agent_{idx + 1}" + ] = overall_fitness + wandb_log[ + f"train/fitness/top_gen_agent_{idx + 1}" + ] = gen_fitness + + # player 2 metrics + # metrics [outer_timesteps, num_opps] + flattened_metrics = {} + if a2_metrics is not None: + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + + agent2._logger.metrics.update(flattened_metrics) + for watcher, agent in zip(watchers, agents): + watcher(agent) + wandb_log = jax.tree_util.tree_map( + lambda x: x.item() if isinstance(x, jax.Array) else x, + wandb_log, + ) + wandb.log(wandb_log) + + return agents diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 92d1511c..d2c1e9e3 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -329,7 +329,7 @@ def _rollout( [jax.random.split(_rng_run, args.num_envs)] * args.num_opps ).reshape((args.num_opps, args.num_envs, -1)) - obs, env_state = env.reset(rngs, _env_params) + obs, env_state = env.batch_reset(rngs, _env_params) rewards = [ jnp.zeros((args.num_opps, args.num_envs)), jnp.zeros((args.num_opps, args.num_envs)), diff --git a/pax/watchers/c_rice.py b/pax/watchers/c_rice.py index fec7967a..ada6f738 100644 --- a/pax/watchers/c_rice.py +++ b/pax/watchers/c_rice.py @@ -78,6 +78,7 @@ def c_rice_eval_stats( lambda x: x.reshape((x.shape[0] * ep_count, ep_length, *x.shape[2:])), env_state, ) + # Output dim: episodes x steps x envs x agents x opps x envs x ... 
# Because the initial obs are not included the timesteps are shifted like so: # 1, 2, ..., 19, 0 # Since the initial obs are always the same we can just shift them to the start @@ -168,7 +169,26 @@ def add_atrib(name, value, axis): ).transpose((1, 0, 2, 3, 4)), observations, ) - add_atrib("club_mitigation_rate", observations[..., 2], axis=(1, 2, 3)) + club_mitigation_rate = observations[..., 2] + add_atrib("club_mitigation_rate", club_mitigation_rate, axis=(1, 2, 3)) add_atrib("club_tariff_rate", observations[..., 3], axis=(1, 2, 3)) + # Mitigation and tariff after climate club applied + real_mitigation = jnp.where( + env_state.club_membership_all == 1, + # Ensure all arrays are of shape (episodes, steps, opps, envs, agents) + actions[..., env.mitigation_rate_action_index].transpose( + (2, 0, 3, 4, 1) + )[..., 0:1], + actions[..., env.mitigation_rate_action_index].transpose( + (2, 0, 3, 4, 1) + )[..., 1:], + ) + add_atrib( + "real_mitigation_rate", + real_mitigation, + axis=(2, 3, 4), + ) + add_atrib("real_tariff", env_state.future_tariff, axis=(0, 2, 3)) + return result diff --git a/pax/watchers/cournot.py b/pax/watchers/cournot.py index 63439a72..5b1c2e70 100644 --- a/pax/watchers/cournot.py +++ b/pax/watchers/cournot.py @@ -20,6 +20,7 @@ def cournot_stats( "cournot/quantity_loss": jnp.mean( (opt_quantity - average_quantity) ** 2 ), + "cournot/price": observations[..., -1], } for i in range(num_players): diff --git a/pax/watchers/fishery.py b/pax/watchers/fishery.py index 6ecad269..850eded0 100644 --- a/pax/watchers/fishery.py +++ b/pax/watchers/fishery.py @@ -11,7 +11,7 @@ def fishery_stats(trajectories: List[NamedTuple], num_players: int) -> dict: traj = trajectories[0] # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim - stock_obs = traj.observations[..., -1] + stock_obs = traj.observations[..., -2] actions = traj.observations[..., :num_players] completed_episodes = jnp.sum(traj.dones) stats = { diff --git a/test/envs/test_fishery.py b/test/envs/test_fishery.py index 659e83bd..08a530f4 100644 --- a/test/envs/test_fishery.py +++ b/test/envs/test_fishery.py @@ -8,7 +8,7 @@ def test_fishery_convergence(): rng = jax.random.PRNGKey(0) ep_length = 300 - env = Fishery(num_players=2, num_inner_steps=ep_length) + env = Fishery(num_players=2) env_params = EnvParams(g=0.15, e=0.009, P=200, w=0.9, s_0=1.0, s_max=1.0) # response parameter diff --git a/test/runners/__init__.py b/test/runners/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/runners/files/eval_mediator/generation_1499 b/test/runners/files/eval_mediator/generation_1499 new file mode 100644 index 00000000..734efbfc Binary files /dev/null and b/test/runners/files/eval_mediator/generation_1499 differ diff --git a/test/runners/test_runners.py b/test/runners/test_runners.py new file mode 100644 index 00000000..c74dc3b1 --- /dev/null +++ b/test/runners/test_runners.py @@ -0,0 +1,75 @@ +import os + +import pytest +from hydra import compose, initialize_config_dir + +from pax.experiment import main + +shared_overrides = [ + "++wandb.mode=disabled", + "++num_iters=1", + "++popsize=2", + "++num_outer_steps=1", + "++num_inner_steps=8", # required for ppo minibatch size + "++num_devices=1", + "++num_envs=1", + "++num_epochs=1", +] + + +@pytest.fixture(scope="module", autouse=True) +def setup_hydra(): + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../pax/conf" + ) + initialize_config_dir(config_dir=path) + + +def _test_runner(overrides): + cfg = compose( + 
config_name="config.yaml", overrides=shared_overrides + overrides + ) + main(cfg) + + +def test_runner_evo_nroles_runs(): + _test_runner(["+experiment/rice=shaper_v_ppo"]) + + +def test_runner_evo_runs(): + _test_runner(["+experiment/cg=mfos"]) + + +def test_runner_sarl_runs(): + _test_runner(["+experiment/sarl=cartpole"]) + + +def test_runner_eval_runs(): + _test_runner( + [ + "+experiment/c_rice=eval_mediator_gs_ppo", + "++model_path=test/runners/files/eval_mediator/generation_1499", + # Eval requires a full episode to be played + "++num_inner_steps=20", + ] + ) + + +def test_runner_marl_runs(): + _test_runner(["+experiment/cg=tabular"]) + + +def test_runner_weight_sharing(): + _test_runner(["+experiment/rice=weight_sharing"]) + + +def test_runner_evo_multishaper(): + _test_runner( + ["+experiment/multiplayer_ipd=3pl_2shap_ipd", "++num_inner_steps=10"] + ) + + +def test_runner_marl_nplayer(): + _test_runner( + ["+experiment/multiplayer_ipd=lola_vs_ppo_ipd", "++num_inner_steps=10"] + )