Commit 8ba77da: add rice environment
chrismatix committed Aug 16, 2023
1 parent 4d9b179 commit 8ba77da
Showing 56 changed files with 2,091 additions and 627 deletions.
26 changes: 26 additions & 0 deletions pax/agents/ppo/networks.py
@@ -649,6 +649,32 @@ def forward_fn(
return network, hidden_state


def make_GRU_rice_network(
num_actions: int,
hidden_size: int,
):
hidden_state = jnp.zeros((1, hidden_size))

def forward_fn(
inputs: jnp.ndarray, state: jnp.ndarray
) -> tuple[tuple[MultivariateNormalDiag, Array], Any]:
"""forward function"""
gru = hk.GRU(
hidden_size,
w_i_init=hk.initializers.Orthogonal(jnp.sqrt(1)),
w_h_init=hk.initializers.Orthogonal(jnp.sqrt(1)),
b_init=hk.initializers.Constant(0),
)

cvh = ContinuousValueHead(num_values=num_actions)
embedding, state = gru(inputs, state)
logits, values = cvh(embedding)
return (logits, values), state

network = hk.without_apply_rng(hk.transform(forward_fn))
return network, hidden_state


def test_GRU():
key = jax.random.PRNGKey(seed=0)
num_actions = 2
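For orientation, a minimal usage sketch of the new factory; the observation and action sizes here are hypothetical placeholders, not values taken from the Rice environment:

import jax
import jax.numpy as jnp

from pax.agents.ppo.networks import make_GRU_rice_network

num_actions, hidden_size, obs_dim = 9, 16, 10  # placeholder dimensions
network, hidden_state = make_GRU_rice_network(num_actions, hidden_size)

rng = jax.random.PRNGKey(0)
dummy_obs = jnp.zeros((1, obs_dim))

# hk.without_apply_rng(hk.transform(...)): init takes an rng, apply does not.
params = network.init(rng, dummy_obs, hidden_state)
(dist, values), hidden_state = network.apply(params, dummy_obs, hidden_state)

# Per the forward_fn annotation, dist is a MultivariateNormalDiag over the actions.
action = dist.sample(seed=rng)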
6 changes: 5 additions & 1 deletion pax/agents/ppo/ppo.py
Expand Up @@ -16,6 +16,8 @@
make_ipd_network, make_cournot_network,
make_fishery_network,
)
from pax.envs.rice import Rice
from pax.envs.sarl_rice import SarlRice
from pax.utils import Logger, MemoryState, TrainingState, get_advantages


@@ -339,7 +341,7 @@ def model_update_epoch(

return new_state, new_memory, metrics

def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
def make_initial_state(key: Any, hidden: jnp.ndarray) -> Tuple[TrainingState, MemoryState]:
"""Initialises the training state (parameters and optimiser state)."""
key, subkey = jax.random.split(key)

@@ -495,6 +497,8 @@ def make_agent(
network = make_cournot_network(action_spec, agent_args.hidden_size)
elif args.env_id == "Fishery":
network = make_fishery_network(action_spec, agent_args.hidden_size)
elif args.env_id == Rice.env_id:
network = make_fishery_network(action_spec, agent_args.hidden_size)
else:
network = make_ipd_network(
action_spec, tabular, agent_args.hidden_size
12 changes: 10 additions & 2 deletions pax/agents/ppo/ppo_gru.py
@@ -13,8 +13,9 @@
make_GRU_cartpole_network,
make_GRU_coingame_network,
make_GRU_ipd_network,
make_GRU_ipditm_network, make_GRU_fishery_network,
make_GRU_ipditm_network, make_GRU_fishery_network, make_GRU_rice_network,
)
from pax.envs.rice import Rice
from pax.utils import MemoryState, TrainingState, get_advantages

# from dm_env import TimeStep
@@ -517,7 +518,14 @@ def make_gru_agent(
network, initial_hidden_state = make_GRU_ipd_network(
action_spec, agent_args.hidden_size
)

elif args.env_id == "Fishery":
network, initial_hidden_state = make_GRU_fishery_network(
action_spec, agent_args.hidden_size
)
elif args.env_id == Rice.env_id:
network, initial_hidden_state = make_GRU_rice_network(
action_spec, agent_args.hidden_size
)
elif args.env_id == "InTheMatrix":
network, initial_hidden_state = make_GRU_ipditm_network(
action_spec,
38 changes: 8 additions & 30 deletions pax/conf/experiment/cournot/gs_v_ppo.yaml
@@ -1,32 +1,32 @@
# @package _global_

# Agents
agent1: 'PPO'
agent2: 'PPO'
# Agent default applies to all agents
agent_default: 'PPO'

# Environment
env_id: Cournot
env_type: meta
a: 100
b: 1
marginal_cost: 10
# This means the optimum quantity is 2(a-marginal_cost)/3b = 60
# This means the nash quantity is 2(a-marginal_cost)/3b = 60

# Runner
runner: evo
runner: tensor_evo

# Training
top_k: 5
popsize: 1000
num_envs: 2
num_envs: 4
num_opps: 1
num_outer_steps: 100
num_inner_steps: 50
num_outer_steps: 300
num_inner_steps: 1 # One-shot game
num_iters: 3000
num_devices: 1

# PPO agent parameters
ppo1:
ppo_default:
num_minibatches: 4
num_epochs: 2
gamma: 0.96
@@ -46,28 +46,6 @@ ppo1:
with_cnn: False
hidden_size: 16

# PPO agent parameters
ppo2:
num_minibatches: 4
num_epochs: 2
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: False
entropy_coeff_start: 0.02
entropy_coeff_horizon: 2000000
entropy_coeff_end: 0.001
lr_scheduling: False
learning_rate: 1
adam_epsilon: 1e-5
with_memory: False
with_cnn: False
hidden_size: 16


# ES parameters
es:
algo: OpenES # [OpenES, CMA_ES]
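The "nash quantity" comment in this config can be checked directly: with inverse demand P(Q) = a - b*Q and a common marginal cost c, each firm's symmetric best response is q = (a - c) / (3b), so total Nash output is 2(a - c) / (3b). A quick check with the values above (standalone sketch, not part of the commit):

a, b, marginal_cost = 100, 1, 10

q_per_firm = (a - marginal_cost) / (3 * b)   # symmetric Cournot best response: 30.0
q_total_nash = 2 * q_per_firm                # 60.0, matching the yaml comment
assert q_total_nash == 60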
33 changes: 6 additions & 27 deletions pax/conf/experiment/cournot/marl_baseline.yaml
@@ -1,51 +1,30 @@
# @package _global_

# Agents
agent1: 'PPO'
agent2: 'PPO'
agent_default: 'PPO'

# Environment
env_id: Cournot
env_type: sequential
a: 100
b: 1
marginal_cost: 10
# This means the optimum quantity is 2(a-marginal_cost)/3b = 60
runner: rl
# This means the nash quantity is 2(a-marginal_cost)/3b = 60
runner: tensor_rl_nplayer

# env_batch_size = num_envs * num_opponents
num_envs: 20
num_opps: 1
num_outer_steps: 1
num_inner_steps: 100 # number of inner steps (only for MetaFinite Env)
num_outer_steps: 300
num_inner_steps: 1 # One-shot game
num_iters: 1e7

# Useful information
# batch_size = num_envs * num_inner_steps
# batch_size % num_minibatches == 0 must hold

# PPO agent parameters
ppo1:
num_minibatches: 10
num_epochs: 4
gamma: 0.96
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.1
entropy_coeff_horizon: 0.25e9
entropy_coeff_end: 0.05
lr_scheduling: True
learning_rate: 3e-4
adam_epsilon: 1e-5
with_memory: True
hidden_size: 16
with_cnn: False

ppo2:
ppo_default:
num_minibatches: 10
num_epochs: 4
gamma: 0.96
6 changes: 3 additions & 3 deletions pax/conf/experiment/cournot/mfos_v_ppo.yaml
@@ -12,15 +12,15 @@ b: 1
marginal_cost: 10

# Runner
runner: evo
runner: tensor_evo

# Training
top_k: 5
popsize: 1000
num_envs: 4
num_opps: 1
num_outer_steps: 100
num_inner_steps: 50
num_outer_steps: 300
num_inner_steps: 1 # One-shot game
num_iters: 3000
num_devices: 1
num_steps: '${num_inner_steps}'
8 changes: 4 additions & 4 deletions pax/conf/experiment/cournot/shaper_v_ppo.yaml
@@ -12,15 +12,15 @@ b: 1
marginal_cost: 10

# Runner
runner: evo
runner: tensor_evo

# Training
top_k: 5
popsize: 1000
num_envs: 2
num_envs: 4
num_opps: 1
num_outer_steps: 1
num_inner_steps: 100
num_outer_steps: 300
num_inner_steps: 1 # One-shot game
num_iters: 5000
num_devices: 1
num_steps: '${num_inner_steps}'
6 changes: 3 additions & 3 deletions pax/conf/experiment/fishery/gs_v_ppo.yaml
@@ -2,7 +2,7 @@

# Agents
agent1: 'PPO'
agent2: 'PPO_memory'
agent_default: 'PPO_memory'

# Environment
env_id: Fishery
@@ -16,14 +16,14 @@ s_max: 1.0


# Runner
runner: evo
runner: tensor_evo

# Training
top_k: 5
popsize: 1000
num_envs: 2
num_opps: 1
num_outer_steps: 40
num_outer_steps: 1
num_inner_steps: 300
num_iters: 1000
num_devices: 1
2 changes: 1 addition & 1 deletion pax/conf/experiment/fishery/marl_baseline.yaml
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ s_0: 0.5
s_max: 1.0

# This means the optimum quantity is 2(a-marginal_cost)/3b = 60
runner: rl
runner: tensor_rl_nplayer

# env_batch_size = num_envs * num_opponents
num_envs: 100
51 changes: 51 additions & 0 deletions pax/conf/experiment/rice/ctde.yaml
@@ -0,0 +1,51 @@
# @package _global_

# Agents
agent1: 'PPO'

# Environment
env_id: Rice-v1
env_type: sequential
num_players: 27
config_folder: pax/envs/region_yamls
runner: ctde
# Training hyperparameters

# env_batch_size = num_envs * num_opponents
num_envs: 50
num_inner_steps: 100
num_iters: 1e7
save_interval: 100
num_steps: '${num_inner_steps}'

# Evaluation
#run_path: ucl-dark/cg/3sp0y2cy
#model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900

ppo0:
num_minibatches: 4
num_epochs: 4
gamma: 0.99
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.01
entropy_coeff_horizon: 1000000
entropy_coeff_end: 0.001
lr_scheduling: True
learning_rate: 2.5e-4 #5e-4
adam_epsilon: 1e-5
with_memory: True
with_cnn: False
output_channels: 16
kernel_shape: [3, 3]
separate: True
hidden_size: 16

# Logging setup
wandb:
group: rice
name: 'rice-SARL-${agent1}'
51 changes: 51 additions & 0 deletions pax/conf/experiment/rice/sarl.yaml
@@ -0,0 +1,51 @@
# @package _global_

# Agents
agent1: 'PPO'

# Environment
env_id: Rice-v1
env_type: sequential
num_players: 27
config_folder: pax/envs/region_yamls
runner: ctde
# Training hyperparameters

# env_batch_size = num_envs * num_opponents
num_envs: 50
num_inner_steps: 100
num_iters: 1e7
save_interval: 100
num_steps: '${num_inner_steps}'

# Evaluation
#run_path: ucl-dark/cg/3sp0y2cy
#model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900

ppo0:
num_minibatches: 4
num_epochs: 4
gamma: 0.99
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.01
entropy_coeff_horizon: 1000000
entropy_coeff_end: 0.001
lr_scheduling: True
learning_rate: 2.5e-4 #5e-4
adam_epsilon: 1e-5
with_memory: True
with_cnn: False
output_channels: 16
kernel_shape: [3, 3]
separate: True
hidden_size: 16

# Logging setup
wandb:
group: rice
name: 'rice-SARL-${agent1}'
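The batch-size comments that recur in these configs (env_batch_size = num_envs * num_opps; batch_size = num_envs * num_inner_steps, which must be divisible by num_minibatches) can be sanity-checked against the Rice values above. A small sketch, assuming num_opps defaults to 1 since the rice configs do not set it:

num_envs, num_opps, num_inner_steps, num_minibatches = 50, 1, 100, 4

env_batch_size = num_envs * num_opps        # 50 parallel environments
batch_size = num_envs * num_inner_steps     # 5000 transitions per update
assert batch_size % num_minibatches == 0    # 5000 / 4 = 1250 per minibatch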