From 57899fdc916fb1a43274417d2c350272cc2f9483 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20Pr=C3=B6schel?=
Date: Wed, 27 Sep 2023 14:49:08 +0200
Subject: [PATCH] cleanup

---
 docs/envs.md                                  |  20 +-
 pax/conf/experiment/c_rice/debug.yaml         |  81 ++
 pax/conf/experiment/c_rice/marl_baseline.yaml |  54 ++
 .../experiment/c_rice/mediator_gs_naive.yaml  |  90 +++
 .../experiment/c_rice/mediator_gs_ppo.yaml    |  80 ++
 pax/conf/experiment/c_rice/shaper_v_ppo.yaml  |  84 ++
 .../experiment/c_rice/weight_sharing.yaml     |  53 ++
 pax/conf/experiment/n_player/3pl.yaml         | 124 ---
 .../experiment/n_player/3pl_old_payoff.yaml   | 124 ---
 pax/conf/experiment/n_player/4pl.yaml         | 146 ----
 pax/conf/experiment/n_player/5pl.yaml         | 168 ----
 pax/conf/experiment/rice/marl_baseline.yaml   |   4 +-
 pax/conf/experiment/rice/shaper_v_ppo.yaml    |   2 +-
 pax/envs/rice/c_rice.py                       |  43 +-
 pax/runners/runner_eval_nplayer.py            | 567 --------------
 pax/runners/runner_evo_nplayer.py             | 716 ------------------
 pax/utils.py                                  |   2 +-
 profile_nplayer.ipynb                         | 218 ------
 18 files changed, 487 insertions(+), 2089 deletions(-)
 create mode 100644 pax/conf/experiment/c_rice/debug.yaml
 create mode 100644 pax/conf/experiment/c_rice/marl_baseline.yaml
 create mode 100644 pax/conf/experiment/c_rice/mediator_gs_naive.yaml
 create mode 100644 pax/conf/experiment/c_rice/mediator_gs_ppo.yaml
 create mode 100644 pax/conf/experiment/c_rice/shaper_v_ppo.yaml
 create mode 100644 pax/conf/experiment/c_rice/weight_sharing.yaml
 delete mode 100644 pax/conf/experiment/n_player/3pl.yaml
 delete mode 100644 pax/conf/experiment/n_player/3pl_old_payoff.yaml
 delete mode 100644 pax/conf/experiment/n_player/4pl.yaml
 delete mode 100644 pax/conf/experiment/n_player/5pl.yaml
 delete mode 100644 pax/runners/runner_eval_nplayer.py
 delete mode 100644 pax/runners/runner_evo_nplayer.py
 delete mode 100644 profile_nplayer.ipynb

diff --git a/docs/envs.md b/docs/envs.md
index 7f4a711f..6f6e5092 100644
--- a/docs/envs.md
+++ b/docs/envs.md
@@ -1,13 +1,17 @@
 ## Environments
 
-Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independetly you can specify your enviroment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`.
+Pax includes many environments specified by `env_id`. These include `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independently, you can specify the environment type via `env_type`; the supported options are `sequential` or `meta`, where `meta` wraps the environment in an inner/outer loop.
 
 These are specified in the config files in `pax/configs/{env_id}/EXPERIMENT.yaml`.
 
-| Environment ID | Environment Type | Description |
-| ----------- | ----------- | ----------- |
-|`iterated_matrix_game`| `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
-|`iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode |
-|`infinite_matrix_game` | `meta`| An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
-|coin_game | `sequential` | A sequential series of episode of the coin game between two players. Each player updates at the end of an episode|
-|coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode|
+| Environment ID         | Environment Type    | Description |
+|------------------------|---------------------|-------------|
+| `iterated_matrix_game` | `sequential`        | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ |
+| `iterated_matrix_game` | `meta`              | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner agent updates every episode, while the outer agent updates every meta-episode |
+| `infinite_matrix_game` | `meta`              | An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ |
+| coin_game              | `sequential`        | A sequential series of episodes of the coin game between two players. Each player updates at the end of an episode |
+| coin_game              | `meta`              | A meta-learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner agent updates every episode, while the outer agent updates every meta-episode |
+| cournot                | `sequential`/`meta` | A one-shot version of a [Cournot competition](https://en.wikipedia.org/wiki/Cournot_competition) |
+| fishery                | `sequential`/`meta` | A dynamic resource harvesting game as specified in Perman et al. |
+| Rice-N                 | `sequential`/`meta` | A re-implementation of the Integrated Assessment Model introduced by [Zhang et al.](https://papers.ssrn.com/abstract=4189735) available with either the original 27 regions or a new calibration of only 5 regions |
+| C-Rice-N               | `sequential`/`meta` | An extension of Rice-N with a simple climate club mechanism |
diff --git a/pax/conf/experiment/c_rice/debug.yaml b/pax/conf/experiment/c_rice/debug.yaml
new file mode 100644
index 00000000..fe6c6422
--- /dev/null
+++ b/pax/conf/experiment/c_rice/debug.yaml
@@ -0,0 +1,81 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO'
+agent_default: 'PPO'
+
+# Environment
+env_id: C-Rice-v1
+env_type: meta
+num_players: 6
+has_mediator: True
+config_folder: pax/envs/Rice/5_regions
+runner: tensor_evo
+
+# Training
+top_k: 5
+popsize: 1000
+num_envs: 1
+num_opps: 1
+num_outer_steps: 2
+num_inner_steps: 20
+num_iters: 1
+num_devices: 1
+num_steps: 4
+
+
+# PPO agent parameters
+ppo_default:
+  num_minibatches: 4
+  num_epochs: 4
+  gamma: 1.0
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.0
+  entropy_coeff_horizon: 10000000
+  entropy_coeff_end: 0.0
+  lr_scheduling: True
+  learning_rate: 1e-4
+  adam_epsilon: 1e-5
+  with_memory: True
+  with_cnn: False
+  output_channels: 16
+  kernel_shape: [3, 3]
+  separate: True
+  hidden_size: 32
+
+# ES parameters
+es:
+  algo: OpenES      # [OpenES, CMA_ES]
+  sigma_init: 0.04  # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999  # Multiplicative decay factor
+  sigma_limit: 0.01   # Smallest possible scale
+  init_min: 0.0     # Range of parameter mean initialization - Min
+  init_max: 0.0     # Range of parameter mean initialization - Max
+  clip_min: -1e10   # Range of parameter proposals - Min
+  clip_max: 1e10    # Range of parameter 
proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-Rice + group: 'mediator' + mode: 'offline' + name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}' + log: False + + diff --git a/pax/conf/experiment/c_rice/marl_baseline.yaml b/pax/conf/experiment/c_rice/marl_baseline.yaml new file mode 100644 index 00000000..e98e55ea --- /dev/null +++ b/pax/conf/experiment/c_rice/marl_baseline.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# Agents +agent_default: 'PPO' + +# Environment +env_id: C-Rice-v1 +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/Rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 20 +num_iters: 2000 +num_devices: 1 +num_steps: 200 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# Logging setup +wandb: + project: c-Rice + group: 'mediator' + name: 'c-Rice-MARL-${agent_default}-seed-${seed}' + log: True diff --git a/pax/conf/experiment/c_rice/mediator_gs_naive.yaml b/pax/conf/experiment/c_rice/mediator_gs_naive.yaml new file mode 100644 index 00000000..9e3b543e --- /dev/null +++ b/pax/conf/experiment/c_rice/mediator_gs_naive.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'Naive' + +# Environment +env_id: C-Rice-v1 +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/Rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 20 +num_iters: 3500 +num_devices: 1 +num_steps: 4 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +naive: + num_minibatches: 1 + num_epochs: 1 + gamma: 1 + gae_lambda: 0.95 + max_gradient_norm: 1.0 + learning_rate: 1.0 + adam_epsilon: 1e-5 + entropy_coeff: 0.0 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + 
lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-Rice + group: 'mediator' + name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml new file mode 100644 index 00000000..dd8769cc --- /dev/null +++ b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO' + +# Environment +env_id: C-Rice-v1 +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/Rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 200 +num_inner_steps: 20 +num_iters: 3500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-Rice + group: 'mediator' + name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml new file mode 100644 index 00000000..ed8bb9d5 --- /dev/null +++ b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml @@ -0,0 +1,84 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +agent2_roles: 4 + +# Environment +env_id: C-Rice-v1 +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: False +config_folder: pax/envs/Rice/5_regions +runner: evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 500 +num_inner_steps: 20 +num_iters: 1500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + 
max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-Rice + group: 'shaper' + name: 'c-Rice-SHAPER-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/weight_sharing.yaml b/pax/conf/experiment/c_rice/weight_sharing.yaml new file mode 100644 index 00000000..fdeb61fb --- /dev/null +++ b/pax/conf/experiment/c_rice/weight_sharing.yaml @@ -0,0 +1,53 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: C-Rice-v1 +env_type: sequential +num_players: 5 +has_mediator: False +config_folder: pax/envs/Rice/5_regions +runner: weight_sharing +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 20 +num_iters: 6e6 +save_interval: 100 +num_steps: 2000 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: c-Rice + group: 'weight_sharing' + name: 'c-Rice-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/conf/experiment/n_player/3pl.yaml b/pax/conf/experiment/n_player/3pl.yaml deleted file mode 100644 index 81459534..00000000 --- a/pax/conf/experiment/n_player/3pl.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# @package _global_ - -# Agents -agent1: 'PPO_memory' -agent2: 'PPO_memory' -agent3: 'PPO_memory' - -# Environment -env_id: iterated_nplayer_tensor_game -num_players: 3 -env_type: meta -env_discount: 0.96 -payoff_table: [ - [4, -1000], - [2.66, 5.66], - [1.33, 4.33], - [-10000, 3], -] -# Runner -runner: tensor_evo_nplayer - - -# Training -top_k: 5 -popsize: 50 -num_envs: 2 -num_opps: 10 -num_inner_steps: 100 -num_outer_steps: 1000 -num_iters: 2000 -# total_timesteps: 2.5e7 -num_devices: 1 -# Evaluation - - -# PPO agent parameters -ppo1: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - 
entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo2: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo3: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -# ES parameters -es: - algo: OpenES # [OpenES, CMA_ES] - sigma_init: 0.04 # Initial scale of isotropic Gaussian noise - sigma_decay: 0.999 # Multiplicative decay factor - sigma_limit: 0.01 # Smallest possible scale - init_min: 0.0 # Range of parameter mean initialization - Min - init_max: 0.0 # Range of parameter mean initialization - Max - clip_min: -1e10 # Range of parameter proposals - Min - clip_max: 1e10 # Range of parameter proposals - Max - lrate_init: 0.1 # Initial learning rate - lrate_decay: 0.9999 # Multiplicative decay factor - lrate_limit: 0.001 # Smallest possible lrate - beta_1: 0.99 # Adam - beta_1 - beta_2: 0.999 # Adam - beta_2 - eps: 1e-8 # eps constant, - centered_rank: False # Fitness centered_rank - w_decay: 0 # Decay old elite fitness - maximise: True # Maximise fitness - z_score: False # Normalise fitness - mean_reduce: True # Remove mean -# Logging setup -wandb: - entity: "ucl-dark" - project: tensor - group: 'n-player' - name: 3pl - log: True \ No newline at end of file diff --git a/pax/conf/experiment/n_player/3pl_old_payoff.yaml b/pax/conf/experiment/n_player/3pl_old_payoff.yaml deleted file mode 100644 index 51db9bf8..00000000 --- a/pax/conf/experiment/n_player/3pl_old_payoff.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# @package _global_ - -# Agents -agent1: 'PPO_memory' -agent2: 'PPO_memory' -agent3: 'PPO_memory' - -# Environment -env_id: iterated_nplayer_tensor_game -num_players: 3 -env_type: meta -env_discount: 0.96 -payoff_table: [ - [-1, -1000], - [-3, 0], - [-5, -2], - [-10000, -4], -] -# Runner -runner: tensor_evo_nplayer - - -# Training -top_k: 5 -popsize: 50 -num_envs: 2 -num_opps: 10 -num_inner_steps: 100 -num_outer_steps: 1000 -num_iters: 2000 -# total_timesteps: 2.5e7 -num_devices: 1 -# Evaluation - - -# PPO agent parameters -ppo1: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo2: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - 
lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo3: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -# ES parameters -es: - algo: OpenES # [OpenES, CMA_ES] - sigma_init: 0.04 # Initial scale of isotropic Gaussian noise - sigma_decay: 0.999 # Multiplicative decay factor - sigma_limit: 0.01 # Smallest possible scale - init_min: 0.0 # Range of parameter mean initialization - Min - init_max: 0.0 # Range of parameter mean initialization - Max - clip_min: -1e10 # Range of parameter proposals - Min - clip_max: 1e10 # Range of parameter proposals - Max - lrate_init: 0.1 # Initial learning rate - lrate_decay: 0.9999 # Multiplicative decay factor - lrate_limit: 0.001 # Smallest possible lrate - beta_1: 0.99 # Adam - beta_1 - beta_2: 0.999 # Adam - beta_2 - eps: 1e-8 # eps constant, - centered_rank: False # Fitness centered_rank - w_decay: 0 # Decay old elite fitness - maximise: True # Maximise fitness - z_score: False # Normalise fitness - mean_reduce: True # Remove mean -# Logging setup -wandb: - entity: "ucl-dark" - project: tensor - group: 'n-player' - name: 3pl_old_payoff - log: True \ No newline at end of file diff --git a/pax/conf/experiment/n_player/4pl.yaml b/pax/conf/experiment/n_player/4pl.yaml deleted file mode 100644 index 05c7133e..00000000 --- a/pax/conf/experiment/n_player/4pl.yaml +++ /dev/null @@ -1,146 +0,0 @@ -# @package _global_ - -# Agents -agent1: 'PPO_memory' -agent2: 'PPO_memory' -agent3: 'PPO_memory' -agent4: 'PPO_memory' - -# Environment -env_id: iterated_nplayer_tensor_game -num_players: 4 -env_type: meta -env_discount: 0.96 -payoff_table: [ - [4, -1000], - [3, 6], - [2, 5], - [1, 4], - [-10000, 3], -] -# Runner -runner: tensor_evo_nplayer - - -# Training -top_k: 5 -popsize: 50 -num_envs: 2 -num_opps: 10 -num_inner_steps: 100 -num_outer_steps: 1000 -num_iters: 2000 -# total_timesteps: 2.5e7 -num_devices: 1 -# Evaluation - - -# PPO agent parameters -ppo1: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo2: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo3: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - 
with_cnn: False - separate: False - hidden_size: 16 -ppo4: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -# ES parameters -es: - algo: OpenES # [OpenES, CMA_ES] - sigma_init: 0.04 # Initial scale of isotropic Gaussian noise - sigma_decay: 0.999 # Multiplicative decay factor - sigma_limit: 0.01 # Smallest possible scale - init_min: 0.0 # Range of parameter mean initialization - Min - init_max: 0.0 # Range of parameter mean initialization - Max - clip_min: -1e10 # Range of parameter proposals - Min - clip_max: 1e10 # Range of parameter proposals - Max - lrate_init: 0.1 # Initial learning rate - lrate_decay: 0.9999 # Multiplicative decay factor - lrate_limit: 0.001 # Smallest possible lrate - beta_1: 0.99 # Adam - beta_1 - beta_2: 0.999 # Adam - beta_2 - eps: 1e-8 # eps constant, - centered_rank: False # Fitness centered_rank - w_decay: 0 # Decay old elite fitness - maximise: True # Maximise fitness - z_score: False # Normalise fitness - mean_reduce: True # Remove mean -# Logging setup -wandb: - entity: "ucl-dark" - project: tensor - group: 'n-player' - name: 4pl - log: True \ No newline at end of file diff --git a/pax/conf/experiment/n_player/5pl.yaml b/pax/conf/experiment/n_player/5pl.yaml deleted file mode 100644 index b7cc739b..00000000 --- a/pax/conf/experiment/n_player/5pl.yaml +++ /dev/null @@ -1,168 +0,0 @@ -# @package _global_ - -# Agents -agent1: 'PPO_memory' -agent2: 'PPO_memory' -agent3: 'PPO_memory' -agent4: 'PPO_memory' -agent5: 'PPO_memory' - -# Environment -env_id: iterated_nplayer_tensor_game -num_players: 5 -env_type: meta -env_discount: 0.96 -payoff_table: [ - [4, -1000], - [3.2, 6.2], - [2.4, 5.4], - [1.6, 4.6], - [0.8, 3.8], - [-10000, 3], -] -# Runner -runner: tensor_evo_nplayer - - -# Training -top_k: 5 -popsize: 50 -num_envs: 2 -num_opps: 10 -num_inner_steps: 100 -num_outer_steps: 1000 -num_iters: 2000 -# total_timesteps: 2.5e7 -num_devices: 1 -# Evaluation - - -# PPO agent parameters -ppo1: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo2: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo3: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 
-ppo4: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -ppo5: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.4e9 - entropy_coeff_end: 0.01 - lr_scheduling: False - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - with_cnn: False - separate: False - hidden_size: 16 -# ES parameters -es: - algo: OpenES # [OpenES, CMA_ES] - sigma_init: 0.04 # Initial scale of isotropic Gaussian noise - sigma_decay: 0.999 # Multiplicative decay factor - sigma_limit: 0.01 # Smallest possible scale - init_min: 0.0 # Range of parameter mean initialization - Min - init_max: 0.0 # Range of parameter mean initialization - Max - clip_min: -1e10 # Range of parameter proposals - Min - clip_max: 1e10 # Range of parameter proposals - Max - lrate_init: 0.1 # Initial learning rate - lrate_decay: 0.9999 # Multiplicative decay factor - lrate_limit: 0.001 # Smallest possible lrate - beta_1: 0.99 # Adam - beta_1 - beta_2: 0.999 # Adam - beta_2 - eps: 1e-8 # eps constant, - centered_rank: False # Fitness centered_rank - w_decay: 0 # Decay old elite fitness - maximise: True # Maximise fitness - z_score: False # Normalise fitness - mean_reduce: True # Remove mean -# Logging setup -wandb: - entity: "ucl-dark" - project: tensor - group: 'n-player' - name: 5pl - log: True \ No newline at end of file diff --git a/pax/conf/experiment/rice/marl_baseline.yaml b/pax/conf/experiment/rice/marl_baseline.yaml index 0a94445c..26f72bfc 100644 --- a/pax/conf/experiment/rice/marl_baseline.yaml +++ b/pax/conf/experiment/rice/marl_baseline.yaml @@ -16,7 +16,7 @@ top_k: 5 popsize: 1000 num_envs: 2 num_opps: 1 -num_outer_steps: 2000 +num_outer_steps: 1 num_inner_steps: 20 num_iters: 2000 num_devices: 1 @@ -49,6 +49,6 @@ ppo_default: # Logging setup wandb: project: rice - group: 'mediator' + group: 'marl_baseline' name: 'rice-MARL-${agent_default}-seed-${seed}' log: True diff --git a/pax/conf/experiment/rice/shaper_v_ppo.yaml b/pax/conf/experiment/rice/shaper_v_ppo.yaml index db0d6a8a..c9c834e7 100644 --- a/pax/conf/experiment/rice/shaper_v_ppo.yaml +++ b/pax/conf/experiment/rice/shaper_v_ppo.yaml @@ -76,7 +76,7 @@ es: # Logging setup wandb: - project: rice + project: c-rice group: 'shaper' name: 'rice-SHAPER-${agent_default}-seed-${seed}' log: True diff --git a/pax/envs/rice/c_rice.py b/pax/envs/rice/c_rice.py index cd100f8b..fb70978b 100644 --- a/pax/envs/rice/c_rice.py +++ b/pax/envs/rice/c_rice.py @@ -25,15 +25,13 @@ class Rice(environment.Environment): - env_id: str = "Rice-v1" + env_id: str = "C-Rice-v1" def __init__(self, num_inner_steps: int, config_folder: str, has_mediator=False): super().__init__() + if has_mediator is False: + raise NotImplementedError("C-Rice environment without mediator is not implemented yet") - # TODO refactor all the constants to use env_params - # 1. Load env params in the experiment.py#env_setup - # 2. type env params as a chex dataclass - # 3. 
change the references in the code to env params params, num_regions = load_rice_params(config_folder) self.has_mediator = has_mediator self.num_players = num_regions @@ -74,6 +72,7 @@ def __init__(self, num_inner_steps: int, config_folder: str, has_mediator=False) self.sub_rate = jnp.asarray(0.5, dtype=float_precision) self.dom_pref = jnp.asarray(0.5, dtype=float_precision) self.for_pref = jnp.asarray([0.5 / (self.num_players - 1)] * self.num_players, dtype=float_precision) + self.default_club_tariff_rate = jnp.asarray(0.1, dtype=float_precision) def _step( key: chex.PRNGKey, @@ -113,15 +112,19 @@ def _step( self.dice_constant["xDelta"], t ) - # Get the maximum carbon price of non-members from the last timestep - club_price = jnp.max(state.carbon_price_all * (1 - state.club_membership_all)) - club_mitigation_rates = get_club_mitigation_rates( - club_price, - intensity_all, - self.rice_constant["xtheta_2"], - mitigation_cost_all, - state.damages_all - ) + + if has_mediator: + club_mitigation_rates = actions[0, self.mitigation_rate_action_index] + else: + # Get the maximum carbon price of non-members from the last timestep + club_price = jnp.max(state.carbon_price_all * (1 - state.club_membership_all)) + club_mitigation_rates = get_club_mitigation_rates( + club_price, + intensity_all, + self.rice_constant["xtheta_2"], + mitigation_cost_all, + state.damages_all + ) mitigation_rate_all = jnp.where( club_membership_all == 1, club_mitigation_rates, @@ -234,6 +237,18 @@ def _step( self.rice_constant["xtheta_2"], damages_all) + if has_mediator: + club_tariff_rate = actions[0, self.tariffs_action_index] + else: + club_tariff_rate = self.default_club_tariff_rate + + desired_future_tariffs = region_actions[:,self.tariffs_action_index: self.tariffs_action_index + self.num_players] + future_tariffs = jnp.where( + club_membership_all == 1, + jnp.clip(desired_future_tariffs, min=club_tariff_rate), + desired_future_tariffs + ) + next_state = EnvState( inner_t=state.inner_t + 1, outer_t=state.outer_t, global_temperature=global_temperature, diff --git a/pax/runners/runner_eval_nplayer.py b/pax/runners/runner_eval_nplayer.py deleted file mode 100644 index 4a8d8cfa..00000000 --- a/pax/runners/runner_eval_nplayer.py +++ /dev/null @@ -1,567 +0,0 @@ -import os -import time -from typing import Any, List, NamedTuple, Tuple - -import jax -import jax.numpy as jnp -from omegaconf import OmegaConf, omegaconf - -import wandb -from pax.utils import MemoryState, TrainingState, load -from pax.watchers import ( - ipditm_stats, - n_player_ipd_visitation, -) - -MAX_WANDB_CALLS = 1000 - - -class Sample(NamedTuple): - """Object containing a batch of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - - -class MFOSSample(NamedTuple): - """Object containing a batch of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - meta_actions: jnp.ndarray - - -@jax.jit -def reduce_outer_traj(traj: Sample) -> Sample: - """Used to collapse lax.scan outputs dims""" - # x: [outer_loop, inner_loop, num_opps, num_envs ...] - # x: [timestep, batch_size, ...] 
- num_envs = traj.observations.shape[2] * traj.observations.shape[3] - num_timesteps = traj.observations.shape[0] * traj.observations.shape[1] - return jax.tree_util.tree_map( - lambda x: x.reshape((num_timesteps, num_envs) + x.shape[4:]), - traj, - ) - - -class NPlayerEvalRunner: - """ - Reinforcement Learning runner provides a convenient example for quickly writing - a MARL runner for PAX. The MARLRunner class can be used to - run any two RL agents together either in a meta-game or regular game, it composes together agents, - watchers, and the environment. Within the init, we declare vmaps and pmaps for training. - Args: - agents (Tuple[agents]): - The set of agents that will run in the experiment. Note, ordering is - important for logic used in the class. - env (gymnax.envs.Environment): - The environment that the agents will run in. - save_dir (string): - The directory to save the model to. - args (NamedTuple): - A tuple of experiment arguments used (usually provided by HydraConfig). - """ - - # flake8: noqa: C901 - def __init__(self, agents, env, save_dir, args): - self.train_steps = 0 - self.train_episodes = 0 - self.start_time = time.time() - self.args = args - self.num_opps = args.num_opps - self.random_key = jax.random.PRNGKey(args.seed) - self.save_dir = save_dir - - def _reshape_opp_dim(x): - # x: [num_opps, num_envs ...] - # x: [batch_size, ...] - batch_size = args.num_envs * args.num_opps - return jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), x - ) - - self.reduce_opp_dim = jax.jit(_reshape_opp_dim) - self.ipd_stats = n_player_ipd_visitation - # VMAP for num envs: we vmap over the rng but not params - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - self.ipditm_stats = jax.jit(ipditm_stats) - # VMAP for num opps: we vmap over the rng but not params - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( - jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - ) - - self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) - self.num_outer_steps = self.args.num_outer_steps - agent1, *other_agents = agents - # set up agents - if args.agent1 == "NaiveEx": - # special case where NaiveEx has a different call signature - agent1.batch_init = jax.jit(jax.vmap(agent1.make_initial_state)) - else: - # batch MemoryState not TrainingState - agent1.batch_init = jax.vmap( - agent1.make_initial_state, - (None, 0), - (None, 0), - ) - agent1.batch_reset = jax.jit( - jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 - ) - - agent1.batch_policy = jax.jit( - jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)) - ) - - # go through opponents, we start with agent2 - for agent_idx, non_first_agent in enumerate(other_agents): - agent_arg = f"agent{agent_idx+2}" - # equivalent of args.agent_n - if OmegaConf.select(args, agent_arg) == "NaiveEx": - non_first_agent.batch_init = jax.jit( - jax.vmap(non_first_agent.make_initial_state) - ) - else: - non_first_agent.batch_init = jax.vmap( - non_first_agent.make_initial_state, (0, None), 0 - ) - non_first_agent.batch_policy = jax.jit( - jax.vmap(non_first_agent._policy) - ) - non_first_agent.batch_reset = jax.jit( - jax.vmap(non_first_agent.reset_memory, (0, None), 0), - static_argnums=1, - ) - non_first_agent.batch_update = jax.jit( - jax.vmap(non_first_agent.update, (1, 0, 0, 0), 0) - ) - - if args.agent1 != "NaiveEx": - # NaiveEx requires env first step to init. 
- init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) - agent1._state, agent1._mem = agent1.batch_init( - agent1._state.random_key, init_hidden - ) - - for agent_idx, non_first_agent in enumerate(other_agents): - agent_arg = f"agent{agent_idx+2}" - # equivalent of args.agent_n - if OmegaConf.select(args, agent_arg) != "NaiveEx": - # NaiveEx requires env first step to init. - init_hidden = jnp.tile( - non_first_agent._mem.hidden, (args.num_opps, 1, 1) - ) - ( - non_first_agent._state, - non_first_agent._mem, - ) = non_first_agent.batch_init( - jax.random.split( - non_first_agent._state.random_key, args.num_opps - ), - init_hidden, - ) - - def _inner_rollout(carry, unused): - """Runner for inner episode""" - ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ) = carry - new_other_agent_mem = [None] * len(other_agents) - # unpack rngs - rngs = self.split(rngs, 4) - env_rng = rngs[:, :, 0, :] - # a1_rng = rngs[:, :, 1, :] - # a2_rng = rngs[:, :, 2, :] - rngs = rngs[:, :, 3, :] - actions = [] - ( - first_action, - first_agent_state, - new_first_agent_mem, - ) = agent1.batch_policy( - first_agent_state, - first_agent_obs, - first_agent_mem, - ) - - actions.append(first_action) - for agent_idx, non_first_agent in enumerate(other_agents): - ( - non_first_action, - other_agent_state[agent_idx], - new_other_agent_mem[agent_idx], - ) = non_first_agent.batch_policy( - other_agent_state[agent_idx], - other_agent_obs[agent_idx], - other_agent_mem[agent_idx], - ) - actions.append(non_first_action) - ( - all_agent_next_obs, - env_state, - all_agent_rewards, - done, - info, - ) = env.step( - env_rng, - env_state, - actions, - env_params, - ) - first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs - first_agent_reward, *other_agent_rewards = all_agent_rewards - - traj1 = Sample( - first_agent_next_obs, - first_action, - first_agent_reward, - new_first_agent_mem.extras["log_probs"], - new_first_agent_mem.extras["values"], - done, - first_agent_mem.hidden, - ) - other_traj = [ - Sample( - other_agent_next_obs[agent_idx], - actions[agent_idx + 1], - other_agent_rewards[agent_idx], - new_other_agent_mem[agent_idx].extras["log_probs"], - new_other_agent_mem[agent_idx].extras["values"], - done, - other_agent_mem[agent_idx].hidden, - ) - for agent_idx in range(len(other_agents)) - ] - return ( - rngs, - first_agent_next_obs, - tuple(other_agent_next_obs), - first_agent_reward, - tuple(other_agent_rewards), - first_agent_state, - other_agent_state, - new_first_agent_mem, - new_other_agent_mem, - env_state, - env_params, - ), (traj1, *other_traj) - - def _outer_rollout(carry, unused): - """Runner for trial""" - # play episode of the game - vals, trajectories = jax.lax.scan( - _inner_rollout, - carry, - None, - length=self.args.num_inner_steps, - ) - other_agent_metrics = [None] * len(other_agents) - ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ) = vals - # MFOS has to take a meta-action for each episode - if args.agent1 == "MFOS": - first_agent_mem = agent1.meta_policy(first_agent_mem) - - # update second agent - for agent_idx, non_first_agent in enumerate(other_agents): - ( - other_agent_state[agent_idx], - other_agent_mem[agent_idx], - other_agent_metrics[agent_idx], - ) = 
non_first_agent.batch_update( - trajectories[agent_idx + 1], - other_agent_obs[agent_idx], - other_agent_state[agent_idx], - other_agent_mem[agent_idx], - ) - return ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ), (trajectories, other_agent_metrics) - - def _rollout( - _rng_run: jnp.ndarray, - first_agent_state: TrainingState, - first_agent_mem: MemoryState, - other_agent_state: List[TrainingState], - other_agent_mem: List[MemoryState], - _env_params: Any, - ): - # env reset - rngs = jnp.concatenate( - [jax.random.split(_rng_run, args.num_envs)] * args.num_opps - ).reshape((args.num_opps, args.num_envs, -1)) - - obs, env_state = env.reset(rngs, _env_params) - rewards = [ - jnp.zeros((args.num_opps, args.num_envs), dtype=jnp.float32) - ] * args.num_players - # Player 1 - first_agent_mem = agent1.batch_reset(first_agent_mem, False) - # Other players - _rng_run, other_agent_rng = jax.random.split(_rng_run, 2) - for agent_idx, non_first_agent in enumerate(other_agents): - # indexing starts at 2 for args - agent_arg = f"agent{agent_idx+2}" - # equivalent of args.agent_n - if OmegaConf.select(args, agent_arg) == "NaiveEx": - ( - other_agent_mem[agent_idx], - other_agent_state[agent_idx], - ) = non_first_agent.batch_init(obs[agent_idx + 1]) - - elif self.args.env_type in ["meta"]: - # meta-experiments - init other agents per trial - ( - other_agent_state[agent_idx], - other_agent_mem[agent_idx], - ) = non_first_agent.batch_init( - jax.random.split(other_agent_rng, self.num_opps), - non_first_agent._mem.hidden, - ) - _rng_run, other_agent_rng = jax.random.split(_rng_run, 2) - - # run trials - vals, stack = jax.lax.scan( - _outer_rollout, - ( - rngs, - obs[0], - tuple(obs[1:]), - rewards[0], - tuple(rewards[1:]), - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - _env_params, - ), - None, - length=self.num_outer_steps, - ) - - ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ) = vals - trajectories, other_agent_metrics = stack - - # reset memory - first_agent_mem = agent1.batch_reset(first_agent_mem, False) - for agent_idx, non_first_agent in enumerate(other_agents): - other_agent_mem[agent_idx] = non_first_agent.batch_reset( - other_agent_mem[agent_idx], False - ) - # Stats - if args.env_id == "iterated_nplayer_tensor_game": - total_env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipd_stats( - trajectories[0].observations, - num_players=args.num_players, - ), - ) - total_rewards = [traj.rewards.mean() for traj in trajectories] - else: - total_env_stats = {} - total_rewards = [traj.rewards.mean() for traj in trajectories] - - return ( - total_env_stats, - total_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - other_agent_metrics, - trajectories, - ) - - # self.rollout = _rollout - self.rollout = jax.jit(_rollout) - - def run_loop(self, env_params, agents, watchers): - """Run training of agents in environment""" - print("Training") - print("-----------------------") - agent1, *other_agents = agents - rng, _ = jax.random.split(self.random_key) - - first_agent_state, first_agent_mem = agent1._state, agent1._mem - other_agent_mem = [None] * len(other_agents) - other_agent_state = [None] * 
len(other_agents) - - for agent_idx, non_first_agent in enumerate(other_agents): - model_path = omegaconf.OmegaConf.select(self.args, f"model_path{agent_idx}", default=None) - if model_path is not None: - wandb.restore( - name=model_path, - run_path=self.args.run_path, - root=os.getcwd(), - ) - pretrained_params = load(model_path) - other_agent_state[agent_idx] = non_first_agent._state._replace( - params=pretrained_params - ) - else: - other_agent_state[agent_idx] = non_first_agent._state - other_agent_mem[agent_idx] = non_first_agent._mem - - wandb.restore( - name=self.args.model_path1, - run_path=self.args.run_path, - root=os.getcwd(), - ) - - pretrained_params = load(self.args.model_path1) - first_agent_state = first_agent_state._replace( - params=pretrained_params - ) - - # run actual loop - rng, rng_run = jax.random.split(rng, 2) - # RL Rollout - ( - env_stats, - total_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - other_agent_metrics, - trajectories, - ) = self.rollout( - rng_run, - first_agent_state, - first_agent_mem, - other_agent_state, - other_agent_mem, - env_params, - ) - - for stat in env_stats.keys(): - print(stat + f": {env_stats[stat].item()}") - print( - f"Total Episode Reward: {[float(rew.mean()) for rew in total_rewards]}" - ) - print() - - if watchers: - - list_traj1 = [ - Sample( - observations=jax.tree_util.tree_map( - lambda x: x[i, ...], trajectories[0].observations - ), - actions=trajectories[0].actions[i, ...], - rewards=trajectories[0].rewards[i, ...], - dones=trajectories[0].dones[i, ...], - # env_state=None, - behavior_log_probs=trajectories[0].behavior_log_probs[ - i, ... - ], - behavior_values=trajectories[0].behavior_values[i, ...], - hiddens=trajectories[0].hiddens[i, ...], - ) - for i in range(self.args.num_outer_steps) - ] - - list_of_env_stats = [ - jax.tree_util.tree_map( - lambda x: x.item(), - self.ipd_stats( - observations=traj.observations, - num_players=self.args.num_players, - ), - ) - for traj in list_traj1 - ] - - # log agent one - watchers[0](agents[0]) - # log the inner episodes - rewards_log = [ - { - f"eval/reward_per_timestep/player_{agent_idx+1}": float( - traj.rewards[i].mean().item() - ) - for (agent_idx, traj) in enumerate(trajectories) - } - for i in range(len(list_of_env_stats)) - ] - - for i in range(len(list_of_env_stats)): - wandb.log( - { - "train_iteration": i, - } - | list_of_env_stats[i] - | rewards_log[i] - ) - total_rewards_log = { - f"eval/meta_reward/player_{idx+1}": float(rew.mean().item()) - for (idx, rew) in enumerate(total_rewards) - } - wandb.log( - { - "episodes": 1, - } - | total_rewards_log - ) - - return agents diff --git a/pax/runners/runner_evo_nplayer.py b/pax/runners/runner_evo_nplayer.py deleted file mode 100644 index 28226fe8..00000000 --- a/pax/runners/runner_evo_nplayer.py +++ /dev/null @@ -1,716 +0,0 @@ -import os -import time -from datetime import datetime -from typing import Any, Callable, NamedTuple -from functools import partial -import jax -import jax.numpy as jnp -from evosax import FitnessShaper -from omegaconf import OmegaConf -import wandb -from pax.utils import MemoryState, TrainingState, save - -# TODO: import when evosax library is updated -# from evosax.utils import ESLog -from pax.watchers import ESLog, n_player_ipd_visitation -from pax.watchers.cournot import cournot_stats -from pax.watchers.fishery import fishery_stats -from pax.watchers.rice import rice_stats - -MAX_WANDB_CALLS = 1000 - - -class Sample(NamedTuple): - """Object containing a batch 
of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - - -class NPlayerEvoRunner: - """ - Evolutionary Strategy runner provides a convenient example for quickly writing - a MARL runner for PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. - It composes together agents, watchers, and the environment. - Within the init, we declare vmaps and pmaps for training. - The environment provided must conform to a meta-environment. - Args: - agents (Tuple[agents]): - The set of agents that will run in the experiment. Note, ordering is - important for logic used in the class. - env (gymnax.envs.Environment): - The meta-environment that the agents will run in. - strategy (evosax.Strategy): - The evolutionary strategy that will be used to train the agents. - param_reshaper (evosax.param_reshaper.ParameterReshaper): - A function that reshapes the parameters of the agents into a format that can be - used by the strategy. - save_dir (string): - The directory to save the model to. - args (NamedTuple): - A tuple of experiment arguments used (usually provided by HydraConfig). - """ - - def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args - ): - self.args = args - self.algo = args.es.algo - self.es_params = es_params - self.generations = 0 - self.num_opps = args.num_opps - self.param_reshaper = param_reshaper - self.popsize = args.popsize - self.random_key = jax.random.PRNGKey(args.seed) - self.start_datetime = datetime.now() - self.save_dir = save_dir - self.start_time = time.time() - self.strategy = strategy - self.top_k = args.top_k - self.train_steps = 0 - self.train_episodes = 0 - # TODO JIT this - self.ipd_stats = n_player_ipd_visitation - self.fishery_stats = fishery_stats - - # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) - # Evo Runner also has an additional pmap dim (num_devices, ...) 
- # For the env we vmap over the rng but not params - - # num envs - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - - # num opps - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - # pop size - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( - jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - ) - self.split = jax.vmap( - jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), - (0, None), - ) - - self.num_outer_steps = args.num_outer_steps - agent1, *other_agents = agents - - # vmap agents accordingly - # agent 1 is batched over popsize and num_opps - agent1.batch_init = jax.vmap( - jax.vmap( - agent1.make_initial_state, - (None, 0), # (params, rng) - (None, 0), # (TrainingState, MemoryState) - ), - # both for Population - ) - agent1.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent1.batch_policy = jax.jit( - jax.vmap( - jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), - ) - ) - # go through opponents, we start with agent2 - for agent_idx, non_first_agent in enumerate(other_agents): - agent_arg = f"agent{agent_idx+2}" - # equivalent of args.agent_n - if OmegaConf.select(args, agent_arg) == "NaiveEx": - # special case where NaiveEx has a different call signature - non_first_agent.batch_init = jax.jit( - jax.vmap(jax.vmap(non_first_agent.make_initial_state)) - ) - else: - non_first_agent.batch_init = jax.jit( - jax.vmap( - jax.vmap( - non_first_agent.make_initial_state, (0, None), 0 - ), - (0, None), - 0, - ) - ) - - non_first_agent.batch_policy = jax.jit( - jax.vmap(jax.vmap(non_first_agent._policy, 0, 0)) - ) - non_first_agent.batch_reset = jax.jit( - jax.vmap( - jax.vmap(non_first_agent.reset_memory, (0, None), 0), - (0, None), - 0, - ), - static_argnums=1, - ) - - non_first_agent.batch_update = jax.jit( - jax.vmap( - jax.vmap(non_first_agent.update, (1, 0, 0, 0)), - (1, 0, 0, 0), - ) - ) - if OmegaConf.select(args, agent_arg) != "NaiveEx": - # NaiveEx requires env first step to init. 
- init_hidden = jnp.tile( - non_first_agent._mem.hidden, (args.num_opps, 1, 1) - ) - - agent_rng = jnp.concatenate( - [ - jax.random.split( - non_first_agent._state.random_key, args.num_opps - ) - ] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - - ( - non_first_agent._state, - non_first_agent._mem, - ) = non_first_agent.batch_init( - agent_rng, - init_hidden, - ) - - # jit evo - strategy.ask = jax.jit(strategy.ask) - strategy.tell = jax.jit(strategy.tell) - param_reshaper.reshape = jax.jit(param_reshaper.reshape) - - def _inner_rollout(carry, unused): - """Runner for inner episode""" - ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ) = carry - new_other_agent_mem = [None] * len(other_agents) - # unpack rngs - rngs = self.split(rngs, 4) - env_rng = rngs[:, :, :, 0, :] - - # a1_rng = rngs[:, :, :, 1, :] - # a2_rng = rngs[:, :, :, 2, :] - rngs = rngs[:, :, :, 3, :] - actions = [] - ( - first_action, - first_agent_state, - new_first_agent_mem, - ) = agent1.batch_policy( - first_agent_state, - first_agent_obs, - first_agent_mem, - ) - actions.append(first_action) - for agent_idx, non_first_agent in enumerate(other_agents): - ( - non_first_action, - other_agent_state[agent_idx], - new_other_agent_mem[agent_idx], - ) = non_first_agent.batch_policy( - other_agent_state[agent_idx], - other_agent_obs[agent_idx], - other_agent_mem[agent_idx], - ) - actions.append(non_first_action) - ( - all_agent_next_obs, - env_state, - all_agent_rewards, - done, - info, - ) = env.step( - env_rng, - env_state, - actions, - env_params, - ) - - first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs - first_agent_reward, *other_agent_rewards = all_agent_rewards - - traj1 = Sample( - first_agent_next_obs, - first_action, - first_agent_reward, - new_first_agent_mem.extras["log_probs"], - new_first_agent_mem.extras["values"], - done, - first_agent_mem.hidden, - ) - other_traj = [ - Sample( - other_agent_next_obs[agent_idx], - actions[agent_idx + 1], - other_agent_rewards[agent_idx], - new_other_agent_mem[agent_idx].extras["log_probs"], - new_other_agent_mem[agent_idx].extras["values"], - done, - other_agent_mem[agent_idx].hidden, - ) - for agent_idx in range(len(other_agents)) - ] - return ( - rngs, - first_agent_next_obs, - tuple(other_agent_next_obs), - first_agent_reward, - tuple(other_agent_rewards), - first_agent_state, - other_agent_state, - new_first_agent_mem, - new_other_agent_mem, - env_state, - env_params, - ), (traj1, *other_traj) - - def _outer_rollout(carry, unused): - """Runner for trial""" - # play episode of the game - vals, trajectories = jax.lax.scan( - _inner_rollout, - carry, - None, - length=args.num_steps, - ) - other_agent_metrics = [None] * len(other_agents) - ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ) = vals - # MFOS has to take a meta-action for each episode - if args.agent1 == "MFOS": - first_agent_mem = agent1.meta_policy(first_agent_mem) - # update opponents, we start with agent2 - for agent_idx, non_first_agent in enumerate(other_agents): - ( - other_agent_state[agent_idx], - other_agent_mem[agent_idx], - other_agent_metrics[agent_idx], - ) = non_first_agent.batch_update( - trajectories[agent_idx + 1], - other_agent_obs[agent_idx], - 
other_agent_state[agent_idx], - other_agent_mem[agent_idx], - ) - return ( - rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - env_params, - ), (trajectories, other_agent_metrics) - - def _rollout( - _params: jnp.ndarray, - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _env_params: Any, - ): - # env reset - env_rngs = jnp.concatenate( - [jax.random.split(_rng_run, args.num_envs)] - * args.num_opps - * args.popsize - ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) - - obs, env_state = env.reset(env_rngs, _env_params) - rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] * args.num_players - - # Player 1 - _a1_state = _a1_state._replace(params=_params) - _a1_mem = agent1.batch_reset(_a1_mem, False) - # Other players - other_agent_mem = [None] * len(other_agents) - other_agent_state = [None] * len(other_agents) - - _rng_run, *other_agent_rngs = jax.random.split( - _rng_run, args.num_players - ) - for agent_idx, non_first_agent in enumerate(other_agents): - # indexing starts at 2 for args - agent_arg = f"agent{agent_idx+2}" - # equivalent of args.agent_n - if OmegaConf.select(args, agent_arg) == "NaiveEx": - ( - other_agent_mem[agent_idx], - other_agent_state[agent_idx], - ) = non_first_agent.batch_init(obs[agent_idx + 1]) - else: - # meta-experiments - init 2nd agent per trial - non_first_agent_rng = jnp.concatenate( - [ - jax.random.split( - other_agent_rngs[agent_idx], args.num_opps - ) - ] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - ( - other_agent_state[agent_idx], - other_agent_mem[agent_idx], - ) = non_first_agent.batch_init( - non_first_agent_rng, - non_first_agent._mem.hidden, - ) - - # run trials - vals, stack = jax.lax.scan( - _outer_rollout, - ( - env_rngs, - obs[0], - tuple(obs[1:]), - rewards[0], - tuple(rewards[1:]), - _a1_state, - other_agent_state, - _a1_mem, - other_agent_mem, - env_state, - _env_params, - ), - None, - length=self.num_outer_steps, - ) - ( - env_rngs, - first_agent_obs, - other_agent_obs, - first_agent_reward, - other_agent_rewards, - first_agent_state, - other_agent_state, - first_agent_mem, - other_agent_mem, - env_state, - _env_params, - ) = vals - trajectories, other_agent_metrics = stack - - # Fitness - fitness = trajectories[0].rewards.mean(axis=(0, 1, 3, 4)) - other_fitness = [ - traj.rewards.mean(axis=(0, 1, 3, 4)) - for traj in trajectories[1:] - ] - # Stats - first_agent_reward = trajectories[0].rewards.mean() - other_agent_rewards = [ - traj.rewards.mean() for traj in trajectories[1:] - ] - if args.env_id in [ - "iterated_nplayer_tensor_game", - ]: - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipd_stats( - trajectories[0].observations, args.num_players - ), - ) - elif args.env_id == "Rice-v1": - env_stats = jax.tree_util.tree_map( - lambda x: x, - rice_stats( - trajectories, args.num_players, args.mediator - ), - ) - elif args.env_id == "Fishery": - env_stats = jax.tree_util.tree_map( - lambda x: x, - self.fishery_stats( - trajectories[0], args.num_players - ), - ) - elif args.env_id == "Cournot": - env_stats = jax.tree_util.tree_map( - lambda x: x, - cournot_stats( - trajectories[0].observations, _env_params, args.num_players - ), - ) - else: - env_stats = {} - - return ( - fitness, - other_fitness, - env_stats, - first_agent_reward, - other_agent_rewards, - other_agent_metrics, - ) - - self.rollout = jax.pmap( 
- _rollout, - in_axes=(0, None, None, None, None), - ) - - print( - f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" - ) - - def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, - ): - """Run training of agents in environment""" - print("Training") - print("------------------------------") - log_interval = max(num_iters / MAX_WANDB_CALLS, 5) - print(f"Number of Generations: {num_iters}") - print(f"Number of Meta Episodes: {self.num_outer_steps}") - print(f"Population Size: {self.popsize}") - print(f"Number of Environments: {self.args.num_envs}") - print(f"Number of Opponent: {self.args.num_opps}") - print(f"Log Interval: {log_interval}") - print("------------------------------") - # Initialize agents and RNG - rng, _ = jax.random.split(self.random_key) - - # Initialize evolution - num_gens = num_iters - strategy = self.strategy - es_params = self.es_params - param_reshaper = self.param_reshaper - popsize = self.popsize - num_opps = self.num_opps - evo_state = strategy.initialize(rng, es_params) - fit_shaper = FitnessShaper( - maximize=self.args.es.maximise, - centered_rank=self.args.es.centered_rank, - w_decay=self.args.es.w_decay, - z_score=self.args.es.z_score, - ) - es_logging = ESLog( - param_reshaper.total_params, - num_gens, - top_k=self.top_k, - maximize=True, - ) - log = es_logging.initialize() - - # Reshape a single agent's params before vmapping - init_hidden = jnp.tile( - agents[0]._mem.hidden, - (popsize, num_opps, 1, 1), - ) - a1_rng = jax.random.split(rng, popsize) - agents[0]._state, agents[0]._mem = agents[0].batch_init( - a1_rng, - init_hidden, - ) - - a1_state, a1_mem = agents[0]._state, agents[0]._mem - - for gen in range(num_gens): - rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) - - # Ask - x, evo_state = strategy.ask(rng_evo, evo_state, es_params) - params = param_reshaper.reshape(x) - if self.args.num_devices == 1: - params = jax.tree_util.tree_map( - lambda x: jax.lax.expand_dims(x, (0,)), params - ) - # Evo Rollout - ( - fitness, - other_fitness, - env_stats, - first_agent_reward, - other_agent_reward, - other_agent_metrics, - ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) - - # Aggregate over devices - fitness = jnp.reshape(fitness, popsize * self.args.num_devices) - env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) - - # Tell - fitness_re = fit_shaper.apply(x, fitness) - - if self.args.es.mean_reduce: - fitness_re = fitness_re - fitness_re.mean() - evo_state = strategy.tell(x, fitness_re, evo_state, es_params) - - # Logging - log = es_logging.update(log, x, fitness) - - # Saving - if gen % self.args.save_interval == 0: - log_savepath = os.path.join(self.save_dir, f"generation_{gen}") - if self.args.num_devices > 1: - top_params = param_reshaper.reshape( - log["top_gen_params"][0 : self.args.num_devices] - ) - top_params = jax.tree_util.tree_map( - lambda x: x[0].reshape(x[0].shape[1:]), top_params - ) - else: - top_params = param_reshaper.reshape( - log["top_gen_params"][0:1] - ) - top_params = jax.tree_util.tree_map( - lambda x: x.reshape(x.shape[1:]), top_params - ) - save(top_params, log_savepath) - if watchers: - print(f"Saving generation {gen} locally and to WandB") - wandb.save(log_savepath) - else: - print(f"Saving iteration {gen} locally") - - if gen % log_interval == 0: - print(f"Generation: {gen}/{num_gens}") - print( - "--------------------------------------------------------------------------" - ) - print( - f"Fitness: {fitness.mean()} | Other Fitness: 
{[fitness.mean() for fitness in other_fitness]}" - ) - print( - f"Reward Per Timestep: {float(first_agent_reward.mean()), *[float(reward.mean()) for reward in other_agent_reward]}" - ) - print( - f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" - ) - print( - "--------------------------------------------------------------------------" - ) - print( - f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" - f" | Std: {log['log_top_gen_std'][gen]}" - ) - print( - "--------------------------------------------------------------------------" - ) - print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") - print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") - print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") - print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") - print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") - print() - - if watchers: - rewards_strs = [ - "train/reward_per_timestep/player_" + str(i) - for i in range(2, len(other_agent_reward) + 2) - ] - rewards_val = [ - float(reward.mean()) for reward in other_agent_reward - ] - rewards_dict = dict(zip(rewards_strs, rewards_val)) - fitness_str = [ - "train/fitness/player_" + str(i) - for i in range(2, len(other_fitness) + 2) - ] - fitness_val = [ - float(fitness.mean()) for fitness in other_fitness - ] - fitness_dict = dict(zip(fitness_str, fitness_val)) - all_rewards = other_agent_reward + [first_agent_reward] - global_welfare = float( - sum([reward.mean() for reward in all_rewards]) - / self.args.num_players - ) - wandb_log = { - "train_iteration": gen, - "train/fitness/top_overall_mean": log["log_top_mean"][gen], - "train/fitness/top_overall_std": log["log_top_std"][gen], - "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], - "train/fitness/top_gen_std": log["log_top_gen_std"][gen], - "train/fitness/gen_std": log["log_gen_std"][gen], - "train/time/minutes": float( - (time.time() - self.start_time) / 60 - ), - "train/time/seconds": float( - (time.time() - self.start_time) - ), - "train/fitness/player_1": float(fitness.mean()), - "train/reward_per_timestep/player_1": float( - first_agent_reward.mean() - ), - "train/global_welfare": global_welfare, - } | rewards_dict - wandb_log = wandb_log | fitness_dict - wandb_log.update(env_stats) - # loop through population - for idx, (overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) - ): - wandb_log[ - f"train/fitness/top_overall_agent_{idx+1}" - ] = overall_fitness - wandb_log[ - f"train/fitness/top_gen_agent_{idx+1}" - ] = gen_fitness - - # other player metrics - # metrics [outer_timesteps, num_opps] - for agent, metrics in zip(agents[1:], other_agent_metrics): - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), metrics - ) - - agent._logger.metrics.update(flattened_metrics) - # TODO fix agent logger - # for watcher, agent in zip(watchers, agents): - # watcher(agent) - wandb_log = jax.tree_util.tree_map( - lambda x: x.item() if isinstance(x, jax.Array) else x, - wandb_log, - ) - wandb.log(wandb_log) - - return agents diff --git a/pax/utils.py b/pax/utils.py index e4fc8578..e2d3d0f0 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -148,7 +148,6 @@ def load(filename: str): return es_logger -<<<<<<< HEAD def copy_state_and_network(agent): import copy @@ -211,5 +210,6 @@ def copy_extended_state_and_network(agent): value_network = agent.value_network return state, policy_network, value_network + # TODO make this part of the args float_precision = jnp.float32 diff --git 
a/profile_nplayer.ipynb b/profile_nplayer.ipynb deleted file mode 100644 index af1c1a79..00000000 --- a/profile_nplayer.ipynb +++ /dev/null @@ -1,218 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "initial_id", - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "!pip install --upgrade tensorflow tensorboard_plugin_profile" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c3892445326a96bd", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-09T13:52:53.364091Z", - "start_time": "2023-09-09T13:52:53.356565Z" - } - }, - "outputs": [], - "source": [ - "%load_ext tensorboard" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "53504be560736467", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-09T16:15:41.215799Z", - "start_time": "2023-09-09T16:15:40.933603Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/0g/tbd8_nhx01s7spsllv70v4100000gn/T/ipykernel_91158/755042361.py:4: UserWarning: \n", - "The version_base parameter is not specified.\n", - "Please specify a compatability version level, or None.\n", - "Will assume defaults for version 1.1\n", - " initialize(config_path=\"../pax/pax/conf\", job_name=\"create_runner\")\n" - ] - } - ], - "source": [ - "# Load the runner\n", - "from hydra import initialize, compose\n", - "\n", - "initialize(config_path=\"../pax/pax/conf\", job_name=\"create_runner\")\n", - "args = compose(config_name=\"config\", overrides=[\"+experiment/rice=debug\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9d7f0265213fb311", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-09T16:15:50.621816Z", - "start_time": "2023-09-09T16:15:42.493724Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/christophproschel/codingprojects/ism/thesis/main/pax/pax/experiment.py:705: UserWarning: \n", - "The version_base parameter is not specified.\n", - "Please specify a compatability version level, or None.\n", - "Will assume defaults for version 1.1\n", - " @hydra.main(config_path=\"conf\", config_name=\"config\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Jax backend: cpu\n", - "Making network for Rice-v1\n", - "Making network for Rice-v1\n", - "Making network for Rice-v1\n", - "Making network for Rice-v1\n", - "Making network for Rice-v1\n", - "Making network for Rice-v1\n", - "ParameterReshaper: 5952 parameters detected for optimization.\n", - "Time to Compile Jax Methods: 2.4162590503692627 Seconds\n" - ] - } - ], - "source": [ - "from pax.utils import Section\n", - "from pax.experiment import global_setup, env_setup, agent_setup, watcher_setup, runner_setup\n", - "import logging\n", - "from jax.lib import xla_bridge\n", - "from jax.config import config\n", - "config.update('jax_disable_jit', True)\n", - "print(f\"Jax backend: {xla_bridge.get_backend().platform}\")\n", - "\n", - "\"\"\"Set up main.\"\"\"\n", - "logger = logging.getLogger()\n", - "with Section(\"Global setup\", logger=logger):\n", - " save_dir = global_setup(args)\n", - "\n", - "with Section(\"Env setup\", logger=logger):\n", - " env, env_params = env_setup(args, logger)\n", - "\n", - "with Section(\"Agent setup\", logger=logger):\n", - " agent_pair = agent_setup(args, env, env_params, logger)\n", - "\n", - "with Section(\"Watcher setup\", logger=logger):\n", - " watchers = watcher_setup(args, logger)\n", - "\n", - "runner = 
runner_setup(args, env, agent_pair, save_dir, logger)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a727691a72bf37cc", - "metadata": { - "is_executing": true, - "ExecuteTime": { - "start_time": "2023-09-09T16:15:50.622298Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training\n", - "------------------------------\n", - "Number of Generations: 1\n", - "Number of Meta Episodes: 2\n", - "Population Size: 1000\n", - "Number of Environments: 1\n", - "Number of Opponent: 1\n", - "Log Interval: 5\n", - "------------------------------\n" - ] - } - ], - "source": [ - "import jax\n", - "from jax.config import config\n", - "config.update('jax_disable_jit', True)\n", - "\n", - "def run(num_iters):\n", - " runner.run_loop(env_params, agent_pair, num_iters, watchers)\n", - "\n", - "# COMPILE (3 times to make sure everything is compiled, you might need more or less \n", - "# depending on your code) \n", - "# for _ in range(3):\n", - "# jax.block_until_ready(run(num_iters=1))\n", - "\n", - "with jax.profiler.trace(\"/tmp/jax-trace\", create_perfetto_link=True):\n", - " jax.block_until_ready(runner.run_loop(env_params, agent_pair, 1, False))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "print(\"Done\")" - ], - "metadata": { - "collapsed": false, - "is_executing": true - }, - "id": "767131758bded205" - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7a7433c1b31005a", - "metadata": { - "ExecuteTime": { - "start_time": "2023-09-09T13:46:51.813932Z" - } - }, - "outputs": [], - "source": [ - "%tensorboard --logdir=/tmp/jax-trace --port 6006" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}
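
Note on the removed profile_nplayer.ipynb: the profiling workflow it demonstrated does not depend on the deleted n-player runners and can still be reproduced as a plain script against the remaining runners. The following is a minimal sketch only, based on the cells shown in the deleted notebook above; the Hydra config_path, the experiment override, the job name, and the trace directory are assumptions to adjust for your checkout, not part of this patch.

# Sketch: reproduce the profiling flow from the deleted notebook as a script.
# Assumes the pax.experiment setup helpers and runner API used in the notebook.
import logging

import jax
from hydra import compose, initialize

from pax.experiment import (
    agent_setup,
    env_setup,
    global_setup,
    runner_setup,
    watcher_setup,
)

# config_path is resolved relative to this file; point it at pax/conf in your checkout.
initialize(config_path="pax/conf", job_name="profile_runner")
args = compose(config_name="config", overrides=["+experiment/rice=debug"])

logger = logging.getLogger()
save_dir = global_setup(args)
env, env_params = env_setup(args, logger)
agent_pair = agent_setup(args, env, env_params, logger)
watchers = watcher_setup(args, logger)
runner = runner_setup(args, env, agent_pair, save_dir, logger)

# Optional, as in the notebook, for stepping through un-jitted code:
# jax.config.update("jax_disable_jit", True)

# Trace one short run; inspect the result with the TensorBoard profile plugin:
#   tensorboard --logdir=/tmp/jax-trace --port 6006
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    jax.block_until_ready(runner.run_loop(env_params, agent_pair, 1, False))

As in the notebook, the trace can be viewed either through the Perfetto link printed by create_perfetto_link=True or by pointing TensorBoard at the same log directory.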