n player fixes, n player cournot
chrismatix committed Aug 7, 2023
1 parent cab28b5 commit 4d9b179
Showing 11 changed files with 115 additions and 135 deletions.
46 changes: 25 additions & 21 deletions pax/envs/cournot.py
@@ -18,39 +18,43 @@ class EnvParams:
b: float
marginal_cost: float


def to_obs_array(params: EnvParams) -> jnp.ndarray:
return jnp.array([params.a, params.b, params.marginal_cost])


class CournotGame(environment.Environment):
def __init__(self, num_inner_steps: int):
def __init__(self, num_players: int, num_inner_steps: int):
super().__init__()
self.num_players = num_players

def _step(
key: chex.PRNGKey,
state: EnvState,
actions: Tuple[float, float],
actions: Tuple[float, ...],
params: EnvParams,
):
assert len(actions) == num_players
t = state.outer_t
done = t >= num_inner_steps
key, _ = jax.random.split(key, 2)
q1 = actions[0]
q2 = actions[1]
p = params.a - params.b * (q1 + q2)
r1 = jnp.squeeze(p * q1 - params.marginal_cost * q1)
r2 = jnp.squeeze(p * q2 - params.marginal_cost * q2)
obs1 = jnp.concatenate([to_obs_array(params), jnp.array(q2), jnp.array(p)])
obs2 = jnp.concatenate([to_obs_array(params), jnp.array(q1), jnp.array(p)])

done = t >= num_inner_steps
actions = jnp.asarray(actions).squeeze()
actions = jnp.clip(actions, a_min=0)
p = params.a - params.b * actions.sum()

all_obs = []
all_rewards = []
for i in range(num_players):
q = actions[i]
r = p * q - params.marginal_cost * q
obs = jnp.concatenate([actions, jnp.array([p])])
all_obs.append(obs)
all_rewards.append(r)

state = EnvState(
inner_t=state.inner_t + 1, outer_t=state.outer_t + 1
)

return (
(obs1, obs2),
tuple(all_obs),
state,
(r1, r2),
tuple(all_rewards),
done,
{},
)
@@ -62,9 +66,9 @@ def _reset(
inner_t=jnp.zeros((), dtype=jnp.int8),
outer_t=jnp.zeros((), dtype=jnp.int8),
)
obs = jax.random.uniform(key, (2,))
obs = jnp.concatenate([to_obs_array(params), obs])
return (obs, obs), state
obs = jax.random.uniform(key, (num_players + 1,))
obs = jnp.concatenate([obs])
return tuple([obs for _ in range(num_players)]), state

self.step = jax.jit(_step)
self.reset = jax.jit(_reset)
@@ -87,7 +91,7 @@ def action_space(

def observation_space(self, params: EnvParams) -> spaces.Box:
"""Observation space of the environment."""
return spaces.Box(low=0, high=float('inf'), shape=5, dtype=jnp.float32)
return spaces.Box(low=0, high=float('inf'), shape=self.num_players + 1, dtype=jnp.float32)

@staticmethod
def nash_policy(params: EnvParams) -> float:
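Note on the n-player generalization above: `_step` now clips each submitted quantity at zero, prices total output with the linear inverse demand p = a − b·Σqᵢ, and pays every player its own profit, while each observation concatenates all quantities with the realized price. Below is a minimal standalone sketch of that payoff logic together with the symmetric Nash quantity q* = (a − c) / (b·(n + 1)) that `nash_policy` is expected to return; function and variable names are illustrative, not the repo's exact code.

```python
import jax.numpy as jnp

def cournot_rewards(quantities, a, b, marginal_cost):
    """Illustrative n-player Cournot payoff: p = a - b * sum(q); player i earns (p - c) * q_i."""
    q = jnp.clip(jnp.asarray(quantities), a_min=0)   # quantities cannot be negative
    price = a - b * q.sum()
    rewards = price * q - marginal_cost * q          # per-player profit, shape (n,)
    obs = jnp.concatenate([q, jnp.array([price])])   # each player observes all quantities + price
    return obs, rewards

def symmetric_nash_quantity(a, b, marginal_cost, num_players):
    # Standard result for linear Cournot with common marginal cost c: q* = (a - c) / (b * (n + 1))
    return (a - marginal_cost) / (b * (num_players + 1))

# Example: a=100, b=1, c=10, n=3  ->  q* = 22.5, price = 32.5, profit = 506.25 per player
q_star = symmetric_nash_quantity(100.0, 1.0, 10.0, 3)
obs, rewards = cournot_rewards([q_star] * 3, 100.0, 1.0, 10.0)
```

With two players this reduces to the previous duopoly case the file used to hard-code.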
105 changes: 37 additions & 68 deletions pax/experiment.py
@@ -50,21 +50,15 @@
from pax.envs.infinite_matrix_game import InfiniteMatrixGame
from pax.envs.iterated_matrix_game import EnvParams as IteratedMatrixGameParams
from pax.envs.iterated_matrix_game import IteratedMatrixGame
from pax.envs.iterated_tensor_game import EnvParams as IteratedTensorGameParams
from pax.envs.iterated_tensor_game import IteratedTensorGame
from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer
from pax.envs.iterated_tensor_game_n_player import (
EnvParams as IteratedTensorGameNPlayerParams,
)
from pax.runners.runner_eval import EvalRunner
from pax.runners.runner_eval_3player import TensorEvalRunner
from pax.runners.runner_eval_nplayer import NPlayerEvalRunner
from pax.runners.runner_evo import EvoRunner
from pax.runners.runner_evo_3player import TensorEvoRunner
from pax.runners.runner_evo_nplayer import NPlayerEvoRunner
from pax.runners.runner_ipditm_eval import IPDITMEvalRunner
from pax.runners.runner_evo_nplayer import TensorEvoRunner
from pax.runners.runner_marl import RLRunner
from pax.runners.runner_marl_3player import TensorRLRunner
from pax.runners.runner_marl_nplayer import NplayerRLRunner
from pax.runners.runner_sarl import SARLRunner
from pax.runners.runner_ipditm_eval import IPDITMEvalRunner
@@ -80,6 +74,7 @@
value_logger_ppo,
)


# NOTE: THIS MUST BE DONE BEFORE IMPORTING JAX
# uncomment to debug multi-devices on CPU
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
@@ -132,19 +127,6 @@ def env_setup(args, logger=None):
f"Env Type: {args.env_type} | Inner Episode Length: {args.num_inner_steps}"
)
logger.info(f"Outer Episode Length: {args.num_outer_steps}")
elif args.env_id == "iterated_tensor_game":
payoff = jnp.array(args.payoff)

env = IteratedTensorGame(
num_inner_steps=args.num_inner_steps,
num_outer_steps=args.num_outer_steps,
)
env_params = IteratedTensorGameParams(payoff_matrix=payoff)

if logger:
logger.info(
f"Env Type: {args.env_type} s| Inner Episode Length: {args.num_inner_steps}"
)
elif args.env_id == "iterated_nplayer_tensor_game":
payoff = jnp.array(args.payoff_table)

@@ -210,7 +192,7 @@ def env_setup(args, logger=None):
a=args.a, b=args.b, marginal_cost=args.marginal_cost
)
env = CournotGame(
num_inner_steps=args.num_inner_steps,
num_players=args.num_players, num_inner_steps=args.num_inner_steps,
)
if logger:
logger.info(
@@ -243,9 +225,6 @@ def runner_setup(args, env, agents, save_dir, logger):
if args.runner == "eval":
logger.info("Evaluating with EvalRunner")
return EvalRunner(agents, env, args)
elif args.runner == "tensor_eval":
logger.info("Training with tensor eval Runner")
return TensorEvalRunner(agents, env, save_dir, args)
elif args.runner == "tensor_eval_nplayer":
logger.info("Training with n-player tensor eval Runner")
return NPlayerEvalRunner(agents, env, save_dir, args)
@@ -254,9 +233,9 @@ def runner_setup(args, env, agents, save_dir, logger):
return IPDITMEvalRunner(agents, env, save_dir, args)

if (
args.runner == "evo"
or args.runner == "tensor_evo"
or args.runner == "tensor_evo_nplayer"
args.runner == "evo"
or args.runner == "tensor_evo"
or args.runner == "tensor_evo_nplayer"
):
agent1 = agents[0]
algo = args.es.algo
@@ -362,23 +341,10 @@ def get_pgpe_strategy(agent):
save_dir,
args,
)
elif args.runner == "tensor_evo_nplayer":
return NPlayerEvoRunner(
agents,
env,
strategy,
es_params,
param_reshaper,
save_dir,
args,
)

elif args.runner == "rl":
logger.info("Training with RL Runner")
return RLRunner(agents, env, save_dir, args)
elif args.runner == "tensor_rl":
logger.info("Training with tensor RL Runner")
return TensorRLRunner(agents, env, save_dir, args)
elif args.runner == "tensor_rl_nplayer":
return NplayerRLRunner(agents, env, save_dir, args)
elif args.runner == "sarl":
@@ -393,9 +359,9 @@ def agent_setup(args, env, env_params, logger):
"""Set up agent variables."""

if (
args.env_id == "iterated_matrix_game"
or args.env_id == "iterated_tensor_game"
or args.env_id == "iterated_nplayer_tensor_game"
args.env_id == "iterated_matrix_game"
or args.env_id == "iterated_tensor_game"
or args.env_id == "iterated_nplayer_tensor_game"
):
obs_shape = env.observation_space(env_params).n
elif args.env_id == "InTheMatrix":
@@ -423,7 +389,8 @@ def get_PPO_memory_agent(seed, player_id):
)

def get_PPO_agent(seed, player_id):
player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id))
default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None)
player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args)

if player_id == 1 and args.env_type == "meta":
num_iterations = args.num_outer_steps
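The PPO lookup above now falls back to a shared `ppo_default` block whenever a player-specific `ppo<i>` block is absent, which keeps n-player configs from repeating identical hyperparameters. A minimal sketch of the pattern with plain OmegaConf; the config keys and values here are made up for illustration:

```python
from omegaconf import OmegaConf

# Hypothetical config: only player 2 overrides the shared PPO settings.
args = OmegaConf.create({
    "ppo_default": {"learning_rate": 3e-4, "num_minibatches": 4},
    "ppo2": {"learning_rate": 1e-4, "num_minibatches": 8},
})

def ppo_args_for(player_id: int):
    default_args = OmegaConf.select(args, "ppo_default", default=None)
    # Fall back to the shared block when no "ppo<i>" entry exists for this player.
    return OmegaConf.select(args, f"ppo{player_id}", default=default_args)

assert ppo_args_for(1).learning_rate == 3e-4   # resolved from ppo_default
assert ppo_args_for(2).learning_rate == 1e-4   # player-specific override
```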
@@ -559,11 +526,11 @@ def get_stay_agent(seed, player_id):
return agent_1

else:
for i in range(1, args.num_players + 1):
assert (
omegaconf.OmegaConf.select(args, "agent" + str(i))
in strategies
)
default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None)
agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in
range(1, args.num_players + 1)]
for strategy in agent_strategies:
assert strategy in strategies

seeds = [
seed for seed in range(args.seed, args.seed + args.num_players)
Expand All @@ -574,14 +541,12 @@ def get_stay_agent(seed, player_id):
for seed, i in zip(seeds, range(1, args.num_players + 1))
]
agents = []
for i in range(args.num_players):
for idx, strategy in enumerate(agent_strategies):
agents.append(
strategies[
omegaconf.OmegaConf.select(args, "agent" + str(i + 1))
](seeds[i], pids[i])
strategies[strategy](seeds[idx], pids[idx])
)
logger.info(
f"Agent Pair: {[omegaconf.OmegaConf.select(args, 'agent' + str(i)) for i in range(1, args.num_players + 1)]}"
f"Agent Pair: {strategies}"
)
logger.info(f"Agent seeds: {seeds}")

@@ -592,7 +557,7 @@ def watcher_setup(args, logger):
"""Set up watcher variables."""

def ppo_memory_log(
agent,
agent,
):
losses = losses_ppo(agent)
if args.env_id not in [
@@ -615,6 +580,8 @@ def ppo_log(agent):
losses = losses_ppo(agent)
if args.env_id not in [
"coin_game",
"Cournot",
"Fishery",
"InTheMatrix",
"iterated_matrix_game",
"iterated_tensor_game",
@@ -697,13 +664,15 @@ def naive_pg_log(agent):
assert args.agent1 in strategies

agent_1_log = naive_pg_log # strategies[args.agent1] #

return agent_1_log
else:
agent_log = []
for i in range(1, args.num_players + 1):
assert getattr(args, f"agent{i}") in strategies
agent_log.append(strategies[getattr(args, f"agent{i}")])
default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None)
agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in
range(1, args.num_players + 1)]
for strategy in agent_strategies:
assert strategy in strategies
agent_log.append(strategies[strategy])
return agent_log


@@ -734,25 +703,25 @@ def main(args):
print(f"Number of Training Iterations: {args.num_iters}")

if (
args.runner == "evo"
or args.runner == "tensor_evo"
or args.runner == "tensor_evo_nplayer"
args.runner == "evo"
or args.runner == "tensor_evo"
or args.runner == "tensor_evo_nplayer"
):
runner.run_loop(env_params, agent_pair, args.num_iters, watchers)

elif (
args.runner == "rl"
or args.runner == "tensor_rl"
or args.runner == "tensor_rl_nplayer"
args.runner == "rl"
or args.runner == "tensor_rl"
or args.runner == "tensor_rl_nplayer"
):
# number of episodes
print(f"Number of Episodes: {args.num_iters}")
runner.run_loop(env_params, agent_pair, args.num_iters, watchers)

elif (
args.runner == "ipditm_eval"
or args.runner == "tensor_eval"
or args.runner == "tensor_eval_nplayer"
args.runner == "ipditm_eval"
or args.runner == "tensor_eval"
or args.runner == "tensor_eval_nplayer"
):
runner.run_loop(env_params, agent_pair, watchers)

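For reference, `env_setup` now forwards `num_players` when constructing the Cournot environment. Putting the pieces visible in this diff together, a rough usage sketch follows; the parameter values are invented, and the reset/step call pattern is inferred from the jitted `_reset`/`_step` signatures shown above rather than taken from documentation:

```python
import jax
from pax.envs.cournot import CournotGame, EnvParams

params = EnvParams(a=100.0, b=1.0, marginal_cost=10.0)
env = CournotGame(num_players=3, num_inner_steps=10)

key = jax.random.PRNGKey(0)
obs, state = env.reset(key, params)              # one observation per player
actions = (20.0, 25.0, 22.5)                     # one quantity per player
obs, state, rewards, done, info = env.step(key, state, actions, params)
```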
3 changes: 1 addition & 2 deletions pax/runners/runner_eval_nplayer.py
@@ -7,11 +7,10 @@
from omegaconf import OmegaConf

import wandb
from pax.utils import MemoryState, TrainingState, save, load
from pax.utils import MemoryState, TrainingState, load
from pax.watchers import (
ipditm_stats,
n_player_ipd_visitation,
tensor_ipd_visitation,
)

MAX_WANDB_CALLS = 1000
3 changes: 2 additions & 1 deletion pax/runners/runner_evo.py
@@ -12,7 +12,8 @@

# TODO: import when evosax library is updated
# from evosax.utils import ESLog
from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats, fishery_stats
from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats
from pax.watchers.fishery import fishery_stats
from pax.watchers.cournot import cournot_stats

MAX_WANDB_CALLS = 1000
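Side note on the import change above: the fishery and Cournot watcher helpers now live in dedicated modules (`pax.watchers.fishery`, `pax.watchers.cournot`) rather than the top-level `pax.watchers` package. As a purely hypothetical illustration of the kind of metric such a watcher can emit (this is not the repo's `cournot_stats` implementation), one could log how far the observed quantities sit from the symmetric Nash benchmark:

```python
import jax.numpy as jnp

def cournot_stats_sketch(quantities, a, b, marginal_cost, num_players):
    """Hypothetical Cournot diagnostics keyed for a wandb-style logger."""
    q = jnp.asarray(quantities)
    nash_q = (a - marginal_cost) / (b * (num_players + 1))  # same benchmark as nash_policy
    return {
        "cournot/average_quantity": q.mean(),
        "cournot/price": a - b * q.sum(),
        "cournot/nash_gap": jnp.abs(q.mean() - nash_q),
    }
```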
4 changes: 2 additions & 2 deletions pax/runners/runner_evo_nplayer.py
@@ -12,7 +12,7 @@

# TODO: import when evosax library is updated
# from evosax.utils import ESLog
from pax.watchers import ESLog, cg_visitation, tensor_ipd_visitation
from pax.watchers import ESLog, cg_visitation, n_player_ipd_visitation

MAX_WANDB_CALLS = 1000

@@ -72,7 +72,7 @@ def __init__(
self.top_k = args.top_k
self.train_steps = 0
self.train_episodes = 0
self.ipd_stats = jax.jit(tensor_ipd_visitation)
self.ipd_stats = jax.jit(n_player_ipd_visitation)
self.cg_stats = jax.jit(jax.vmap(cg_visitation))

# Evo Runner has 3 vmap dims (popsize, num_opps, num_envs)
3 changes: 2 additions & 1 deletion pax/runners/runner_marl.py
@@ -7,7 +7,8 @@

import wandb
from pax.utils import MemoryState, TrainingState, save
from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats, fishery_stats
from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats
from pax.watchers.fishery import fishery_stats
from pax.watchers.cournot import cournot_stats

MAX_WANDB_CALLS = 1000
(The remaining 5 changed files are not shown here.)
