Add lola #11

Open · wants to merge 36 commits into base: main
Commits (36)
2d5ed96
begin adding centralized learning
newtonkwan Jun 27, 2022
ec3ce6a
first commit. begin adding centralized training for LOLA
newtonkwan Jul 5, 2022
2664eea
add base lola
newtonkwan Jul 5, 2022
461f772
Merge branch 'main' into add_lola
newtonkwan Jul 5, 2022
bcb4833
add centralized learner
newtonkwan Jul 5, 2022
39de443
resolve merge conflict. add centralized learning
newtonkwan Jul 5, 2022
6c17d02
add lola machinery to experiments.py
newtonkwan Jul 5, 2022
8fab167
fix entropy annealing
newtonkwan Jul 6, 2022
21dc4fd
fix done condition in additional rollout step in PPO
newtonkwan Jul 7, 2022
6fe1d02
minor changes to lola
newtonkwan Jul 8, 2022
ac3a7f1
merge main with add_lola
newtonkwan Jul 12, 2022
b409692
minor bug fix
newtonkwan Jul 12, 2022
8f42170
add changes to buffer
newtonkwan Jul 13, 2022
acc514a
merge recent main updates Merge branch 'main' into add_lola
newtonkwan Jul 13, 2022
fce83a4
update confs
newtonkwan Jul 15, 2022
d1be0c5
add naive learner
newtonkwan Jul 19, 2022
dbe82c3
pull changes from main
newtonkwan Jul 19, 2022
a855c4b
lazy commit. committing to add naive learner PR
newtonkwan Jul 22, 2022
f11dce8
merge main
newtonkwan Jul 22, 2022
bb7b03b
add logic for lola (still debugging)
newtonkwan Jul 27, 2022
c169c75
add lola (doesn't quite work yet)
newtonkwan Jul 27, 2022
1a00280
compiling lola...
newtonkwan Jul 28, 2022
be33bc8
working lola
newtonkwan Jul 28, 2022
3cedf4e
update configs
newtonkwan Jul 28, 2022
b37aa91
tidy up
newtonkwan Jul 28, 2022
a2eb9e2
pull in main
newtonkwan Jul 29, 2022
99c7906
add working lola with new runner using lax.scan
newtonkwan Jul 29, 2022
dbb3937
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
11b98c6
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
aeb6426
lola compiles, move TrainingState to utils
newtonkwan Aug 1, 2022
c9cd40c
latest lola
newtonkwan Aug 1, 2022
dd29be0
fix axis
newtonkwan Aug 1, 2022
78c8196
fix axis
newtonkwan Aug 1, 2022
c4b2e72
similar lola
newtonkwan Aug 1, 2022
2f30284
half working lola
newtonkwan Aug 1, 2022
4ad1b76
temporary lola
newtonkwan Aug 2, 2022
64 changes: 64 additions & 0 deletions pax/centralized_learners.py
@@ -0,0 +1,64 @@
from typing import Callable, List

from dm_env import TimeStep
import jax.numpy as jnp


class CentralizedLearners:
"""Interface for a set of batched agents to work with environment
Performs centralized training"""

def __init__(self, agents: list):
self.num_agents: int = len(agents)
self.agents: list = agents

def select_action(self, timesteps: List[TimeStep]) -> List[jnp.ndarray]:
assert len(timesteps) == self.num_agents
return [
agent.select_action(t) for agent, t in zip(self.agents, timesteps)
]

def in_lookahead(self, env):
"""Simulates a rollout and gradient update"""
counter = 0
for agent in self.agents:
# All other agents in a list
# i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...]
other_agents = self.agents[:counter] + self.agents[counter + 1 :]
agent.in_lookahead(env, other_agents)
counter += 1

def out_lookahead(self, env):
"""Performs a real rollout and update"""
counter = 0
for agent in self.agents:
# All other agents in a list
# i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...]
other_agents = self.agents[:counter] + self.agents[counter + 1 :]
agent.out_lookahead(env, other_agents)
counter += 1

    # TODO: Obsolete at the moment. This can be folded into the LOLA agent.
def update(
self,
old_timesteps: List[TimeStep],
actions: List[jnp.ndarray],
timesteps: List[TimeStep],
) -> None:
counter = 0
for agent, t, action, t_1 in zip(
self.agents, old_timesteps, actions, timesteps
):
# All other agents in a list
# i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...]
other_agents = self.agents[:counter] + self.agents[counter + 1 :]
agent.update(t, action, t_1, other_agents)
counter += 1

def log(self, metrics: List[Callable]) -> None:
for metric, agent in zip(metrics, self.agents):
metric(agent)

def eval(self, set_flag: bool) -> None:
for agent in self.agents:
agent.eval = set_flag
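
For context, here is a minimal sketch of how the `CentralizedLearners` wrapper is driven. `DummyAgent` and `DummyEnv` are stand-ins invented for this example; the real LOLA agent and environment constructors live elsewhere in the repo.

# Illustrative sketch only; DummyAgent/DummyEnv are hypothetical stand-ins
# for the repo's actual agent and environment classes.
from pax.centralized_learners import CentralizedLearners


class DummyAgent:
    """Minimal agent exposing the interface CentralizedLearners expects."""

    def in_lookahead(self, env, other_agents):
        pass  # simulate the opponents' updates (inner step)

    def out_lookahead(self, env, other_agents):
        pass  # update own parameters for real (outer step)


class DummyEnv:
    pass


env = DummyEnv()
agents = CentralizedLearners([DummyAgent(), DummyAgent()])

for _ in range(10):  # training episodes
    agents.in_lookahead(env)   # each agent differentiates through the
    agents.out_lookahead(env)  # opponents' simulated updates, then steps

agents.eval(True)  # put every wrapped agent into evaluation mode
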
79 changes: 47 additions & 32 deletions pax/conf/config.yaml
@@ -9,77 +9,92 @@ hydra:
level: INFO

# Global variables
seed: 0
seed: 25
save_dir: "./exp/${wandb.group}/${wandb.name}"
debug: False

# Agents
agent1: 'Hyper'
agent2: 'NaiveLearnerEx'
agent1: 'LOLA'
agent2: 'LOLA'

# Environment
env_id: ipd
game: ipd
env_type: infinite
env_type: finite
env_discount: 0.96
payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]]
payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]]
centralized: True

# Training hyperparameters
num_envs: 4000
num_steps: 100 # number of steps per episode
total_timesteps: 1.6e9
eval_every: 0.4e9 # timesteps for update
num_envs: 128
num_steps: 150 # number of steps per episode
total_timesteps: 1_000_000
eval_every: 100_000 # eval every n episodes, not timesteps


# Useful information
# num_episodes = total_timesteps / (num_steps * num_envs)
# num_updates = num_episodes / eval_every
# batch_size = num_envs * num_steps

# DQN agent parameters
dqn:
batch_size: 256
discount: 0.99
learning_rate: 1e-2
epsilon: 0.5
replay_capacity: 100000
min_replay_size: 1000
sgd_period: 1
target_update_period: 4

# PPO agent parameters
ppo:
num_minibatches: 1
num_minibatches: 10
num_epochs: 4
gamma: 0.96
gae_lambda: 0.99
gae_lambda: 0.95
ppo_clipping_epsilon: 0.2
value_coeff: 0.5
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: False
entropy_coeff_start: 0.01
entropy_coeff_horizon: 0.8e9
entropy_coeff_end: 0.001
lr_scheduling: False
learning_rate: 4e-3
anneal_entropy: True
entropy_coeff_start: 0.2
entropy_coeff_horizon: 5_000_000
  # to anneal to the end value halfway through training, the horizon should be (1/2) * (total_timesteps / num_envs)
entropy_coeff_end: 0.01
lr_scheduling: True
learning_rate: 2.5e-3
adam_epsilon: 1e-5
adam_eps_root: 0.
with_memory: False

# Naive Learner parameters
naive:
lr: 1.0
num_minibatches: 1
num_epochs: 1
gamma: 0.96
gae_lambda: 0.95
max_gradient_norm: 1
lr_scheduling: False
learning_rate: 1
adam_epsilon: 1e-5

# LOLA agent parameters
# lola:
# ...
lola:
use_baseline: True
adam_epsilon: 1e-5
lr_in: 0.3
lr_out: 0.2
gamma: 0.96
num_lookaheads: 1

# Logging setup
wandb:
entity: "ucl-dark"
project: ipd
group: 'MFOS-${agent1}-vs-${agent2}-${game}'
group: 'LOLA-vs-${agent2}-${game}'
name: run-seed-${seed}
log: True
log: False


# DQN agent parameters
dqn:
batch_size: 256
discount: 0.99
learning_rate: 1e-2
epsilon: 0.5
replay_capacity: 100000
min_replay_size: 1000
sgd_period: 1
target_update_period: 4
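
As a quick sanity check of the "Useful information" comments above, here is a minimal sketch of the derived quantities, assuming the new values introduced in this diff:

# Back-of-the-envelope check of the bookkeeping comments in config.yaml,
# using the values introduced in this diff.
num_envs = 128
num_steps = 150          # steps per episode
total_timesteps = 1_000_000

batch_size = num_envs * num_steps                        # 19_200 transitions per episode
num_episodes = total_timesteps / (num_steps * num_envs)  # ~52 episodes

print(batch_size, num_episodes)
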
2 changes: 1 addition & 1 deletion pax/conf/experiment/debug.yaml
@@ -4,5 +4,5 @@ debug: true

wandb:
group: debug
log: true
log: False

43 changes: 43 additions & 0 deletions pax/conf/experiment/lola.yaml
@@ -0,0 +1,43 @@
# @package _global_

# Agents
agent1: 'LOLA'
agent2: 'LOLA'
centralized: True

# Environment
env_id: ipd
game: ipd
env_type: finite
env_discount: 0.96
payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]]


# Training hyperparameters
num_envs: 128
num_steps: 150 # number of steps per episode
total_timesteps: 4_000_000
eval_every: 4_000_000 # timesteps

# Useful information
# num_episodes = total_timesteps / num_steps
# num_updates = num_episodes / eval_every
# batch_size = num_envs * num_steps

# LOLA agent parameters
lola:
use_baseline: False
adam_epsilon: 1e-5
lr_in: 0.3
lr_out: 0.2
lr_value: 0.1
gamma: 0.96
num_lookaheads: 0

# Logging setup
wandb:
entity: "ucl-dark"
project: ipd
group: 'LOLA-vs-${agent2}-${game}'
name: run-seed-${seed}-${lola.num_lookaheads}-lookaheads
log: True
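
To make the roles of `lr_in`, `lr_out`, and `num_lookaheads` above concrete, the following is a minimal JAX sketch of a LOLA-style lookahead update on a toy differentiable game. It illustrates the general technique, not this repo's agent: the toy losses and `lola_step` function are invented for the example, and `num_lookaheads` is set to 1 here for illustration (the experiment config above uses 0, which would skip the inner step).

# Illustrative LOLA-style lookahead on a toy two-player differentiable game.
import jax
import jax.numpy as jnp

lr_in, lr_out = 0.3, 0.2  # values from lola.yaml
num_lookaheads = 1        # illustrative; lola.yaml uses 0


def loss1(theta1, theta2):
    # Toy smooth game standing in for the IPD policy objective.
    return jnp.sum(theta1 * theta2) + 0.5 * jnp.sum(theta1 ** 2)


def loss2(theta1, theta2):
    return -jnp.sum(theta1 * theta2) + 0.5 * jnp.sum(theta2 ** 2)


def lola_step(theta1, theta2):
    def outer_loss(t1):
        # Inner lookahead: simulate the opponent's naive gradient step(s)
        # with learning rate lr_in, keeping the dependence on t1.
        t2 = theta2
        for _ in range(num_lookaheads):
            t2 = t2 - lr_in * jax.grad(loss2, argnums=1)(t1, t2)
        return loss1(t1, t2)

    # Outer update: differentiate through the opponent's simulated update.
    return theta1 - lr_out * jax.grad(outer_loss)(theta1)


theta1 = jnp.array([0.5, -0.2])
theta2 = jnp.array([0.1, 0.3])
print(lola_step(theta1, theta2))
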
2 changes: 2 additions & 0 deletions pax/env.py
@@ -53,6 +53,8 @@ def step(
return self.reset()
action_1, action_2 = actions
self._num_steps += 1
# print("action_1.shape", action_1.shape)
# print("action_1", action_1)
assert action_1.shape == action_2.shape
assert action_1.shape == (self.num_envs,)
