Add lola #11

Open
wants to merge 36 commits into base: main
Changes from 8 commits

Commits (36)
2d5ed96
begin adding centralized learning
newtonkwan Jun 27, 2022
ec3ce6a
first commit. begin adding centralized training for LOLA
newtonkwan Jul 5, 2022
2664eea
add base lola
newtonkwan Jul 5, 2022
461f772
Merge branch 'main' into add_lola
newtonkwan Jul 5, 2022
bcb4833
add centralized learner
newtonkwan Jul 5, 2022
39de443
resolve merge conflict. add centralized learning
newtonkwan Jul 5, 2022
6c17d02
add lola machinery to experiments.py
newtonkwan Jul 5, 2022
8fab167
fix entropy annealing
newtonkwan Jul 6, 2022
21dc4fd
fix done condition in additional rollout step in PPO
newtonkwan Jul 7, 2022
6fe1d02
minor changes to lola
newtonkwan Jul 8, 2022
ac3a7f1
merge main with add_lola
newtonkwan Jul 12, 2022
b409692
minor bug fix
newtonkwan Jul 12, 2022
8f42170
add changes to buffer
newtonkwan Jul 13, 2022
acc514a
merge recent main updates Merge branch 'main' into add_lola
newtonkwan Jul 13, 2022
fce83a4
update confs
newtonkwan Jul 15, 2022
d1be0c5
add naive learner
newtonkwan Jul 19, 2022
dbe82c3
pull changes from main
newtonkwan Jul 19, 2022
a855c4b
lazy commit. committing to add naive learner PR
newtonkwan Jul 22, 2022
f11dce8
merge main
newtonkwan Jul 22, 2022
bb7b03b
add logic for lola (still debugging)
newtonkwan Jul 27, 2022
c169c75
add lola (doesn't quite work yet)
newtonkwan Jul 27, 2022
1a00280
compiling lola...
newtonkwan Jul 28, 2022
be33bc8
working lola
newtonkwan Jul 28, 2022
3cedf4e
update configs
newtonkwan Jul 28, 2022
b37aa91
tidy up
newtonkwan Jul 28, 2022
a2eb9e2
pull in main
newtonkwan Jul 29, 2022
99c7906
add working lola with new runner using lax.scan
newtonkwan Jul 29, 2022
dbb3937
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
11b98c6
tidy up watchers, fix naive learner, LOLA getting exploited hard ....
newtonkwan Jul 30, 2022
aeb6426
lola compiles, move TrainingState to utils
newtonkwan Aug 1, 2022
c9cd40c
latest lola
newtonkwan Aug 1, 2022
dd29be0
fix axis
newtonkwan Aug 1, 2022
78c8196
fix axis
newtonkwan Aug 1, 2022
c4b2e72
similar lola
newtonkwan Aug 1, 2022
2f30284
half working lola
newtonkwan Aug 1, 2022
4ad1b76
temporary lola
newtonkwan Aug 2, 2022
43 changes: 43 additions & 0 deletions pax/centralized_learners.py
@@ -0,0 +1,43 @@
from typing import Callable, List

from dm_env import TimeStep
import jax.numpy as jnp


class CentralizedLearners:
"""Interface for a set of batched agents to work with environment
Performs centralized training"""

def __init__(self, agents: list):
self.num_agents: int = len(agents)
self.agents: list = agents

def select_action(self, timesteps: List[TimeStep]) -> List[jnp.ndarray]:
assert len(timesteps) == self.num_agents
return [
agent.select_action(t) for agent, t in zip(self.agents, timesteps)
]

def update(
self,
old_timesteps: List[TimeStep],
actions: List[jnp.ndarray],
timesteps: List[TimeStep],
) -> None:
counter = 0
for agent, t, action, t_1 in zip(
self.agents, old_timesteps, actions, timesteps
):
# All other agents in a list
# i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...]
other_agents = self.agents[:counter] + self.agents[counter + 1 :]
agent.update(t, action, t_1, other_agents)
counter += 1

def log(self, metrics: List[Callable]) -> None:
for metric, agent in zip(metrics, self.agents):
metric(agent)

def eval(self, set_flag: bool) -> None:
for agent in self.agents:
agent.eval = set_flag
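
A minimal usage sketch of the new wrapper (not part of this PR): the DummyAgent below is hypothetical and simply stands in for real pax agents, which are expected to expose the same select_action/update interface; it assumes a pax checkout is importable.

from dm_env import StepType, TimeStep
import jax.numpy as jnp

from pax.centralized_learners import CentralizedLearners


class DummyAgent:
    """Hypothetical stand-in with the interface CentralizedLearners expects."""

    def __init__(self):
        self.eval = False

    def select_action(self, t: TimeStep) -> jnp.ndarray:
        return jnp.zeros((1,), dtype=jnp.int32)

    def update(self, t, action, t_1, other_agents=None):
        # A centralized learner (e.g. LOLA) can inspect the other agents'
        # parameters or trajectory buffers here.
        assert other_agents is not None and len(other_agents) == 1


ts = TimeStep(
    step_type=StepType.MID,
    reward=0.0,
    discount=1.0,
    observation=jnp.zeros((1, 5)),
)
pair = CentralizedLearners([DummyAgent(), DummyAgent()])
actions = pair.select_action([ts, ts])
pair.update([ts, ts], actions, [ts, ts])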
11 changes: 6 additions & 5 deletions pax/conf/config.yaml
@@ -13,13 +13,14 @@ seed: 0
save_dir: "./exp/${wandb.group}/${wandb.name}"

# Agents
agent1: 'PPO'
agent2: 'TitForTat'
agent1: 'LOLA'
agent2: 'PPO'

# Environment
env_id: ipd
game: ipd
payoff:
centralized: True

# Training hyperparameters
num_envs: 100
@@ -54,8 +55,8 @@ ppo:
clip_value: True
max_gradient_norm: 0.5
anneal_entropy: True
entropy_coeff_start: 0.1
entropy_coeff_horizon: 200_000_000
entropy_coeff_start: 0.2
entropy_coeff_horizon: 500_000
entropy_coeff_end: 0.01
lr_scheduling: True
learning_rate: 2.5e-2
@@ -70,6 +71,6 @@ ppo:
wandb:
entity: "ucl-dark"
project: ipd
group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-final'
group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3'
name: run-seed-${seed}
log: True
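
The ${...} interpolations in wandb.group suggest the config is consumed through OmegaConf/Hydra; a hedged sketch of inspecting the resolved values from a repository checkout (paths and expected output are assumptions):

from omegaconf import OmegaConf

cfg = OmegaConf.load("pax/conf/config.yaml")
print(cfg.agent1, cfg.agent2, cfg.centralized)  # expected: LOLA PPO True
# Resolve interpolations such as '${agent1}-vs-${agent2}-...'
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["wandb"]["group"])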
1 change: 1 addition & 0 deletions pax/dqn/agent.py
@@ -183,6 +183,7 @@ def update(
timestep: dm_env.TimeStep,
action: jnp.array,
new_timestep: dm_env.TimeStep,
other_agents=None,
):

self._replay.add_batch(
12 changes: 11 additions & 1 deletion pax/experiment.py
@@ -9,9 +9,11 @@
BattleOfTheSexes,
Chicken,
)
from pax.centralized_learners import CentralizedLearners
from pax.independent_learners import IndependentLearners
from pax.ppo.ppo import make_agent
from pax.ppo.ppo_gru import make_gru_agent
from pax.lola.lola import make_lola
from pax.runner import Runner
from pax.sac.agent import SAC
from pax.strategies import (
@@ -21,7 +23,6 @@
Random,
Human,
GrimTrigger,
# ZDExtortion,
)
from pax.utils import Section
from pax.watchers import (
@@ -146,6 +147,10 @@ def get_PPO_agent(seed, player_id):
)
return ppo_agent

def get_LOLA_agent(seed, player_id):
lola_agent = make_lola(seed)
return lola_agent

strategies = {
"TitForTat": TitForTat,
"Defect": Defect,
@@ -157,6 +162,7 @@ def get_PPO_agent(seed, player_id):
"SAC": get_SAC_agent,
"DQN": get_DQN_agent,
"PPO": get_PPO_agent,
"LOLA": get_LOLA_agent,
}

assert args.agent1 in strategies
@@ -177,6 +183,9 @@ def get_PPO_agent(seed, player_id):
logger.info(f"Agent Pair: {args.agent1} | {args.agent2}")
logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}")

if args.centralized:
return CentralizedLearners([agent_0, agent_1])

return IndependentLearners([agent_0, agent_1])


@@ -225,6 +234,7 @@ def dumb_log(agent, *args):
"SAC": sac_log,
"DQN": dqn_log,
"PPO": ppo_log,
"LOLA": dumb_log,
}

assert args.agent1 in strategies
6 changes: 4 additions & 2 deletions pax/independent_learners.py
@@ -5,7 +5,9 @@


class IndependentLearners:
"Interface for a set of batched agents to work with environment"
"""Interface for a set of batched agents to work with environment
Performs independent learning
"""

def __init__(self, agents: list):
self.num_agents: int = len(agents)
@@ -23,7 +25,7 @@ def update(
actions: List[jnp.ndarray],
timesteps: List[TimeStep],
) -> None:
# might have to add some centralised training to this

for agent, t, action, t_1 in zip(
self.agents, old_timesteps, actions, timesteps
):
Empty file added pax/lola/__init__.py
99 changes: 99 additions & 0 deletions pax/lola/lola.py
@@ -0,0 +1,99 @@
# Learning with Opponent-Learning Awareness (LOLA) implementation in JAX
# https://arxiv.org/pdf/1709.04326.pdf

from typing import NamedTuple, Any

from dm_env import TimeStep
import haiku as hk
import jax
import jax.numpy as jnp


class TrainingState(NamedTuple):
params: hk.Params
random_key: jnp.ndarray


class LOLA:
"""Implements LOLA with exact value functions"""

def __init__(self, network: hk.Transformed, random_key: jnp.ndarray):
def policy(
params: hk.Params, observation: jnp.ndarray, state: TrainingState
):
"""Determines how to choose an action"""
key, subkey = jax.random.split(state.random_key)
logits = network.apply(params, observation)
actions = jax.random.categorical(subkey, logits)
state = state._replace(random_key=key)
return actions, state

def loss():
"""Loss function"""
pass

def sgd():
"""Stochastic gradient descent"""
pass

def make_initial_state(key: jnp.ndarray) -> TrainingState:
"""Make initial training state for LOLA"""
key, subkey = jax.random.split(key)
dummy_obs = jnp.zeros(shape=(1, 5))
params = network.init(subkey, dummy_obs)
return TrainingState(params=params, random_key=key)

self.state = make_initial_state(random_key)
self._policy = policy

def select_action(self, t: TimeStep):
"""Select action based on observation"""
# Unpack
params = self.state.params
state = self.state
action, self.state = self._policy(params, t.observation, state)
return action

def update(
self,
t: TimeStep,
actions: jnp.ndarray,
t_prime: TimeStep,
other_agents: list = None,
):
"""Update agent"""
# for agent in other_agents:
# other_agent_obs = agent._trajectory_buffer.observations
pass


def make_lola(seed: int) -> LOLA:
"""Instantiate LOLA"""
random_key = jax.random.PRNGKey(seed)

def forward(inputs):
"""Forward pass for LOLA"""
values = hk.Linear(2, with_bias=False)
return values(inputs)

network = hk.without_apply_rng(hk.transform(forward))

return LOLA(network=network, random_key=random_key)


if __name__ == "__main__":
lola = make_lola(seed=0)
print(f"LOLA state: {lola.state}")
timestep = TimeStep(
step_type=0,
reward=1,
discount=1,
observation=jnp.array([[1, 0, 0, 0, 0]]),
)
action = lola.select_action(timestep)
print("Action", action)
timestep = TimeStep(
step_type=0, reward=1, discount=1, observation=jnp.zeros(shape=(1, 5))
)
action = lola.select_action(timestep)
print("Action", action)
Empty file added pax/lola/network.py
2 changes: 2 additions & 0 deletions pax/ppo/buffer.py
@@ -153,6 +153,8 @@ def reset(self):
(self._num_envs, self._num_steps, self.gru_dim)
)

self.parameters = jnp.zeros((self._num_envs, self._num_steps))


if __name__ == "__main__":
pass
6 changes: 4 additions & 2 deletions pax/ppo/networks.py
@@ -19,12 +19,14 @@ def __init__(
super().__init__(name=name)
self._logit_layer = hk.Linear(
num_values,
w_init=hk.initializers.Orthogonal(0.01), # baseline
# w_init=hk.initializers.Orthogonal(0.01), # baseline
w_init=hk.initializers.Constant(0.5),
with_bias=False,
)
self._value_layer = hk.Linear(
1,
w_init=hk.initializers.Orthogonal(1.0), # baseline
# w_init=hk.initializers.Orthogonal(1.0), # baseline
w_init=hk.initializers.Constant(0.5),
with_bias=False,
)

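
Since the orthogonal baseline initializers are now commented out in favour of Constant(0.5), an alternative (sketch only; the debug_constant_init flag is hypothetical) is to make the choice switchable rather than toggling comments:

import haiku as hk
import jax
import jax.numpy as jnp


def forward(x, debug_constant_init: bool = True):
    # Switch between the orthogonal baseline init and the constant init
    # used while debugging, without editing the source.
    w_init = (
        hk.initializers.Constant(0.5)
        if debug_constant_init
        else hk.initializers.Orthogonal(0.01)
    )
    return hk.Linear(2, w_init=w_init, with_bias=False)(x)


net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 5)))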
18 changes: 13 additions & 5 deletions pax/ppo/ppo.py
@@ -184,9 +184,10 @@ def loss(
fraction * entropy_coeff_start
+ (1 - fraction) * entropy_coeff_end
)

# Constant Entropy term
else:
entropy_cost = entropy_coeff_start
# else:
# entropy_cost = entropy_coeff_start
entropy_loss = -jnp.mean(entropy)

# Total loss: Minimize policy and value loss; maximize entropy
@@ -201,6 +202,7 @@
"loss_policy": policy_loss,
"loss_value": value_loss,
"loss_entropy": entropy_loss,
"entropy_cost": entropy_cost,
}

@jax.jit
@@ -371,8 +373,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState:
dummy_obs = utils.add_batch_dim(dummy_obs)
initial_params = network.init(subkey, dummy_obs)
initial_opt_state = optimizer.init(initial_params)
# for dict_key in initial_params.keys():
# print(initial_params[dict_key])
return TrainingState(
params=initial_params,
opt_state=initial_opt_state,
@@ -401,6 +401,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState:
"loss_policy": 0,
"loss_value": 0,
"loss_entropy": 0,
"entropy_cost": entropy_coeff_start,
}

# Initialize functions
@@ -421,7 +422,13 @@ def select_action(self, t: TimeStep):
)
return utils.to_numpy(actions)

def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep):
def update(
self,
t: TimeStep,
actions: np.array,
t_prime: TimeStep,
other_agents=None,
):
# Adds agent and environment info to buffer
self._rollouts(
buffer=self._trajectory_buffer,
@@ -474,6 +481,7 @@ def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep):
self._logger.metrics["loss_policy"] = results["loss_policy"]
self._logger.metrics["loss_value"] = results["loss_value"]
self._logger.metrics["loss_entropy"] = results["loss_entropy"]
self._logger.metrics["entropy_cost"] = results["entropy_cost"]


# TODO: seed, and player_id not used in CartPole
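
The annealing fix above blends the start and end entropy coefficients linearly. A standalone sketch of that schedule, assuming fraction decays from 1 to 0 over entropy_coeff_horizon steps (the exact computation of fraction in pax may differ):

import jax.numpy as jnp


def entropy_coeff(steps, start=0.2, end=0.01, horizon=500_000):
    fraction = jnp.clip(1.0 - steps / horizon, 0.0, 1.0)
    return fraction * start + (1.0 - fraction) * end


print(entropy_coeff(0), entropy_coeff(250_000), entropy_coeff(1_000_000))
# roughly 0.2 at the start, 0.105 halfway, 0.01 past the horizon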
11 changes: 10 additions & 1 deletion pax/ppo/ppo_gru.py
@@ -209,6 +209,7 @@ def loss(
"loss_policy": policy_loss,
"loss_value": value_loss,
"loss_entropy": entropy_loss,
"entropy_cost": entropy_cost,
}
# }, new_rnn_unroll_state

@@ -429,6 +430,7 @@ def make_initial_state(
"loss_policy": 0,
"loss_value": 0,
"loss_entropy": 0,
"entropy_cost": entropy_coeff_start,
}

# Initialize functions
@@ -450,7 +452,13 @@ def select_action(self, t: TimeStep):
)
return utils.to_numpy(actions)

def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep):
def update(
self,
t: TimeStep,
actions: np.array,
t_prime: TimeStep,
other_agents=None,
):
# Adds agent and environment info to buffer
self._rollouts(
buffer=self._trajectory_buffer,
@@ -497,6 +505,7 @@ def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep):
self._logger.metrics["loss_policy"] = results["loss_policy"]
self._logger.metrics["loss_value"] = results["loss_value"]
self._logger.metrics["loss_entropy"] = results["loss_entropy"]
self._logger.metrics["entropy_cost"] = results["entropy_cost"]


# TODO: seed, and player_id not used in CartPole