diff --git a/pax/agents/ppo/ppo_gru.py b/pax/agents/ppo/ppo_gru.py
index 6a6428fd..d56d7c2c 100644
--- a/pax/agents/ppo/ppo_gru.py
+++ b/pax/agents/ppo/ppo_gru.py
@@ -546,7 +546,7 @@ def make_gru_agent(
         network, initial_hidden_state = make_GRU_fishery_network(
             action_spec, agent_args.hidden_size
         )
-    elif args.env_id == Rice.env_id:
+    elif args.env_id in [Rice.env_id, "Rice-v1"]:
         network, initial_hidden_state = make_GRU_rice_network(
             action_spec, agent_args.hidden_size
         )
diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml
index 5c31854c..79d87c73 100644
--- a/pax/conf/config.yaml
+++ b/pax/conf/config.yaml
@@ -25,6 +25,7 @@ agent1: 'PPO'
 agent2: 'PPO'
 shuffle_players: False
 agent2_roles: 1
+agent2_reset_interval: 1 # Reset every rollout
 
 # Logging setup
 wandb:
diff --git a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml
index d623a31e..d6a9bcde 100644
--- a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml
+++ b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml
@@ -19,7 +19,7 @@ top_k: 5
 popsize: 1000
 num_envs: 1
 num_opps: 1
-num_outer_steps: 2000
+num_outer_steps: 180
 num_inner_steps: 20
 num_iters: 1500
 num_devices: 1
diff --git a/pax/conf/experiment/rice/gs_v_ppo.yaml b/pax/conf/experiment/rice/gs_v_ppo.yaml
new file mode 100644
index 00000000..daf31c69
--- /dev/null
+++ b/pax/conf/experiment/rice/gs_v_ppo.yaml
@@ -0,0 +1,100 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO'
+agent_default: 'PPO_memory'
+agent2_roles: 4
+
+# Environment
+env_id: Rice-N
+env_type: meta
+num_players: 5
+has_mediator: False
+shuffle_players: False
+config_folder: pax/envs/rice/5_regions
+runner: evo
+
+
+# Training
+top_k: 5
+popsize: 1000
+num_envs: 4
+num_opps: 1
+num_outer_steps: 500
+num_inner_steps: 20
+num_iters: 1500
+num_devices: 1
+num_steps: 200
+
+# PPO agent parameters
+ppo1:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+# PPO agent parameters
+ppo2:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+
+# ES parameters
+es:
+  algo: OpenES            # [OpenES, CMA_ES]
+  sigma_init: 0.04        # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999      # Multiplicative decay factor
+  sigma_limit: 0.01       # Smallest possible scale
+  init_min: 0.0           # Range of parameter mean initialization - Min
+  init_max: 0.0           # Range of parameter mean initialization - Max
+  clip_min: -1e10         # Range of parameter proposals - Min
+  clip_max: 1e10          # Range of parameter proposals - Max
+  lrate_init: 0.01        # Initial learning rate
+  lrate_decay: 0.9999     # Multiplicative decay factor
+  lrate_limit: 0.001      # Smallest possible lrate
+  beta_1: 0.99            # Adam - beta_1
+  beta_2: 0.999           # Adam - beta_2
+  eps: 1e-8               # eps constant,
+  centered_rank: False    # Fitness centered_rank
+  w_decay: 0              # Decay old elite fitness
+  maximise: True          # Maximise fitness
+  z_score: False          # Normalise fitness
+  mean_reduce: True       # Remove mean
+
+# Logging setup
+wandb:
+  project: rice
+  name: 'rice-GS-${agent1}-vs-${agent2}'
+  log: True
+
+
diff --git a/pax/conf/experiment/rice/shaper_v_ppo.yaml b/pax/conf/experiment/rice/shaper_v_ppo.yaml
index 7ed03d8e..61e619bb 100644
--- a/pax/conf/experiment/rice/shaper_v_ppo.yaml
+++ b/pax/conf/experiment/rice/shaper_v_ppo.yaml
@@ -12,7 +12,7 @@ env_id: Rice-N
 env_type: meta
 num_players: 5
 has_mediator: False
-shuffle_players: False
+shuffle_players: True
 config_folder: pax/envs/rice/5_regions
 runner: evo
 
@@ -76,7 +76,7 @@ es:
 
 # Logging setup
 wandb:
-  project: c-rice
+  project: rice
   group: 'shaper'
   name: 'rice-SHAPER-${agent_default}-seed-${seed}'
   log: True
diff --git a/pax/envs/rice/rice.py b/pax/envs/rice/rice.py
index ba85c759..31a862c1 100644
--- a/pax/envs/rice/rice.py
+++ b/pax/envs/rice/rice.py
@@ -108,7 +108,7 @@ def __init__(self, num_inner_steps: int, config_folder: str, has_mediator=False)
         self.mitigation_rate_action_index = self.savings_action_index + self.savings_action_n
         self.export_action_index = self.mitigation_rate_action_index + self.mitigation_rate_action_n
         self.tariffs_action_index = self.export_action_index + self.export_action_n
-        self.desired_imports_action_index = self.tariffs_action_index + self.tariff_actions_n
+        self.desired_imports_action_index = self.tariffs_action_index + self.import_actions_n
 
         # Parameters for armington aggregation utility
         self.sub_rate = jnp.asarray(0.5, dtype=float_precision)
@@ -188,7 +188,7 @@ def _step(
             tariffed_imports = scaled_imports * (1 - prev_tariffs)
             # calculate tariffed imports, tariff revenue and budget balance
             # In the paper this goes to a "special reserve fund", i.e. it's not used
-            tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=0)
+            tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=1)
 
             total_exports = scaled_imports.sum(axis=0)
             balance_all = balance_all + self.dice_constant["xDelta"] * (
@@ -229,9 +229,6 @@
                 self.dice_constant["xB_M"],
                 jnp.sum(aux_m_all),
             )
-            # def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m):
-            #     return jnp.dot(phi_m, carbon_mass) + jnp.dot(b_m, aux_m)
-
             capital_depreciation = get_capital_depreciation(
                 self.rice_constant["xdelta_K"],
                 self.dice_constant["xDelta"]
diff --git a/pax/experiment.py b/pax/experiment.py
index c8942880..df65b4c5 100644
--- a/pax/experiment.py
+++ b/pax/experiment.py
@@ -723,7 +723,7 @@ def naive_pg_log(agent):
     return agent_log
 
 
-@hydra.main(config_path="conf", config_name="config")
+@hydra.main(config_path="conf", config_name="config", version_base="1.1")
 def main(args):
     print(f"Jax backend: {xla_bridge.get_backend().platform}")
 
diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py
index 252a6676..3b654516 100644
--- a/pax/runners/runner_evo.py
+++ b/pax/runners/runner_evo.py
@@ -326,6 +326,7 @@ def _rollout(
             _rng_run: jnp.ndarray,
             _a1_state: TrainingState,
             _a1_mem: MemoryState,
+            _a2_state: TrainingState,
             _env_params: Any,
         ):
             # env reset
@@ -343,7 +344,9 @@ def _rollout(
             _a1_state = _a1_state._replace(params=_params)
             _a1_mem = agent1.batch_reset(_a1_mem, False)
             # Player 2
-            if args.agent2 == "NaiveEx":
+            if _a2_state is not None:
+                a2_state = _a2_state
+            elif args.agent2 == "NaiveEx":
                 a2_state, a2_mem = agent2.batch_init(obs[1])
             else:
                 # meta-experiments - init 2nd agent per trial
@@ -469,7 +472,7 @@ def _rollout(
 
         self.rollout = jax.pmap(
             _rollout,
-            in_axes=(0, None, None, None, None),
+            in_axes=(0, None, None, None, None, None),
         )
 
         print(
@@ -532,6 +535,7 @@ def run_loop(
         )
 
         a1_state, a1_mem = agent1._state, agent1._mem
+        a2_state = None
        for gen in range(num_gens):
             rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4)
 
@@ -543,6 +547,10 @@ def run_loop(
                 params = jax.tree_util.tree_map(
                     lambda x: jax.lax.expand_dims(x, (0,)), params
                 )
+
+            if gen % self.args.agent2_reset_interval == 0:
+                a2_state = None
+
             # Evo Rollout
             (
                 fitness,
@@ -552,7 +560,7 @@ def run_loop(
                 rewards_2,
                 a2_metrics,
                 a2_state
-            ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params)
+            ) = self.rollout(params, rng_run, a1_state, a1_mem, a2_state, env_params)
 
             # Aggregate over devices
             fitness = jnp.reshape(fitness, popsize * self.args.num_devices).astype(dtype=jnp.float32)
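The runner_evo.py and config.yaml changes above introduce an opponent reset interval: run_loop clears a2_state every agent2_reset_interval generations and otherwise passes the carried-over state into the rollout, which re-initialises agent 2 only when it receives None. Below is a minimal standalone sketch of that control flow, assuming only the names taken from the diff (agent2_reset_interval, a2_state, num_gens); the string-valued "state" and helper name trace_opponent_state are illustrative placeholders, not part of pax.

from typing import List, Optional


def trace_opponent_state(num_gens: int, agent2_reset_interval: int) -> List[str]:
    """Return, per generation, which opponent state the rollout would use."""
    a2_state: Optional[str] = None
    trace: List[str] = []
    for gen in range(num_gens):
        # Mirrors the run_loop change: drop the opponent state on the interval.
        if gen % agent2_reset_interval == 0:
            a2_state = None
        # Mirrors the _rollout change: reuse a carried-over state if present,
        # otherwise (re-)initialise agent 2 for this rollout.
        if a2_state is None:
            a2_state = f"init@gen{gen}"
        trace.append(a2_state)
    return trace


if __name__ == "__main__":
    # Interval 1 (the config.yaml default) re-initialises agent 2 every
    # generation, i.e. the previous behaviour; interval 3 keeps the trained
    # opponent state for three consecutive generations.
    print(trace_opponent_state(num_gens=6, agent2_reset_interval=1))
    print(trace_opponent_state(num_gens=6, agent2_reset_interval=3))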