
Add lola #11

Open · wants to merge 36 commits into main

Changes from 1 commit · 36 commits
2d5ed96
begin adding centralized learning
newtonkwan Jun 27, 2022
ec3ce6a
first commit. begin adding centralized training for LOLA
newtonkwan Jul 5, 2022
2664eea
add base lola
newtonkwan Jul 5, 2022
461f772
Merge branch 'main' into add_lola
newtonkwan Jul 5, 2022
bcb4833
add centralized learner
newtonkwan Jul 5, 2022
39de443
resolve merge conflict. add centralized learning
newtonkwan Jul 5, 2022
6c17d02
add lola machinery to experiments.py
newtonkwan Jul 5, 2022
8fab167
fix entropy annealing
newtonkwan Jul 6, 2022
21dc4fd
fix done condition in additional rollout step in PPO
newtonkwan Jul 7, 2022
6fe1d02
minor changes to lola
newtonkwan Jul 8, 2022
ac3a7f1
merge main with add_lola
newtonkwan Jul 12, 2022
b409692
minor bug fix
newtonkwan Jul 12, 2022
8f42170
add changes to buffer
newtonkwan Jul 13, 2022
acc514a
merge recent main updates Merge branch 'main' into add_lola
newtonkwan Jul 13, 2022
fce83a4
update confs
newtonkwan Jul 15, 2022
d1be0c5
add naive learner
newtonkwan Jul 19, 2022
dbe82c3
pull changes from main
newtonkwan Jul 19, 2022
a855c4b
lazy commit. committing to add naive learner PR
newtonkwan Jul 22, 2022
f11dce8
merge main
newtonkwan Jul 22, 2022
bb7b03b
add logic for lola (still debugging)
newtonkwan Jul 27, 2022
c169c75
add lola (doesn't quite work yet)
newtonkwan Jul 27, 2022
1a00280
compiling lola...
newtonkwan Jul 28, 2022
be33bc8
working lola
newtonkwan Jul 28, 2022
3cedf4e
update configs
newtonkwan Jul 28, 2022
b37aa91
tidy up
newtonkwan Jul 28, 2022
a2eb9e2
pull in main
newtonkwan Jul 29, 2022
99c7906
add working lola with new runner using lax.scan
newtonkwan Jul 29, 2022
dbb3937
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
11b98c6
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
aeb6426
lola compiles, move TrainingState to utils
newtonkwan Aug 1, 2022
c9cd40c
latest lola
newtonkwan Aug 1, 2022
dd29be0
fix axis
newtonkwan Aug 1, 2022
78c8196
fix axis
newtonkwan Aug 1, 2022
c4b2e72
similar lola
newtonkwan Aug 1, 2022
2f30284
half working lola
newtonkwan Aug 1, 2022
4ad1b76
temporary lola
newtonkwan Aug 2, 2022
11 changes: 6 additions & 5 deletions pax/conf/config.yaml
@@ -13,13 +13,14 @@ seed: 0
save_dir: "./exp/${wandb.group}/${wandb.name}"

# Agents
agent1: 'PPO'
agent2: 'TitForTat'
agent1: 'LOLA'
agent2: 'PPO'

# Environment
env_id: ipd
game: ipd
payoff:
centralized: True

# Training hyperparameters
num_envs: 100
@@ -54,8 +55,8 @@ ppo:
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.1
entropy_coeff_horizon: 200_000_000
entropy_coeff_start: 0.2
entropy_coeff_horizon: 500_000
entropy_coeff_end: 0.01
lr_scheduling: True
learning_rate: 2.5e-2
@@ -70,6 +71,6 @@ ppo:
wandb:
entity: "ucl-dark"
project: ipd
group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-final'
group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3'
name: run-seed-${seed}
log: True
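
The config now pits LOLA (agent1) against PPO (agent2), flags the run as centralized, and bumps the wandb group suffix; the group name itself is assembled by interpolation. As an illustrative sketch only, assuming OmegaConf/Hydra-style ${...} resolution (which the config syntax suggests) and a placeholder with_memory value, the group string resolves like this:

# Illustrative sketch: how the interpolated wandb group resolves, assuming
# OmegaConf-style ${...} interpolation. The with_memory value is a placeholder.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "agent1": "LOLA",
    "agent2": "PPO",
    "game": "ipd",
    "ppo": {"with_memory": False},
    "wandb": {"group": "${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3"},
})
print(cfg.wandb.group)  # LOLA-vs-PPO-ipd-with-memory=False-v3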
1 change: 0 additions & 1 deletion pax/experiment.py
@@ -23,7 +23,6 @@
Random,
Human,
GrimTrigger,
# ZDExtortion,
)
from pax.utils import Section
from pax.watchers import (
13 changes: 6 additions & 7 deletions pax/lola/lola.py
@@ -59,22 +59,21 @@ def update(
t: TimeStep,
actions: jnp.ndarray,
t_prime: TimeStep,
other_agents=None,
other_agents: list = None,
):
"""Update agent"""
# an sgd step requires the parameters of the other agent.
# currently, the runner file doesn't have access to the other agent's gradients
# we could put the parameters of the agent inside the timestep
# for agent in other_agents:
# other_agent_obs = agent._trajectory_buffer.observations
pass


def make_lola(seed: int) -> LOLA:
""" "Instantiate LOLA"""
"""Instantiate LOLA"""
random_key = jax.random.PRNGKey(seed)

def forward(inputs):
"""Forward pass for LOLA exact"""
values = hk.Linear(1, with_bias=False)
"""Forward pass for LOLA"""
values = hk.Linear(2, with_bias=False)
return values(inputs)

network = hk.without_apply_rng(hk.transform(forward))
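
The comment left in update() points at the core requirement here: a LOLA step has to differentiate through the opponent's own learning update, which is why the agent needs access to the other agent's parameters (or gradients). A minimal sketch of that idea in JAX, with hypothetical names and simple differentiable losses rather than anything from this PR:

# Hypothetical sketch, not code from this PR: agent 1 imagines the opponent
# taking one naive gradient step on its own loss, then differentiates its own
# loss through that imagined step (the LOLA opponent-shaping term).
import jax
import jax.numpy as jnp

def lola_update(theta1, theta2, loss1, loss2, inner_lr=1.0, outer_lr=0.1):
    def shaped_loss(t1):
        # Opponent's anticipated naive-learner step on its own loss.
        t2_lookahead = theta2 - inner_lr * jax.grad(loss2, argnums=1)(t1, theta2)
        # Agent 1 is evaluated against the updated opponent.
        return loss1(t1, t2_lookahead)
    return theta1 - outer_lr * jax.grad(shaped_loss)(theta1)

# Toy check with scalar "policies".
loss_a = lambda a, b: (a - b) ** 2
loss_b = lambda a, b: (a + b) ** 2
print(lola_update(jnp.array(1.0), jnp.array(0.5), loss_a, loss_b))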
2 changes: 2 additions & 0 deletions pax/ppo/buffer.py
@@ -153,6 +153,8 @@ def reset(self):
(self._num_envs, self._num_steps, self.gru_dim)
)

self.parameters = jnp.zeros((self._num_envs, self._num_steps))


if __name__ == "__main__":
pass
6 changes: 4 additions & 2 deletions pax/ppo/networks.py
@@ -19,12 +19,14 @@ def __init__(
super().__init__(name=name)
self._logit_layer = hk.Linear(
num_values,
w_init=hk.initializers.Orthogonal(0.01), # baseline
# w_init=hk.initializers.Orthogonal(0.01), # baseline
w_init=hk.initializers.Constant(0.5),
with_bias=False,
)
self._value_layer = hk.Linear(
1,
w_init=hk.initializers.Orthogonal(1.0), # baseline
# w_init=hk.initializers.Orthogonal(1.0), # baseline
w_init=hk.initializers.Constant(0.5),
with_bias=False,
)

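
One side effect of switching the logit layer from an orthogonal to a constant initializer with no bias is that every action logit starts out identical, so the initial policy is uniform, which makes early LOLA-vs-PPO runs easier to read. A small illustrative check (the module and names here are assumptions, not the PR's network):

# Illustrative only: constant weight init with no bias yields equal logits at
# initialization, i.e. a uniform starting policy.
import haiku as hk
import jax
import jax.numpy as jnp

def logits_fn(obs):
    layer = hk.Linear(2, w_init=hk.initializers.Constant(0.5), with_bias=False)
    return layer(obs)

net = hk.without_apply_rng(hk.transform(logits_fn))
params = net.init(jax.random.PRNGKey(0), jnp.ones((1, 5)))
print(net.apply(params, jnp.ones((1, 5))))  # [[2.5, 2.5]] -> uniform policy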
10 changes: 6 additions & 4 deletions pax/ppo/ppo.py
@@ -184,9 +184,10 @@ def loss(
fraction * entropy_coeff_start
+ (1 - fraction) * entropy_coeff_end
)

# Constant Entropy term
else:
entropy_cost = entropy_coeff_start
# else:
# entropy_cost = entropy_coeff_start
entropy_loss = -jnp.mean(entropy)

# Total loss: Minimize policy and value loss; maximize entropy
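
For reference, the expression above interpolates linearly between the start and end coefficients as fraction decays. With the values now in config.yaml (start 0.2, end 0.01, horizon 500_000) it behaves like the sketch below; the exact definition of fraction is assumed here, since it sits outside the diff:

# Hedged sketch of the annealed entropy coefficient, assuming `fraction`
# falls from 1 to 0 over entropy_coeff_horizon timesteps.
import jax.numpy as jnp

def entropy_cost(timesteps, start=0.2, end=0.01, horizon=500_000):
    fraction = jnp.clip(1.0 - timesteps / horizon, 0.0, 1.0)
    return fraction * start + (1.0 - fraction) * end

# entropy_cost(0) -> 0.2, entropy_cost(250_000) -> 0.105, entropy_cost(500_000) -> 0.01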
@@ -201,6 +202,7 @@ def loss(
"loss_policy": policy_loss,
"loss_value": value_loss,
"loss_entropy": entropy_loss,
"entropy_cost": entropy_cost,
}

@jax.jit
@@ -371,8 +373,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState:
dummy_obs = utils.add_batch_dim(dummy_obs)
initial_params = network.init(subkey, dummy_obs)
initial_opt_state = optimizer.init(initial_params)
# for dict_key in initial_params.keys():
# print(initial_params[dict_key])
return TrainingState(
params=initial_params,
opt_state=initial_opt_state,
@@ -401,6 +401,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState:
"loss_policy": 0,
"loss_value": 0,
"loss_entropy": 0,
"entropy_cost": entropy_coeff_start,
}

# Initialize functions
@@ -480,6 +481,7 @@ def update(
self._logger.metrics["loss_policy"] = results["loss_policy"]
self._logger.metrics["loss_value"] = results["loss_value"]
self._logger.metrics["loss_entropy"] = results["loss_entropy"]
self._logger.metrics["entropy_cost"] = results["entropy_cost"]


# TODO: seed, and player_id not used in CartPole
3 changes: 3 additions & 0 deletions pax/ppo/ppo_gru.py
@@ -209,6 +209,7 @@ def loss(
"loss_policy": policy_loss,
"loss_value": value_loss,
"loss_entropy": entropy_loss,
"entropy_cost": entropy_cost,
}
# }, new_rnn_unroll_state

@@ -429,6 +430,7 @@ def make_initial_state(
"loss_policy": 0,
"loss_value": 0,
"loss_entropy": 0,
"entropy_cost": entropy_coeff_start,
}

# Initialize functions
@@ -503,6 +505,7 @@ def update(
self._logger.metrics["loss_policy"] = results["loss_policy"]
self._logger.metrics["loss_value"] = results["loss_value"]
self._logger.metrics["loss_entropy"] = results["loss_entropy"]
self._logger.metrics["entropy_cost"] = results["entropy_cost"]


# TODO: seed, and player_id not used in CartPole
18 changes: 10 additions & 8 deletions pax/watchers.py
@@ -6,11 +6,11 @@
# five possible states
START = jnp.array([[0, 0, 0, 0, 1]])
CC = jnp.array([[1, 0, 0, 0, 0]])
CD = jnp.array([[0, 1, 0, 0, 0]])
DC = jnp.array([[0, 0, 1, 0, 0]])
DC = jnp.array([[0, 1, 0, 0, 0]])
CD = jnp.array([[0, 0, 1, 0, 0]])
DD = jnp.array([[0, 0, 0, 1, 0]])
STATE_NAMES = ["START", "CC", "CD", "DC", "DD"]
ALL_STATES = [START, CC, CD, DC, DD]
STATE_NAMES = ["START", "CC", "DC", "CD", "DD"]
ALL_STATES = [START, CC, DC, CD, DD]


def policy_logger(agent) -> None:
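
With the reordering above, index 1 of the one-hot observation is DC and index 2 is CD, and STATE_NAMES / ALL_STATES are reordered to match. A quick illustrative check (hypothetical helper, not part of the PR):

# Illustrative helper: map a one-hot IPD observation back to its state name
# under the layout defined above (CC=0, DC=1, CD=2, DD=3, START=4).
import jax.numpy as jnp

NAMES_BY_INDEX = ["CC", "DC", "CD", "DD", "START"]

def state_from_obs(obs: jnp.ndarray) -> str:
    return NAMES_BY_INDEX[int(jnp.argmax(obs))]

print(state_from_obs(jnp.array([[0, 1, 0, 0, 0]])))  # DC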
@@ -119,11 +119,13 @@ def ppo_losses(agent) -> None:
loss_policy = agent._logger.metrics["loss_policy"]
loss_value = agent._logger.metrics["loss_value"]
loss_entropy = agent._logger.metrics["loss_entropy"]
entropy_coefficient = agent._logger.metrics["entropy_cost"]
losses = {
"sgd_steps": sgd_steps,
"losses/total": loss_total,
"losses/policy": loss_policy,
"losses/value": loss_value,
"losses/entropy": loss_entropy,
"train/total": loss_total,
"train/policy": loss_policy,
"train/value": loss_value,
"train/entropy": loss_entropy,
"train/entropy_coefficient": entropy_coefficient,
}
return losses
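
Because ppo_losses now returns its metrics under a train/ prefix along with the annealed entropy coefficient, the dict can be logged as-is. A short usage sketch, assuming a wandb run has already been initialised elsewhere in the experiment setup:

# Usage sketch (assumes wandb.init(...) was called during experiment setup).
import wandb

metrics = ppo_losses(agent)  # {"sgd_steps": ..., "train/total": ..., "train/entropy_coefficient": ...}
wandb.log(metrics)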