
Commit

checkpoint
chrismatix committed Sep 29, 2023
1 parent 066480c commit 3f15a2e
Showing 8 changed files with 119 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pax/agents/ppo/ppo_gru.py
@@ -546,7 +546,7 @@ def make_gru_agent(
        network, initial_hidden_state = make_GRU_fishery_network(
            action_spec, agent_args.hidden_size
        )
-    elif args.env_id == Rice.env_id:
+    elif args.env_id in [Rice.env_id, "Rice-v1"]:
        network, initial_hidden_state = make_GRU_rice_network(
            action_spec, agent_args.hidden_size
        )
1 change: 1 addition & 0 deletions pax/conf/config.yaml
@@ -25,6 +25,7 @@ agent1: 'PPO'
agent2: 'PPO'
shuffle_players: False
agent2_roles: 1
+agent2_reset_interval: 1 # Reset every rollout

# Logging setup
wandb:
2 changes: 1 addition & 1 deletion pax/conf/experiment/c_rice/mediator_gs_ppo.yaml
@@ -19,7 +19,7 @@ top_k: 5
popsize: 1000
num_envs: 1
num_opps: 1
-num_outer_steps: 2000
+num_outer_steps: 180
num_inner_steps: 20
num_iters: 1500
num_devices: 1
100 changes: 100 additions & 0 deletions pax/conf/experiment/rice/gs_v_ppo.yaml
@@ -0,0 +1,100 @@
# @package _global_

# Agents
agent1: 'PPO'
agent_default: 'PPO_memory'
agent2_roles: 4

# Environment
env_id: Rice-N
env_type: meta
num_players: 5
has_mediator: False
shuffle_players: False
config_folder: pax/envs/rice/5_regions
runner: evo


# Training
top_k: 5
popsize: 1000
num_envs: 4
num_opps: 1
num_outer_steps: 500
num_inner_steps: 20
num_iters: 1500
num_devices: 1
num_steps: 200

# PPO agent parameters
ppo1:
  num_minibatches: 4
  num_epochs: 2
  gamma: 0.96
  gae_lambda: 0.95
  ppo_clipping_epsilon: 0.2
  value_coeff: 0.5
  clip_value: True
  max_gradient_norm: 0.5
  anneal_entropy: False
  entropy_coeff_start: 0.02
  entropy_coeff_horizon: 2000000
  entropy_coeff_end: 0.001
  lr_scheduling: False
  learning_rate: 1
  adam_epsilon: 1e-5
  with_memory: False
  with_cnn: False
  hidden_size: 16

# PPO agent parameters
ppo2:
  num_minibatches: 4
  num_epochs: 2
  gamma: 0.96
  gae_lambda: 0.95
  ppo_clipping_epsilon: 0.2
  value_coeff: 0.5
  clip_value: True
  max_gradient_norm: 0.5
  anneal_entropy: False
  entropy_coeff_start: 0.02
  entropy_coeff_horizon: 2000000
  entropy_coeff_end: 0.001
  lr_scheduling: False
  learning_rate: 1
  adam_epsilon: 1e-5
  with_memory: False
  with_cnn: False
  hidden_size: 16


# ES parameters
es:
  algo: OpenES          # [OpenES, CMA_ES]
  sigma_init: 0.04      # Initial scale of isotropic Gaussian noise
  sigma_decay: 0.999    # Multiplicative decay factor
  sigma_limit: 0.01     # Smallest possible scale
  init_min: 0.0         # Range of parameter mean initialization - Min
  init_max: 0.0         # Range of parameter mean initialization - Max
  clip_min: -1e10       # Range of parameter proposals - Min
  clip_max: 1e10        # Range of parameter proposals - Max
  lrate_init: 0.01      # Initial learning rate
  lrate_decay: 0.9999   # Multiplicative decay factor
  lrate_limit: 0.001    # Smallest possible lrate
  beta_1: 0.99          # Adam - beta_1
  beta_2: 0.999         # Adam - beta_2
  eps: 1e-8             # eps constant
  centered_rank: False  # Fitness centered_rank
  w_decay: 0            # Decay old elite fitness
  maximise: True        # Maximise fitness
  z_score: False        # Normalise fitness
  mean_reduce: True     # Remove mean

# Logging setup
wandb:
  project: rice
  name: 'rice-GS-${agent1}-vs-${agent2}'
  log: True
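A note on the es block above: the sigma_* and lrate_* fields read like a per-generation multiplicative decay with a floor, which is the usual OpenES-style schedule. The sketch below illustrates that reading; it is an assumption about the semantics of these fields, not code from pax or its ES backend.

def decayed(init: float, decay: float, limit: float, gen: int) -> float:
    # Multiplicative decay applied once per generation, clipped at a floor.
    return max(init * decay ** gen, limit)

# Values from the config above, evaluated after 500 outer generations.
sigma_500 = decayed(0.04, 0.999, 0.01, 500)    # ~0.024, still above sigma_limit
lrate_500 = decayed(0.01, 0.9999, 0.001, 500)  # ~0.0095, still above lrate_limit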


4 changes: 2 additions & 2 deletions pax/conf/experiment/rice/shaper_v_ppo.yaml
@@ -12,7 +12,7 @@ env_id: Rice-N
env_type: meta
num_players: 5
has_mediator: False
-shuffle_players: False
+shuffle_players: True
config_folder: pax/envs/rice/5_regions
runner: evo

@@ -76,7 +76,7 @@ es:

# Logging setup
wandb:
-  project: c-rice
+  project: rice
  group: 'shaper'
  name: 'rice-SHAPER-${agent_default}-seed-${seed}'
  log: True
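Turning on shuffle_players randomises which agent sits in which player slot, so the shaping agent no longer always occupies the same position. The snippet below is an illustration of that idea only, with made-up shapes; it is not how pax implements the shuffle.

import jax
import jax.numpy as jnp

num_players = 5
obs = jnp.arange(num_players * 3, dtype=jnp.float32).reshape(num_players, 3)  # per-player observations

key = jax.random.PRNGKey(0)
perm = jax.random.permutation(key, num_players)  # random assignment of agents to player slots
shuffled_obs = obs[perm]                         # observations re-ordered under that assignment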
7 changes: 2 additions & 5 deletions pax/envs/rice/rice.py
@@ -108,7 +108,7 @@ def __init__(self, num_inner_steps: int, config_folder: str, has_mediator=False)
        self.mitigation_rate_action_index = self.savings_action_index + self.savings_action_n
        self.export_action_index = self.mitigation_rate_action_index + self.mitigation_rate_action_n
        self.tariffs_action_index = self.export_action_index + self.export_action_n
-        self.desired_imports_action_index = self.tariffs_action_index + self.tariff_actions_n
+        self.desired_imports_action_index = self.tariffs_action_index + self.import_actions_n

        # Parameters for armington aggregation utility
        self.sub_rate = jnp.asarray(0.5, dtype=float_precision)
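The *_action_index assignments above carve each player's flat action vector into contiguous segments (savings, mitigation rate, export, tariffs, desired imports), with every index equal to the running sum of the preceding segment sizes. Whether tariff_actions_n and import_actions_n coincide numerically (as they would if both segments hold one entry per region) is an assumption this diff does not confirm; the sketch below only illustrates the running-offset pattern with hypothetical sizes.

# Hypothetical segment widths for a 5-region game; only the offset pattern matters.
segments = {"savings": 1, "mitigation_rate": 1, "export": 1, "tariffs": 5, "desired_imports": 5}

indices = {}
offset = 0
for name, size in segments.items():
    indices[name] = offset  # start of this segment in the flat action vector
    offset += size

# indices == {'savings': 0, 'mitigation_rate': 1, 'export': 2, 'tariffs': 3, 'desired_imports': 8}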
@@ -188,7 +188,7 @@ def _step(
        tariffed_imports = scaled_imports * (1 - prev_tariffs)
        # calculate tariffed imports, tariff revenue and budget balance
        # In the paper this goes to a "special reserve fund", i.e. it's not used
-        tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=0)
+        tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=1)

        total_exports = scaled_imports.sum(axis=0)
        balance_all = balance_all + self.dice_constant["xDelta"] * (
@@ -229,9 +229,6 @@ def _step(
self.dice_constant["xB_M"],
jnp.sum(aux_m_all),
)
# def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m):
# return jnp.dot(phi_m, carbon_mass) + jnp.dot(b_m, aux_m)


capital_depreciation = get_capital_depreciation(
self.rice_constant["xdelta_K"], self.dice_constant["xDelta"]
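On the tariff_revenue_all change in the _step hunk above: assuming scaled_imports[i, j] is what region i imports from region j and prev_tariffs[i, j] is the tariff region i levies on those imports, summing over axis=1 credits the revenue to the importing, tariff-levying region, while axis=0 would have attributed it to exporters. That index convention is an assumption used only for this worked example.

import jax.numpy as jnp

# Toy 3-region example: rows = importing region, columns = exporting region (assumed convention).
scaled_imports = jnp.array([[0.0, 2.0, 1.0],
                            [1.0, 0.0, 0.0],
                            [3.0, 1.0, 0.0]])
prev_tariffs = jnp.full((3, 3), 0.1)

revenue_per_importer = jnp.sum(scaled_imports * prev_tariffs, axis=1)  # [0.3, 0.1, 0.4]
revenue_per_exporter = jnp.sum(scaled_imports * prev_tariffs, axis=0)  # [0.4, 0.3, 0.1]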
2 changes: 1 addition & 1 deletion pax/experiment.py
@@ -723,7 +723,7 @@ def naive_pg_log(agent):
    return agent_log


@hydra.main(config_path="conf", config_name="config")
@hydra.main(config_path="conf", config_name="config", version_base="1.1")
def main(args):
print(f"Jax backend: {xla_bridge.get_backend().platform}")

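Adding version_base="1.1" pins Hydra's compatibility level: on Hydra 1.2 and later, leaving version_base unset emits a deprecation warning and switches some defaults (for example whether the process working directory is changed to the run's output directory), so pinning keeps the long-standing 1.1 behaviour explicit. A minimal standalone sketch of the decorator, with a hypothetical config tree:

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf", config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    # Under the 1.1 compatibility level Hydra changes the working directory
    # to the run's output directory before this function is called.
    print(cfg)

if __name__ == "__main__":
    main()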
14 changes: 11 additions & 3 deletions pax/runners/runner_evo.py
@@ -326,6 +326,7 @@ def _rollout(
        _rng_run: jnp.ndarray,
        _a1_state: TrainingState,
        _a1_mem: MemoryState,
+        _a2_state: TrainingState,
        _env_params: Any,
    ):
        # env reset
@@ -343,7 +344,9 @@
        _a1_state = _a1_state._replace(params=_params)
        _a1_mem = agent1.batch_reset(_a1_mem, False)
        # Player 2
-        if args.agent2 == "NaiveEx":
+        if _a2_state is not None:
+            a2_state = _a2_state
+        elif args.agent2 == "NaiveEx":
            a2_state, a2_mem = agent2.batch_init(obs[1])
        else:
            # meta-experiments - init 2nd agent per trial
@@ -469,7 +472,7 @@ def _rollout(

        self.rollout = jax.pmap(
            _rollout,
-            in_axes=(0, None, None, None, None),
+            in_axes=(0, None, None, None, None, None),
        )

        print(
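The extra None appended to in_axes just above tells jax.pmap to broadcast the new _a2_state argument to every device rather than splitting it along a leading axis, the same treatment the other non-mapped arguments already get. A self-contained illustration of that behaviour, unrelated to pax's own types:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# The first argument is mapped over its leading (device) axis; the second is broadcast.
add_shared = jax.pmap(lambda x, shared: x + shared, in_axes=(0, None))

x = jnp.arange(n_dev, dtype=jnp.float32).reshape(n_dev, 1)  # one row per device
shared = jnp.ones((1,))                                      # identical on every device

print(add_shared(x, shared))  # each device adds the broadcast value to its own row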
@@ -532,6 +535,7 @@ def run_loop(
)

        a1_state, a1_mem = agent1._state, agent1._mem
+        a2_state = None

        for gen in range(num_gens):
            rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4)
@@ -543,6 +547,10 @@
            params = jax.tree_util.tree_map(
                lambda x: jax.lax.expand_dims(x, (0,)), params
            )

+            if gen % self.args.agent2_reset_interval == 0:
+                a2_state = None

            # Evo Rollout
            (
                fitness,
@@ -552,7 +560,7 @@
                rewards_2,
                a2_metrics,
                a2_state
-            ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params)
+            ) = self.rollout(params, rng_run, a1_state, a1_mem, a2_state, env_params)

            # Aggregate over devices
            fitness = jnp.reshape(fitness, popsize * self.args.num_devices).astype(dtype=jnp.float32)
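Taken together, the run_loop changes make the second agent's TrainingState persist across generations: the rollout now returns a2_state, the next call receives it, and it is only cleared when the generation index is a multiple of the new agent2_reset_interval setting. Below is a compressed sketch of that control flow with stand-in types and a dummy rollout, assuming the interval semantics implied by the config comment ("Reset every rollout" for the default of 1).

from typing import Any, Optional

def rollout(a2_state: Optional[Any]) -> Any:
    # Stand-in for self.rollout: initialise agent 2 only when no state was carried in,
    # otherwise keep training the state that was passed, and hand the result back.
    return {"step": 0} if a2_state is None else {"step": a2_state["step"] + 1}

agent2_reset_interval = 2  # hypothetical: keep the opponent for two generations at a time
a2_state = None

for gen in range(6):
    if gen % agent2_reset_interval == 0:
        a2_state = None           # fresh opponent for this generation
    a2_state = rollout(a2_state)  # updated state is carried into the next generation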
