From 8ba77da4d344ea70a06daa9a33349924d0f2a7f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Pr=C3=B6schel?= Date: Wed, 16 Aug 2023 20:59:29 +0200 Subject: [PATCH] add rice environment --- pax/agents/ppo/networks.py | 26 + pax/agents/ppo/ppo.py | 6 +- pax/agents/ppo/ppo_gru.py | 12 +- pax/conf/experiment/cournot/gs_v_ppo.yaml | 38 +- .../experiment/cournot/marl_baseline.yaml | 33 +- pax/conf/experiment/cournot/mfos_v_ppo.yaml | 6 +- pax/conf/experiment/cournot/shaper_v_ppo.yaml | 8 +- pax/conf/experiment/fishery/gs_v_ppo.yaml | 6 +- .../experiment/fishery/marl_baseline.yaml | 2 +- pax/conf/experiment/rice/ctde.yaml | 51 ++ pax/conf/experiment/rice/sarl.yaml | 51 ++ pax/envs/cournot.py | 3 +- pax/envs/fishery.py | 58 +- pax/envs/region_yamls/11.yml | 13 + pax/envs/region_yamls/12.yml | 13 + pax/envs/region_yamls/13.yml | 13 + pax/envs/region_yamls/14.yml | 13 + pax/envs/region_yamls/15.yml | 13 + pax/envs/region_yamls/16.yml | 13 + pax/envs/region_yamls/17.yml | 13 + pax/envs/region_yamls/18.yml | 13 + pax/envs/region_yamls/19.yml | 13 + pax/envs/region_yamls/2.yml | 13 + pax/envs/region_yamls/20.yml | 13 + pax/envs/region_yamls/21.yml | 13 + pax/envs/region_yamls/22.yml | 13 + pax/envs/region_yamls/23.yml | 13 + pax/envs/region_yamls/24.yml | 13 + pax/envs/region_yamls/25.yml | 13 + pax/envs/region_yamls/26.yml | 13 + pax/envs/region_yamls/27.yml | 13 + pax/envs/region_yamls/28.yml | 13 + pax/envs/region_yamls/29.yml | 13 + pax/envs/region_yamls/3.yml | 13 + pax/envs/region_yamls/30.yml | 13 + pax/envs/region_yamls/4.yml | 13 + pax/envs/region_yamls/5.yml | 13 + pax/envs/region_yamls/6.yml | 13 + pax/envs/region_yamls/7.yml | 13 + pax/envs/region_yamls/9.yml | 13 + pax/envs/region_yamls/default.yml | 68 ++ pax/envs/rice.py | 645 +++++++++++++++++ pax/envs/rice_n.py | 131 ---- pax/envs/sarl_rice.py | 65 ++ pax/experiment.py | 46 +- pax/runners/runner_ctde.py | 251 +++++++ pax/runners/runner_evo.py | 4 +- pax/runners/runner_evo_nplayer.py | 646 +++++++++--------- pax/runners/runner_marl.py | 4 +- pax/runners/runner_marl_nplayer.py | 74 +- pax/runners/runner_sarl.py | 5 +- pax/watchers/cournot.py | 21 +- pax/watchers/fishery.py | 21 +- pax/watchers/rice.py | 19 + test/envs/test_fishery.py | 2 +- test/envs/test_rice.py | 65 ++ 56 files changed, 2091 insertions(+), 627 deletions(-) create mode 100644 pax/conf/experiment/rice/ctde.yaml create mode 100644 pax/conf/experiment/rice/sarl.yaml create mode 100644 pax/envs/region_yamls/11.yml create mode 100644 pax/envs/region_yamls/12.yml create mode 100644 pax/envs/region_yamls/13.yml create mode 100644 pax/envs/region_yamls/14.yml create mode 100644 pax/envs/region_yamls/15.yml create mode 100644 pax/envs/region_yamls/16.yml create mode 100644 pax/envs/region_yamls/17.yml create mode 100644 pax/envs/region_yamls/18.yml create mode 100644 pax/envs/region_yamls/19.yml create mode 100644 pax/envs/region_yamls/2.yml create mode 100644 pax/envs/region_yamls/20.yml create mode 100644 pax/envs/region_yamls/21.yml create mode 100644 pax/envs/region_yamls/22.yml create mode 100644 pax/envs/region_yamls/23.yml create mode 100644 pax/envs/region_yamls/24.yml create mode 100644 pax/envs/region_yamls/25.yml create mode 100644 pax/envs/region_yamls/26.yml create mode 100644 pax/envs/region_yamls/27.yml create mode 100644 pax/envs/region_yamls/28.yml create mode 100644 pax/envs/region_yamls/29.yml create mode 100644 pax/envs/region_yamls/3.yml create mode 100644 pax/envs/region_yamls/30.yml create mode 100644 pax/envs/region_yamls/4.yml create mode 100644 
pax/envs/region_yamls/5.yml create mode 100644 pax/envs/region_yamls/6.yml create mode 100644 pax/envs/region_yamls/7.yml create mode 100644 pax/envs/region_yamls/9.yml create mode 100644 pax/envs/region_yamls/default.yml create mode 100644 pax/envs/rice.py delete mode 100644 pax/envs/rice_n.py create mode 100644 pax/envs/sarl_rice.py create mode 100644 pax/runners/runner_ctde.py create mode 100644 pax/watchers/rice.py create mode 100644 test/envs/test_rice.py diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py index fcb5e66f..ab0cb11c 100644 --- a/pax/agents/ppo/networks.py +++ b/pax/agents/ppo/networks.py @@ -649,6 +649,32 @@ def forward_fn( return network, hidden_state +def make_GRU_rice_network( + num_actions: int, + hidden_size: int, +): + hidden_state = jnp.zeros((1, hidden_size)) + + def forward_fn( + inputs: jnp.ndarray, state: jnp.ndarray + ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]: + """forward function""" + gru = hk.GRU( + hidden_size, + w_i_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + w_h_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + b_init=hk.initializers.Constant(0), + ) + + cvh = ContinuousValueHead(num_values=num_actions) + embedding, state = gru(inputs, state) + logits, values = cvh(embedding) + return (logits, values), state + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network, hidden_state + + def test_GRU(): key = jax.random.PRNGKey(seed=0) num_actions = 2 diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index ace779df..e2dd30eb 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -16,6 +16,8 @@ make_ipd_network, make_cournot_network, make_fishery_network, ) +from pax.envs.rice import Rice +from pax.envs.sarl_rice import SarlRice from pax.utils import Logger, MemoryState, TrainingState, get_advantages @@ -339,7 +341,7 @@ def model_update_epoch( return new_state, new_memory, metrics - def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: + def make_initial_state(key: Any, hidden: jnp.ndarray) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) @@ -495,6 +497,8 @@ def make_agent( network = make_cournot_network(action_spec, agent_args.hidden_size) elif args.env_id == "Fishery": network = make_fishery_network(action_spec, agent_args.hidden_size) + elif args.env_id == Rice.env_id: + network = make_fishery_network(action_spec, agent_args.hidden_size) else: network = make_ipd_network( action_spec, tabular, agent_args.hidden_size diff --git a/pax/agents/ppo/ppo_gru.py b/pax/agents/ppo/ppo_gru.py index d3f5a9d0..6d68e9ed 100644 --- a/pax/agents/ppo/ppo_gru.py +++ b/pax/agents/ppo/ppo_gru.py @@ -13,8 +13,9 @@ make_GRU_cartpole_network, make_GRU_coingame_network, make_GRU_ipd_network, - make_GRU_ipditm_network, make_GRU_fishery_network, + make_GRU_ipditm_network, make_GRU_fishery_network, make_GRU_rice_network, ) +from pax.envs.rice import Rice from pax.utils import MemoryState, TrainingState, get_advantages # from dm_env import TimeStep @@ -517,7 +518,14 @@ def make_gru_agent( network, initial_hidden_state = make_GRU_ipd_network( action_spec, agent_args.hidden_size ) - + elif args.env_id == "Fishery": + network, initial_hidden_state = make_GRU_fishery_network( + action_spec, agent_args.hidden_size + ) + elif args.env_id == Rice.env_id: + network, initial_hidden_state = make_GRU_rice_network( + action_spec, agent_args.hidden_size + ) elif args.env_id == "InTheMatrix": network, 
initial_hidden_state = make_GRU_ipditm_network( action_spec, diff --git a/pax/conf/experiment/cournot/gs_v_ppo.yaml b/pax/conf/experiment/cournot/gs_v_ppo.yaml index d86a4838..448ffafa 100644 --- a/pax/conf/experiment/cournot/gs_v_ppo.yaml +++ b/pax/conf/experiment/cournot/gs_v_ppo.yaml @@ -1,8 +1,8 @@ # @package _global_ # Agents -agent1: 'PPO' -agent2: 'PPO' +# Agent default applies to all agents +agent_default: 'PPO' # Environment env_id: Cournot @@ -10,23 +10,23 @@ env_type: meta a: 100 b: 1 marginal_cost: 10 -# This means the optimum quantity is 2(a-marginal_cost)/3b = 60 +# This means the nash quantity is 2(a-marginal_cost)/3b = 60 # Runner -runner: evo +runner: tensor_evo # Training top_k: 5 popsize: 1000 -num_envs: 2 +num_envs: 4 num_opps: 1 -num_outer_steps: 100 -num_inner_steps: 50 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game num_iters: 3000 num_devices: 1 # PPO agent parameters -ppo1: +ppo_default: num_minibatches: 4 num_epochs: 2 gamma: 0.96 @@ -46,28 +46,6 @@ ppo1: with_cnn: False hidden_size: 16 -# PPO agent parameters -ppo2: - num_minibatches: 4 - num_epochs: 2 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: False - entropy_coeff_start: 0.02 - entropy_coeff_horizon: 2000000 - entropy_coeff_end: 0.001 - lr_scheduling: False - learning_rate: 1 - adam_epsilon: 1e-5 - with_memory: False - with_cnn: False - hidden_size: 16 - - # ES parameters es: algo: OpenES # [OpenES, CMA_ES] diff --git a/pax/conf/experiment/cournot/marl_baseline.yaml b/pax/conf/experiment/cournot/marl_baseline.yaml index baceb17f..4fce1d49 100644 --- a/pax/conf/experiment/cournot/marl_baseline.yaml +++ b/pax/conf/experiment/cournot/marl_baseline.yaml @@ -1,8 +1,7 @@ # @package _global_ # Agents -agent1: 'PPO' -agent2: 'PPO' +agent_default: 'PPO' # Environment env_id: Cournot @@ -10,14 +9,14 @@ env_type: sequential a: 100 b: 1 marginal_cost: 10 -# This means the optimum quantity is 2(a-marginal_cost)/3b = 60 -runner: rl +# This means the nash quantity is 2(a-marginal_cost)/3b = 60 +runner: tensor_rl_nplayer # env_batch_size = num_envs * num_opponents num_envs: 20 num_opps: 1 -num_outer_steps: 1 -num_inner_steps: 100 # number of inner steps (only for MetaFinite Env) +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game num_iters: 1e7 # Useful information @@ -25,27 +24,7 @@ num_iters: 1e7 # batch_size % num_minibatches == 0 must hold # PPO agent parameters -ppo1: - num_minibatches: 10 - num_epochs: 4 - gamma: 0.96 - gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 - clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: True - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 0.25e9 - entropy_coeff_end: 0.05 - lr_scheduling: True - learning_rate: 3e-4 - adam_epsilon: 1e-5 - with_memory: True - hidden_size: 16 - with_cnn: False - -ppo2: +ppo_default: num_minibatches: 10 num_epochs: 4 gamma: 0.96 diff --git a/pax/conf/experiment/cournot/mfos_v_ppo.yaml b/pax/conf/experiment/cournot/mfos_v_ppo.yaml index 2b7afc60..3514873a 100644 --- a/pax/conf/experiment/cournot/mfos_v_ppo.yaml +++ b/pax/conf/experiment/cournot/mfos_v_ppo.yaml @@ -12,15 +12,15 @@ b: 1 marginal_cost: 10 # Runner -runner: evo +runner: tensor_evo # Training top_k: 5 popsize: 1000 num_envs: 4 num_opps: 1 -num_outer_steps: 100 -num_inner_steps: 50 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game num_iters: 3000 num_devices: 1 num_steps: '${num_inner_steps}' diff --git a/pax/conf/experiment/cournot/shaper_v_ppo.yaml 
b/pax/conf/experiment/cournot/shaper_v_ppo.yaml index b3ea5785..de216982 100644 --- a/pax/conf/experiment/cournot/shaper_v_ppo.yaml +++ b/pax/conf/experiment/cournot/shaper_v_ppo.yaml @@ -12,15 +12,15 @@ b: 1 marginal_cost: 10 # Runner -runner: evo +runner: tensor_evo # Training top_k: 5 popsize: 1000 -num_envs: 2 +num_envs: 4 num_opps: 1 -num_outer_steps: 1 -num_inner_steps: 100 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game num_iters: 5000 num_devices: 1 num_steps: '${num_inner_steps}' diff --git a/pax/conf/experiment/fishery/gs_v_ppo.yaml b/pax/conf/experiment/fishery/gs_v_ppo.yaml index 195f3c13..c141f0bd 100644 --- a/pax/conf/experiment/fishery/gs_v_ppo.yaml +++ b/pax/conf/experiment/fishery/gs_v_ppo.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'PPO' -agent2: 'PPO_memory' +agent_default: 'PPO_memory' # Environment env_id: Fishery @@ -16,14 +16,14 @@ s_max: 1.0 # Runner -runner: evo +runner: tensor_evo # Training top_k: 5 popsize: 1000 num_envs: 2 num_opps: 1 -num_outer_steps: 40 +num_outer_steps: 1 num_inner_steps: 300 num_iters: 1000 num_devices: 1 diff --git a/pax/conf/experiment/fishery/marl_baseline.yaml b/pax/conf/experiment/fishery/marl_baseline.yaml index f0cba580..0387cd2f 100644 --- a/pax/conf/experiment/fishery/marl_baseline.yaml +++ b/pax/conf/experiment/fishery/marl_baseline.yaml @@ -15,7 +15,7 @@ s_0: 0.5 s_max: 1.0 # This means the optimum quantity is 2(a-marginal_cost)/3b = 60 -runner: rl +runner: tensor_rl_nplayer # env_batch_size = num_envs * num_opponents num_envs: 100 diff --git a/pax/conf/experiment/rice/ctde.yaml b/pax/conf/experiment/rice/ctde.yaml new file mode 100644 index 00000000..cb701f84 --- /dev/null +++ b/pax/conf/experiment/rice/ctde.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: Rice-v1 +env_type: sequential +num_players: 27 +config_folder: pax/envs/region_yamls +runner: ctde +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 100 +num_iters: 1e7 +save_interval: 100 +num_steps: '${num_inner_steps}' + +# Evaluation +#run_path: ucl-dark/cg/3sp0y2cy +#model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 + +ppo0: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 2.5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + +# Logging setup +wandb: + group: rice + name: 'rice-SARL-${agent1}' diff --git a/pax/conf/experiment/rice/sarl.yaml b/pax/conf/experiment/rice/sarl.yaml new file mode 100644 index 00000000..cb701f84 --- /dev/null +++ b/pax/conf/experiment/rice/sarl.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: Rice-v1 +env_type: sequential +num_players: 27 +config_folder: pax/envs/region_yamls +runner: ctde +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 100 +num_iters: 1e7 +save_interval: 100 +num_steps: '${num_inner_steps}' + +# Evaluation +#run_path: ucl-dark/cg/3sp0y2cy +#model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 + +ppo0: + num_minibatches: 4 + num_epochs: 4 + 
gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 2.5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + +# Logging setup +wandb: + group: rice + name: 'rice-SARL-${agent1}' diff --git a/pax/envs/cournot.py b/pax/envs/cournot.py index b3f2b4b9..d5335e25 100644 --- a/pax/envs/cournot.py +++ b/pax/envs/cournot.py @@ -42,9 +42,10 @@ def _step( all_rewards = [] for i in range(num_players): q = actions[i] - r = p * q - params.marginal_cost * q obs = jnp.concatenate([actions, jnp.array([p])]) all_obs.append(obs) + + r = p * q - params.marginal_cost * q all_rewards.append(r) state = EnvState( diff --git a/pax/envs/fishery.py b/pax/envs/fishery.py index adf44c4d..00b724e8 100644 --- a/pax/envs/fishery.py +++ b/pax/envs/fishery.py @@ -46,37 +46,28 @@ def to_obs_array(params: EnvParams) -> jnp.ndarray: class Fishery(environment.Environment): - def __init__(self, num_inner_steps: int): + def __init__(self, num_players: int, num_inner_steps: int): super().__init__() + self.num_players = num_players def _step( key: chex.PRNGKey, state: EnvState, - actions: Tuple[float, float], + actions: Tuple[float, ...], params: EnvParams, ): t = state.inner_t key, _ = jax.random.split(key, 2) - # TODO implement action clipping as part of the runners - e1 = jnp.clip(actions[0].squeeze(), 0) - e2 = jnp.clip(actions[1].squeeze(), 0) - E = e1 + e2 + done = t >= num_inner_steps + + actions = jnp.asarray(actions).squeeze() + actions = jnp.clip(actions, a_min=0) + E = actions.sum() s_growth = state.s + params.g * state.s * (1 - state.s / params.s_max) # Prevent s from dropping below 0 H = jnp.clip(E * state.s * params.e, a_max=s_growth) s_next = s_growth - H - - # reward = benefit - cost - # = P * H - w * E - r1 = params.P * jnp.where(E != 0, e1 / E, 0) * H - params.w * e1 - r2 = params.P * jnp.where(E != 0, e2 / E, 0) * H - params.w * e2 - - obs1 = jnp.concatenate([jnp.array([state.s, e1, e2]), to_obs_array(params)]) - obs2 = jnp.concatenate([jnp.array([state.s, e2, e1]), to_obs_array(params)]) - - done = t >= num_inner_steps - next_state = EnvState( inner_t=state.inner_t + 1, outer_t=state.outer_t, s=s_next @@ -84,21 +75,29 @@ def _step( reset_obs, reset_state = _reset(key, params) reset_state = reset_state.replace(outer_t=state.outer_t + 1) - obs1 = jnp.where(done, reset_obs[0], obs1) - obs2 = jnp.where(done, reset_obs[1], obs2) + all_obs = [] + all_rewards = [] + for i in range(num_players): + obs = jnp.concatenate([actions, jnp.array([s_next])]) + obs = jax.lax.select(done, reset_obs[i], obs) + all_obs.append(obs) + + e = actions[i] + # reward = benefit - cost + # = P * H - w * E + r = jnp.where(E != 0, params.P * e / E * H - params.w * e, 0) + all_rewards.append(r) state = jax.tree_map( lambda x, y: jax.lax.select(done, x, y), reset_state, next_state, ) - r1 = jax.lax.select(done, 0.0, r1) - r2 = jax.lax.select(done, 0.0, r2) return ( - (obs1, obs2), + tuple(all_obs), state, - (r1, r2), + tuple(all_rewards), done, { "H": H, @@ -114,13 +113,18 @@ def _reset( outer_t=jnp.zeros((), dtype=jnp.int16), s=params.s_0 ) - obs = jax.random.uniform(key, (2,)) - obs = jnp.concatenate([jnp.array([state.s]), obs, to_obs_array(params)]) - return (obs, obs), state + obs = jax.random.uniform(key, (num_players,)) + obs 
= jnp.concatenate([obs, jnp.array([state.s])]) + return tuple([obs for _ in range(num_players)]), state self.step = jax.jit(_step) self.reset = jax.jit(_reset) + @staticmethod + def name() -> str: + """Environment name.""" + return "Fishery" + @property def name(self) -> str: """Environment name.""" @@ -139,7 +143,7 @@ def action_space( def observation_space(self, params: EnvParams) -> spaces.Box: """Observation space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=7, dtype=jnp.float32) + return spaces.Box(low=0, high=float('inf'), shape=self.num_players + 1, dtype=jnp.float32) @staticmethod def equilibrium(params: EnvParams) -> float: diff --git a/pax/envs/region_yamls/11.yml b/pax/envs/region_yamls/11.yml new file mode 100644 index 00000000..43a27ea8 --- /dev/null +++ b/pax/envs/region_yamls/11.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 1.8724419952820714 + xK_0: 0.239419592 + xL_0: 476.878017 + xL_a: 669.593553 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.13880429539557834 + xg_A: 0.12202134941105497 + xgamma: 0.3 + xl_g: 0.034238352160625596 + xsigma_0: 0.4559257467059924 diff --git a/pax/envs/region_yamls/12.yml b/pax/envs/region_yamls/12.yml new file mode 100644 index 00000000..c1a0a6b7 --- /dev/null +++ b/pax/envs/region_yamls/12.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.405493223457656 + xK_0: 3.30354611 + xL_0: 68.394527 + xL_a: 93.497311 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1880269001436297 + xg_A: 0.10300420806704261 + xgamma: 0.3 + xl_g: 0.05753057218640376 + xsigma_0: 0.5289744017993728 diff --git a/pax/envs/region_yamls/13.yml b/pax/envs/region_yamls/13.yml new file mode 100644 index 00000000..befbd70e --- /dev/null +++ b/pax/envs/region_yamls/13.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.5579000509140952 + xK_0: 0.109143954 + xL_0: 64.122372 + xL_a: 135.074132 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.16127452284439697 + xg_A: 0.12735655631209186 + xgamma: 0.3 + xl_g: 0.02623933488387354 + xsigma_0: 0.8162518983719008 diff --git a/pax/envs/region_yamls/14.yml b/pax/envs/region_yamls/14.yml new file mode 100644 index 00000000..c27b8b91 --- /dev/null +++ b/pax/envs/region_yamls/14.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 1.92663301826947 + xK_0: 1.423908312 + xL_0: 284.698846 + xL_a: 465.307807 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.24445012982169362 + xg_A: 0.1335428337437049 + xgamma: 0.3 + xl_g: 0.024422285778436918 + xsigma_0: 1.220638524516315 diff --git a/pax/envs/region_yamls/15.yml b/pax/envs/region_yamls/15.yml new file mode 100644 index 00000000..95ebe5c8 --- /dev/null +++ b/pax/envs/region_yamls/15.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.111280036435135 + xK_0: 0.268152174 + xL_0: 28.141422 + xL_a: 23.573851 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.16335430971807735 + xg_A: 0.10573757974990125 + xgamma: 0.3 + xl_g: -0.05715547594428186 + xsigma_0: 0.29029694003558093 diff --git a/pax/envs/region_yamls/16.yml b/pax/envs/region_yamls/16.yml new file mode 100644 index 00000000..ea3ef19f --- /dev/null +++ b/pax/envs/region_yamls/16.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 4.217213133650901 + xK_0: 3.18362519 + xL_0: 548.75442 + xL_a: 560.054221 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1703030497846267 + xg_A: 0.09485139864239062 + xgamma: 0.3 + xl_g: 0.08033413573292254 + xsigma_0: 0.3019631318655498 diff --git a/pax/envs/region_yamls/17.yml b/pax/envs/region_yamls/17.yml new file mode 100644 index 00000000..8e79670d --- /dev/null +++ 
b/pax/envs/region_yamls/17.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.4913586566019945 + xK_0: 0.043635414 + xL_0: 46.488546 + xL_a: 59.987638 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05834835573455604 + xg_A: 0.049004053769246436 + xgamma: 0.3 + xl_g: 0.03709027315241262 + xsigma_0: 0.4196283605267465 diff --git a/pax/envs/region_yamls/18.yml b/pax/envs/region_yamls/18.yml new file mode 100644 index 00000000..1298357c --- /dev/null +++ b/pax/envs/region_yamls/18.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.5248787296824777 + xK_0: 1.080409098 + xL_0: 69.194146 + xL_a: 100.015768 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.3464232696284064 + xg_A: 0.0785686384327884 + xgamma: 0.3 + xl_g: 0.028895835870575235 + xsigma_0: 1.0104732880546095 diff --git a/pax/envs/region_yamls/19.yml b/pax/envs/region_yamls/19.yml new file mode 100644 index 00000000..b6f3dc25 --- /dev/null +++ b/pax/envs/region_yamls/19.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.4596628149703816 + xK_0: 0.183982308 + xL_0: 513.737375 + xL_a: 1867.771496 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 1.8390289471577375 + xg_A: 0.46217845237530203 + xgamma: 0.3 + xl_g: 0.017149514576045286 + xsigma_0: 0.3103140976545981 diff --git a/pax/envs/region_yamls/2.yml b/pax/envs/region_yamls/2.yml new file mode 100644 index 00000000..e68f20ec --- /dev/null +++ b/pax/envs/region_yamls/2.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 12.157936179442062 + xK_0: 2.64167507 + xL_0: 38.101107 + xL_a: 56.990157 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.13084887390535965 + xg_A: 0.06274070897633105 + xgamma: 0.3 + xl_g: 0.020192884216840113 + xsigma_0: 0.35044418275452427 diff --git a/pax/envs/region_yamls/20.yml b/pax/envs/region_yamls/20.yml new file mode 100644 index 00000000..3210b7da --- /dev/null +++ b/pax/envs/region_yamls/20.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 0.9929511285910457 + xK_0: 0.160199062 + xL_0: 522.481879 + xL_a: 1830.325243 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08560686741591728 + xg_A: 0.06506072277236097 + xgamma: 0.3 + xl_g: 0.01902705663391574 + xsigma_0: 0.23517024551671273 diff --git a/pax/envs/region_yamls/21.yml b/pax/envs/region_yamls/21.yml new file mode 100644 index 00000000..6ff60218 --- /dev/null +++ b/pax/envs/region_yamls/21.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 5.000360862831762 + xK_0: 2.289358084859004 + xL_0: 165.293239 + xL_a: 230.19114338372032 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.18278991377259912 + xg_A: 0.07108490122759262 + xgamma: 0.3 + xl_g: 0.026773049602328805 + xsigma_0: 0.4187771240034329 diff --git a/pax/envs/region_yamls/22.yml b/pax/envs/region_yamls/22.yml new file mode 100644 index 00000000..3a3af1db --- /dev/null +++ b/pax/envs/region_yamls/22.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 29.853559456004625 + xK_0: 2.019951041942154 + xL_0: 165.75054 + xL_a: 216.9269455 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757142538145 + xg_A: 0.07541285925058157 + xgamma: 0.3 + xl_g: -0.0024986057450947508 + xsigma_0: 0.25439108584131914 diff --git a/pax/envs/region_yamls/23.yml b/pax/envs/region_yamls/23.yml new file mode 100644 index 00000000..60ce3e51 --- /dev/null +++ b/pax/envs/region_yamls/23.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 23.314991608844633 + xK_0: 3.0391651447451187 + xL_0: 109.39535640000001 + xL_a: 143.17178403 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757821836926 + xg_A: 0.07541286104378697 + xgamma: 0.3 + xl_g: -0.002498605753950543 + xsigma_0: 0.25439108584131914 
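Note on these region files: each one overrides only a subset of the constants defined in pax/envs/region_yamls/default.yml; load_rice_params in pax/envs/rice.py (added later in this patch) fills any missing keys from _RICE_CONSTANT_DEFAULT. A minimal standalone sketch of that merge is shown here for reference (load_region is an illustrative helper, not part of the patch):

import os
import yaml

def load_region(config_dir: str, filename: str) -> dict:
    # Defaults shared by every region.
    with open(os.path.join(config_dir, "default.yml"), "r", encoding="utf-8") as f:
        defaults = yaml.safe_load(f)["_RICE_CONSTANT_DEFAULT"]
    # Per-region overrides.
    with open(os.path.join(config_dir, filename), "r", encoding="utf-8") as f:
        region = yaml.safe_load(f)["_RICE_CONSTANT"]
    # Any constant the region file does not set falls back to the default.
    return {**defaults, **region}

region_11 = load_region("pax/envs/region_yamls", "11.yml")
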
diff --git a/pax/envs/region_yamls/24.yml b/pax/envs/region_yamls/24.yml new file mode 100644 index 00000000..db156cb1 --- /dev/null +++ b/pax/envs/region_yamls/24.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 29.853559456004625 + xK_0: 0.6867833542603324 + xL_0: 56.355183600000004 + xL_a: 73.75516147 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757142538145 + xg_A: 0.07541285925058157 + xgamma: 0.3 + xl_g: -0.002498605778464897 + xsigma_0: 0.25439108584131914 diff --git a/pax/envs/region_yamls/25.yml b/pax/envs/region_yamls/25.yml new file mode 100644 index 00000000..43ebae03 --- /dev/null +++ b/pax/envs/region_yamls/25.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 10.922004036104973 + xK_0: 0.6059142357084183 + xL_0: 705.464681 + xL_a: 532.496728 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877082389193 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/region_yamls/26.yml b/pax/envs/region_yamls/26.yml new file mode 100644 index 00000000..0dc50d55 --- /dev/null +++ b/pax/envs/region_yamls/26.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 9.633893693772771 + xK_0: 0.6076078389971926 + xL_0: 465.60668946000004 + xL_a: 351.44784048 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877094273714 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/region_yamls/27.yml b/pax/envs/region_yamls/27.yml new file mode 100644 index 00000000..dd580c1b --- /dev/null +++ b/pax/envs/region_yamls/27.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.620918323265558 + xK_0: 0.45330037729157585 + xL_0: 239.85799154000003 + xL_a: 181.04888752000002 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877127171766 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/region_yamls/28.yml b/pax/envs/region_yamls/28.yml new file mode 100644 index 00000000..8906fe2a --- /dev/null +++ b/pax/envs/region_yamls/28.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.1898536850408714 + xK_0: 0.1287514001006796 + xL_0: 690.0021925 + xL_a: 723.512806 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05389348561714375 + xg_A: 0.06812795377170236 + xgamma: 0.3 + xl_g: -0.012597171762552104 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/region_yamls/29.yml b/pax/envs/region_yamls/29.yml new file mode 100644 index 00000000..8e40829c --- /dev/null +++ b/pax/envs/region_yamls/29.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.033527139192083 + xK_0: 0.3810937821808831 + xL_0: 455.40144705 + xL_a: 477.51845196000005 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.053893489919463196 + xg_A: 0.06812795559760368 + xgamma: 0.3 + xl_g: -0.012597171772754604 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/region_yamls/3.yml b/pax/envs/region_yamls/3.yml new file mode 100644 index 00000000..8298c5ec --- /dev/null +++ b/pax/envs/region_yamls/3.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 13.219587477586199 + xK_0: 16.295084052817813 + xL_0: 502.409662 + xL_a: 445.861101 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.25224119422016716 + xg_A: 0.07423569831381745 + xgamma: 0.3 + xl_g: -0.033398145012670695 + xsigma_0: 0.17048017530013193 diff --git a/pax/envs/region_yamls/30.yml b/pax/envs/region_yamls/30.yml new file mode 100644 index 00000000..c7d7c205 --- /dev/null +++ b/pax/envs/region_yamls/30.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.1898536850408714 + xK_0: 
0.04377547603423107 + xL_0: 234.60074545 + xL_a: 245.99435404000002 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05389348561714375 + xg_A: 0.06812795377170236 + xgamma: 0.3 + xl_g: -0.012597171800996251 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/region_yamls/4.yml b/pax/envs/region_yamls/4.yml new file mode 100644 index 00000000..aa99fbe7 --- /dev/null +++ b/pax/envs/region_yamls/4.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 6.386787299600672 + xK_0: 1.094110266 + xL_0: 317.880267 + xL_a: 287.533185 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.19401001361417486 + xg_A: 0.23666124530625898 + xgamma: 0.3 + xl_g: -0.052512175141607595 + xsigma_0: 0.8402859337043421 diff --git a/pax/envs/region_yamls/5.yml b/pax/envs/region_yamls/5.yml new file mode 100644 index 00000000..d54301c4 --- /dev/null +++ b/pax/envs/region_yamls/5.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.480556451289923 + xK_0: 0.090493838 + xL_0: 94.484285 + xL_a: 102.997258 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.20277540450432627 + xg_A: 0.20063089079847785 + xgamma: 0.3 + xl_g: 0.036907187009908436 + xsigma_0: 1.6646404809736024 diff --git a/pax/envs/region_yamls/6.yml b/pax/envs/region_yamls/6.yml new file mode 100644 index 00000000..769ac92e --- /dev/null +++ b/pax/envs/region_yamls/6.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 10.852953800501595 + xK_0: 17.553847656 + xL_0: 222.891134 + xL_a: 168.350837 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.005000000000006631 + xg_A: -0.00046726965128841526 + xgamma: 0.3 + xl_g: -0.011976043898247184 + xsigma_0: 0.2851271547872655 diff --git a/pax/envs/region_yamls/7.yml b/pax/envs/region_yamls/7.yml new file mode 100644 index 00000000..f3f99553 --- /dev/null +++ b/pax/envs/region_yamls/7.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 4.135420683261054 + xK_0: 1.00243116 + xL_0: 103.2943 + xL_a: 87.417937 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1577590297489783 + xg_A: 0.12252697231832654 + xgamma: 0.3 + xl_g: -0.06254962519160859 + xsigma_0: 0.6013249328720022 diff --git a/pax/envs/region_yamls/9.yml b/pax/envs/region_yamls/9.yml new file mode 100644 index 00000000..58ded32e --- /dev/null +++ b/pax/envs/region_yamls/9.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.7159049409990375 + xK_0: 1.0340369 + xL_0: 573.818276 + xL_a: 681.210099 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09686897303646223 + xg_A: 0.10149076832484585 + xgamma: 0.3 + xl_g: 0.04313067107477931 + xsigma_0: 0.6378326935010085 diff --git a/pax/envs/region_yamls/default.yml b/pax/envs/region_yamls/default.yml new file mode 100644 index 00000000..dd2b3d81 --- /dev/null +++ b/pax/envs/region_yamls/default.yml @@ -0,0 +1,68 @@ +_DICE_CONSTANT: + xt_0: 2015 # starting year of the whole model + xDelta: 5 # the time interval (year) + xN: 20 # total time steps + + # Climate diffusion parameters + xPhi_T: [ [ 0.8718, 0.0088 ], [ 0.025, 0.975 ] ] + xB_T: [ 0.1005, 0 ] + # xB_T: [0.03, 0] + + # Carbon cycle diffusion parameters (the zeta matrix in the paper) + xPhi_M: [ [ 0.88, 0.196, 0 ], [ 0.12, 0.797, 0.001465 ], [ 0, 0.007, 0.99853488 ] ] + # xB_M: [0.2727272727272727, 0, 0] # 12/44 + xB_M: [ 1.36388, 0, 0 ] # 12/44 + xeta: 3.6813 #?? 
I don't find where it's used + + xM_AT_1750: 588 # atmospheric mass of carbon in the year of 1750 + xf_0: 0.5 # in Eq 3 param to effect of greenhouse gases other than carbon dioxide + xf_1: 1 # in Eq 3 param to effect of greenhouse gases other than carbon dioxide + xt_f: 20 # in Eq 3 time step param to effect of greenhouse gases other than carbon dioxide + xE_L0: 2.6 # 2.6 # in Eq 4 param to the emissions due to land use changes + xdelta_EL: 0.001 # 0.115 # 0.115 # in Eq 4 param to the emissions due to land use changes + + xM_AT_0: 851 # in CAP the atmospheric mass of carbon in the year t + xM_UP_0: 460 # in CAP the atmospheric upper bound of mass of carbon in the year t + xM_LO_0: 1740 # in CAP the atmospheric lower bound of mass of carbon in the year t + xe_0: 35.85 # in EI define the initial simga_0: e0/(q0(1-mu0)) + xq_0: 105.5 # in EI define the initial simga_0: e0/(q0(1-mu0)) + xmu_0: 0.03 # in EI define the initial simga_0: e0/(q0(1-mu0)) + + # From Python implementation PyDICE + xF_2x: 3.6813 # 3.6813 # Forcing that doubles equilibrium carbon. + xT_2x: 3.1 # 3.1 # Equilibrium temperature increase at double carbon eq. + +_RICE_CONSTANT_DEFAULT: + xA_0: 5.115 # in TFP technology at starting point + xK_0: 223 # in CAP initial condition for capital + xL_0: 7403 # in POP population at the staring point + xL_a: 11500 # in POP the expected population at convergence + xa_1: 0 + xa_2: 0.00236 # in CAP Eq 6 + xa_3: 2 # in CAP Eq 6 + xdelta_A: 0.005 # in TFP control the rate of increasing of tech smaller->faster + xg_A: 0.076 # in TFP control the rate of increasing of tech larger->faster + xgamma: 0.3 # in CAP Eq 5 the capital elasticty + xl_g: 0.134 # in POP control the rate to converge + xsigma_0: 0.3503 # e0/(q0(1-mu0)) in EI emission intensity at the starting point + +_RICE_GLOBAL_CONSTANT: + xtheta_2: 2.6 # in CAP Eq 6 + xdelta_K: 0.1 # in CAP Eq 9 param discribe the depreciate of the capital + xalpha: 1.45 # Utility function param + + xrho: 0.015 # discount factor of the utility + + xg_sigma: 0.0025 # 0.0152 # 0.0025 in EI control the rate of mitigation larger->reduce more emission + xdelta_sigma: 0.1 # 0.01 in EI control the rate of mitigation larger->reduce less emission + xp_b: 550 # 550 # in Eq 2 (estimate of the cost of mitigation) represents the price of a backstop technology that can remove carbon dioxide from the atmosphere + xdelta_pb: 0.001 # 0.025 # in Eq 2 control the how the cost of mitigation change through time larger->cost less as time goes by + + xscale_1: 0.030245527 # in Eq 29 Nordhaus scaled cost function param + xscale_2: 10993.704 # in Eq 29 Nordhaus scaled cost function param + + xT_AT_0: 0.85 # in CAP a part of damage function initial condition + xT_LO_0: 0.0068 # in CAP a part of damage function initial condition + r: 0.1 # balance interest rate adjusted for xDelta=5 + + diff --git a/pax/envs/rice.py b/pax/envs/rice.py new file mode 100644 index 00000000..0b998bfd --- /dev/null +++ b/pax/envs/rice.py @@ -0,0 +1,645 @@ +import os +from typing import Optional, Tuple + +import chex +import jax +import jax.debug +import jax.numpy as jnp +import yaml +from gymnax.environments import environment, spaces + + +@chex.dataclass +class EnvState: + inner_t: int + outer_t: int + + # Ecological + global_temperature: chex.ArrayDevice + global_carbon_mass: chex.ArrayDevice + global_exogenous_emissions: float + global_land_emissions: float + + # Economic + labor_all: chex.ArrayDevice + capital_all: chex.ArrayDevice + production_factor_all: chex.ArrayDevice + intensity_all: 
chex.ArrayDevice + balance_all: chex.ArrayDevice + + # Tariffs are applied to the next time step + future_tariff: chex.ArrayDevice + + # The following values are intermediary values + # that we only track in the state for easier evaluation and logging + gross_output_all: chex.ArrayDevice + investment_all: chex.ArrayDevice + production_all: chex.ArrayDevice + utility_all: chex.ArrayDevice + social_welfare_all: chex.ArrayDevice + capital_depreciation_all: chex.ArrayDevice + mitigation_cost_all: chex.ArrayDevice + consumption_all: chex.ArrayDevice + damages_all: chex.ArrayDevice + abatement_cost_all: chex.ArrayDevice + + +@chex.dataclass +class EnvParams: + pass + + +""" +Based off the MARL environment from https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4189735 +which in turn is an adaptation of the RICE IAM +""" + + +class Rice(environment.Environment): + env_id: str = "Rice-v1" + + def __init__(self, num_inner_steps: int, config_folder: str): + super().__init__() + + # TODO refactor all the constants to use env_params + # 1. Load env params in the experiment.py#env_setup + # 2. type env params as a chex dataclass + # 3. change the references in the code to env params + params, num_regions = load_rice_params(config_folder) + self.num_players = num_regions + self.rice_constant = params["_RICE_GLOBAL_CONSTANT"] + self.dice_constant = params["_DICE_CONSTANT"] + self.region_constants = params["_REGIONS"] + + self.savings_action_n = 1 + self.mitigation_rate_action_n = 1 + # Each region sets max allowed export from own region + self.export_action_n = 1 + # Each region sets import bids (max desired imports from other countries) + # TODO Find an "automatic" model for trade imports + # Reason: it's counter-intuitive that regions would both set import tariffs and imports + self.import_actions_n = self.num_players + # Each region sets import tariffs imposed on other countries + self.tariff_actions_n = self.num_players + + self.actions_n = ( + self.savings_action_n + + self.mitigation_rate_action_n + + self.export_action_n + + self.import_actions_n + + self.tariff_actions_n + ) + + # Determine the index of each action to slice them in the step function + self.savings_action_index = 0 + self.mitigation_rate_action_index = self.savings_action_index + self.savings_action_n + self.export_action_index = self.mitigation_rate_action_index + self.mitigation_rate_action_n + self.tariffs_action_index = self.export_action_index + self.export_action_n + self.desired_imports_action_index = self.tariffs_action_index + self.tariff_actions_n + + # Parameters for armington aggregation utility + self.sub_rate = 0.5 + self.dom_pref = 0.5 + self.for_pref = jnp.asarray([0.5 / (self.num_players - 1)] * self.num_players) + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, + ): + t = state.inner_t + key, _ = jax.random.split(key, 2) + done = t >= num_inner_steps + + t_at = state.global_temperature[0] + global_exogenous_emissions = get_exogenous_emissions( + self.dice_constant["xf_0"], self.dice_constant["xf_1"], self.dice_constant["xt_f"], t + ) + + # TODO clip all rates between 0 and 1 + actions = jnp.asarray(actions).squeeze() + actions = jnp.clip(actions, a_min=0, a_max=1) + + # Intermediary variables + scaled_imports = jnp.zeros((self.num_players, self.num_players)) + gross_output_all = [] + investment_all = [] + production_all = [] + utility_all = [] + social_welfare_all = [] + capital_depreciation_all = [] + mitigation_cost_all = [] + consumption_all = [] + 
damages_all = [] + abatement_cost_all = [] + + # Next state variables + next_labor_all = [] + next_capital_all = [] + next_production_factor_all = [] + next_intensity_all = [] + next_balance_all = [] + + for i in range(self.num_players): + savings = actions[i, self.savings_action_index] + mitigation_rate = actions[i, self.mitigation_rate_action_index] + + intensity = state.intensity_all[i] + production_factor = state.production_factor_all[i] + capital = state.capital_all[i] + labor = state.labor_all[i] + gov_balance = state.balance_all[i] + + region_const = self.region_constants[i] + + mitigation_cost = get_mitigation_cost( + self.rice_constant["xp_b"], + self.rice_constant["xtheta_2"], + self.rice_constant["xdelta_pb"], + intensity, + t + ) + mitigation_cost_all.append(mitigation_cost) + damages = get_damages(t_at, region_const["xa_1"], region_const["xa_2"], region_const["xa_3"]) + damages_all.append(damages) + abatement_cost = get_abatement_cost( + mitigation_rate, mitigation_cost, self.rice_constant["xtheta_2"] + ) + abatement_cost_all.append(abatement_cost) + production = get_production( + production_factor, + capital, + labor, + region_const["xgamma"], + ) + production_all.append(production) + + gross_output = get_gross_output(damages, abatement_cost, production) + gross_output_all.append(gross_output) + gov_balance = gov_balance * (1 + self.rice_constant["r"]) + next_balance_all.append(gov_balance) + investment = get_investment(savings, gross_output) + investment_all.append(investment) + + # Trade + desired_imports = actions[i, + self.desired_imports_action_index: self.desired_imports_action_index + self.import_actions_n] + # Countries cannot import from themselves + desired_imports.at[i].set(0) + total_desired_imports = desired_imports.sum() + clipped_desired_imports = jnp.clip(total_desired_imports, 0, gross_output) + desired_imports = desired_imports * clipped_desired_imports / total_desired_imports + + # Scale imports based on gov balance + init_capital_multiplier = 10.0 # TODO: Why this number? 
+ # TODO missing paranthesis + debt_ratio = gov_balance / init_capital_multiplier * region_const["xK_0"] + debt_ratio = jnp.clip(debt_ratio, -1.0, 0.0) + desired_imports *= 1 + debt_ratio + scaled_imports.at[i].set(desired_imports) + + # TODO this loop can be vectorized + # - get max potential exports as a vector + # - rescale along each axis + # Second loop to calculate actual exports knowing the imports + for i in range(self.num_players): + gross_output = gross_output_all[i] + export_limit = actions[i, self.export_action_index] + investment = investment_all[i] + + # scale desired imports according to max exports + max_potential_exports = get_max_potential_exports( + export_limit, gross_output, investment + ) + total_desired_exports = jnp.sum(scaled_imports[:, i]) + clipped_desired_exports = jnp.clip(total_desired_exports, 0, max_potential_exports) + scaled_imports.at[:, i].set(scaled_imports[:, i] * clipped_desired_exports / total_desired_exports) + + tariffed_imports = jnp.zeros((self.num_players, self.num_players)) + prev_tariffs = state.future_tariff + # Third loop to calculate tariffs and welfare + for i in range(self.num_players): + gross_output = gross_output_all[i] + investment = investment_all[i] + labor = state.labor_all[i] + + # calculate tariffed imports, tariff revenue and budget balance + tariffed_imports.at[i].set(scaled_imports[i] * (1 - state.future_tariff[i])) + + tariff_revenue = jnp.sum( + scaled_imports[i, :] * prev_tariffs[i, :] + ) + # Aggregate consumption from domestic and foreign goods + # domestic consumption + c_dom = get_consumption(gross_output, investment, exports=scaled_imports[:, i]) + + consumption = get_armington_agg( + c_dom=c_dom, + c_for=tariffed_imports[i, :], + sub_rate=self.sub_rate, + dom_pref=self.dom_pref, + for_pref=self.for_pref, + ) + consumption_all.append(consumption) + + utility = get_utility(labor, consumption, self.rice_constant["xalpha"]) + utility_all.append(utility) + social_welfare_all.append(get_social_welfare( + utility, self.rice_constant["xrho"], self.dice_constant["xDelta"], t + )) + + # Update government balance + # TODO add tariff revenue?? 
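                # Illustrative sketch (an assumption, not applied by this patch): the TODO above
                # could be resolved by crediting the tariff_revenue computed earlier in this loop
                # iteration to the same balance update, e.g.
                #     next_balance_all[i] = (next_balance_all[i] + tariff_revenue
                #         + self.dice_constant["xDelta"] * (jnp.sum(scaled_imports[:, i])
                #                                           - jnp.sum(scaled_imports[i, :])))
                # Whether the revenue should also be scaled by xDelta is a modeling choice.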
+ next_balance_all[i] = (next_balance_all[i] + + self.dice_constant["xDelta"] * ( + jnp.sum(scaled_imports[:, i]) + - jnp.sum(scaled_imports[i, :]) + )) + + # Update ecology + m_at = state.global_carbon_mass[0] + global_temperature = get_global_temperature( + self.dice_constant["xPhi_T"], + state.global_temperature, + self.dice_constant["xB_T"], + self.dice_constant["xF_2x"], + m_at, + self.dice_constant["xM_AT_1750"], + global_exogenous_emissions, + ) + + # TODO it should be possible to vectorize this + aux_m_all = jnp.zeros(self.num_players) + global_land_emissions = get_land_emissions( + self.dice_constant["xE_L0"], self.dice_constant["xdelta_EL"], t, self.num_players + ) + for i in range(self.num_players): + mitigation_rate = actions[i, self.mitigation_rate_action_index] + aux_m_all.at[i].set(get_aux_m( + state.intensity_all[i], + mitigation_rate, + production_all[i], + global_land_emissions + )) + + global_carbon_mass = get_global_carbon_mass( + self.dice_constant["xPhi_M"], + state.global_carbon_mass, + self.dice_constant["xB_M"], + jnp.sum(aux_m_all), + ) + + for i in range(self.num_players): + region_const = self.region_constants[i] + capital_depreciation = get_capital_depreciation( + self.rice_constant["xdelta_K"], self.dice_constant["xDelta"] + ) + capital_depreciation_all.append(capital_depreciation) + + next_capital_all.append(get_capital( + capital_depreciation, state.capital_all[i], + self.dice_constant["xDelta"], + investment_all[i] + )) + next_labor_all.append( + get_labor(state.labor_all[i], region_const["xL_a"], region_const["xl_g"])) + next_production_factor_all.append(get_production_factor( + state.production_factor_all[i], + region_const["xg_A"], + region_const["xdelta_A"], + self.dice_constant["xDelta"], + t, + )) + next_intensity_all.append(get_carbon_intensity( + state.intensity_all[i], + region_const["xsigma_0"], + self.rice_constant["xdelta_sigma"], + self.dice_constant["xDelta"], + t + )) + + next_state = EnvState( + inner_t=state.inner_t + 1, outer_t=state.outer_t, + global_temperature=global_temperature, + global_carbon_mass=global_carbon_mass, + global_exogenous_emissions=global_exogenous_emissions, + global_land_emissions=global_land_emissions, + + labor_all=jnp.asarray(next_labor_all), + capital_all=jnp.asarray(next_capital_all), + production_factor_all=jnp.asarray(next_production_factor_all), + intensity_all=jnp.asarray(next_intensity_all), + balance_all=jnp.asarray(next_balance_all), + + future_tariff=actions[:, self.tariffs_action_index: self.tariffs_action_index + self.num_players], + + gross_output_all=jnp.asarray(gross_output_all), + investment_all=jnp.asarray(investment_all), + production_all=jnp.asarray(production_all), + utility_all=jnp.asarray(utility_all), + social_welfare_all=jnp.asarray(social_welfare_all), + capital_depreciation_all=jnp.asarray(capital_depreciation_all), + mitigation_cost_all=jnp.asarray(mitigation_cost_all), + consumption_all=jnp.asarray(consumption_all), + damages_all=jnp.asarray(damages_all), + abatement_cost_all=jnp.asarray(abatement_cost_all), + ) + + reset_obs, reset_state = _reset(key, params) + reset_state = reset_state.replace(outer_t=state.outer_t + 1) + + obs = [] + for i in range(self.num_players): + obs.append(self._generate_observation(i, actions, next_state)) + obs = jax.tree_map(lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs)) + + state = jax.tree_map( + lambda x, y: jnp.where(done, x, y), + reset_state, + next_state, + ) + + return ( + tuple(obs), + state, + tuple(state.utility_all), + done, + {}, + ) 
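
For orientation, a minimal sketch of how the jitted _step/_reset pair defined in this constructor can be driven. It assumes the region_yamls folder added by this patch is on disk; num_inner_steps=20 and the variable names are illustrative only.

import jax

from pax.envs.rice import Rice, EnvParams

env = Rice(num_inner_steps=20, config_folder="pax/envs/region_yamls")
key = jax.random.PRNGKey(0)
params = EnvParams()

obs, env_state = env.reset(key, params)
# One flat action vector per region; _step clips every entry into [0, 1].
actions = tuple(
    jax.random.uniform(jax.random.fold_in(key, i), (env.num_actions,))
    for i in range(env.num_players)
)
obs, env_state, rewards, done, info = env.step(key, env_state, actions, params)
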
+ + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[Tuple, EnvState]: + state = self._get_initial_state() + actions = jnp.asarray([jax.random.uniform(key, (self.num_actions,)) for _ in range(self.num_players)]) + return tuple([self._generate_observation(i, actions, state) for i in range(self.num_players)]), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + def _get_initial_state(self) -> EnvState: + return EnvState( + inner_t=jnp.zeros((), dtype=jnp.int16), + outer_t=jnp.zeros((), dtype=jnp.int16), + global_temperature=jnp.array([self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]]), + global_carbon_mass=jnp.array( + [self.dice_constant["xM_AT_0"], self.dice_constant["xM_UP_0"], self.dice_constant["xM_LO_0"]]), + global_exogenous_emissions=jnp.zeros((), dtype=jnp.float32), + global_land_emissions=jnp.zeros((), dtype=jnp.float32), + labor_all=jnp.array( + [self.region_constants[region_id]["xL_0"] for region_id in range(self.num_players)]), + capital_all=jnp.array( + [self.region_constants[region_id]["xK_0"] for region_id in range(self.num_players)]), + production_factor_all=jnp.array( + [self.region_constants[region_id]["xA_0"] for region_id in range(self.num_players)]), + intensity_all=jnp.array( + [self.region_constants[region_id]["xsigma_0"] for region_id in range(self.num_players)]), + + balance_all=jnp.zeros(self.num_players, dtype=jnp.float32), + future_tariff=jnp.zeros((self.num_players, self.num_players), dtype=jnp.float32), + + gross_output_all=jnp.zeros(self.num_players, dtype=jnp.float32), + investment_all=jnp.zeros(self.num_players, dtype=jnp.float32), + production_all=jnp.zeros(self.num_players, dtype=jnp.float32), + utility_all=jnp.zeros(self.num_players, dtype=jnp.float32), + social_welfare_all=jnp.zeros(self.num_players, dtype=jnp.float32), + capital_depreciation_all=jnp.zeros(self.num_players, dtype=jnp.float32), + mitigation_cost_all=jnp.zeros(self.num_players, dtype=jnp.float32), + consumption_all=jnp.zeros(self.num_players, dtype=jnp.float32), + damages_all=jnp.zeros(self.num_players, dtype=jnp.float32), + abatement_cost_all=jnp.zeros(self.num_players, dtype=jnp.float32), + ) + + def _generate_observation(self, index: int, actions: chex.ArrayDevice, state: EnvState): + return jnp.concatenate([ + # Public features + jnp.asarray([index]), + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.capital_all, + state.labor_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + # Private features + jnp.asarray([state.production_factor_all[index]]), + jnp.asarray([state.intensity_all[index]]), + jnp.asarray([state.mitigation_cost_all[index]]), + jnp.asarray([state.damages_all[index]]), + jnp.asarray([state.abatement_cost_all[index]]), + jnp.asarray([state.production_all[index]]), + jnp.asarray([state.utility_all[index]]), + jnp.asarray([state.social_welfare_all[index]]), + # All agent actions + actions.ravel() + ]) + + @property + def name(self) -> str: + return self.env_id + + @property + def num_actions(self) -> int: + return self.actions_n + + def action_space( + self, params: Optional[EnvParams] = None + ) -> spaces.Box: + return spaces.Box(low=0, high=1, shape=(self.actions_n,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + init_state = self._get_initial_state() + obs = self._generate_observation(0, init_state) + return 
spaces.Box(low=0, high=float('inf'), shape=obs.shape, dtype=jnp.float32) + + +def load_rice_params(config_dir=None): + """Helper function to read yaml data and set environment configs.""" + assert config_dir is not None + base_params = load_yaml_data(os.path.join(config_dir, "default.yml")) + file_list = sorted(os.listdir(config_dir)) # + yaml_files = [] + for file in file_list: + if file[-4:] == ".yml" and file != "default.yml": + yaml_files.append(file) + + region_params = [] + for file in yaml_files: + region_params.append(load_yaml_data(os.path.join(config_dir, file))) + + # Overwrite rice params) + base_params["_REGIONS"] = [] + for idx, param in enumerate(region_params): + region_to_append = param["_RICE_CONSTANT"] + for k in base_params["_RICE_CONSTANT_DEFAULT"].keys(): + if k not in region_to_append.keys(): + region_to_append[k] = base_params["_RICE_CONSTANT_DEFAULT"][k] + base_params["_REGIONS"].append(region_to_append) + + return base_params, len(region_params) + + +def get_exogenous_emissions(f_0, f_1, t_f, timestep): + return f_0 + jnp.min(jnp.array([f_1 - f_0, (f_1 - f_0) / t_f * (timestep - 1)])) + + +def get_land_emissions(e_l0, delta_el, timestep, num_regions): + return e_l0 * pow(1 - delta_el, timestep - 1) / num_regions + + +def get_mitigation_cost(p_b, theta_2, delta_pb, intensity, timestep): + return p_b / (1000 * theta_2) * pow(1 - delta_pb, timestep - 1) * intensity + + +def get_damages(t_at, a_1, a_2, a_3): + return 1 / (1 + a_1 * t_at + a_2 * pow(t_at, a_3)) + + +def get_abatement_cost(mitigation_rate, mitigation_cost, theta_2): + return mitigation_cost * pow(mitigation_rate, theta_2) + + +def get_production(production_factor, capital, labor, gamma): + """Obtain the amount of goods produced.""" + return production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma) + + +def get_gross_output(damages, abatement_cost, production): + """Compute the gross production output, taking into account + damages and abatement cost.""" + return damages * (1 - abatement_cost) * production + + +def get_investment(savings, gross_output): + return savings * gross_output + + +def get_consumption(gross_output, investment, exports): + total_exports = jnp.sum(exports) + return jnp.max(jnp.asarray([0.0, gross_output - investment - total_exports])) + + +def get_max_potential_exports(x_max, gross_output, investment): + return jnp.min(jnp.array([x_max * gross_output, gross_output - investment])) + + +def get_capital_depreciation(x_delta_k, x_delta): + """Compute the global capital depreciation.""" + return pow(1 - x_delta_k, x_delta) + + +# Returns shape 2 +def get_global_temperature( + phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions +): + return jnp.dot(phi_t, temperature) + jnp.dot( + b_t, f_2x * jnp.log(m_at / m_at_1750) / jnp.log(2) + exogenous_emissions + ) + + +def get_aux_m(intensity, mitigation_rate, production, land_emissions): + """Auxiliary variable to denote carbon mass levels.""" + return intensity * (1 - mitigation_rate) * production + land_emissions + + +def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m): + """Get the carbon mass level.""" + return jnp.dot(phi_m, carbon_mass) + jnp.dot(b_m, aux_m) + + +def get_capital(capital_depreciation, capital, delta, investment): + """Evaluate capital.""" + return capital_depreciation * capital + delta * investment + + +def get_labor(labor, l_a, l_g): + """Compute total labor.""" + return labor * pow((1 + l_a) / (1 + labor), l_g) + + +def get_production_factor(production_factor, g_a, delta_a, delta, 
timestep): + """Compute the production factor.""" + return production_factor * ( + jnp.exp(0.0033) + g_a * jnp.exp(-delta_a * delta * (timestep - 1)) + ) + + +def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep): + """Determine the carbon emission intensity.""" + return intensity * jnp.exp( + -g_sigma * pow(1 - delta_sigma, delta * (timestep - 1)) * delta + ) + + +_SMALL_NUM = 1e-0 + + +def get_utility(labor, consumption, alpha): + return ( + (labor / 1000.0) + * (pow(consumption / (labor / 1000.0) + _SMALL_NUM, 1 - alpha) - 1) + / (1 - alpha) + ) + + +def get_social_welfare(utility, rho, delta, timestep): + return utility / pow(1 + rho, delta * timestep) + + +def get_armington_agg( + c_dom, + c_for, # np.array + sub_rate=0.5, # in (0,1) + dom_pref=0.5, # in [0,1] + for_pref=None, # np.array +): + """ + Armington aggregate from Lessmann, 2009. + Consumption goods from different regions act as imperfect substitutes. + As such, consumption of domestic and foreign goods are scaled according to + relative preferences, as well as a substitution rate, which are modeled + by a CES functional form. + Inputs : + `C_dom` : A scalar representing domestic consumption. The value of + C_dom is what is left over from initial production after + investment and exports are deducted. + `C_for` : An array reprensenting foreign consumption. Each element + is the consumption imported from a given country. + `sub_rate` : A substitution parameter in (0,1). The elasticity of + substitution is 1 / (1 - sub_rate). + `dom_pref` : A scalar in [0,1] representing the relative preference for + domestic consumption over foreign consumption. + `for_pref` : An array of the same size as `C_for`. Each element is the + relative preference for foreign goods from that country. 
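    The aggregate consumption therefore takes the CES form
        C_agg = (dom_pref * C_dom**sub_rate + sum_j for_pref[j] * C_for[j]**sub_rate) ** (1 / sub_rate),
    which is what the function body below computes.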
+ """ + + c_dom_pref = dom_pref * (c_dom ** sub_rate) + c_for_pref = jnp.sum(for_pref * pow(c_for, sub_rate)) + + c_agg = (c_dom_pref + c_for_pref) ** (1 / sub_rate) # CES function + return c_agg + + +def load_yaml_data(yaml_file: str): + """Helper function to read yaml configuration data.""" + with open(yaml_file, "r", encoding="utf-8") as file_ptr: + file_data = file_ptr.read() + data = yaml.load(file_data, Loader=yaml.FullLoader) + return rec_array_conversion(data) + + +def rec_array_conversion(data): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, list): + data[key] = jnp.asarray(value) + elif isinstance(value, dict): + data[key] = rec_array_conversion(value) + elif isinstance(data, list): + data = jnp.asarray(data) + return data diff --git a/pax/envs/rice_n.py b/pax/envs/rice_n.py deleted file mode 100644 index 6dd451d6..00000000 --- a/pax/envs/rice_n.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Optional, Tuple - -import chex -import jax -import jax.debug -import jax.numpy as jnp -from gymnax.environments import environment, spaces - - -@chex.dataclass -class EnvState: - inner_t: int - outer_t: int - - activity_step: int - - # Ecological - global_temp: chex.ArrayDevice - global_carbon_mass: chex.ArrayDevice - global_exogenous_emissions: float - global_land_emissions: float - - # Economic - labor_all_regions: chex.ArrayDevice - production_factor_all_regions: chex.ArrayDevice - intensity_all_regions: chex.ArrayDevice - - - - -@chex.dataclass -class EnvParams: - g: float - e: float - P: float - w: float - s_0: float - s_max: float - - -def to_obs_array(params: EnvParams) -> jnp.ndarray: - return jnp.array([params.g, params.e, params.P, params.w]) - - -""" -Based off the MARL environment from https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4189735 -which in turn is adapted from the RICE IAM -""" - - -class RiceN(environment.Environment): - def __init__(self, num_inner_steps: int): - super().__init__() - - def _step( - key: chex.PRNGKey, - state: EnvState, - actions: Tuple[float, float], - params: EnvParams, - ): - t = state.inner_t - key, _ = jax.random.split(key, 2) - - done = t >= num_inner_steps - - next_state = EnvState( - inner_t=state.inner_t + 1, outer_t=state.outer_t, - ) - reset_obs, reset_state = _reset(key, params) - reset_state = reset_state.replace(outer_t=state.outer_t + 1) - - obs1 = jnp.where(done, reset_obs[0], obs1) - obs2 = jnp.where(done, reset_obs[1], obs2) - - state = jax.tree_map( - lambda x, y: jax.lax.select(done, x, y), - reset_state, - next_state, - ) - r1 = jax.lax.select(done, 0.0, r1) - r2 = jax.lax.select(done, 0.0, r2) - - return ( - (obs1, obs2), - state, - (r1, r2), - done, - { - "H": H, - "E": E, - }, - ) - - def _reset( - key: chex.PRNGKey, params: EnvParams - ) -> Tuple[Tuple, EnvState]: - state = EnvState( - inner_t=jnp.zeros((), dtype=jnp.int16), - outer_t=jnp.zeros((), dtype=jnp.int16), - s=params.s_0 - ) - obs = jax.random.uniform(key, (2,)) - obs = jnp.concatenate([jnp.array([state.s]), obs, to_obs_array(params)]) - return (obs, obs), state - - self.step = jax.jit(_step) - self.reset = jax.jit(_reset) - - @property - def name(self) -> str: - """Environment name.""" - return "Fishery-v1" - - @property - def num_actions(self) -> int: - """Number of actions possible in environment.""" - return 1 - - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: - """Action space of the environment.""" - return spaces.Box(low=0, high=params.s_max, shape=(1,)) - - def 
observation_space(self, params: EnvParams) -> spaces.Box: - """Observation space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=7, dtype=jnp.float32) - - @staticmethod - def equilibrium(params: EnvParams) -> float: - return params.s_max * (1 - params.g / params.e / params.P) diff --git a/pax/envs/sarl_rice.py b/pax/envs/sarl_rice.py new file mode 100644 index 00000000..8d1eaa4e --- /dev/null +++ b/pax/envs/sarl_rice.py @@ -0,0 +1,65 @@ +from typing import Optional, Tuple + +import chex +import jax +import jax.debug +import jax.numpy as jnp +from gymnax.environments import environment, spaces + +from pax.envs.rice import Rice, EnvState, EnvParams + +""" +Wrapper to turn Rice into a single-agent environment. +""" + + +class SarlRice(environment.Environment): + env_id: str = "SarlRice-v1" + + def __init__(self, num_inner_steps: int, config_folder: str): + super().__init__() + self.rice = Rice(num_inner_steps, config_folder) + + def _step( + key: chex.PRNGKey, + state: EnvState, + action: chex.Array, + params: EnvParams, + ): + actions = jnp.split(action, self.rice.num_players) + obs, state, rewards, done, info = self.rice.step(key, state, tuple(actions), params) + + return ( + jnp.concatenate(obs), + state, + jnp.asarray(rewards).sum(), + done, + info, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + obs, state = self.rice.reset(key, params) + return jnp.asarray(obs), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + @property + def name(self) -> str: + return self.env_id + + @property + def num_actions(self) -> int: + return self.rice.num_actions * self.rice.num_players + + def action_space( + self, params: Optional[EnvParams] = None + ) -> spaces.Box: + return spaces.Box(low=0, high=1, shape=(self.num_actions,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + obs_space = self.rice.observation_space(params) + obs_space.shape = (obs_space.shape[0] * self.rice.num_players,) + return obs_space diff --git a/pax/experiment.py b/pax/experiment.py index 774dd7df..61df6a7f 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -54,10 +54,13 @@ from pax.envs.iterated_tensor_game_n_player import ( EnvParams as IteratedTensorGameNPlayerParams, ) +from pax.envs.rice import Rice, EnvParams as RiceParams +from pax.envs.sarl_rice import SarlRice +from pax.runners.runner_ctde import CTDERunner from pax.runners.runner_eval import EvalRunner from pax.runners.runner_eval_nplayer import NPlayerEvalRunner from pax.runners.runner_evo import EvoRunner -from pax.runners.runner_evo_nplayer import TensorEvoRunner +from pax.runners.runner_evo_nplayer import NPlayerEvoRunner from pax.runners.runner_marl import RLRunner from pax.runners.runner_marl_nplayer import NplayerRLRunner from pax.runners.runner_sarl import SARLRunner @@ -208,12 +211,32 @@ def env_setup(args, logger=None): s_max=args.s_max, ) env = Fishery( - num_inner_steps=args.num_inner_steps, + num_players=args.num_players, num_inner_steps=args.num_inner_steps, ) if logger: logger.info( f"Env Type: Fishery | Inner Episode Length: {args.num_inner_steps}" ) + elif args.env_id == Rice.env_id: + env_params = RiceParams() + env = Rice( + num_inner_steps=args.num_inner_steps, + config_folder=args.config_folder, + ) + if logger: + logger.info( + f"Env Type: Rice | Inner Episode Length: {args.num_inner_steps}" + ) + elif args.env_id == SarlRice.env_id: + env_params = RiceParams() + env = SarlRice( + num_inner_steps=args.num_inner_steps, + 
config_folder=args.config_folder, + ) + if logger: + logger.info( + f"Env Type: SarlRice | Inner Episode Length: {args.num_inner_steps}" + ) elif args.runner == "sarl": env, env_params = gymnax.make(args.env_id) else: @@ -332,7 +355,7 @@ def get_pgpe_strategy(agent): args, ) elif args.runner == "tensor_evo": - return TensorEvoRunner( + return NPlayerEvoRunner( agents, env, strategy, @@ -350,6 +373,9 @@ def get_pgpe_strategy(agent): elif args.runner == "sarl": logger.info("Training with SARL Runner") return SARLRunner(agents, env, save_dir, args) + elif args.runner == "ctde": + logger.info("Training with CTDE Runner") + return CTDERunner(agents, env, save_dir, args) else: raise ValueError(f"Unknown runner type {args.runner}") @@ -508,23 +534,18 @@ def get_stay_agent(seed, player_id): "HyperTFT": partial(HyperTFT, args.num_envs), } - if args.runner == "sarl": + if args.runner in ["sarl", "ctde"]: assert args.agent1 in strategies - num_agents = 1 seeds = [args.seed] # Create Player IDs by normalizing seeds to 1, 2 respectively pids = [0] agent_1 = strategies[args.agent1](seeds[0], pids[0]) # player 1 - - if args.agent1 in ["PPO", "PPO_memory"] and args.ppo.with_cnn: - logger.info(f"PPO with CNN: {args.ppo.with_cnn}") logger.info(f"Agent Pair: {args.agent1}") logger.info(f"Agent seeds: {seeds[0]}") if args.runner in ["eval", "sarl"]: logger.info("Using Independent Learners") - return agent_1 - + return agent_1 else: default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None) agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in @@ -660,7 +681,7 @@ def naive_pg_log(agent): "MFOS_pretrained": dumb_log, } - if args.runner == "sarl": + if args.runner in ["sarl", "ctde"]: assert args.agent1 in strategies agent_1_log = naive_pg_log # strategies[args.agent1] # @@ -728,6 +749,9 @@ def main(args): elif args.runner == "sarl": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) + elif args.runner == "ctde": + print(f"Number of Episodes: {args.num_iters}") + runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) elif args.runner == "eval": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) diff --git a/pax/runners/runner_ctde.py b/pax/runners/runner_ctde.py new file mode 100644 index 00000000..c297809a --- /dev/null +++ b/pax/runners/runner_ctde.py @@ -0,0 +1,251 @@ +import os +import time +from typing import Any, NamedTuple, List, Tuple + +import jax +import jax.numpy as jnp +import wandb + +from pax.utils import MemoryState, TrainingState, save + +# from jax.config import config +# config.update('jax_disable_jit', True) + +MAX_WANDB_CALLS = 1000000 + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class CTDERunner: + """Holds the runner's state.""" + + def __init__(self, agent, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) 
+ + self.split = jax.vmap(jax.random.split, (0, None)) + # set up agent + if args.agent1 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent.batch_init = jax.jit(jax.vmap(agent.make_initial_state)) + else: + # batch MemoryState not TrainingState + agent.batch_init = jax.jit(agent.make_initial_state) + + agent.batch_reset = jax.jit(agent.reset_memory, static_argnums=1) + + agent.batch_policy = jax.jit(agent._policy) + + if args.agent1 != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile(agent._mem.hidden, (1)) + agent._state, agent._mem = agent.batch_init( + agent._state.random_key, init_hidden + ) + + def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]: + """Runner for inner episode""" + ( + rngs, + obs, + a1_state, + memories, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 2) + env_rng = rngs[:, 0, :] + rngs = rngs[:, 1, :] + + actions = [] + new_memories = [] + for i in range(args.num_players): + a1, a1_state, new_a1_mem = agent.batch_policy( + a1_state, + obs[i], + memories[i], + ) + actions.append(a1) + new_memories.append(new_a1_mem) + + next_obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + tuple(actions), + env_params, + ) + + trajectories = [ + Sample(observation, + action, + reward * jnp.logical_not(done), + memory.extras["log_probs"], + memory.extras["values"], + done, + memory.hidden) for observation, action, reward, memory in + zip(obs, actions, rewards, memories)] + + return ( + rngs, + next_obs, + a1_state, + new_memories, + env_state, + env_params, + ), trajectories + + def _rollout( + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _memories: List[MemoryState], + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + ).reshape((args.num_envs, -1)) + + obs, env_state = env.reset(rngs, _env_params) + _memories = [agent.batch_reset(mem, False) for mem in _memories] + + # run trials + vals, trajectories = jax.lax.scan( + _inner_rollout, + ( + rngs, + obs, + _a1_state, + _memories, + env_state, + _env_params, + ), + None, + length=args.num_steps, + ) + + ( + rngs, + obs, + _a1_state, + _memories, + env_state, + env_params, + ) = vals + + for sample in trajectories: + _a1_state, _, _a1_metrics = agent.update( + sample, + sample.observations, + _a1_state, + sample, + ) + + # reset memory + _memories = agent.batch_reset(_memories, False) + + # Stats + rewards = jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8) + env_stats = {} + + return ( + env_stats, + rewards, + _a1_state, + _memories, + _a1_metrics, + ) + + self.rollout = _rollout + # self.rollout = jax.jit(_rollout) + + def run_loop(self, env, env_params, agent, num_iters, watcher): + """Run training of agent in environment""" + print("Training") + print("-----------------------") + agent = agent + rng, _ = jax.random.split(self.random_key) + + a1_state, a1_mem = agent._state, agent._mem + + num_iters = max(int(num_iters / (self.args.num_envs)), 1) + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + + print(f"Log Interval {log_interval}") + print(f"Running for total iterations: {num_iters}") + # run actual loop + for i in range(num_iters): + rng, rng_run = jax.random.split(rng, 2) + memories = tuple([a1_mem for _ in range(self.args.num_players)]) + # RL Rollout + ( + env_stats, + rewards_1, + a1_state, + a1_mem, + a1_metrics, + ) = self.rollout(rng_run, a1_state, memories, env_params) + + if i % self.args.save_interval == 0: + 
log_savepath = os.path.join(self.save_dir, f"iteration_{i}") + save(a1_state.params, log_savepath) + if watcher: + print(f"Saving iteration {i} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {i} locally") + + # logging + self.train_episodes += 1 + if num_iters % log_interval == 0: + print(f"Episode {i}") + + print(f"Env Stats: {env_stats}") + print(f"Total Episode Reward: {float(rewards_1.mean())}") + print() + + if watcher: + # metrics [outer_timesteps] + flattened_metrics_1 = jax.tree_util.tree_map( + lambda x: jnp.mean(x), a1_metrics + ) + agent._logger.metrics = ( + agent._logger.metrics | flattened_metrics_1 + ) + + watcher(agent) + wandb.log( + { + "episodes": self.train_episodes, + "train/episode_reward/player_1": float( + rewards_1.mean() + ), + } + | env_stats, + ) + + agent._state = a1_state + return agent diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index 5a811d00..ffdb1989 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -415,9 +415,9 @@ def _rollout( env_stats = jax.tree_util.tree_map( lambda x: x.mean(), self.cournot_stats( - traj_1, - traj_2, + traj_1.observations, _env_params, + 2 ), ) elif args.env_id == "Fishery": diff --git a/pax/runners/runner_evo_nplayer.py b/pax/runners/runner_evo_nplayer.py index 304b500a..718ddbb0 100644 --- a/pax/runners/runner_evo_nplayer.py +++ b/pax/runners/runner_evo_nplayer.py @@ -2,17 +2,19 @@ import time from datetime import datetime from typing import Any, Callable, NamedTuple - +from functools import partial import jax import jax.numpy as jnp from evosax import FitnessShaper - +from omegaconf import OmegaConf import wandb from pax.utils import MemoryState, TrainingState, save # TODO: import when evosax library is updated # from evosax.utils import ESLog -from pax.watchers import ESLog, cg_visitation, n_player_ipd_visitation +from pax.watchers import ESLog, n_player_ipd_visitation +from pax.watchers.cournot import cournot_stats +from pax.watchers.fishery import fishery_stats MAX_WANDB_CALLS = 1000 @@ -29,9 +31,9 @@ class Sample(NamedTuple): hiddens: jnp.ndarray -class TensorEvoRunner: +class NPlayerEvoRunner: """ - Evoluationary Strategy runner provides a convenient example for quickly writing + Evolutionary Strategy runner provides a convenient example for quickly writing a MARL runner for PAX. The EvoRunner class can be used to run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. It composes together agents, watchers, and the environment. @@ -55,7 +57,7 @@ class TensorEvoRunner: """ def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args + self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): self.args = args self.algo = args.es.algo @@ -72,8 +74,10 @@ def __init__( self.top_k = args.top_k self.train_steps = 0 self.train_episodes = 0 - self.ipd_stats = jax.jit(n_player_ipd_visitation) - self.cg_stats = jax.jit(jax.vmap(cg_visitation)) + # TODO JIT this + self.ipd_stats = n_player_ipd_visitation + self.cournot_stats = jax.jit(cournot_stats) + self.fishery_stats = jax.jit(fishery_stats) # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) # Evo Runner also has an additional pmap dim (num_devices, ...) 
@@ -103,8 +107,7 @@ def __init__( ) self.num_outer_steps = args.num_outer_steps - - agent1, agent2, agent3 = agents + agent1, *other_agents = agents # vmap agents accordingly # agent 1 is batched over popsize and num_opps @@ -128,90 +131,68 @@ def __init__( jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), ) ) + # go through opponents, we start with agent2 + for agent_idx, non_first_agent in enumerate(other_agents): + agent_arg = f"agent{agent_idx+2}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + # special case where NaiveEx has a different call signature + non_first_agent.batch_init = jax.jit( + jax.vmap(jax.vmap(non_first_agent.make_initial_state)) + ) + else: + non_first_agent.batch_init = jax.jit( + jax.vmap( + jax.vmap( + non_first_agent.make_initial_state, (0, None), 0 + ), + (0, None), + 0, + ) + ) - if args.agent2 == "NaiveEx": - # special case where NaiveEx has a different call signature - agent2.batch_init = jax.jit( - jax.vmap(jax.vmap(agent2.make_initial_state)) + non_first_agent.batch_policy = jax.jit( + jax.vmap(jax.vmap(non_first_agent._policy, 0, 0)) ) - else: - agent2.batch_init = jax.jit( + non_first_agent.batch_reset = jax.jit( jax.vmap( - jax.vmap(agent2.make_initial_state, (0, None), 0), + jax.vmap(non_first_agent.reset_memory, (0, None), 0), (0, None), 0, - ) - ) - - agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) - agent2.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent2.batch_update = jax.jit( - jax.vmap( - jax.vmap(agent2.update, (1, 0, 0, 0)), - (1, 0, 0, 0), - ) - ) - if args.agent2 != "NaiveEx": - # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - - a2_rng = jnp.concatenate( - [jax.random.split(agent2._state.random_key, args.num_opps)] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - - agent2._state, agent2._mem = agent2.batch_init( - a2_rng, - init_hidden, + ), + static_argnums=1, ) - if args.agent3 == "NaiveEx": - # special case where NaiveEx has a different call signature - agent3.batch_init = jax.jit( - jax.vmap(jax.vmap(agent3.make_initial_state)) - ) - else: - agent3.batch_init = jax.jit( + non_first_agent.batch_update = jax.jit( jax.vmap( - jax.vmap(agent3.make_initial_state, (0, None), 0), - (0, None), - 0, + jax.vmap(non_first_agent.update, (1, 0, 0, 0)), + (1, 0, 0, 0), ) ) + if OmegaConf.select(args, agent_arg) != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile( + non_first_agent._mem.hidden, (args.num_opps, 1, 1) + ) - agent3.batch_policy = jax.jit(jax.vmap(jax.vmap(agent3._policy, 0, 0))) - agent3.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent3.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent3.batch_update = jax.jit( - jax.vmap( - jax.vmap(agent3.update, (1, 0, 0, 0)), - (1, 0, 0, 0), - ) - ) - if args.agent3 != "NaiveEx": - # NaiveEx requires env first step to init. 
- init_hidden = jnp.tile(agent3._mem.hidden, (args.num_opps, 1, 1)) + agent_rng = jnp.concatenate( + [ + jax.random.split( + non_first_agent._state.random_key, args.num_opps + ) + ] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) - a3_rng = jnp.concatenate( - [jax.random.split(agent3._state.random_key, args.num_opps)] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) + ( + non_first_agent._state, + non_first_agent._mem, + ) = non_first_agent.batch_init( + agent_rng, + init_hidden, + ) - agent3._state, agent3._mem = agent3.batch_init( - a3_rng, - init_hidden, - ) + # jit evo strategy.ask = jax.jit(strategy.ask) strategy.tell = jax.jit(strategy.tell) param_reshaper.reshape = jax.jit(param_reshaper.reshape) @@ -220,105 +201,97 @@ def _inner_rollout(carry, unused): """Runner for inner episode""" ( rngs, - obs1, - obs2, - obs3, - r1, - r2, - r3, - a1_state, - a1_mem, - a2_state, - a2_mem, - a3_state, - a3_mem, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, env_state, env_params, ) = carry - + new_other_agent_mem = [None] * len(other_agents) # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, :, 0, :] + # a1_rng = rngs[:, :, :, 1, :] # a2_rng = rngs[:, :, :, 2, :] rngs = rngs[:, :, :, 3, :] - - a1, a1_state, new_a1_mem = agent1.batch_policy( - a1_state, - obs1, - a1_mem, - ) - a2, a2_state, new_a2_mem = agent2.batch_policy( - a2_state, - obs2, - a2_mem, - ) - a3, a3_state, new_a3_mem = agent3.batch_policy( - a3_state, - obs3, - a3_mem, + actions = [] + ( + first_action, + first_agent_state, + new_first_agent_mem, + ) = agent1.batch_policy( + first_agent_state, + first_agent_obs, + first_agent_mem, ) + actions.append(first_action) + for agent_idx, non_first_agent in enumerate(other_agents): + ( + non_first_action, + other_agent_state[agent_idx], + new_other_agent_mem[agent_idx], + ) = non_first_agent.batch_policy( + other_agent_state[agent_idx], + other_agent_obs[agent_idx], + other_agent_mem[agent_idx], + ) + actions.append(non_first_action) ( - (next_obs1, next_obs2, next_obs3), + all_agent_next_obs, env_state, - rewards, + all_agent_rewards, done, info, ) = env.step( env_rng, env_state, - (a1, a2, a3), + actions, env_params, ) + first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs + first_agent_reward, *other_agent_rewards = all_agent_rewards + traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], + first_agent_next_obs, + first_action, + first_agent_reward, + new_first_agent_mem.extras["log_probs"], + new_first_agent_mem.extras["values"], done, - a1_mem.hidden, - ) - traj2 = Sample( - obs2, - a2, - rewards[1], - new_a2_mem.extras["log_probs"], - new_a2_mem.extras["values"], - done, - a2_mem.hidden, - ) - traj3 = Sample( - obs3, - a3, - rewards[2], - new_a3_mem.extras["log_probs"], - new_a3_mem.extras["values"], - done, - a3_mem.hidden, + first_agent_mem.hidden, ) + other_traj = [ + Sample( + other_agent_next_obs[agent_idx], + actions[agent_idx + 1], + other_agent_rewards[agent_idx], + new_other_agent_mem[agent_idx].extras["log_probs"], + new_other_agent_mem[agent_idx].extras["values"], + done, + other_agent_mem[agent_idx].hidden, + ) + for agent_idx in range(len(other_agents)) + ] return ( rngs, - next_obs1, - next_obs2, - next_obs3, - rewards[0], - rewards[1], - rewards[2], - a1_state, - new_a1_mem, - a2_state, - new_a2_mem, - a3_state, - new_a3_mem, + first_agent_next_obs, + 
tuple(other_agent_next_obs), + first_agent_reward, + tuple(other_agent_rewards), + first_agent_state, + other_agent_state, + new_first_agent_mem, + new_other_agent_mem, env_state, env_params, - ), ( - traj1, - traj2, - traj3, - ) + ), (traj1, *other_traj) def _outer_rollout(carry, unused): """Runner for trial""" @@ -329,182 +302,190 @@ def _outer_rollout(carry, unused): None, length=args.num_inner_steps, ) + other_agent_metrics = [None] * len(other_agents) ( rngs, - obs1, - obs2, - obs3, - r1, - r2, - r3, - a1_state, - a1_mem, - a2_state, - a2_mem, - a3_state, - a3_mem, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, env_state, env_params, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": - a1_mem = agent1.meta_policy(a1_mem) - - # update second and third agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2, - a2_state, - a2_mem, - ) - a3_state, a3_mem, a3_metrics = agent3.batch_update( - trajectories[2], - obs3, - a3_state, - a3_mem, - ) + first_agent_mem = agent1.meta_policy(first_agent_mem) + # update opponents, we start with agent2 + for agent_idx, non_first_agent in enumerate(other_agents): + ( + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + other_agent_metrics[agent_idx], + ) = non_first_agent.batch_update( + trajectories[agent_idx + 1], + other_agent_obs[agent_idx], + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + ) return ( rngs, - obs1, - obs2, - obs3, - r1, - r2, - r3, - a1_state, - a1_mem, - a2_state, - a2_mem, - a3_state, - a3_mem, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, env_state, env_params, - ), (*trajectories, a2_metrics, a3_metrics) + ), (trajectories, other_agent_metrics) def _rollout( - _params: jnp.ndarray, - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _env_params: Any, + _params: jnp.ndarray, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _env_params: Any, ): # env reset - _rng_run, env_reset_rng = jax.random.split(_rng_run) env_rngs = jnp.concatenate( - [jax.random.split(env_reset_rng, args.num_envs)] + [jax.random.split(_rng_run, args.num_envs)] * args.num_opps * args.popsize ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) obs, env_state = env.reset(env_rngs, _env_params) rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] * args.num_players # Player 1 _a1_state = _a1_state._replace(params=_params) _a1_mem = agent1.batch_reset(_a1_mem, False) + # Other players + other_agent_mem = [None] * len(other_agents) + other_agent_state = [None] * len(other_agents) - if self.args.env_type not in ["meta"]: - raise RuntimeError( - "Only meta-experiments are supported with evo runner" - ) - _rng_run, agent2_rng, agent3_rng = jax.random.split(_rng_run, 3) - a2_rng = jnp.concatenate( - [jax.random.split(agent2_rng, args.num_opps)] * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - a3_rng = jnp.concatenate( - [jax.random.split(agent3_rng, args.num_opps)] * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - # meta-experiments - init 2nd agent per trial - 
a2_state, a2_mem = agent2.batch_init( - a2_rng, - agent2._mem.hidden, - ) - # meta-experiments - init 3nd agent per trial - a3_state, a3_mem = agent3.batch_init( - a3_rng, - agent3._mem.hidden, + _rng_run, *other_agent_rngs = jax.random.split( + _rng_run, args.num_players ) + for agent_idx, non_first_agent in enumerate(other_agents): + # indexing starts at 2 for args + agent_arg = f"agent{agent_idx+2}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + ( + other_agent_mem[agent_idx], + other_agent_state[agent_idx], + ) = non_first_agent.batch_init(obs[agent_idx + 1]) + else: + # meta-experiments - init 2nd agent per trial + non_first_agent_rng = jnp.concatenate( + [ + jax.random.split( + other_agent_rngs[agent_idx], args.num_opps + ) + ] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + ( + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + ) = non_first_agent.batch_init( + non_first_agent_rng, + non_first_agent._mem.hidden, + ) # run trials vals, stack = jax.lax.scan( _outer_rollout, ( env_rngs, - *obs, - *rewards, + obs[0], + tuple(obs[1:]), + rewards[0], + tuple(rewards[1:]), _a1_state, + other_agent_state, _a1_mem, - a2_state, - a2_mem, - a3_state, - a3_mem, + other_agent_mem, env_state, _env_params, ), None, length=self.num_outer_steps, ) - ( env_rngs, - obs1, - obs2, - obs3, - r1, - r2, - r3, - _a1_state, - _a1_mem, - a2_state, - a2_mem, - a3_state, - a3_mem, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, env_state, _env_params, ) = vals - traj_1, traj_2, traj_3, a2_metrics, a3_metrics = stack + trajectories, other_agent_metrics = stack # Fitness - fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) - other2_fitness = traj_3.rewards.mean(axis=(0, 1, 3, 4)) + fitness = trajectories[0].rewards.mean(axis=(0, 1, 3, 4)) + other_fitness = [ + traj.rewards.mean(axis=(0, 1, 3, 4)) + for traj in trajectories[1:] + ] # Stats - if args.env_id == "coin_game": - env_stats = jax.tree_util.tree_map( - lambda x: x, - self.cg_stats(env_state), - ) - - rewards_1 = traj_1.rewards.sum(axis=1).mean() - rewards_2 = traj_2.rewards.sum(axis=1).mean() - rewards_3 = traj_3.rewards.sum(axis=1).mean() - - elif args.env_id in [ - "iterated_tensor_game", + first_agent_reward = trajectories[0].rewards.mean() + other_agent_rewards = [ + traj.rewards.mean() for traj in trajectories[1:] + ] + if args.env_id in [ + "iterated_nplayer_tensor_game", ]: env_stats = jax.tree_util.tree_map( lambda x: x.mean(), self.ipd_stats( - obs1, + trajectories[0].observations, args.num_players + ), + ) + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cournot_stats( + trajectories[0].observations, _env_params, args.num_players + ), + ) + elif args.env_id == "Fishery": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.fishery_stats( + trajectories[0].observations, args.num_players + ), + ) + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cournot_stats( + trajectories[0].observations, _env_params, args.num_players ), ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - rewards_3 = traj_3.rewards.mean() + else: + env_stats = {} + return ( fitness, other_fitness, - other2_fitness, env_stats, - rewards_1, - rewards_2, - rewards_3, - a2_metrics, - a3_metrics, + first_agent_reward, + 
other_agent_rewards, + other_agent_metrics, ) self.rollout = jax.pmap( @@ -512,12 +493,16 @@ def _rollout( in_axes=(0, None, None, None, None), ) + print( + f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" + ) + def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, + self, + env_params, + agents, + num_iters: int, + watchers: Callable, ): """Run training of agents in environment""" print("Training") @@ -531,7 +516,6 @@ def run_loop( print(f"Log Interval: {log_interval}") print("------------------------------") # Initialize agents and RNG - agent1, agent2, agent3 = agents rng, _ = jax.random.split(self.random_key) # Initialize evolution @@ -558,16 +542,16 @@ def run_loop( # Reshape a single agent's params before vmapping init_hidden = jnp.tile( - agent1._mem.hidden, + agents[0]._mem.hidden, (popsize, num_opps, 1, 1), ) a1_rng = jax.random.split(rng, popsize) - agent1._state, agent1._mem = agent1.batch_init( + agents[0]._state, agents[0]._mem = agents[0].batch_init( a1_rng, init_hidden, ) - a1_state, a1_mem = agent1._state, agent1._mem + a1_state, a1_mem = agents[0]._state, agents[0]._mem for gen in range(num_gens): rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) @@ -583,33 +567,28 @@ def run_loop( ( fitness, other_fitness, - other2_fitness, env_stats, - rewards_1, - rewards_2, - rewards_3, - a2_metrics, - a3_metrics, + first_agent_reward, + other_agent_reward, + other_agent_metrics, ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) # Aggregate over devices fitness = jnp.reshape(fitness, popsize * self.args.num_devices) env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) - # Maximize fitness - fitness_re = fit_shaper.apply(x, fitness) - # Tell fitness_re = fit_shaper.apply(x, fitness) if self.args.es.mean_reduce: fitness_re = fitness_re - fitness_re.mean() evo_state = strategy.tell(x, fitness_re, evo_state, es_params) + # Logging log = es_logging.update(log, x, fitness) # Saving - if gen % self.args.save_interval == 0 or gen == 50: + if gen % self.args.save_interval == 0: log_savepath = os.path.join(self.save_dir, f"generation_{gen}") if self.args.num_devices > 1: top_params = param_reshaper.reshape( @@ -638,10 +617,10 @@ def run_loop( "--------------------------------------------------------------------------" ) print( - f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()} | Third Fitness: {other2_fitness.mean()}" + f"Fitness: {fitness.mean()} | Other Fitness: {[fitness.mean() for fitness in other_fitness]}" ) print( - f"Reward Per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" + f"Reward Per Timestep: {float(first_agent_reward.mean()), *[float(reward.mean()) for reward in other_agent_reward]}" ) print( f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" @@ -664,36 +643,51 @@ def run_loop( print() if watchers: + rewards_strs = [ + "train/reward_per_timestep/player_" + str(i) + for i in range(2, len(other_agent_reward) + 2) + ] + rewards_val = [ + float(reward.mean()) for reward in other_agent_reward + ] + rewards_dict = dict(zip(rewards_strs, rewards_val)) + fitness_str = [ + "train/fitness/player_" + str(i) + for i in range(2, len(other_fitness) + 2) + ] + fitness_val = [ + float(fitness.mean()) for fitness in other_fitness + ] + fitness_dict = dict(zip(fitness_str, fitness_val)) + all_rewards = other_agent_reward + [first_agent_reward] + global_welfare = float( + sum([reward.mean() for reward in all_rewards]) + / self.args.num_players + ) wandb_log = { - 
"train_iteration": gen, - "train/fitness/player_1": float(fitness.mean()), - "train/fitness/player_2": float(other_fitness.mean()), - "train/fitness/player_3": float(other2_fitness.mean()), - "train/fitness/top_overall_mean": log["log_top_mean"][gen], - "train/fitness/top_overall_std": log["log_top_std"][gen], - "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], - "train/fitness/top_gen_std": log["log_top_gen_std"][gen], - "train/fitness/gen_std": log["log_gen_std"][gen], - "train/time/minutes": float( - (time.time() - self.start_time) / 60 - ), - "train/time/seconds": float( - (time.time() - self.start_time) - ), - "train/reward_per_timestep/player_1": float( - rewards_1.mean() - ), - "train/reward_per_timestep/player_2": float( - rewards_2.mean() - ), - "train/reward_per_timestep/player_3": float( - rewards_3.mean() - ), - } + "train_iteration": gen, + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/fitness/player_1": float(fitness.mean()), + "train/reward_per_timestep/player_1": float( + first_agent_reward.mean() + ), + "train/global_welfare": global_welfare, + } | rewards_dict + wandb_log = wandb_log | fitness_dict wandb_log.update(env_stats) # loop through population for idx, (overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) + zip(log["top_fitness"], log["top_gen_fitness"]) ): wandb_log[ f"train/fitness/top_overall_agent_{idx+1}" @@ -702,19 +696,17 @@ def run_loop( f"train/fitness/top_gen_agent_{idx+1}" ] = gen_fitness - # player 2 metrics + # other player metrics # metrics [outer_timesteps, num_opps] - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics - ) - agent2._logger.metrics.update(flattened_metrics) - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), a3_metrics - ) - agent3._logger.metrics.update(flattened_metrics) + for agent, metrics in zip(agents[1:], other_agent_metrics): + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), metrics + ) - for watcher, agent in zip(watchers, agents): - watcher(agent) + agent._logger.metrics.update(flattened_metrics) + # TODO fix agent logger + # for watcher, agent in zip(watchers, agents): + # watcher(agent) wandb_log = jax.tree_util.tree_map( lambda x: x.item() if isinstance(x, jax.Array) else x, wandb_log, diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 3a9fa222..5d0d337d 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -407,9 +407,9 @@ def _rollout( env_stats = jax.tree_util.tree_map( lambda x: x.mean(), self.cournot_stats( - traj_1, - traj_2, + traj_1.observations, _env_params, + 2 ), ) elif args.env_id == "Fishery": diff --git a/pax/runners/runner_marl_nplayer.py b/pax/runners/runner_marl_nplayer.py index d693e33b..2de09ce9 100644 --- a/pax/runners/runner_marl_nplayer.py +++ b/pax/runners/runner_marl_nplayer.py @@ -9,6 +9,8 @@ import wandb from pax.utils import MemoryState, TrainingState, save from pax.watchers import n_player_ipd_visitation +from pax.watchers.cournot import cournot_stats +from pax.watchers.fishery import fishery_stats 
MAX_WANDB_CALLS = 1000 @@ -89,6 +91,8 @@ def _reshape_opp_dim(x): self.reduce_opp_dim = jax.jit(_reshape_opp_dim) self.ipd_stats = n_player_ipd_visitation + self.cournot_stats = jax.jit(cournot_stats) + self.fishery_stats = jax.jit(fishery_stats) # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) env.step = jax.vmap( @@ -134,7 +138,7 @@ def _reshape_opp_dim(x): # go through opponents, we start with agent2 for agent_idx, non_first_agent in enumerate(other_agents): - agent_arg = f"agent{agent_idx+2}" + agent_arg = f"agent{agent_idx + 2}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": # special case where NaiveEx has a different call signature @@ -191,8 +195,6 @@ def _inner_rollout(carry, unused): # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, 0, :] - # a1_rng = rngs[:, :, 1, :] - # a2_rng = rngs[:, :, 2, :] rngs = rngs[:, :, 3, :] new_other_agent_mem = [None] * len(other_agents) @@ -305,7 +307,6 @@ def _outer_rollout(carry, unused): # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": first_agent_mem = agent1.meta_policy(first_agent_mem) - # TODO update first agent regularly? # update second agent for agent_idx, non_first_agent in enumerate(other_agents): @@ -334,12 +335,12 @@ def _outer_rollout(carry, unused): ), (trajectories, other_agent_metrics) def _rollout( - _rng_run: jnp.ndarray, - first_agent_state: TrainingState, - first_agent_mem: MemoryState, - other_agent_state: List[TrainingState], - other_agent_mem: List[MemoryState], - _env_params: Any, + _rng_run: jnp.ndarray, + first_agent_state: TrainingState, + first_agent_mem: MemoryState, + other_agent_state: List[TrainingState], + other_agent_mem: List[MemoryState], + _env_params: Any, ): # env reset rngs = jnp.concatenate( @@ -348,8 +349,8 @@ def _rollout( obs, env_state = env.reset(rngs, _env_params) rewards = [ - jnp.zeros((args.num_opps, args.num_envs)), - ] * args.num_players + jnp.zeros((args.num_opps, args.num_envs)), + ] * args.num_players # Player 1 first_agent_mem = agent1.batch_reset(first_agent_mem, False) @@ -359,7 +360,7 @@ def _rollout( for agent_idx, non_first_agent in enumerate(other_agents): # indexing starts at 2 for args - agent_arg = f"agent{agent_idx+2}" + agent_arg = f"agent{agent_idx + 2}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": ( @@ -429,6 +430,11 @@ def _rollout( other_agent_mem[agent_idx], False ) + first_agent_reward = trajectories[0].rewards.mean() + other_agent_rewards = [ + traj.rewards.mean() for traj in trajectories[1:] + ] + if args.env_id == "iterated_nplayer_tensor_game": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -436,16 +442,22 @@ def _rollout( trajectories[0].observations, args.num_players ), ) - first_agent_reward = trajectories[0].rewards.mean() - other_agent_rewards = [ - traj.rewards.mean() for traj in trajectories[1:] - ] + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cournot_stats( + trajectories[0].observations, _env_params, args.num_players + ), + ) + elif args.env_id == "Fishery": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.fishery_stats( + trajectories[0].observations, _env_params, args.num_players + ), + ) else: env_stats = {} - first_agent_reward = trajectories[0].rewards.mean() - other_agent_rewards = [ - traj.rewards.mean() for traj in trajectories[1:] - ] return ( env_stats, @@ -537,14 +549,14 @@ def run_loop(self, env_params, agents, num_iters, 
watchers): lambda x: jnp.mean(x), first_agent_metrics ) agent1._logger.metrics = ( - agent1._logger.metrics | flattened_metrics_1 + agent1._logger.metrics | flattened_metrics_1 ) for agent, metric in zip(other_agents, other_agent_metrics): flattened_metrics = jax.tree_util.tree_map( lambda x: jnp.mean(x), first_agent_metrics ) agent._logger.metrics = ( - agent._logger.metrics | flattened_metrics + agent._logger.metrics | flattened_metrics ) for watcher, agent in zip(watchers, agents): @@ -562,14 +574,14 @@ def run_loop(self, env_params, agents, num_iters, watchers): ] rewards_dict = dict(zip(rewards_strs, rewards_val)) wandb_log = ( - { - "train_iteration": i, - "train/reward_per_timestep/player_1": float( - first_agent_reward.mean().item() - ), - } - | rewards_dict - | env_stats + { + "train_iteration": i, + "train/reward_per_timestep/player_1": float( + first_agent_reward.mean().item() + ), + } + | rewards_dict + | env_stats ) wandb.log(wandb_log) diff --git a/pax/runners/runner_sarl.py b/pax/runners/runner_sarl.py index 5e0f1973..53c071d0 100644 --- a/pax/runners/runner_sarl.py +++ b/pax/runners/runner_sarl.py @@ -162,10 +162,8 @@ def _rollout( _a1_mem, ) - # reset memory _a1_mem = agent.batch_reset(_a1_mem, False) - # Stats rewards = jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8) env_stats = {} @@ -177,8 +175,7 @@ def _rollout( _a1_metrics, ) - self.rollout = _rollout - # self.rollout = jax.jit(_rollout) + self.rollout = jax.jit(_rollout) def run_loop(self, env, env_params, agent, num_iters, watcher): """Run training of agent in environment""" diff --git a/pax/watchers/cournot.py b/pax/watchers/cournot.py index a81f418f..8c20dfa9 100644 --- a/pax/watchers/cournot.py +++ b/pax/watchers/cournot.py @@ -3,15 +3,18 @@ from pax.envs.cournot import EnvParams as CournotEnvParams, CournotGame -def cournot_stats(observations: jnp.ndarray, params: CournotEnvParams, n_player: int) -> dict: +def cournot_stats(observations: jnp.ndarray, params: CournotEnvParams, num_players: int) -> dict: opt_quantity = CournotGame.nash_policy(params) - #average_quantity = (traj1.actions + traj2.actions) / 2 - return { - # "quantity/1": jnp.mean(traj1.actions), - # "quantity/2": jnp.mean(traj2.actions), - "quantity/average": jnp.mean(average_quantity), - # How strongly do the joint actions deviate from the socially optimum? - # Since the reward is a linear function of the quantity there is no need to consider it separately. 
-        "quantity/loss": jnp.mean((opt_quantity / 2 - average_quantity) ** 2),
+    actions = observations[..., :num_players]
+    average_quantity = actions.mean()
+
+    stats = {
+        "cournot/average_quantity": average_quantity,
+        "cournot/quantity_loss": jnp.mean((opt_quantity - average_quantity) ** 2),
     }
+
+    for i in range(num_players):
+        stats["cournot/quantity_" + str(i)] = jnp.mean(observations[..., i])
+
+    return stats
diff --git a/pax/watchers/fishery.py b/pax/watchers/fishery.py
index 23244bf4..d0c104c0 100644
--- a/pax/watchers/fishery.py
+++ b/pax/watchers/fishery.py
@@ -6,22 +6,23 @@
 from jax import numpy as jnp
 
 
-def fishery_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict:
+def fishery_stats(observations: jnp.ndarray, num_players: int) -> dict:
     # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim
-    stock_obs = traj1.observations[..., 0]
+    stock_obs = observations[..., -1]
+    actions = observations[..., :num_players]
     # TODO this blows up the memory usage
     # flattened_stock_obs = jnp.ravel(stock_obs)
     # split_stock_obs = jnp.array(jnp.split(flattened_stock_obs, flattened_stock_obs.shape[0] // num_inner_steps))
-    return {
-        # "stock/ep_mean": split_stock_obs.mean(axis=1),
-        # "stock/ep_std": split_stock_obs.std(axis=1),
-        # "stock/ep_min": split_stock_obs.min(axis=1),
-        # "stock/ep_max": split_stock_obs.max(axis=1),
-        "fishery/stock_avg": jnp.mean(stock_obs),
-        "fishery/effort_1": jnp.mean(jax.nn.sigmoid(traj1.actions)),
-        "fishery/effort_2": jnp.mean(jax.nn.sigmoid(traj2.actions)),
+    stats = {
+        "fishery/stock": jnp.mean(stock_obs),
+        "fishery/mean_effort": actions.mean()
     }
+    for i in range(num_players):
+        stats["fishery/effort_" + str(i)] = jnp.mean(observations[..., i])
+
+    return stats
+
 
 def fishery_eval_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict:
     # Calculate effort for both agents
diff --git a/pax/watchers/rice.py b/pax/watchers/rice.py
new file mode 100644
index 00000000..f08fb1c4
--- /dev/null
+++ b/pax/watchers/rice.py
@@ -0,0 +1,19 @@
+from typing import NamedTuple
+
+import jax
+import numpy as np
+import wandb
+from jax import numpy as jnp
+
+
+def rice_stats(traj: NamedTuple, num_players: int) -> dict:
+    # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim
+    observations = traj.observations
+    actions = traj.actions
+
+    stats = {
+        "rice/mean_reward": jnp.mean(traj.rewards),
+        "rice/mean_action": jnp.mean(actions),
+    }
+    return stats
+
diff --git a/test/envs/test_fishery.py b/test/envs/test_fishery.py
index ef61e410..776abf73 100644
--- a/test/envs/test_fishery.py
+++ b/test/envs/test_fishery.py
@@ -8,7 +8,7 @@ def test_fishery_convergence():
     rng = jax.random.PRNGKey(0)
     ep_length = 300
 
-    env = Fishery(num_inner_steps=ep_length)
+    env = Fishery(num_players=2, num_inner_steps=ep_length)
     env_params = EnvParams(
         g=0.15,
         e=0.009,
diff --git a/test/envs/test_rice.py b/test/envs/test_rice.py
new file mode 100644
index 00000000..828dd5d4
--- /dev/null
+++ b/test/envs/test_rice.py
@@ -0,0 +1,65 @@
+import os
+import time
+
+import jax
+from pytest import mark
+
+from pax.envs.rice import Rice, EnvParams
+
+file_dir = os.path.join(os.path.dirname(__file__))
+env_dir = os.path.join(file_dir, "../../pax/envs")
+
+def test_rice():
+    rng = jax.random.PRNGKey(0)
+    ep_length = 20
+    num_players = 27
+
+    env = Rice(num_inner_steps=ep_length, config_folder=os.path.join(env_dir, "region_yamls"))
+    env_params = EnvParams()
+    obs, env_state = env.reset(rng, EnvParams())
+
+    for i in range(3 * ep_length + 1):
+        # Do random actions
+        rng, key = jax.random.split(rng, 2)
+        action = jax.random.uniform(key, (env.num_actions,))
+        actions = tuple([action for _ in range(num_players)])
+        obs, env_state, rewards, done, info = env.step(
+            key, env_state, actions, env_params
+        )
+        for player_idx in range(num_players):
+            assert env_state.consumption_all[player_idx].item() >= 0, "consumption cannot be negative!"
+
+    # assert all obs positive
+    # assert done firing correctly
+
+
+# TODO assert gross_output - investment - total_exports > -1e-5, "consumption cannot be negative!"
+
+@mark.skip(reason="Benchmark")
+def test_benchmark_rice_n():
+    rng = jax.random.PRNGKey(0)
+    ep_length = 100
+    iterations = 100
+    num_players = 27
+
+    env = Rice(num_inner_steps=ep_length, config_folder=os.path.join(env_dir, "region_yamls"))
+    env_params = EnvParams()
+    obs, env_state = env.reset(rng, EnvParams())
+
+    start_time = time.time()
+
+    for i in range(iterations * ep_length):
+        # Do random actions
+        rng, key = jax.random.split(rng, 2)
+        action = jax.random.uniform(key, (env.num_actions,))
+        actions = tuple([action for _ in range(num_players)])
+        obs, env_state, rewards, done, info = env.step(
+            key, env_state, actions, env_params
+        )
+
+    end_time = time.time()  # End timing
+    total_time = end_time - start_time
+
+    # Print or log the total time taken for all iterations
+    print(f"Total time taken:\t{total_time:.4f} seconds")
+    print(f"Average step duration:\t{total_time / (iterations * ep_length):.4f} seconds")
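
Note (not part of the diff above): the CES/Armington aggregation used by get_armington_agg in pax/envs/rice.py can be sanity-checked with a tiny standalone sketch. The numbers and the helper name below are made up for illustration; the sketch simply re-applies the same formula.

import jax.numpy as jnp

def armington_example():
    # Made-up values: domestic consumption plus imports from two partner regions.
    c_dom = jnp.asarray(10.0)
    c_for = jnp.asarray([2.0, 3.0])
    sub_rate = 0.5                       # elasticity of substitution = 1 / (1 - 0.5) = 2
    dom_pref = 0.5
    for_pref = jnp.asarray([0.25, 0.25])

    # Same CES form as get_armington_agg in pax/envs/rice.py.
    c_dom_pref = dom_pref * c_dom ** sub_rate
    c_for_pref = jnp.sum(for_pref * c_for ** sub_rate)
    return (c_dom_pref + c_for_pref) ** (1 / sub_rate)  # roughly 5.6 for these inputs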
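
Similarly, a minimal sketch of the flattening that SarlRice performs: the single agent emits one flat action vector that is split per region (as in SarlRice._step), and per-region observations are concatenated back into one flat vector. The sizes here are hypothetical.

import jax.numpy as jnp

def sarl_flattening_example(num_players: int = 3, actions_per_player: int = 4, obs_dim: int = 5):
    # One flat action vector from the single-agent policy...
    flat_action = jnp.arange(num_players * actions_per_player, dtype=jnp.float32)
    # ...split into one action per region, mirroring jnp.split in SarlRice._step.
    per_player_actions = jnp.split(flat_action, num_players)

    # Per-region observations are concatenated into one flat observation,
    # mirroring jnp.concatenate in SarlRice._step.
    per_player_obs = [jnp.full((obs_dim,), i, dtype=jnp.float32) for i in range(num_players)]
    flat_obs = jnp.concatenate(per_player_obs)
    return per_player_actions, flat_obs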
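
Finally, a short usage sketch of load_yaml_data / rec_array_conversion from pax/envs/rice.py, which turn the nested region YAML configuration into jnp arrays. The YAML keys shown are placeholders, not the real region schema; the real per-region files live in pax/envs/region_yamls/.

import jax.numpy as jnp
import yaml

from pax.envs.rice import rec_array_conversion

def yaml_conversion_example():
    # Placeholder snippet standing in for one region file.
    raw = yaml.safe_load("region:\n  scalar_param: 1.0\n  list_param: [1.0, 2.0, 3.0]\n")
    converted = rec_array_conversion(raw)
    assert isinstance(converted["region"]["list_param"], jnp.ndarray)
    return converted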