
more refactoring
chrismatix committed Sep 27, 2023
1 parent 57899fd commit 1a4d194
Showing 25 changed files with 100 additions and 68 deletions.
pax/agents/naive/naive.py (2 changes: 1 addition & 1 deletion)

@@ -397,7 +397,7 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int):
     if args.env_id == "coin_game":
         print(f"Making network for {args.env_id} with CNN")
         network = make_coingame_network(action_spec, args)
-    elif args.env_id == "Rice-v1":
+    elif args.env_id == "Rice-N":
         network = make_rice_network(action_spec)
     else:
         network = make_network(action_spec)
pax/agents/ppo/ppo.py (7 changes: 4 additions & 3 deletions)

@@ -16,6 +16,7 @@
     make_sarl_network, make_cournot_network,
     make_fishery_network, make_rice_sarl_network,
 )
+from pax.envs.rice.c_rice import ClubRice
 from pax.envs.rice.rice import Rice
 from pax.envs.rice.sarl_rice import SarlRice
 from pax.utils import Logger, MemoryState, TrainingState, get_advantages, float_precision
@@ -503,12 +504,12 @@ def make_agent(
         network = make_rice_sarl_network(action_spec, agent_args.hidden_size)
     elif args.env_id == Rice.env_id:
         network = make_rice_sarl_network(action_spec, agent_args.hidden_size)
+    elif args.env_id == ClubRice.env_id:
+        network = make_rice_sarl_network(action_spec, agent_args.hidden_size)
     elif args.runner == "sarl":
         network = make_sarl_network(action_spec)
     else:
-        network = make_ipd_network(
-            action_spec, tabular, agent_args.hidden_size
-        )
+        raise NotImplementedError(f"No ppo network implemented for env {args.env_id}")

     # Optimizer
     transition_steps = (
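Note on the two ppo.py hunks above: ClubRice now resolves to the same SARL-style Rice network as Rice, and the previous silent fallback to make_ipd_network for unrecognised environments is replaced by an explicit NotImplementedError. The sketch below shows the assumed resulting dispatch only; select_network is a hypothetical wrapper, while Rice, ClubRice and the make_*_network factories are the names imported at the top of pax/agents/ppo/ppo.py.

# Sketch only (not the repository's code): how the network dispatch in
# make_agent is assumed to behave after this commit.
def select_network(args, agent_args, action_spec):
    if args.env_id == Rice.env_id:
        return make_rice_sarl_network(action_spec, agent_args.hidden_size)
    elif args.env_id == ClubRice.env_id:
        # Club-Rice reuses the Rice SARL-style network unchanged.
        return make_rice_sarl_network(action_spec, agent_args.hidden_size)
    elif args.runner == "sarl":
        return make_sarl_network(action_spec)
    else:
        # The old fallback built an IPD network here; unknown environments
        # now fail loudly instead.
        raise NotImplementedError(
            f"No ppo network implemented for env {args.env_id}"
        )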
pax/agents/ppo/ppo_gru.py (5 changes: 5 additions & 0 deletions)

@@ -16,6 +16,7 @@
     make_GRU_ipditm_network, make_GRU_fishery_network, make_GRU_rice_network,
 )
 from pax.envs.rice.rice import Rice
+from pax.envs.rice.c_rice import ClubRice
 from pax.utils import MemoryState, TrainingState, get_advantages

 # from dm_env import TimeStep
@@ -549,6 +550,10 @@ def make_gru_agent(
         network, initial_hidden_state = make_GRU_rice_network(
             action_spec, agent_args.hidden_size
         )
+    elif args.env_id == ClubRice.env_id:
+        network, initial_hidden_state = make_GRU_rice_network(
+            action_spec, agent_args.hidden_size
+        )
     elif args.env_id == "InTheMatrix":
         network, initial_hidden_state = make_GRU_ipditm_network(
             action_spec,
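The GRU variant gets the same treatment: both Rice and ClubRice map to make_GRU_rice_network, which returns a network together with its initial hidden state. The sketch below illustrates the assumed relationship between the class-level env_id attributes this dispatch compares against and the renamed id strings in the configs further down; the stand-in classes and the pick_gru_network helper are hypothetical, and the exact id strings are inferred from this commit's config changes.

# Hypothetical, stripped-down stand-ins for the real environment classes in
# pax/envs/rice/rice.py and pax/envs/rice/c_rice.py. Only the class-level id
# used for dispatch is shown; the strings are inferred from the env_id values
# in the experiment configs touched by this commit.
class Rice:
    env_id = "Rice-N"


class ClubRice:
    env_id = "C-Rice-N"


def pick_gru_network(args, agent_args, action_spec):
    # Both Rice variants reuse the same recurrent factory; make_GRU_rice_network
    # is the function imported at the top of pax/agents/ppo/ppo_gru.py and
    # returns (network, initial_hidden_state).
    if args.env_id in (Rice.env_id, ClubRice.env_id):
        return make_GRU_rice_network(action_spec, agent_args.hidden_size)
    # Other environments ("InTheMatrix", the fishery, ...) keep their own
    # branches in the real make_gru_agent; this sketch stops here.
    raise NotImplementedError(f"No GRU network implemented for env {args.env_id}")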
pax/conf/experiment/c_rice/debug.yaml (10 changes: 5 additions & 5 deletions)

@@ -5,12 +5,12 @@ agent1: 'PPO'
 agent_default: 'PPO'

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
-config_folder: pax/envs/Rice/5_regions
-runner: tensor_evo
+config_folder: pax/envs/rice/5_regions
+runner: evo

 # Training
 top_k: 5
@@ -72,10 +72,10 @@ es:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'mediator'
   mode: 'offline'
-  name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}'
+  name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}'
   log: False
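Across the c_rice configs above and below, the same renames recur: env_id moves from C-Rice-v1 to C-Rice-N, config_folder is lowercased to pax/envs/rice/5_regions, where present the runner switches from tensor_evo to evo, and the wandb project and run names are lowercased to c-rice. A rename like this is easy to miss in a single file, so a small consistency check can help. The script below is not part of the repository; it assumes PyYAML is installed and limits the accepted ids to the ones that appear in this commit's diffs.

# Hypothetical helper script (not in the repo): flag experiment configs whose
# env_id is not one of the ids the agent factories dispatch on after this
# commit. The accepted set is taken from the diffs in this commit only.
import pathlib

import yaml

KNOWN_ENV_IDS = {"Rice-N", "C-Rice-N", "SarlRice-N", "coin_game", "InTheMatrix"}

for path in sorted(pathlib.Path("pax/conf/experiment").rglob("*.yaml")):
    cfg = yaml.safe_load(path.read_text()) or {}
    env_id = cfg.get("env_id")
    if env_id is not None and env_id not in KNOWN_ENV_IDS:
        print(f"{path}: unexpected env_id {env_id!r}")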


pax/conf/experiment/c_rice/marl_baseline.yaml (10 changes: 5 additions & 5 deletions)

@@ -4,12 +4,12 @@
 agent_default: 'PPO'

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
-config_folder: pax/envs/Rice/5_regions
-runner: tensor_evo
+config_folder: pax/envs/rice/5_regions
+runner: evo

 # Training
 top_k: 5
@@ -48,7 +48,7 @@ ppo_default:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'mediator'
-  name: 'c-Rice-MARL-${agent_default}-seed-${seed}'
+  name: 'c-rice-MARL-${agent_default}-seed-${seed}'
   log: True
pax/conf/experiment/c_rice/mediator_gs_naive.yaml (17 changes: 9 additions & 8 deletions)

@@ -3,25 +3,26 @@
 # Agents
 agent1: 'PPO'
 agent_default: 'Naive'
+agent2_roles: 5

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
-config_folder: pax/envs/Rice/5_regions
-runner: tensor_evo
+config_folder: pax/envs/rice/5_regions
+runner: evo

 # Training
 top_k: 5
 popsize: 1000
 num_envs: 1
 num_opps: 1
-num_outer_steps: 1
+num_outer_steps: 2000
 num_inner_steps: 20
-num_iters: 3500
+num_iters: 1500
 num_devices: 1
-num_steps: 4
+num_steps: 200

 # PPO agent parameters
@@ -82,9 +83,9 @@ es:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'mediator'
-  name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}'
+  name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}'
   log: True
pax/conf/experiment/c_rice/mediator_gs_ppo.yaml (18 changes: 10 additions & 8 deletions)

@@ -2,24 +2,26 @@
 # Agents
 agent1: 'PPO'
-agent_default: 'PPO'
+agent2: 'PPO_memory'
+agent_default: 'PPO_memory'
+agent2_roles: 5

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
-config_folder: pax/envs/Rice/5_regions
-runner: tensor_evo
+config_folder: pax/envs/rice/5_regions
+runner: evo

 # Training
 top_k: 5
 popsize: 1000
 num_envs: 1
 num_opps: 1
-num_outer_steps: 200
+num_outer_steps: 2000
 num_inner_steps: 20
-num_iters: 3500
+num_iters: 1500
 num_devices: 1
 num_steps: 200
@@ -72,9 +74,9 @@ es:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'mediator'
-  name: 'c-Rice-mediator-GS-${agent_default}-seed-${seed}'
+  name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}'
   log: True
pax/conf/experiment/c_rice/shaper_v_ppo.yaml (8 changes: 4 additions & 4 deletions)

@@ -8,12 +8,12 @@ agent_default: 'PPO_memory'
 agent2_roles: 4

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: meta
 num_players: 5
 has_mediator: False
 shuffle_players: False
-config_folder: pax/envs/Rice/5_regions
+config_folder: pax/envs/rice/5_regions
 runner: evo

 # Training
@@ -76,9 +76,9 @@ es:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'shaper'
-  name: 'c-Rice-SHAPER-${agent_default}-seed-${seed}'
+  name: 'c-rice-SHAPER-${agent_default}-seed-${seed}'
   log: True
pax/conf/experiment/c_rice/weight_sharing.yaml (8 changes: 4 additions & 4 deletions)

@@ -6,11 +6,11 @@ agent2: 'PPO_memory'
 agent_default: 'PPO_memory'

 # Environment
-env_id: C-Rice-v1
+env_id: C-Rice-N
 env_type: sequential
 num_players: 5
 has_mediator: False
-config_folder: pax/envs/Rice/5_regions
+config_folder: pax/envs/rice/5_regions
 runner: weight_sharing
 # Training hyperparameters
@@ -48,6 +48,6 @@ ppo_default:

 # Logging setup
 wandb:
-  project: c-Rice
+  project: c-rice
   group: 'weight_sharing'
-  name: 'c-Rice-weight_sharing-${agent1}-seed-${seed}'
+  name: 'c-rice-weight_sharing-${agent1}-seed-${seed}'
pax/conf/experiment/rice/debug.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@ agent1: 'PPO'
 agent_default: 'PPO'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
pax/conf/experiment/rice/marl_baseline.yaml (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 agent_default: 'PPO'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
pax/conf/experiment/rice/mediator_gs_naive.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@ agent1: 'PPO'
 agent_default: 'Naive'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
pax/conf/experiment/rice/mediator_gs_ppo.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@ agent1: 'PPO'
 agent_default: 'PPO'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
pax/conf/experiment/rice/mediator_shaper.yaml (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 agent_default: 'PPO_memory'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 6
 has_mediator: True
pax/conf/experiment/rice/mfos_v_ppo.yaml (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@ agent_default: 'PPO_memory'
 agent2_roles: 4

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 5
 has_mediator: False
pax/conf/experiment/rice/sarl.yaml (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 agent1: 'PPO'

 # Environment
-env_id: SarlRice-v1
+env_id: SarlRice-N
 env_type: sequential
 num_players: 5
 has_mediator: False
pax/conf/experiment/rice/shaper_v_ppo.yaml (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@ agent_default: 'PPO_memory'
 agent2_roles: 4

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: meta
 num_players: 5
 has_mediator: False
pax/conf/experiment/rice/weight_sharing.yaml (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ agent2: 'PPO_memory'
 agent_default: 'PPO_memory'

 # Environment
-env_id: Rice-v1
+env_id: Rice-N
 env_type: sequential
 num_players: 5
 has_mediator: False
