diff --git a/pax/agents/agent.py b/pax/agents/agent.py
index 8b45f2dd..9c590440 100644
--- a/pax/agents/agent.py
+++ b/pax/agents/agent.py
@@ -1,8 +1,9 @@
 from typing import Tuple
 
-from pax.utils import MemoryState, TrainingState
 import jax.numpy as jnp
 
+from pax.utils import MemoryState, TrainingState
+
 
 class AgentInterface:
     """Interface for agents to interact with runners and environments.
diff --git a/pax/agents/hyper/ppo.py b/pax/agents/hyper/ppo.py
index 5398df9d..cfb810b9 100644
--- a/pax/agents/hyper/ppo.py
+++ b/pax/agents/hyper/ppo.py
@@ -329,7 +329,9 @@ def model_update_epoch(
         return new_state, new_mem, metrics
 
     @jax.jit
-    def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
+    def make_initial_state(
+        key: Any, hidden: jnp.ndarray
+    ) -> Tuple[TrainingState, MemoryState]:
         """Initialises the training state (parameters and optimiser state)."""
         key, subkey = jax.random.split(key)
         dummy_obs = jnp.zeros(shape=obs_spec)
diff --git a/pax/agents/lola/__init__.py b/pax/agents/lola/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pax/agents/lola/lola.py b/pax/agents/lola/lola.py
new file mode 100644
index 00000000..d1d95126
--- /dev/null
+++ b/pax/agents/lola/lola.py
@@ -0,0 +1,867 @@
+from typing import Any, Dict, List, Mapping, NamedTuple, Tuple
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from dm_env import TimeStep
+
+# from pax.lola.buffer import TrajectoryBuffer
+from pax.agents.lola.network import make_network
+
+from pax import utils
+from pax.agents.ppo.ppo_gru import PPO
+from pax.runners.runner_marl import Sample
+from pax.utils import MemoryState, TrainingState
+
+
+class LOLASample(NamedTuple):
+    obs_self: jnp.ndarray
+    obs_other: List[jnp.ndarray]
+    actions_self: jnp.ndarray
+    actions_other: List[jnp.ndarray]
+    dones: jnp.ndarray
+    rewards_self: jnp.ndarray
+    rewards_other: List[jnp.ndarray]
+
+
+class Batch(NamedTuple):
+    """A batch of data; all shapes are expected to be [B, ...]."""
+
+    observations: jnp.ndarray
+    actions: jnp.ndarray
+    advantages: jnp.ndarray
+
+    # Target value estimate used to bootstrap the value function.
+    target_values: jnp.ndarray
+
+    # Value estimate and action log-prob at behavior time.
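+    # NOTE: Batch appears unused by the LOLA agent below, which stores rollouts in LOLASample.
+    # "Behavior time" means the values/log-probs recorded by the policy that generated the data.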
+ behavior_values: jnp.ndarray + behavior_log_probs: jnp.ndarray + + +def magic_box(x): + return jnp.exp(x - jax.lax.stop_gradient(x)) + + +class Logger: + metrics: dict + + +class LOLA: + """LOLA with the DiCE objective function.""" + + def __init__( + self, + args, + network: NamedTuple, + outer_optimizer: optax.GradientTransformation, + random_key: jnp.ndarray, + player_id: int, + obs_spec: Tuple, + env_params: Any, + env_step, + env_reset, + num_envs: int = 4, + num_steps: int = 150, + use_baseline: bool = True, + gamma: float = 0.96, + ): + self._num_envs = num_envs # number of environments + self._num_opps = args.num_opps + self.env_step = env_step + self.env_reset = env_reset + # self.agent2 = None + self.other_agents = None + self.args = args + + @jax.jit + def policy( + state: TrainingState, observation: jnp.ndarray, mem: MemoryState + ): + """Agent policy to select actions and calculate agent specific information""" + + key, subkey = jax.random.split(state.random_key) + dist, values = network.apply(state.params, observation) + actions = dist.sample(seed=subkey) + mem.extras["values"] = values + mem.extras["log_probs"] = dist.log_prob(actions) + state = state._replace(random_key=key) + return actions, state, mem + + def outer_loss(params, mem, other_params, other_mems, samples): + """Used for the outer rollout""" + # Unpack the samples + obs_1 = samples.obs_self + other_obs = samples.obs_other + + # we care about our own rewards + self_rewards = samples.rewards_self + actions_1 = samples.actions_self + other_actions = samples.actions_other + # jax.debug.breakpoint() + + # Get distribution and value using my network + distribution, values = self.network.apply(params, obs_1) + self_log_prob = distribution.log_prob(actions_1) + + # Get distribution and value using other player's network + other_log_probs = [] + for idx, agent in enumerate(self.other_agents): + if self.args.agent2 == "PPO_memory": + (distribution, _,), _ = agent.network.apply( + other_params[idx], + other_obs[idx], + other_mems[idx].hidden, + ) + else: + distribution, _ = agent.network.apply( + other_params[idx], other_obs[idx] + ) + other_log_probs.append( + distribution.log_prob(other_actions[idx]) + ) + + # flatten opponent and num_envs into one dimension + + # apply discount: + cum_discount = ( + jnp.cumprod(self.gamma * jnp.ones(self_rewards.shape), axis=0) + / self.gamma + ) + + discounted_rewards = self_rewards * cum_discount + discounted_values = values * cum_discount + # jax.debug.breakpoint() + # TODO no clue if this makes any sense + # stochastics nodes involved in rewards dependencies: + sum_other_log_probs = jnp.sum( + jnp.stack(other_log_probs, axis=0), axis=0 + ) + dependencies = jnp.cumsum( + self_log_prob + sum_other_log_probs, axis=0 + ) + # logprob of each stochastic nodes: + stochastic_nodes = self_log_prob + sum_other_log_probs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=0) + ) + + if use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=0, + ) + ) + dice_objective = dice_objective + baseline_term + + G_ts = reverse_cumsum(discounted_rewards, axis=0) + R_ts = G_ts / cum_discount + # # want to minimize this value + value_objective = jnp.mean((R_ts - values) ** 2) + + # want to maximize this objective + loss_total = -dice_objective + value_objective + return loss_total, { + "loss_total": -dice_objective + value_objective, + "loss_policy": 
-dice_objective, + "loss_value": value_objective, + } + + def inner_loss( + chosen_op_params, + chosen_op_idx, + lola_params, + mem, + other_params, + other_mems, + samples, + ): + """Used for the inner rollout""" + obs_1 = samples.obs_self + other_obs = samples.obs_other + + # we care about the chosen opponent player's rewards + self_rewards = samples.rewards_self + other_rewards = samples.rewards_other + actions_1 = samples.actions_self + other_actions = samples.actions_other + + # Get distribution and valwue using my network + distribution, _ = self.network.apply(lola_params, obs_1) + self_log_prob = distribution.log_prob(actions_1) + + other_log_probs = [] + other_values = [] + for idx, agent in enumerate(self.other_agents): + # treating this case separately cause we want the grads on the chosen other players params + if idx == chosen_op_idx: + if self.args.agent2 == "PPO_memory": + ( + distribution, + values, + ), hidden_state = agent.network.apply( + chosen_op_params, + other_obs[idx], + other_mems[idx].hidden, + ) + + else: + distribution, values = agent.network.apply( + chosen_op_params, other_obs[idx] + ) + else: + if self.args.agent2 == "PPO_memory": + ( + distribution, + values, + ), hidden_state = agent.network.apply( + other_params[idx], + other_obs[idx], + other_mems[idx].hidden, + ) + + else: + distribution, values = agent.network.apply( + other_params[idx], other_obs[idx] + ) + other_values.append(values) + other_log_probs.append( + distribution.log_prob(other_actions[idx]) + ) + # apply discount: + cum_discount = ( + jnp.cumprod(self.gamma * jnp.ones(self_rewards.shape), axis=0) + / self.gamma + ) + discounted_rewards = [ + reward * cum_discount for reward in other_rewards + ] + discounted_values = [ + values * cum_discount for values in other_values + ] + # TODO do actual maths here - no idea what this is doing + # stochastics nodes involved in rewards dependencies: + # dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=0) + # # logprob of each stochastic nodes: + # stochastic_nodes = self_log_prob + other_log_prob + sum_other_log_probs = jnp.sum( + jnp.stack(other_log_probs, axis=0), axis=0 + ) + dependencies = jnp.cumsum( + self_log_prob + sum_other_log_probs, axis=0 + ) + # logprob of each stochastic nodes: + stochastic_nodes = self_log_prob + sum_other_log_probs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum( + magic_box(dependencies) + * discounted_rewards[chosen_op_idx], + axis=0, + ) + ) + + if use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) + * discounted_values[chosen_op_idx], + axis=0, + ) + ) + dice_objective = dice_objective + baseline_term + + G_ts = reverse_cumsum(discounted_rewards[chosen_op_idx], axis=0) + R_ts = G_ts / cum_discount + # # want to minimize this value + value_objective = jnp.mean( + (R_ts - other_values[chosen_op_idx]) ** 2 + ) + + # want to maximize this objective + loss_total = -dice_objective + value_objective + + return loss_total, { + "loss_total": -dice_objective + value_objective, + "loss_policy": -dice_objective, + "loss_value": value_objective, + } + + def make_initial_state(key: Any, hidden) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key, subkey = jax.random.split(key) + dummy_obs = jnp.zeros(shape=obs_spec) + dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init(subkey, dummy_obs) + initial_opt_state = outer_optimizer.init(initial_params) + return TrainingState( + 
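+                # NOTE: as with the PPO agents patched above, this actually returns a
+                # (TrainingState, MemoryState) pair; the `-> TrainingState` annotation is narrower.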
params=initial_params, + opt_state=initial_opt_state, + random_key=key, + timesteps=0, + ), MemoryState( + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + hidden=jnp.zeros((self._num_envs, 1)), + ) + + inner_rollout_rng, random_key = jax.random.split(random_key) + self.inner_rollout_rng = inner_rollout_rng + # Initialise training state (parameters, optimiser state, extras). + self._state, self._mem = make_initial_state(random_key, obs_spec) + + self.make_initial_state = make_initial_state + + # Setup player id + self.player_id = player_id + + self.env_params = env_params + + grad_inner = jax.grad(inner_loss, has_aux=True, argnums=0) + grad_outer = jax.grad(outer_loss, has_aux=True, argnums=0) + # vmap over num_envs + self.grad_fn_inner = jax.jit( + jax.vmap(grad_inner, (None, None, None, 0, None, 0, 0), (0, 0)), + static_argnums=1, + ) + self.grad_fn_outer = jax.jit( + jax.vmap( + jax.vmap(grad_outer, (None, 0, None, 0, 0), (0, 0)), + (None, 0, 0, 0, 0), + (0, 0), + ) + ) + + # Set up counters and logger + self._logger = Logger() + self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + } + + # Initialize functions + self._policy = policy + self.network = network + + # initialize some variables + self._outer_optimizer = outer_optimizer + self.gamma = gamma + + # Other useful hyperparameters + self._num_steps = num_steps # number of steps per environment + self._batch_size = int(num_envs * num_steps) # number in one batch + self._obs_spec = obs_spec + + # def select_action(self, state, t: TimeStep): + # """Selects action and updates info with PPO specific information""" + # actions, state = self._policy( + # state.params, t.observation, state + # ) + # utils.to_numpy(actions), state + + def in_lookahead(self, rng, my_state, my_mem, other_states, other_mems): + """ + Performs a rollout using the current parameters of both agents + and simulates a naive learning update step for the other agent + + INPUT: + env: SequentialMatrixGame, an environment object of the game being played + """ + + # do a full rollout + # we want to play num_envs games at once wiht one opponent + rng, reset_rng = jax.random.split(rng) + reset_rngs = jax.random.split(reset_rng, self._num_envs).reshape( + (self._num_envs, -1) + ) + + batch_reset = jax.vmap(self.env_reset, (0, None), 0) + obs, env_state = batch_reset(reset_rngs, self.env_params) + + rewards = [jnp.zeros(self._num_envs)] * self.args.num_players + + inner_rollout_rng, rng = jax.random.split(rng) + inner_rollout_rngs = jax.random.split( + inner_rollout_rng, self._num_envs + ).reshape((self._num_envs, -1)) + batch_step = jax.vmap(self.env_step, (0, 0, 0, None), 0) + batch_policy1 = jax.vmap(self._policy, (None, 0, 0), (0, None, 0)) + batch_policies = [ + jax.vmap(agent._policy, (None, 0, 0), (0, None, 0)) + for agent in self.other_agents + ] + # batch_policy2 = jax.vmap(self.agent2._policy, (None, 0, 0), (0, None, 0)) + + def lola_inlookahead_rollout(carry, unused): + """Runner for inner episode""" + + ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_states, + first_agent_mem, + other_agent_mems, + env_state, + env_params, + ) = carry + # unpack rngs + + # this fn is not batched over num_envs! 
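+            # split each per-env key into four sub-keys: index 0 drives the env step,
+            # indices 1-2 are reserved for the agents (currently unused), index 3 is carried forward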
+ vmap_split = jax.vmap(jax.random.split, (0, None), 0) + rngs = vmap_split(rngs, 4) + + env_rng = rngs[:, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, 3, :] + + actions = [] + ( + first_action, + first_agent_state, + new_first_agent_mem, + ) = batch_policy1( + first_agent_state, + first_agent_obs, + first_agent_mem, + ) + actions.append(first_action) + new_other_agent_mems = [None] * len(self.other_agents) + for agent_idx, other_policy in enumerate(batch_policies): + ( + non_first_action, + other_agent_states[agent_idx], + new_other_agent_mems[agent_idx], + ) = other_policy( + other_agent_states[agent_idx], + other_agent_obs[agent_idx], + other_agent_mems[agent_idx], + ) + actions.append(non_first_action) + + ( + all_agent_next_obs, + env_state, + all_agent_rewards, + done, + info, + ) = batch_step( + env_rng, + env_state, + actions, + env_params, + ) + first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs + first_agent_reward, *other_agent_rewards = all_agent_rewards + traj1 = LOLASample( + first_agent_next_obs, + other_agent_next_obs, + actions[0], + actions[1:], + done, + first_agent_reward, + other_agent_rewards, + ) + + other_traj = [ + Sample( + other_agent_obs[agent_idx], + actions[agent_idx + 1], + other_agent_rewards[agent_idx], + new_other_agent_mems[agent_idx].extras["log_probs"], + new_other_agent_mems[agent_idx].extras["values"], + done, + other_agent_mems[agent_idx].hidden, + ) + for agent_idx in range(len(self.other_agents)) + ] + return ( + rngs, + first_agent_next_obs, + tuple(other_agent_next_obs), + first_agent_reward, + tuple(other_agent_rewards), + first_agent_state, + other_agent_states, + new_first_agent_mem, + new_other_agent_mems, + env_state, + env_params, + ), (traj1, *other_traj) + + # jax.debug.breakpoint() + + carry, trajectories = jax.lax.scan( + lola_inlookahead_rollout, + ( + inner_rollout_rngs, + obs[0], + tuple(obs[1:]), + rewards[0], + tuple(rewards[1:]), + my_state, + other_states, + my_mem, + other_mems, + env_state, + self.env_params, + ), + None, + length=self._num_steps, # num_inner_steps + ) + my_mem = carry[7] + other_mems = carry[8] + # jax.debug.breakpoint() + # flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs + vmap_trajectories = jax.tree_map( + lambda x: jnp.swapaxes(x, 0, 1), trajectories + ) + + sample = LOLASample( + obs_self=vmap_trajectories[0].obs_self, + obs_other=[traj.observations for traj in vmap_trajectories[1:]], + actions_self=vmap_trajectories[0].actions_self, + actions_other=[traj.actions for traj in vmap_trajectories[1:]], + dones=vmap_trajectories[0].dones, + rewards_self=vmap_trajectories[0].rewards_self, + rewards_other=[traj.rewards for traj in vmap_trajectories[1:]], + ) + # jax.debug.breakpoint() + # get gradients of opponents + other_gradients = [] + for idx in range(len((self.other_agents))): + chosen_op_idx = idx + chosen_op_params = other_states[idx].params + + gradient, _ = self.grad_fn_inner( + chosen_op_params, + chosen_op_idx, + my_state.params, + my_mem, + [state.params for state in other_states], + other_mems, + sample, + ) + other_gradients.append(gradient) + # avg over numenvs + other_gradients = [ + jax.tree_map(lambda x: x.mean(axis=0), other_gradient) + for other_gradient in other_gradients + ] + + # Update the optimizer + new_other_states = [] + for idx, agent in enumerate(self.other_agents): + updates, opt_state = agent.optimizer.update( + other_gradients[idx], other_states[idx].opt_state + ) + # apply the optimizer updates + params = 
optax.apply_updates(other_states[idx].params, updates) + + # replace the other player's current parameters with a simulated update + new_other_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=other_states[idx].random_key, + timesteps=other_states[idx].timesteps, + ) + new_other_states.append(new_other_state) + + return new_other_states, other_mems + + def out_lookahead(self, rng, my_state, my_mem, other_states, other_mems): + """ + Performs a real rollout using the current parameters of both agents + and a naive learning update step for the other agent + + INPUT: + env: SequentialMatrixGame, an environment object of the game being played + other_agents: list, a list of objects of the other agents + """ + + rng, reset_rng = jax.random.split(rng) + reset_rngs = jax.random.split( + reset_rng, self._num_envs * self._num_opps + ).reshape((self._num_opps, self._num_envs, -1)) + + batch_reset = jax.vmap( + jax.vmap(self.env_reset, (0, None), 0), (0, None), 0 + ) + obs, env_state = batch_reset(reset_rngs, self.env_params) + + rewards = [ + jnp.zeros((self._num_opps, self._num_envs)), + ] * self.args.num_players + + inner_rollout_rng, _ = jax.random.split(rng) + inner_rollout_rngs = jax.random.split( + inner_rollout_rng, self._num_envs * self._num_opps + ).reshape((self._num_opps, self._num_envs, -1)) + batch_step = jax.vmap( + jax.vmap(self.env_step, (0, 0, 0, None), 0), (0, 0, 0, None), 0 + ) + + def lola_outlookahead_rollout(carry, unused): + """Runner for inner episode""" + + ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_states, + first_agent_mem, + other_agent_mems, + env_state, + env_params, + ) = carry + + # unpack rngs + + vmap_split = jax.vmap( + jax.vmap(jax.random.split, (0, None), 0), (0, None), 0 + ) + rngs = vmap_split(rngs, 4) + + env_rng = rngs[:, :, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, :, 3, :] + + batch_policy1 = jax.vmap(self._policy, (None, 0, 0), (0, None, 0)) + batch_policies = [ + jax.vmap( + jax.vmap(agent._policy, (None, 0, 0), (0, None, 0)), + (0, 0, 0), + (0, 0, 0), + ) + for agent in self.other_agents + ] + actions = [] + ( + first_action, + first_agent_state, + new_first_agent_mem, + ) = batch_policy1( + first_agent_state, + first_agent_obs, + first_agent_mem, + ) + actions.append(first_action) + new_other_agent_mems = [None] * len(self.other_agents) + for agent_idx, other_policy in enumerate(batch_policies): + ( + non_first_action, + other_agent_states[agent_idx], + new_other_agent_mems[agent_idx], + ) = other_policy( + other_agent_states[agent_idx], + other_agent_obs[agent_idx], + other_agent_mems[agent_idx], + ) + actions.append(non_first_action) + + ( + all_agent_next_obs, + env_state, + all_agent_rewards, + done, + info, + ) = batch_step( + env_rng, + env_state, + actions, + env_params, + ) + + first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs + first_agent_reward, *other_agent_rewards = all_agent_rewards + traj1 = LOLASample( + first_agent_obs, + other_agent_obs, + actions[0], + actions[1:], + done, + first_agent_reward, + other_agent_rewards, + ) + + other_traj = [ + Sample( + other_agent_obs[agent_idx], + actions[agent_idx + 1], + other_agent_rewards[agent_idx], + new_other_agent_mems[agent_idx].extras["log_probs"], + new_other_agent_mems[agent_idx].extras["values"], + done, + other_agent_mems[agent_idx].hidden, + ) + for agent_idx in range(len(self.other_agents)) + ] + return ( + rngs, + 
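+                # carry mirrors the inner-lookahead rollout: rngs, observations, rewards,
+                # training states, memories and env state/params threaded through jax.lax.scan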
first_agent_next_obs, + tuple(other_agent_next_obs), + first_agent_reward, + tuple(other_agent_rewards), + first_agent_state, + other_agent_states, + new_first_agent_mem, + new_other_agent_mems, + env_state, + env_params, + ), (traj1, *other_traj) + + # do a full rollout + _, trajectories = jax.lax.scan( + lola_outlookahead_rollout, + ( + inner_rollout_rngs, + obs[0], + tuple(obs[1:]), + rewards[0], + tuple(rewards[1:]), + my_state, + other_states, + my_mem, + other_mems, + env_state, + self.env_params, + ), + None, + length=self._num_steps, + ) + # num_inner, num_opps, num_envs to num_opps, num_envs, num_inner + vmap_trajectories = jax.tree_map( + lambda x: jnp.moveaxis(x, 0, 2), trajectories + ) + sample = LOLASample( + obs_self=vmap_trajectories[0].obs_self, + obs_other=[traj.observations for traj in vmap_trajectories[1:]], + actions_self=vmap_trajectories[0].actions_self, + actions_other=[traj.actions for traj in vmap_trajectories[1:]], + dones=vmap_trajectories[0].dones, + rewards_self=vmap_trajectories[0].rewards_self, + rewards_other=[traj.rewards for traj in vmap_trajectories[1:]], + ) + # print("Before updating") + # print("---------------------") + # print("params", self._state.params) + # print("opt_state", self._state.opt_state) + # print() + # calculate the gradients + gradients, results = self.grad_fn_outer( + my_state.params, + my_mem, + [state.params for state in other_states], + other_mems, + sample, + ) + gradients = jax.tree_map(lambda x: x.mean(axis=(0, 1)), gradients) + # print("Gradients", gradients) + # print() + # Update the optimizer + updates, opt_state = self._outer_optimizer.update( + gradients, my_state.opt_state + ) + # print("Updates", updates) + # print("Updated optimizer", opt_state) + # print() + + # apply the optimizer updates + params = optax.apply_updates(my_state.params, updates) + + # Update internal agent's timesteps + self._total_steps += self._num_envs + self._logger.metrics["total_steps"] += self._num_envs + self._state._replace(timesteps=self._total_steps) + + # Logging + self._logger.metrics["sgd_steps"] += 1 + self._logger.metrics["loss_total"] = results["loss_total"] + self._logger.metrics["loss_policy"] = results["loss_policy"] + self._logger.metrics["loss_value"] = results["loss_value"] + + # replace the player's current parameters with a real update + new_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=self._state.random_key, + timesteps=self._state.timesteps, + ) + return new_state + + def reset_memory(self, memory, eval=False) -> MemoryState: + num_envs = 1 if eval else self._num_envs + memory = memory._replace( + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + hidden=jnp.zeros((self._num_envs, 1)), + ) + return memory + + +def make_lola( + args, + obs_spec, + action_spec, + seed: int, + player_id: int, + env_params: Any, + env_step, + env_reset, +): + """Make Naive Learner Policy Gradient agent""" + # Create Haiku network + network = make_network(action_spec) + # Outer optimizer uses Adam + outer_optimizer = optax.adam(args.lola.lr_out) + # Random key + random_key = jax.random.PRNGKey(seed=seed) + + return LOLA( + args=args, + network=network, + outer_optimizer=outer_optimizer, + random_key=random_key, + obs_spec=obs_spec, + env_params=env_params, + env_step=env_step, + env_reset=env_reset, + player_id=player_id, + num_envs=args.num_envs, + num_steps=args.num_inner_steps, + use_baseline=args.lola.use_baseline, + gamma=args.lola.gamma, + ) + + +def reverse_cumsum(x, axis): + 
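+    # entry t is sum_{k >= t} x[k]: the total minus the exclusive prefix sum, i.e. a
+    # cumulative sum taken from the end of `axis` (used above to form returns-to-go)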
return x + jnp.sum(x, axis=axis, keepdims=True) - jnp.cumsum(x, axis=axis) + + +if __name__ == "__main__": + pass diff --git a/pax/agents/lola/network.py b/pax/agents/lola/network.py new file mode 100644 index 00000000..bc32ac44 --- /dev/null +++ b/pax/agents/lola/network.py @@ -0,0 +1,159 @@ +from typing import Optional + +import distrax +import haiku as hk +import jax +import jax.numpy as jnp + +from pax import utils + + +class CategoricalValueHead(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0), + # w_init=hk.initializers.RandomNormal(), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0), + # w_init=hk.initializers.RandomNormal(), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + logits = self._logit_layer(inputs) + # logits = jax.nn.sigmoid(self._logit_layer(inputs)) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (distrax.Categorical(logits=logits), value) + + +class BernoulliValueHead(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + # matching the way that they do it. + logits = jnp.squeeze(self._logit_layer(inputs), axis=-1) + probs = jax.nn.sigmoid(logits) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (distrax.Bernoulli(probs=1 - probs), value) + + +class PolicyHead(hk.Module): + """Network head that produces a categorical distribution.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + logits = jnp.squeeze(self._logit_layer(inputs), axis=-1) + probs = jax.nn.sigmoid(logits) + return distrax.Bernoulli(probs=1 - probs) + + +class ValueHead(hk.Module): + """Network head that produces a value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return value + + +def make_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + CategoricalValueHead(num_values=num_actions), + # BernoulliValueHead(num_values=1), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + +def make_policy_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + PolicyHead(num_values=1), + ] + ) + policy_value_network = hk.Sequential(layers) + return 
policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + +def make_value_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + ValueHead(num_values=1), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network diff --git a/pax/agents/naive_exact.py b/pax/agents/naive_exact.py index 074d85e9..72f04a8f 100644 --- a/pax/agents/naive_exact.py +++ b/pax/agents/naive_exact.py @@ -2,8 +2,8 @@ import jax import jax.numpy as jnp -from pax.agents.agent import AgentInterface +from pax.agents.agent import AgentInterface from pax.envs.infinite_matrix_game import EnvParams as InfiniteMatrixGameParams from pax.utils import MemoryState diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index fa839104..49fdf7ef 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -10,10 +10,10 @@ from pax import utils from pax.agents.agent import AgentInterface from pax.agents.ppo.networks import ( - make_ipditm_network, - make_sarl_network, make_coingame_network, make_ipd_network, + make_ipditm_network, + make_sarl_network, ) from pax.utils import Logger, MemoryState, TrainingState, get_advantages @@ -336,7 +336,9 @@ def model_update_epoch( return new_state, new_memory, metrics - def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: + def make_initial_state( + key: Any, hidden: jnp.ndarray + ) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) @@ -357,6 +359,7 @@ def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: dummy_obs = utils.add_batch_dim(dummy_obs) initial_params = network.init(subkey, dummy_obs) initial_opt_state = optimizer.init(initial_params) + self.optimizer = optimizer return TrainingState( random_key=key, params=initial_params, @@ -413,6 +416,7 @@ def prepare_batch( # Initialize functions self._policy = policy self.player_id = player_id + self.network = network # Other useful hyperparameters self._num_envs = num_envs # number of environments diff --git a/pax/agents/ppo/ppo_gru.py b/pax/agents/ppo/ppo_gru.py index d8fbfeec..2944a9d6 100644 --- a/pax/agents/ppo/ppo_gru.py +++ b/pax/agents/ppo/ppo_gru.py @@ -377,6 +377,7 @@ def make_initial_state( subkey, dummy_obs, initial_hidden_state ) initial_opt_state = optimizer.init(initial_params) + self.optimizer = optimizer return TrainingState( random_key=key, params=initial_params, @@ -441,6 +442,7 @@ def prepare_batch( } # Initialize functions + self.network = network self._policy = policy self.forward = network.apply self.player_id = player_id @@ -471,8 +473,7 @@ def update( ): """Update the agent -> only called at the end of a trajectory""" - - _, _, mem = self._policy(state, obs, mem) + _1, _2, mem = self._policy(state, obs, mem) traj_batch = self._prepare_batch( traj_batch, traj_batch.dones[-1, ...], mem.extras ) @@ -513,7 +514,13 @@ def make_gru_agent( agent_args.output_channels, agent_args.kernel_shape, ) - elif args.env_id == "iterated_matrix_game": + elif args.env_id in [ + "iterated_matrix_game", + "iterated_tensor_game", + "iterated_nplayer_tensor_game", + "third_party_punishment", + "third_party_random", + ]: network, initial_hidden_state = make_GRU_ipd_network( action_spec, agent_args.hidden_size ) diff --git 
a/pax/agents/strategies.py b/pax/agents/strategies.py index 8a93627a..cda6ebf5 100644 --- a/pax/agents/strategies.py +++ b/pax/agents/strategies.py @@ -4,8 +4,8 @@ import jax.numpy as jnp import jax.random -from pax.agents.agent import AgentInterface +from pax.agents.agent import AgentInterface from pax.utils import Logger, MemoryState, TrainingState # states are [CC, CD, DC, DD, START] @@ -381,7 +381,7 @@ def _policy( def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: # now either 0, 1, 2, 3 - batch_size, _ = obs.shape + # batch_size, _ = obs.shape obs = obs.argmax(axis=-1) # if 0 | 2 | 4 -> C # if 1 | 3 -> D @@ -488,7 +488,7 @@ def make_initial_state(self, _unused, *args) -> TrainingState: class Stay(AgentInterface): - def __init__(self, num_actions: int, num_envs: int): + def __init__(self, num_actions: int, num_envs: int, num_players: int = 2): self.make_initial_state = initial_state_fun(num_envs) self._state, self._mem = self.make_initial_state(None, None) self._logger = Logger() diff --git a/pax/agents/tensor_strategies.py b/pax/agents/tensor_strategies.py new file mode 100644 index 00000000..7d77bbed --- /dev/null +++ b/pax/agents/tensor_strategies.py @@ -0,0 +1,249 @@ +from functools import partial +from re import A +from typing import Callable, NamedTuple + +import jax.numpy as jnp +import jax.random + +from pax.agents.agent import AgentInterface +from pax.agents.strategies import initial_state_fun +from pax.utils import Logger, MemoryState, TrainingState + + +class TitForTatStrictStay(AgentInterface): + # Switch to what opponents did if the other two played the same move + # otherwise play as before + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + obs = obs.argmax(axis=-1) + # in state 0-3 we cooped, 4-7 we defected + # 0 is cooperate, 1 is defect + # default is we play the same as before + action = jnp.where(obs > 3, 1, 0) + # if obs is 0 or 4, they both cooperated and we cooperate next round + # if obs is 8, we start by cooperating + action = jnp.where(obs % 4 == 0, 0, action) + # if obs is 3 or 7, they both defected and we defect next round + action = jnp.where(obs % 4 == 3, 1, action) + return action + + +class TitForTatStrictSwitch(AgentInterface): + # Switch to what opponents did if the other two played the same move + # otherwise switch + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x 
num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + obs = obs.argmax(axis=-1) + # in state 0-3 we cooped, 4-7 we defected + # 0 is cooperate, 1 is defect + # default is switch from before + action = jnp.where(obs > 3, 0, 1) + # if obs is 0 or 4, they both cooperated and we cooperate next round + # if obs is 8, we start by cooperating + action = jnp.where(obs % 4 == 0, 0, action) + # if obs is 3 or 7, they both defected and we defect next round + action = jnp.where(obs % 4 == 3, 1, action) + return action + + +class TitForTatCooperate(AgentInterface): + # Cooperate unless they both defect + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + obs = obs.argmax(axis=-1) + # in state 3 and 7 they both defected, we defect + # otherwise we cooperate + # 0 is cooperate, 1 is defect + action = jnp.where(obs % 4 == 3, 1, 0) + return action + + +class TitForTatDefect(AgentInterface): + # Defect unless they both cooperate + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + obs = obs.argmax(axis=-1) + # in state 0 and 4 they both cooperated, we cooperate + # state 8 is start, we cooperate + # otherwise we defect + # 0 is cooperate, 1 is defect + action = jnp.where(obs % 4 == 0, 0, 1) + return action + + +class TitForTatHarsh(AgentInterface): + # Defect unless everyone cooperates, works for n players + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + num_players = jnp.log2(obs.shape[-1] - 1).astype(int) + obs = obs.argmax(axis=-1) + + # 0th 
state is all c...cc + # 1st state is all c...cd + # (2**num_playerth)-1 state is d...dd + # we cooperate in states _CCCC (0th and (2**num_player-1)th) and in start state (state 2**num_player)th ) + # otherwise we defect + # 0 is cooperate, 1 is defect + action = jnp.where(obs % jnp.exp2(num_players - 1) == 0, 0, 1) + return action + + +class TitForTatSoft(AgentInterface): + # Defect if majority defects, works for n players + def __init__(self, num_envs, *args): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self._logger = Logger() + self._logger.metrics = {} + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + + def reset_memory(self, mem, *args) -> MemoryState: + return mem + + @partial(jax.jit, static_argnums=(0,)) + def _policy( + self, + state: NamedTuple, + obs: jnp.ndarray, + mem: NamedTuple, + ) -> tuple[jnp.ndarray, NamedTuple, NamedTuple]: + # state is [batch x time_step x num_players] + # return [batch] + return self._reciprocity(obs), state, mem + + def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: + num_players = jnp.log2(obs.shape[-1] - 1).astype(jnp.int32) + max_defect_allowed = (num_players - 1) / 2 + # from one_hot to int + obs = jnp.argmax(obs, axis=-1, keepdims=True).astype(jnp.uint8) + # convert to binary to get actions of opponents + # -num_player-1th is the start state indicator not action + obs_actions = jnp.unpackbits(obs, axis=-1) + num_defects = obs_actions.sum(axis=-1) + # substract our own actions and make sure start state has 0 defect + # our actions are on obs_actions[...,-num_players] + # start state is on obs_actions[...,-num_players-1] + opps_defect = ( + num_defects + - obs_actions[..., -num_players] + - obs_actions[..., -num_players - 1] + ) + # if more than half of them are 1, we defect + # num_defects = jnp.sum(cutoff_obs_actions,axis=-1) + assert opps_defect.shape == obs.shape[:-1] + action = jnp.where(opps_defect <= max_defect_allowed, 0, 1).astype( + jnp.int32 + ) + return action diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index e2e8350f..36b25085 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -16,6 +16,8 @@ save_interval: 100 debug: False # Agents +num_players: 2 +num_shapers: 1 agent1: 'PPO' agent2: 'PPO' diff --git a/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd.yaml b/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd.yaml new file mode 100644 index 00000000..1b6a151e --- /dev/null +++ b/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd.yaml @@ -0,0 +1,132 @@ +# @package _global_ +# two shapers, both trained against PPO mem agents playing each other + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent3: 'PPO_memory' + + +# Environment +env_id: iterated_nplayer_tensor_game +env_type: meta +env_discount: 0.96 + +# Runner +num_players: 3 +num_shapers: 2 + +# rows are (C,D) payoffs depending on number of D in total +# top is 0D payoff, bottom is 0C, +# one more player plays D going downwards +payoff_table: [ + [ 4 , 1000 ], + [ 2 , 5 ], + [ 0 , 3 ], + [ -1000 , 1 ], +] +# Runner +runner: multishaper_evo + +# Training +top_k: 5 +popsize: 100 +num_envs: 2 +num_opps: 10 +num_inner_steps: 100 +num_outer_steps: 1000 +num_iters: 200000 +num_devices: 1 + +# PPO agent parameters +# PPO agent parameters +ppo1: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + 
entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 +ppo2: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 +ppo3: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean +# Logging setup +wandb: + entity: "ucl-dark" + project: tensor-ipd + group: 'test' + name: 3pl_2shap_ipd_${seed} + log: True + diff --git a/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd_eval.yaml b/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd_eval.yaml new file mode 100644 index 00000000..298353a8 --- /dev/null +++ b/pax/conf/experiment/multiplayer_ipd/3pl_2shap_ipd_eval.yaml @@ -0,0 +1,115 @@ +# @package _global_ +# two shapers, both trained against PPO mem agents playing each other + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent3: 'PPO_memory' + + +# Environment +env_id: iterated_nplayer_tensor_game +env_type: meta +env_discount: 0.96 + +# Runner +num_players: 3 +num_shapers: 2 +payoff_table: [ + [ 4 , 1000 ], + [ 2 , 5 ], + [ 0 , 3 ], + [ -1000 , 1 ], +] + + +# # Runner +runner: multishaper_eval +# paths for loading shaper agents +run_path1: ucl-dark/tensor-ipd/ab3sovla +model_path1: exp/3pl-2shap-ipd/3pl_2shap_ipd_1/2023-08-01_19.28.45.983999/generation_1999_agent_0 +run_path2: ucl-dark/tensor-ipd/ab3sovla +model_path2: exp/3pl-2shap-ipd/3pl_2shap_ipd_1/2023-08-01_19.28.45.983999/generation_1999_agent_1 + + +# Training +num_envs: 2 +num_opps: 10 +num_inner_steps: 100 +num_outer_steps: 1000 +num_iters: 1 # shouldn't do anything, just for lr scheduler to not complain +# total_timesteps: 2.5e7 +num_devices: 1 + +# PPO agent parameters +# PPO agent parameters +ppo1: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + 
entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 +ppo2: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 +ppo3: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.4e9 + entropy_coeff_end: 0.01 + lr_scheduling: False + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + separate: False + hidden_size: 16 + + +# Logging setup +wandb: + entity: "ucl-dark" + project: tensor-ipd + group: 'test' + name: 3pl_2shap_ipd_1-eval + log: True + diff --git a/pax/conf/experiment/multiplayer_ipd/lola_vs_ppo_ipd.yaml b/pax/conf/experiment/multiplayer_ipd/lola_vs_ppo_ipd.yaml new file mode 100644 index 00000000..cc79c85b --- /dev/null +++ b/pax/conf/experiment/multiplayer_ipd/lola_vs_ppo_ipd.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +# Agents +agent1: 'LOLA' +agent2: 'PPO' + +# Environment +env_id: iterated_nplayer_tensor_game +env_type: sequential +env_discount: 0.96 +payoff_table: [ +[ -1 , 1000 ], +[ -3 , 0 ], +[ 1000 , -2 ], +] +runner: tensor_rl_nplayer +num_players: 2 + +num_envs: 100 +num_opps: 1 # TODO idk? +num_outer_steps: 1 +num_inner_steps: 100 # how long a game takes +num_iters: 1000 + +# LOLA agent parameters +lola: + use_baseline: True + adam_epsilon: 1e-5 + lr_out: 0.1 + gamma: 0.96 + num_lookaheads: 1 + +ppo1: # TODO unsure if I need ppo1 or ppo2 if second agent is ppo + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.25e9 + entropy_coeff_end: 0.05 + lr_scheduling: True + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + hidden_size: 16 + with_cnn: False + +ppo2: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.25e9 + entropy_coeff_end: 0.05 + lr_scheduling: True + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + hidden_size: 16 + with_cnn: False + + +# Logging setup +wandb: + entity: "ucl-dark" + project: tensor-ipd + group: 'test' + name: 'LOLA-vs-${agent2}' + log: True \ No newline at end of file diff --git a/pax/envs/in_the_matrix.py b/pax/envs/in_the_matrix.py index 7916abd4..7bcdefcb 100644 --- a/pax/envs/in_the_matrix.py +++ b/pax/envs/in_the_matrix.py @@ -1,5 +1,5 @@ -from enum import IntEnum import math +from enum import IntEnum from typing import Any, Optional, Tuple, Union import chex @@ -18,7 +18,6 @@ rotate_fn, ) - GRID_SIZE = 8 OBS_SIZE = 5 PADDING = OBS_SIZE - 1 diff --git a/pax/envs/iterated_tensor_game_n_player.py b/pax/envs/iterated_tensor_game_n_player.py new file mode 100644 index 00000000..ef0267fe --- /dev/null +++ 
b/pax/envs/iterated_tensor_game_n_player.py @@ -0,0 +1,142 @@ +from typing import Optional, Tuple + +import chex +import jax +import jax.numpy as jnp +from flax import struct +from gymnax.environments import environment, spaces + + +@chex.dataclass +class EnvState: + inner_t: int + outer_t: int + + +@chex.dataclass +class EnvParams: + payoff_table: chex.ArrayDevice + + +class IteratedTensorGameNPlayer(environment.Environment): + """ + JAX Compatible version of tensor game environment. + """ + + def __init__( + self, num_players: int, num_inner_steps: int, num_outer_steps: int + ): + super().__init__() + self.num_players = num_players + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[int, ...], + params: EnvParams, + ): + assert len(actions) == num_players + inner_t, outer_t = state.inner_t, state.outer_t + inner_t += 1 + reset_inner = inner_t == num_inner_steps + + # sum of 1s is number of defectors + num_defect = sum(list(actions)).astype(jnp.int8) + + # calculate rewards + # row of payoff table is number of defectors + # column of payoff table is whether player defected or not + payoff_array = jnp.array(params.payoff_table, dtype=jnp.float32) + assert payoff_array.shape == (num_players + 1, 2) + action_array = jnp.array(actions, dtype=jnp.int8) + relevant_row = payoff_array[num_defect] + + rewards_array = relevant_row[action_array] + all_rewards = tuple(rewards_array) + # calculate states + # we want these to be from first person perspective + # eg as if players were sitting in a circle + # and they start with theirs, then go clockwise + + # list of first person perspective actions + # with each player starting with theirs + fpp_actions = [] + for i in range(num_players): + fpp_action = actions[i:] + actions[:i] + fpp_action = jnp.array(fpp_action) + fpp_actions.append(fpp_action) + + # Start is the last state eg 2**num_players_th + start_state_idx = 2**num_players + len_one_hot = 2**num_players + 1 + # binary to int conversion for decoding actions + b2i = 2 ** jnp.arange(num_players - 1, -1, -1) + all_obs = [] + for i in range(num_players): + # first state is all c...cc + # second state is all c...cd + # 2**num_playerth state is d...dd + # so we can just binary decode the actions to get state + obs = (fpp_actions[i] * b2i).sum() + # if first step then return START state. 
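+                # reset_inner is True on the step that ends an inner episode, so the returned
+                # observation becomes the one-hot START index (2**num_players) instead of the
+                # binary-decoded joint action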
+ obs = jax.lax.select( + reset_inner, + start_state_idx * jnp.ones_like(obs), + obs, + ) + # one hot encode + obs = jax.nn.one_hot(obs, len_one_hot, dtype=jnp.int8) + all_obs.append(obs) + all_obs = tuple(all_obs) + # out step keeping + inner_t = jax.lax.select( + reset_inner, jnp.zeros_like(inner_t), inner_t + ) + outer_t_new = outer_t + 1 + outer_t = jax.lax.select(reset_inner, outer_t_new, outer_t) + reset_outer = outer_t == num_outer_steps + state = EnvState(inner_t=inner_t, outer_t=outer_t) + + return ( + all_obs, + state, + all_rewards, + reset_outer, + {"discount": jnp.zeros((), dtype=jnp.int8)}, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + state = EnvState( + inner_t=jnp.zeros((), dtype=jnp.int8), + outer_t=jnp.zeros((), dtype=jnp.int8), + ) + # start state is the last one eg 2**num_players_th + # and one hot vectors should be 2**num_players+1 long + obs = jax.nn.one_hot( + (2**num_players) * jnp.ones(()), + 2**num_players + 1, + dtype=jnp.int8, + ) + all_obs = tuple([obs for _ in range(num_players)]) + return all_obs, state + + # overwrite Gymnax as it makes single-agent assumptions + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + @property + def name(self) -> str: + """Environment name.""" + return "IteratedTensorGame-Nplayer" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return 2 + + def observation_space(self, params: EnvParams) -> spaces.Discrete: + """Observation space of the environment.""" + obs_space = jnp.power(2, self.num_players) + 1 + return spaces.Discrete(obs_space) diff --git a/pax/experiment.py b/pax/experiment.py index 46f3e327..b990e885 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -3,21 +3,16 @@ from datetime import datetime from functools import partial -# NOTE: THIS MUST BE DONE BEFORE IMPORTING JAX -# uncomment to debug multi-devices on CPU -# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" -# from jax.config import config -# config.update('jax_disable_jit', True) - +import gymnax import hydra +import jax import jax.numpy as jnp import omegaconf from evosax import CMA_ES, PGPE, OpenES, ParameterReshaper, SimpleGA -import gymnax -import jax import wandb from pax.agents.hyper.ppo import make_hyper +from pax.agents.lola.lola import make_lola from pax.agents.mfos_ppo.ppo_gru import make_mfos_agent from pax.agents.naive.naive import make_naive_pg from pax.agents.naive_exact import NaiveExact @@ -37,21 +32,34 @@ Stay, TitForTat, ) +from pax.agents.tensor_strategies import ( + TitForTatCooperate, + TitForTatDefect, + TitForTatHarsh, + TitForTatSoft, + TitForTatStrictStay, + TitForTatStrictSwitch, +) from pax.envs.coin_game import CoinGame from pax.envs.coin_game import EnvParams as CoinGameParams +from pax.envs.in_the_matrix import EnvParams as InTheMatrixParams +from pax.envs.in_the_matrix import InTheMatrix from pax.envs.infinite_matrix_game import EnvParams as InfiniteMatrixGameParams from pax.envs.infinite_matrix_game import InfiniteMatrixGame from pax.envs.iterated_matrix_game import EnvParams as IteratedMatrixGameParams from pax.envs.iterated_matrix_game import IteratedMatrixGame -from pax.envs.in_the_matrix import InTheMatrix -from pax.envs.in_the_matrix import ( - EnvParams as InTheMatrixParams, +from pax.envs.iterated_tensor_game_n_player import ( + EnvParams as IteratedTensorGameNPlayerParams, ) +from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer from pax.runners.runner_eval 
import EvalRunner +from pax.runners.runner_eval_multishaper import MultishaperEvalRunner from pax.runners.runner_evo import EvoRunner +from pax.runners.runner_evo_multishaper import MultishaperEvoRunner +from pax.runners.runner_ipditm_eval import IPDITMEvalRunner from pax.runners.runner_marl import RLRunner +from pax.runners.runner_marl_nplayer import NplayerRLRunner from pax.runners.runner_sarl import SARLRunner -from pax.runners.runner_ipditm_eval import IPDITMEvalRunner from pax.utils import Section from pax.watchers import ( logger_hyper, @@ -64,6 +72,13 @@ value_logger_ppo, ) +# NOTE: THIS MUST BE sDONE BEFORE IMPORTING JAX +# uncomment to debug multi-devices on CPU +# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" +# from jax.config import config + +# config.update("jax_disable_jit", True) + def global_setup(args): """Set up global variables.""" @@ -108,7 +123,20 @@ def env_setup(args, logger=None): f"Env Type: {args.env_type} | Inner Episode Length: {args.num_inner_steps}" ) logger.info(f"Outer Episode Length: {args.num_outer_steps}") + elif args.env_id == "iterated_nplayer_tensor_game": + payoff = jnp.array(args.payoff_table) + + env = IteratedTensorGameNPlayer( + num_players=args.num_players, + num_inner_steps=args.num_inner_steps, + num_outer_steps=args.num_outer_steps, + ) + env_params = IteratedTensorGameNPlayerParams(payoff_table=payoff) + if logger: + logger.info( + f"Env Type: {args.env_type} s| Inner Episode Length: {args.num_inner_steps}" + ) elif args.env_id == "infinite_matrix_game": payoff = jnp.array(args.payoff) env = InfiniteMatrixGame(num_steps=args.num_steps) @@ -165,12 +193,15 @@ def runner_setup(args, env, agents, save_dir, logger): if args.runner == "eval": logger.info("Evaluating with EvalRunner") return EvalRunner(agents, env, args) + elif args.runner == "multishaper_eval": + logger.info("Training with multishaper eval Runner") + return MultishaperEvalRunner(agents, env, save_dir, args) elif args.runner == "ipditm_eval": logger.info("Evaluating with ipditmEvalRunner") return IPDITMEvalRunner(agents, env, save_dir, args) - if args.runner == "evo": - agent1, _ = agents + if args.runner == "evo" or args.runner == "multishaper_evo": + agent1 = agents[0] algo = args.es.algo strategies = {"CMA_ES", "OpenES", "PGPE", "SimpleGA"} assert algo in strategies, f"{algo} not in evolution strategies" @@ -254,14 +285,34 @@ def get_pgpe_strategy(agent): strategy, es_params, param_reshaper = get_ga_strategy(agent1) logger.info(f"Evolution Strategy: {algo}") - - return EvoRunner( - agents, env, strategy, es_params, param_reshaper, save_dir, args - ) + if args.runner == "evo": + logger.info("Training with EVO runner") + return EvoRunner( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) + elif args.runner == "multishaper_evo": + logger.info("Training with multishaper EVO runner") + return MultishaperEvoRunner( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) elif args.runner == "rl": logger.info("Training with RL Runner") return RLRunner(agents, env, save_dir, args) + elif args.runner == "tensor_rl_nplayer": + return NplayerRLRunner(agents, env, save_dir, args) elif args.runner == "sarl": logger.info("Training with SARL Runner") return SARLRunner(agents, env, save_dir, args) @@ -272,7 +323,10 @@ def get_pgpe_strategy(agent): # flake8: noqa: C901 def agent_setup(args, env, env_params, logger): """Set up agent variables.""" - if args.env_id == "iterated_matrix_game": + if ( + args.env_id == 
"iterated_matrix_game" + or args.env_id == "iterated_nplayer_tensor_game" + ): obs_shape = env.observation_space(env_params).n elif args.env_id == "InTheMatrix": obs_shape = jax.tree_map( @@ -283,8 +337,20 @@ def agent_setup(args, env, env_params, logger): num_actions = env.num_actions + def get_LOLA_agent(seed, player_id): + return make_lola( + args, + obs_spec=obs_shape, + action_spec=num_actions, + seed=seed, + player_id=player_id, + env_params=env_params, + env_step=env.step, + env_reset=env.reset, + ) + def get_PPO_memory_agent(seed, player_id): - player_args = args.ppo1 if player_id == 1 else args.ppo2 + player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id)) num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps @@ -299,10 +365,12 @@ def get_PPO_memory_agent(seed, player_id): ) def get_PPO_agent(seed, player_id): - player_args = args.ppo1 if player_id == 1 else args.ppo2 - num_iterations = args.num_iters + player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id)) + if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps + else: + num_iterations = args.num_iters ppo_agent = make_agent( args, player_args, @@ -384,12 +452,18 @@ def get_random_agent(seed, player_id): # flake8: noqa: C901 def get_stay_agent(seed, player_id): - agent = Stay(num_actions, args.num_envs) + agent = Stay(num_actions, args.num_envs, args.num_players) agent.player_id = player_id return agent strategies = { "TitForTat": partial(TitForTat, args.num_envs), + "TitForTatStrictStay": partial(TitForTatStrictStay, args.num_envs), + "TitForTatStrictSwitch": partial(TitForTatStrictSwitch, args.num_envs), + "TitForTatCooperate": partial(TitForTatCooperate, args.num_envs), + "TitForTatDefect": partial(TitForTatDefect, args.num_envs), + "TitForTatHarsh": partial(TitForTatHarsh, args.num_envs), + "TitForTatSoft": partial(TitForTatSoft, args.num_envs), "Defect": partial(Defect, args.num_envs), "Altruistic": partial(Altruistic, args.num_envs), "Random": get_random_agent, @@ -398,6 +472,7 @@ def get_stay_agent(seed, player_id): "GoodGreedy": partial(GoodGreedy, args.num_envs), "EvilGreedy": partial(EvilGreedy, args.num_envs), "RandomGreedy": partial(RandomGreedy, args.num_envs), + "LOLA": get_LOLA_agent, "PPO": get_PPO_agent, "PPO_memory": get_PPO_memory_agent, "Naive": get_naive_pg, @@ -427,54 +502,66 @@ def get_stay_agent(seed, player_id): if args.runner in ["eval", "sarl"]: logger.info("Using Independent Learners") return agent_1 + else: - assert args.agent1 in strategies - assert args.agent2 in strategies + for i in range(1, args.num_players + 1): + assert ( + omegaconf.OmegaConf.select(args, "agent" + str(i)) + in strategies + ) - num_agents = 2 - seeds = [seed for seed in range(args.seed, args.seed + num_agents)] - # Create Player IDs by normalizing seeds to 1, 2 respectively + seeds = [ + seed for seed in range(args.seed, args.seed + args.num_players) + ] + # Create Player IDs by normalizing seeds to 1, 2, 3 ..n respectively pids = [ seed % seed + i if seed != 0 else 1 - for seed, i in zip(seeds, range(1, num_agents + 1)) + for seed, i in zip(seeds, range(1, args.num_players + 1)) ] - agent_0 = strategies[args.agent1](seeds[0], pids[0]) # player 1 - agent_1 = strategies[args.agent2](seeds[1], pids[1]) # player 2 + agents = [] + for i in range(args.num_players): + agents.append( + strategies[ + omegaconf.OmegaConf.select(args, "agent" + str(i + 1)) + ](seeds[i], pids[i]) + ) + logger.info( + f"Agent Pair: 
{[omegaconf.OmegaConf.select(args, 'agent' + str(i)) for i in range(1, args.num_players + 1)]}" + ) + logger.info(f"Agent seeds: {seeds}") - if args.agent1 in ["PPO", "PPO_memory"]: - logger.info(f"PPO with CNN: {args.ppo1.with_cnn}") - logger.info(f"Agent Pair: {args.agent1} | {args.agent2}") - logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}") - return (agent_0, agent_1) + return agents def watcher_setup(args, logger): """Set up watcher variables.""" - def ppo_memory_log(agent): + def ppo_memory_log( + agent, + ): losses = losses_ppo(agent) - if args.env_id not in ["coin_game", "InTheMatrix", "iterated_matrix_game"]: + if args.env_id not in [ + "coin_game", + "InTheMatrix", + "iterated_matrix_game", + "iterated_nplayer_tensor_game", + ]: policy = policy_logger_ppo_with_memory(agent) losses.update(policy) - if args.wandb.log: - losses = jax.tree_util.tree_map( - lambda x: x.item() if isinstance(x, jax.Array) else x, losses - ) - wandb.log(losses) return def ppo_log(agent): losses = losses_ppo(agent) - if args.env_id not in ["coin_game", "InTheMatrix", "iterated_matrix_game"]: + if args.env_id not in [ + "coin_game", + "InTheMatrix", + "iterated_matrix_game", + "iterated_nplayer_tensor_game", + ]: policy = policy_logger_ppo(agent) value = value_logger_ppo(agent) losses.update(value) losses.update(policy) - if args.wandb.log: - losses = jax.tree_util.tree_map( - lambda x: x.item() if isinstance(x, jax.Array) else x, losses - ) - wandb.log(losses) return def dumb_log(agent, *args): @@ -501,7 +588,10 @@ def naive_logger(agent): def naive_pg_log(agent): losses = naive_pg_losses(agent) - if args.env_id in ["finite_matrix_game"]: + if args.env_id in [ + "iterated_matrix_game", + "iterated_nplayer_tensor_game", + ]: policy = policy_logger_ppo(agent) value = value_logger_ppo(agent) losses.update(value) @@ -523,6 +613,7 @@ def naive_pg_log(agent): "RandomGreedy": dumb_log, "MFOS": dumb_log, "PPO": ppo_log, + "LOLA": dumb_log, "PPO_memory": ppo_memory_log, "Naive": naive_pg_log, "Hyper": hyper_log, @@ -542,13 +633,11 @@ def naive_pg_log(agent): return agent_1_log else: - assert args.agent1 in strategies - assert args.agent2 in strategies - - agent_0_log = strategies[args.agent1] - agent_1_log = strategies[args.agent2] - - return [agent_0_log, agent_1_log] + agent_log = [] + for i in range(1, args.num_players + 1): + assert getattr(args, f"agent{i}") in strategies + agent_log.append(strategies[getattr(args, f"agent{i}")]) + return agent_log @hydra.main(config_path="conf", config_name="config") @@ -563,7 +652,6 @@ def main(args): with Section("Agent setup", logger=logger): agent_pair = agent_setup(args, env, env_params, logger) - with Section("Watcher setup", logger=logger): watchers = watcher_setup(args, logger) @@ -575,21 +663,20 @@ def main(args): print(f"Number of Training Iterations: {args.num_iters}") - if args.runner == "evo": + if args.runner == "evo" or args.runner == "multishaper_evo": runner.run_loop(env_params, agent_pair, args.num_iters, watchers) - elif args.runner == "rl": + elif args.runner == "rl" or args.runner == "tensor_rl_nplayer": # number of episodes print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env_params, agent_pair, args.num_iters, watchers) - elif args.runner == "ipditm_eval": + elif args.runner == "ipditm_eval" or args.runner == "multishaper_eval": runner.run_loop(env_params, agent_pair, watchers) elif args.runner == "sarl": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) - elif 
args.runner == "eval": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) diff --git a/pax/runners/runner_eval_multishaper.py b/pax/runners/runner_eval_multishaper.py new file mode 100644 index 00000000..fcc85929 --- /dev/null +++ b/pax/runners/runner_eval_multishaper.py @@ -0,0 +1,681 @@ +import os +import time +from typing import Any, List, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +from omegaconf import OmegaConf + +import wandb +from pax.utils import MemoryState, TrainingState, load, save +from pax.watchers import ( + ipditm_stats, + n_player_ipd_visitation, + tensor_ipd_visitation, +) + +MAX_WANDB_CALLS = 1000 + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class MFOSSample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + meta_actions: jnp.ndarray + + +@jax.jit +def reduce_outer_traj(traj: Sample) -> Sample: + """Used to collapse lax.scan outputs dims""" + # x: [outer_loop, inner_loop, num_opps, num_envs ...] + # x: [timestep, batch_size, ...] + num_envs = traj.observations.shape[2] * traj.observations.shape[3] + num_timesteps = traj.observations.shape[0] * traj.observations.shape[1] + return jax.tree_util.tree_map( + lambda x: x.reshape((num_timesteps, num_envs) + x.shape[4:]), + traj, + ) + + +class MultishaperEvalRunner: + """ + Reinforcement Learning runner provides a convenient example for quickly writing + a MARL runner for PAX. The MARLRunner class can be used to + run any two RL agents together either in a meta-game or regular game, it composes together agents, + watchers, and the environment. Within the init, we declare vmaps and pmaps for training. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The environment that the agents will run in. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + # flake8: noqa: C901 + def __init__(self, agents, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.num_opps = args.num_opps + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + + def _reshape_opp_dim(x): + # x: [num_opps, num_envs ...] + # x: [batch_size, ...] 
+ batch_size = args.num_envs * args.num_opps + return jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), x + ) + + self.reduce_opp_dim = jax.jit(_reshape_opp_dim) + self.ipd_stats = n_player_ipd_visitation + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + self.ipditm_stats = jax.jit(ipditm_stats) + # VMAP for num opps: we vmap over the rng but not params + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + + self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + self.num_players = args.num_players + self.num_shapers = args.num_shapers + self.num_targets = args.num_players - args.num_shapers + self.num_outer_steps = self.args.num_outer_steps + shapers = agents[: self.num_shapers] + targets = agents[self.num_shapers :] + # set up agents + # batch MemoryState not TrainingState + for agent_idx, shaper_agent in enumerate(shapers): + agent_arg = f"agent{agent_idx+1}" + if OmegaConf.select(args, agent_arg) == "NaiveEx": + # special case where NaiveEx has a different call signature + shaper_agent.batch_init = jax.jit( + jax.vmap(shaper_agent.make_initial_state) + ) + else: + + shaper_agent.batch_init = jax.vmap( + shaper_agent.make_initial_state, + (None, 0), + (None, 0), + ) + shaper_agent.batch_reset = jax.jit( + jax.vmap(shaper_agent.reset_memory, (0, None), 0), + static_argnums=1, + ) + + shaper_agent.batch_policy = jax.jit( + jax.vmap(shaper_agent._policy, (None, 0, 0), (0, None, 0)) + ) + + # go through opponents + for agent_idx, target_agent in enumerate(targets): + agent_arg = f"agent{agent_idx+self.num_shapers+1}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + target_agent.batch_init = jax.jit( + jax.vmap(target_agent.make_initial_state) + ) + else: + target_agent.batch_init = jax.vmap( + target_agent.make_initial_state, (0, None), 0 + ) + target_agent.batch_policy = jax.jit(jax.vmap(target_agent._policy)) + target_agent.batch_reset = jax.jit( + jax.vmap(target_agent.reset_memory, (0, None), 0), + static_argnums=1, + ) + target_agent.batch_update = jax.jit( + jax.vmap(target_agent.update, (1, 0, 0, 0), 0) + ) + + for agent_idx, shaper_agent in enumerate(shapers): + agent_arg = f"agent{agent_idx+1}" + if OmegaConf.select(args, agent_arg) != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile( + shaper_agent._mem.hidden, (args.num_opps, 1, 1) + ) + ( + shaper_agent._state, + shaper_agent._mem, + ) = shaper_agent.batch_init( + shaper_agent._state.random_key, init_hidden + ) + + for agent_idx, target_agent in enumerate(targets): + agent_arg = f"agent{agent_idx+self.num_shapers+1}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) != "NaiveEx": + # NaiveEx requires env first step to init. 
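# Editor's illustration (not part of the patch): opponent agents are
# initialised once per opponent by vmapping the init over a batch of split
# PRNG keys while broadcasting a shared hidden state (in_axes=(0, None)).
# `make_state` is a hypothetical stand-in for the agents' real init method.
import jax
import jax.numpy as jnp

def make_state(key, hidden):
    return {"key": key, "hidden": hidden}

num_opps, hidden_dim = 2, 16
batch_init = jax.vmap(make_state, in_axes=(0, None), out_axes=0)
keys = jax.random.split(jax.random.PRNGKey(0), num_opps)
states = batch_init(keys, jnp.zeros((1, hidden_dim)))
# Every leaf gains a leading num_opps axis, e.g.
# states["hidden"].shape == (num_opps, 1, hidden_dim)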
+ init_hidden = jnp.tile( + target_agent._mem.hidden, (args.num_opps, 1, 1) + ) + ( + target_agent._state, + target_agent._mem, + ) = target_agent.batch_init( + jax.random.split( + target_agent._state.random_key, args.num_opps + ), + init_hidden, + ) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + shapers_obs, + targets_obs, + shapers_reward, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ) = carry + new_targets_mem = [None] * self.num_targets + new_shapers_mem = [None] * self.num_shapers + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, :, 3, :] + shapers_actions = [] + targets_actions = [] + for agent_idx, shaper_agent in enumerate(shapers): + ( + shaper_action, + shapers_state[agent_idx], + new_shapers_mem[agent_idx], + ) = shaper_agent.batch_policy( + shapers_state[agent_idx], + shapers_obs[agent_idx], + shapers_mem[agent_idx], + ) + shapers_actions.append(shaper_action) + for agent_idx, target_agent in enumerate(targets): + ( + target_action, + targets_state[agent_idx], + new_targets_mem[agent_idx], + ) = target_agent.batch_policy( + targets_state[agent_idx], + targets_obs[agent_idx], + targets_mem[agent_idx], + ) + targets_actions.append(target_action) + ( + all_agent_next_obs, + env_state, + all_agent_rewards, + done, + info, + ) = env.step( + env_rng, + env_state, + tuple(shapers_actions + targets_actions), + env_params, + ) + shapers_next_obs = all_agent_next_obs[: self.num_shapers] + targets_next_obs = all_agent_next_obs[self.num_shapers :] + shapers_reward, targets_reward = ( + all_agent_rewards[: self.num_shapers], + all_agent_rewards[self.num_shapers :], + ) + + shapers_traj = [ + Sample( + shapers_obs[agent_idx], + shapers_actions[agent_idx], + shapers_reward[agent_idx], + new_shapers_mem[agent_idx].extras["log_probs"], + new_shapers_mem[agent_idx].extras["values"], + done, + shapers_mem[agent_idx].hidden, + ) + for agent_idx in range(self.num_shapers) + ] + targets_traj = [ + Sample( + targets_obs[agent_idx], + targets_actions[agent_idx], + targets_rewards[agent_idx], + new_targets_mem[agent_idx].extras["log_probs"], + new_targets_mem[agent_idx].extras["values"], + done, + targets_mem[agent_idx].hidden, + ) + for agent_idx in range(self.num_targets) + ] + return ( + rngs, + tuple(shapers_next_obs), + tuple(targets_next_obs), + tuple(shapers_reward), + tuple(targets_reward), + shapers_state, + targets_state, + new_shapers_mem, + new_targets_mem, + env_state, + env_params, + ), tuple(shapers_traj + targets_traj) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=self.args.num_inner_steps, + ) + targets_metrics = [None] * self.num_targets + ( + rngs, + shapers_obs, + targets_obs, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode + for agent_idx, shaper_agent in enumerate(shapers): + agent_arg = f"agent{agent_idx+1}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "MFOS": + shapers_mem[agent_idx] = shaper_agent.meta_policy( + shapers_mem[agent_idx] + ) + + # update second agent + targets_traj = trajectories[self.num_shapers :] + for agent_idx, target_agent in enumerate(targets): + ( + 
targets_state[agent_idx], + targets_mem[agent_idx], + targets_metrics[agent_idx], + ) = target_agent.batch_update( + targets_traj[agent_idx], + targets_obs[agent_idx], + targets_state[agent_idx], + targets_mem[agent_idx], + ) + return ( + rngs, + shapers_obs, + targets_obs, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ), (trajectories, targets_metrics) + + def _rollout( + _rng_run: jnp.ndarray, + _shapers_state: List[TrainingState], + _shapers_mem: List[MemoryState], + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] * args.num_opps + ).reshape((args.num_opps, args.num_envs, -1)) + + obs, env_state = env.reset(rngs, _env_params) + shapers_obs = obs[: self.num_shapers] + targets_obs = obs[self.num_shapers :] + rewards = [ + jnp.zeros((args.num_opps, args.num_envs), dtype=jnp.float32) + ] * args.num_players + # Player 1 + for agent_idx, shaper_agent in enumerate(shapers): + _shapers_mem[agent_idx] = shaper_agent.batch_reset( + _shapers_mem[agent_idx], False + ) + # Other players + _rng_run, *target_rngs = jax.random.split( + _rng_run, self.num_players + ) + targets_mem = [None] * self.num_targets + targets_state = [None] * self.num_targets + for agent_idx, target_agent in enumerate(targets): + # if eg 2 shapers, agent3 is the first non-shaper + agent_arg = f"agent{agent_idx+1+self.num_shapers}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + ( + targets_mem[agent_idx], + targets_state[agent_idx], + ) = target_agent.batch_init(targets_obs[agent_idx]) + + elif self.args.env_type in ["meta"]: + # meta-experiments - init other agents per trial + ( + targets_state[agent_idx], + targets_mem[agent_idx], + ) = target_agent.batch_init( + jax.random.split( + target_rngs[agent_idx], self.num_opps + ), + target_agent._mem.hidden, + ) + + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + rngs, + tuple(obs[: self.num_shapers]), + tuple(obs[self.num_shapers :]), + tuple(rewards[: self.num_shapers]), + tuple(rewards[self.num_shapers :]), + _shapers_state, + targets_state, + _shapers_mem, + targets_mem, + env_state, + _env_params, + ), + None, + length=self.num_outer_steps, + ) + + ( + rngs, + shapers_obs, + targets_obs, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ) = vals + trajectories, targets_metrics = stack + shapers_traj = trajectories[: self.num_shapers] + targets_traj = trajectories[self.num_shapers :] + + # reset memory + for agent_idx, shaper_agent in enumerate(shapers): + shapers_mem[agent_idx] = shaper_agent.batch_reset( + shapers_mem[agent_idx], False + ) + + for agent_idx, target_agent in enumerate(targets): + targets_mem[agent_idx] = target_agent.batch_reset( + targets_mem[agent_idx], False + ) + # Stats + if args.env_id == "iterated_nplayer_tensor_game": + total_env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + trajectories[0].observations, + num_players=args.num_players, + ), + ) + shapers_rewards = [ + shapers_traj[shaper_idx].rewards.mean() + for shaper_idx in range(self.num_shapers) + ] + + targets_rewards = [ + targets_traj[target_idx].rewards.mean() + for target_idx in range(self.num_targets) + ] + else: + raise NotImplementedError + + return ( + total_env_stats, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + 
targets_metrics, + trajectories, + ) + + # self.rollout = _rollout + self.rollout = jax.jit(_rollout) + + def run_loop(self, env_params, agents, watchers): + """Run training of agents in environment""" + print("Training") + print("-----------------------") + shaper_agents = agents[: self.num_shapers] + target_agents = agents[self.num_shapers :] + rng, _ = jax.random.split(self.random_key) + + # get initial state and memory + shapers_state = [] + shapers_mem = [] + for agent_idx, shaper_agent in enumerate(shaper_agents): + shapers_state.append(shaper_agent._state) + shapers_mem.append(shaper_agent._mem) + targets_state = [] + targets_mem = [] + for agent_idx, target_agent in enumerate(target_agents): + targets_state.append(target_agent._state) + targets_mem.append(target_agent._mem) + + for agent_idx, shaper_agent in enumerate(shaper_agents): + model_path = f"model_path{agent_idx+1}" + run_path = f"run_path{agent_idx+1}" + if model_path not in self.args: + raise ValueError( + f"Please provide a model path for shaper {agent_idx+1}" + ) + + wandb.restore( + name=self.args[model_path], + run_path=self.args[run_path], + root=os.getcwd(), + ) + pretrained_params = load(self.args[model_path]) + shapers_state[agent_idx] = shapers_state[agent_idx]._replace( + params=pretrained_params + ) + + # run actual loop + rng, rng_run = jax.random.split(rng, 2) + # RL Rollout + ( + env_stats, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + targets_metrics, + trajectories, + ) = self.rollout( + rng_run, + shapers_state, + shapers_mem, + env_params, + ) + + # for stat in env_stats.keys(): + # print(stat + f": {env_stats[stat].item()}") + # print( + # f"Shapers Reward Per Timestep: {[float(reward.mean()) for reward in shapers_rewards]}" + # ) + # print( + # f"Targets Reward Per Timestep: {[float(reward.mean()) for reward in targets_rewards]}" + # ) + # print() + + if watchers: + + list_traj1 = [ + Sample( + observations=jax.tree_util.tree_map( + lambda x: x[i, ...], trajectories[0].observations + ), + actions=trajectories[0].actions[i, ...], + rewards=trajectories[0].rewards[i, ...], + dones=trajectories[0].dones[i, ...], + # env_state=None, + behavior_log_probs=trajectories[0].behavior_log_probs[ + i, ... 
+ ], + behavior_values=trajectories[0].behavior_values[i, ...], + hiddens=trajectories[0].hiddens[i, ...], + ) + for i in range(self.args.num_outer_steps) + ] + + list_of_env_stats = [ + jax.tree_util.tree_map( + lambda x: x.item(), + self.ipd_stats( + observations=traj.observations, + num_players=self.args.num_players, + ), + ) + for traj in list_traj1 + ] + shaper_traj = trajectories[: self.num_shapers] + target_traj = trajectories[self.num_shapers :] + + # log agent one + watchers[0](agents[0]) + # log the inner episodes + shaper_rewards_log = [ + { + f"eval/reward_per_timestep/shaper_{shaper_idx+1}": float( + traj.rewards[i].mean().item() + ) + for (shaper_idx, traj) in enumerate(shaper_traj) + } + for i in range(len(list_of_env_stats)) + ] + target_rewards_log = [ + { + f"eval/reward_per_timestep/target_{target_idx+1}": float( + traj.rewards[i].mean().item() + ) + for (target_idx, traj) in enumerate(target_traj) + } + for i in range(len(list_of_env_stats)) + ] + + # log avg reward for players combined + global_welfare_log = [ + { + f"eval/global_welfare_per_timestep": float( + sum( + traj.rewards[i].mean().item() + for traj in trajectories + ) + ) + / len(trajectories) + } + for i in range(len(list_of_env_stats)) + ] + shaper_welfare_log = [ + { + f"eval/shaper_welfare_per_timestep": float( + sum( + traj.rewards[i].mean().item() + for traj in shaper_traj + ) + ) + / len(shaper_traj) + } + for i in range(len(list_of_env_stats)) + ] + target_welfare_log = [ + { + f"eval/target_welfare_per_timestep": float( + sum( + traj.rewards[i].mean().item() + for traj in target_traj + ) + ) + / len(target_traj) + } + for i in range(len(list_of_env_stats)) + ] + + for i in range(len(list_of_env_stats)): + wandb.log( + { + "train_iteration": i, + } + | list_of_env_stats[i] + | shaper_rewards_log[i] + | target_rewards_log[i] + | shaper_welfare_log[i] + | target_welfare_log[i] + | global_welfare_log[i] + ) + shaper_rewards_log = { + f"eval/meta_reward/shaper{idx+1}": float(rew.mean().item()) + for (idx, rew) in enumerate(shapers_rewards) + } + target_rewards_log = { + f"eval/meta_reward/target{idx+1}": float(rew.mean().item()) + for (idx, rew) in enumerate(targets_rewards) + } + + wandb.log( + { + "episodes": 1, + } + | shaper_rewards_log + | target_rewards_log + ) + + return agents diff --git a/pax/runners/runner_evo_multishaper.py b/pax/runners/runner_evo_multishaper.py new file mode 100644 index 00000000..afe0e36e --- /dev/null +++ b/pax/runners/runner_evo_multishaper.py @@ -0,0 +1,814 @@ +import os +import time +from datetime import datetime +from functools import partial +from typing import Any, Callable, List, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +from evosax import FitnessShaper +from omegaconf import OmegaConf + +import wandb +from pax.utils import MemoryState, TrainingState, save + +# TODO: import when evosax library is updated +# from evosax.utils import ESLog +from pax.watchers import ESLog, n_player_ipd_visitation + +MAX_WANDB_CALLS = 1000 + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class MultishaperEvoRunner: + """ + Evoluationary Strategy runner provides a convenient example for quickly writing + a MARL runner for PAX. The EvoRunner class can be used to + run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. 
+ It composes together agents, watchers, and the environment. + Within the init, we declare vmaps and pmaps for training. + The environment provided must conform to a meta-environment. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The meta-environment that the agents will run in. + strategy (evosax.Strategy): + The evolutionary strategy that will be used to train the agents. + param_reshaper (evosax.param_reshaper.ParameterReshaper): + A function that reshapes the parameters of the agents into a format that can be + used by the strategy. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + def __init__( + self, agents, env, strategy, es_params, param_reshaper, save_dir, args + ): + self.args = args + self.algo = args.es.algo + self.es_params = es_params + self.generations = 0 + self.num_opps = args.num_opps + self.param_reshaper = param_reshaper + self.popsize = args.popsize + self.random_key = jax.random.PRNGKey(args.seed) + self.start_datetime = datetime.now() + self.save_dir = save_dir + self.start_time = time.time() + self.strategy = strategy + self.top_k = args.top_k + self.train_steps = 0 + self.train_episodes = 0 + # TODO JIT this + self.ipd_stats = n_player_ipd_visitation + + # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) + # Evo Runner also has an additional pmap dim (num_devices, ...) + # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + self.num_players = args.num_players + self.num_shapers = args.num_shapers + self.num_targets = args.num_players - args.num_shapers + self.num_outer_steps = args.num_outer_steps + shapers = agents[: self.num_shapers] + targets = agents[self.num_shapers :] + + # vmap agents accordingly + # shapers are batched over popsize and num_opps + for shaper_agent in shapers: + shaper_agent.batch_init = jax.vmap( + jax.vmap( + shaper_agent.make_initial_state, + (None, 0), # (params, rng) + (None, 0), # (TrainingState, MemoryState) + ), + # both for Population + ) + shaper_agent.batch_reset = jax.jit( + jax.vmap( + jax.vmap(shaper_agent.reset_memory, (0, None), 0), + (0, None), + 0, + ), + static_argnums=1, + ) + + shaper_agent.batch_policy = jax.jit( + jax.vmap( + jax.vmap(shaper_agent._policy, (None, 0, 0), (0, None, 0)), + ) + ) + # go through opponents, we start with agent2 + for agent_idx, target_agent in enumerate(targets): + agent_arg = f"agent{agent_idx+self.num_shapers+1}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + # special case where NaiveEx has a different call signature + target_agent.batch_init = jax.jit( + jax.vmap(jax.vmap(target_agent.make_initial_state)) + ) + else: + target_agent.batch_init = jax.jit( + jax.vmap( + jax.vmap( + target_agent.make_initial_state, (0, None), 0 + ), + 
(0, None), + 0, + ) + ) + + target_agent.batch_policy = jax.jit( + jax.vmap(jax.vmap(target_agent._policy, 0, 0)) + ) + target_agent.batch_reset = jax.jit( + jax.vmap( + jax.vmap(target_agent.reset_memory, (0, None), 0), + (0, None), + 0, + ), + static_argnums=1, + ) + + target_agent.batch_update = jax.jit( + jax.vmap( + jax.vmap(target_agent.update, (1, 0, 0, 0)), + (1, 0, 0, 0), + ) + ) + if OmegaConf.select(args, agent_arg) != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile( + target_agent._mem.hidden, (args.num_opps, 1, 1) + ) + + agent_rng = jnp.concatenate( + [ + jax.random.split( + target_agent._state.random_key, args.num_opps + ) + ] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + + ( + target_agent._state, + target_agent._mem, + ) = target_agent.batch_init( + agent_rng, + init_hidden, + ) + + # jit evo + strategy.ask = jax.jit(strategy.ask) + strategy.tell = jax.jit(strategy.tell) + param_reshaper.reshape = jax.jit(param_reshaper.reshape) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + shapers_obs, + targets_obs, + shapers_reward, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ) = carry + new_targets_mem = [None] * self.num_targets + new_shapers_mem = [None] * self.num_shapers + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, :, 0, :] + + # a1_rng = rngs[:, :, :, 1, :] + # a2_rng = rngs[:, :, :, 2, :] + rngs = rngs[:, :, :, 3, :] + shapers_actions = [] + targets_actions = [] + for agent_idx, shaper_agent in enumerate(shapers): + ( + shaper_action, + shapers_state[agent_idx], + new_shapers_mem[agent_idx], + ) = shaper_agent.batch_policy( + shapers_state[agent_idx], + shapers_obs[agent_idx], + shapers_mem[agent_idx], + ) + shapers_actions.append(shaper_action) + + for agent_idx, target_agent in enumerate(targets): + ( + target_action, + targets_state[agent_idx], + new_targets_mem[agent_idx], + ) = target_agent.batch_policy( + targets_state[agent_idx], + targets_obs[agent_idx], + targets_mem[agent_idx], + ) + targets_actions.append(target_action) + ( + all_agent_next_obs, + env_state, + all_agent_rewards, + done, + info, + ) = env.step( + env_rng, + env_state, + tuple(shapers_actions + targets_actions), + env_params, + ) + shapers_next_obs = all_agent_next_obs[: self.num_shapers] + targets_next_obs = all_agent_next_obs[self.num_shapers :] + shapers_reward, targets_reward = ( + all_agent_rewards[: self.num_shapers], + all_agent_rewards[self.num_shapers :], + ) + + shapers_traj = [ + Sample( + shapers_obs[agent_idx], + shapers_actions[agent_idx], + shapers_reward[agent_idx], + new_shapers_mem[agent_idx].extras["log_probs"], + new_shapers_mem[agent_idx].extras["values"], + done, + shapers_mem[agent_idx].hidden, + ) + for agent_idx in range(self.num_shapers) + ] + targets_traj = [ + Sample( + targets_obs[agent_idx], + targets_actions[agent_idx], + targets_rewards[agent_idx], + new_targets_mem[agent_idx].extras["log_probs"], + new_targets_mem[agent_idx].extras["values"], + done, + targets_mem[agent_idx].hidden, + ) + for agent_idx in range(self.num_targets) + ] + # jax.debug.breakpoint() + # print(len(shapers_next_obs)) + # print(len(targets_next_obs)) + # print(shapers_next_obs[0].shape) + # print(targets_next_obs[0].shape) + return ( + rngs, + tuple(shapers_next_obs), + tuple(targets_next_obs), + tuple(shapers_reward), + tuple(targets_reward), + shapers_state, + targets_state, + new_shapers_mem, + new_targets_mem, + 
env_state, + env_params, + ), tuple(shapers_traj + targets_traj) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=args.num_inner_steps, + ) + targets_metrics = [None] * self.num_targets + ( + rngs, + shapers_obs, + targets_obs, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode + for agent_idx, shaper_agent in enumerate(shapers): + agent_arg = f"agent{agent_idx+1}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "MFOS": + shapers_mem[agent_idx] = shaper_agent.meta_policy( + shapers_mem[agent_idx] + ) + # update opponents + targets_traj = trajectories[self.num_shapers :] + for agent_idx, target_agent in enumerate(targets): + ( + targets_state[agent_idx], + targets_mem[agent_idx], + targets_metrics[agent_idx], + ) = target_agent.batch_update( + targets_traj[agent_idx], + targets_obs[agent_idx], + targets_state[agent_idx], + targets_mem[agent_idx], + ) + return ( + rngs, + shapers_obs, + targets_obs, + shapers_rewards, + targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + env_params, + ), (trajectories, targets_metrics) + + def _rollout( + _params: List[jnp.ndarray], + _rng_run: jnp.ndarray, + _shapers_state: List[TrainingState], + _shapers_mem: List[MemoryState], + _env_params: Any, + ): + # env reset + env_rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + * args.num_opps + * args.popsize + ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) + + obs, env_state = env.reset(env_rngs, _env_params) + shapers_obs = obs[: self.num_shapers] + targets_obs = obs[self.num_shapers :] + rewards = [ + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] * args.num_players + + # Shapers + for agent_idx, shaper_agent in enumerate(shapers): + _shapers_state[agent_idx] = _shapers_state[agent_idx]._replace( + params=_params[agent_idx] + ) + _shapers_mem[agent_idx] = shaper_agent.batch_reset( + _shapers_mem[agent_idx], False + ) + # Other players + targets_mem = [None] * self.num_targets + targets_state = [None] * self.num_targets + + _rng_run, *target_rngs = jax.random.split( + _rng_run, self.num_players + ) + for agent_idx, target_agent in enumerate(targets): + # if eg 2 shapers, agent3 is the first non-shaper + agent_arg = f"agent{agent_idx+1+self.num_shapers}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + ( + targets_mem[agent_idx], + targets_state[agent_idx], + ) = target_agent.batch_init(targets_obs[agent_idx]) + else: + # meta-experiments - init 2nd agent per trial + target_agent_rng = jnp.concatenate( + [ + jax.random.split( + target_rngs[agent_idx], args.num_opps + ) + ] + * args.popsize + ).reshape(args.popsize, args.num_opps, -1) + ( + targets_state[agent_idx], + targets_mem[agent_idx], + ) = target_agent.batch_init( + target_agent_rng, + target_agent._mem.hidden, + ) + + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + env_rngs, + tuple(obs[: self.num_shapers]), + tuple(obs[self.num_shapers :]), + tuple(rewards[: self.num_shapers]), + tuple(rewards[self.num_shapers :]), + _shapers_state, + targets_state, + _shapers_mem, + targets_mem, + env_state, + _env_params, + ), + None, + length=self.num_outer_steps, + ) + ( + env_rngs, + shapers_obs, + targets_obs, + shapers_rewards, + 
targets_rewards, + shapers_state, + targets_state, + shapers_mem, + targets_mem, + env_state, + _env_params, + ) = vals + trajectories, targets_metrics = stack + shapers_traj = trajectories[: self.num_shapers] + targets_traj = trajectories[self.num_shapers :] + + # Fitness + shapers_fitness = [ + shapers_traj[shaper_idx].rewards.mean(axis=(0, 1, 3, 4)) + for shaper_idx in range(self.num_shapers) + ] + targets_fitness = [ + targets_traj[targets_idx].rewards.mean(axis=(0, 1, 3, 4)) + for targets_idx in range(self.num_targets) + ] + # # Stats + if args.env_id in [ + "iterated_nplayer_tensor_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + trajectories[0].observations, args.num_players + ), + ) + shapers_rewards = [ + shapers_traj[shaper_idx].rewards.mean() + for shaper_idx in range(self.num_shapers) + ] + + targets_rewards = [ + targets_traj[target_idx].rewards.mean() + for target_idx in range(self.num_targets) + ] + else: + raise NotImplementedError + + return ( + shapers_fitness, + targets_fitness, + env_stats, + shapers_rewards, + targets_rewards, + targets_metrics, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None), + ) + + print( + f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" + ) + + def run_loop( + self, + env_params, + agents, + num_iters: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + print(f"Number of Players: {self.num_players}") + print(f"Number of Shapers: {self.num_shapers}") + print(f"Number of Targets: {self.num_targets}") + print(f"Number of Generations: {num_iters}") + print(f"Number of Meta Episodes: {self.num_outer_steps}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + rng, *evo_keys = jax.random.split( + self.random_key, self.num_shapers + 1 + ) + # Initialize evolution + num_gens = num_iters + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_states = [ + strategy.initialize(evo_keys[idx], es_params) + for idx in range(self.num_shapers) + ] + fit_shaper = FitnessShaper( + maximize=self.args.es.maximise, + centered_rank=self.args.es.centered_rank, + w_decay=self.args.es.w_decay, + z_score=self.args.es.z_score, + ) + es_logging = [ + ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + ] * self.num_shapers + logs = [es_log.initialize() for es_log in es_logging] + + # Reshape a single agent's params before vmapping + shaper_agents = agents[: self.num_shapers] + target_agents = agents[self.num_shapers :] + + init_hiddens = [ + jnp.tile( + shaper_agents[shaper_idx]._mem.hidden, + (popsize, num_opps, 1, 1), + ) + for shaper_idx in range(self.num_shapers) + ] + + shapers_state = [] + shapers_mem = [] + for shaper_idx, shaper_agent in enumerate(shaper_agents): + rng, shaper_rng = jax.random.split(rng, 2) + pop_shaper_rng = jax.random.split(shaper_rng, popsize) + shaper_agent._state, shaper_agent._mem = shaper_agent.batch_init( + pop_shaper_rng, + init_hiddens[shaper_idx], + ) + shapers_state.append(shaper_agent._state) + shapers_mem.append(shaper_agent._mem) + + for gen in 
range(num_gens): + rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) + # Ask + xs = [] + shapers_params = [] + old_evo_states = evo_states + evo_states = [] + for shaper_idx, shaper_agent in enumerate(shaper_agents): + shaper_rng, rng_evo = jax.random.split(rng_evo, 2) + x, evo_state = strategy.ask( + shaper_rng, old_evo_states[shaper_idx], es_params + ) + params = param_reshaper.reshape(x) + if self.args.num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + evo_states.append(evo_state) + xs.append(x) + shapers_params.append(params) + # Evo Rollout + ( + shapers_fitness, + targets_fitness, + env_stats, + shapers_rewards, + targets_rewards, + targets_metrics, + ) = self.rollout( + shapers_params, rng_run, shapers_state, shapers_mem, env_params + ) + + # Aggregate over devices + shapers_fitness = [ + jnp.reshape(fitness, popsize * self.args.num_devices) + for fitness in shapers_fitness + ] + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Tell + fitness_re = [ + fit_shaper.apply(x, fitness) + for x, fitness in zip(xs, shapers_fitness) + ] + + if self.args.es.mean_reduce: + fitness_re = [fit_re - fit_re.mean() for fit_re in fitness_re] + evo_states = [ + strategy.tell(x, fit_re, evo_state, es_params) + for x, fit_re, evo_state in zip(xs, fitness_re, evo_states) + ] + + # Logging + logs = [ + es_log.update(log, x, fitness) + for es_log, log, x, fitness in zip( + es_logging, logs, xs, shapers_fitness + ) + ] + # Saving + if gen % self.args.save_interval == 0 or gen == num_gens - 1: + for shaper_idx in range(self.num_shapers): + log_savepath = os.path.join( + self.save_dir, f"generation_{gen}_agent_{shaper_idx}" + ) + if self.args.num_devices > 1: + top_params = param_reshaper.reshape( + logs[shaper_idx]["top_gen_params"][ + 0 : self.args.num_devices + ] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + logs[shaper_idx]["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath) + if watchers: + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {gen} locally") + + if gen % log_interval == 0: + print(f"Generation: {gen}") + print( + "--------------------------------------------------------------------------" + ) + print( + [ + f"Shaper{idx} fitness: {shapers_fitness[idx].mean()} " + for idx in range(self.num_shapers) + ] + ) + print( + [ + f"Target{idx} fitness: {targets_fitness[idx].mean()} " + for idx in range(self.num_targets) + ] + ) + print( + f"Shapers Reward Per Timestep: {[float(reward.mean()) for reward in shapers_rewards]}" + ) + print( + f"Targets Reward Per Timestep: {[float(reward.mean()) for reward in targets_rewards]}" + ) + print( + f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" + ) + + if watchers: + shaper_rewards_strs = [ + "train/reward_per_timestep/shaper_" + str(i) + for i in range(1, self.num_shapers + 1) + ] + shaper_rewards_val = [ + float(reward.mean()) for reward in shapers_rewards + ] + target_rewards_strs = [ + "train/reward_per_timestep/target_" + str(i) + for i in range(1, self.num_targets + 1) + ] + target_rewards_val = [ + float(reward.mean()) for reward in targets_rewards + ] + rewards_strs = shaper_rewards_strs + target_rewards_strs + rewards_val = shaper_rewards_val + target_rewards_val + rewards_dict = 
dict(zip(rewards_strs, rewards_val)) + + shaper_fitness_str = [ + "train/fitness/shaper_" + str(i) + for i in range(1, self.num_shapers + 1) + ] + shaper_fitness_val = [ + float(fitness.mean()) for fitness in shapers_fitness + ] + target_fitness_str = [ + "train/fitness/target_" + str(i) + for i in range(1, self.num_targets + 1) + ] + target_fitness_val = [ + float(fitness.mean()) for fitness in targets_fitness + ] + fitness_strs = shaper_fitness_str + target_fitness_str + fitness_vals = shaper_fitness_val + target_fitness_val + + fitness_dict = dict(zip(fitness_strs, fitness_vals)) + + shaper_welfare = float( + sum([reward.mean() for reward in shapers_rewards]) + / self.num_shapers + ) + if self.num_targets > 0: + target_welfare = float( + sum([reward.mean() for reward in targets_rewards]) + / self.num_targets + ) + else: + target_welfare = 0 + + all_rewards = shapers_rewards + targets_rewards + global_welfare = float( + sum([reward.mean() for reward in all_rewards]) + / self.args.num_players + ) + wandb_log = { + "train_iteration": gen, + # "train/fitness/top_overall_mean": log["log_top_mean"][gen], + # "train/fitness/top_overall_std": log["log_top_std"][gen], + # "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + # "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + # "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/welfare/shaper": shaper_welfare, + "train/welfare/target": target_welfare, + "train/global_welfare": global_welfare, + } | rewards_dict + wandb_log = wandb_log | fitness_dict + wandb_log.update(env_stats) + # # loop through population + # for idx, (overall_fitness, gen_fitness) in enumerate( + # zip(log["top_fitness"], log["top_gen_fitness"]) + # ): + # wandb_log[ + # f"train/fitness/top_overall_agent_{idx+1}" + # ] = overall_fitness + # wandb_log[ + # f"train/fitness/top_gen_agent_{idx+1}" + # ] = gen_fitness + + # other player metrics + # metrics [outer_timesteps, num_opps] + for agent, metrics in zip(agents[1:], targets_metrics): + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), metrics + ) + + agent._logger.metrics.update(flattened_metrics) + # TODO fix agent logger + # for watcher, agent in zip(watchers, agents): + # watcher(agent) + wandb_log = jax.tree_util.tree_map( + lambda x: x.item() if isinstance(x, jax.Array) else x, + wandb_log, + ) + wandb.log(wandb_log) + + return agents diff --git a/pax/runners/runner_ipditm_eval.py b/pax/runners/runner_ipditm_eval.py index 7eedaac7..4d1d6733 100644 --- a/pax/runners/runner_ipditm_eval.py +++ b/pax/runners/runner_ipditm_eval.py @@ -1,22 +1,18 @@ import os import time -from typing import Any, NamedTuple -from PIL import Image from datetime import datetime +from typing import Any, NamedTuple import jax import jax.numpy as jnp +import numpy as onp +from PIL import Image +from tqdm import tqdm import wandb +from pax.envs.in_the_matrix import EnvState, InTheMatrix from pax.utils import MemoryState, TrainingState, load from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats -from pax.envs.in_the_matrix import ( - InTheMatrix, - EnvState, -) - -from tqdm import tqdm -import numpy as onp MAX_WANDB_CALLS = 1000 diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 2eed3e1d..98263fc4 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -6,7 +6,13 @@ import jax.numpy as jnp 
import wandb -from pax.utils import MemoryState, TrainingState, save +from pax.utils import ( + MemoryState, + TrainingState, + copy_state_and_mem, + copy_state_and_network, + save, +) from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats MAX_WANDB_CALLS = 1000 @@ -37,6 +43,16 @@ class MFOSSample(NamedTuple): meta_actions: jnp.ndarray +class LOLASample(NamedTuple): + obs_self: jnp.ndarray + obs_other: jnp.ndarray + actions_self: jnp.ndarray + actions_other: jnp.ndarray + dones: jnp.ndarray + rewards_self: jnp.ndarray + rewards_other: jnp.ndarray + + @jax.jit def reduce_outer_traj(traj: Sample) -> Sample: """Used to collapse lax.scan outputs dims""" @@ -92,22 +108,25 @@ def _reshape_opp_dim(x): # VMAP for num_envs self.ipditm_stats = jax.jit(ipditm_stats) # VMAP for num envs: we vmap over the rng but not params - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( + env.batch_reset = jax.vmap(env.reset, (0, None), 0) + env.batch_step = jax.vmap( env.step, (0, 0, 0, None), 0 # rng, state, actions, params ) # VMAP for num opps: we vmap over the rng but not params - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( + env.batch_reset = jax.jit(jax.vmap(env.batch_reset, (0, None), 0)) + env.batch_step = jax.jit( jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params + env.batch_step, + (0, 0, 0, None), + 0, # rng, state, actions, params ) ) self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) num_outer_steps = self.args.num_outer_steps agent1, agent2 = agents + agent1.agent2 = agent2 # Pointer for LOLA # set up agents if args.agent1 == "NaiveEx": @@ -120,6 +139,11 @@ def _reshape_opp_dim(x): (None, 0), (None, 0), ) + if args.agent1 == "LOLA": + # batch for num_opps + agent1.batch_in_lookahead = jax.vmap( + agent1.in_lookahead, (0, None, 0, 0, 0), (0, 0) + ) agent1.batch_reset = jax.jit( jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 ) @@ -191,7 +215,13 @@ def _inner_rollout(carry, unused): obs2, a2_mem, ) - (next_obs1, next_obs2), env_state, rewards, done, info = env.step( + ( + (next_obs1, next_obs2), + env_state, + rewards, + done, + info, + ) = env.batch_step( env_rng, env_state, (a1, a2), @@ -209,16 +239,15 @@ def _inner_rollout(carry, unused): a1_mem.hidden, a1_mem.th, ) - else: - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) + traj1 = Sample( + obs1, + a1, + rewards[0], + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) traj2 = Sample( obs2, a2, @@ -304,8 +333,9 @@ def _rollout( rngs = jnp.concatenate( [jax.random.split(_rng_run, args.num_envs)] * args.num_opps ).reshape((args.num_opps, args.num_envs, -1)) + _rng_run, _ = jax.random.split(_rng_run) - obs, env_state = env.reset(rngs, _env_params) + obs, env_state = env.batch_reset(rngs, _env_params) rewards = [ jnp.zeros((args.num_opps, args.num_envs)), jnp.zeros((args.num_opps, args.num_envs)), @@ -358,12 +388,39 @@ def _rollout( traj_1, traj_2, a2_metrics = stack # update outer agent - a1_state, _, a1_metrics = agent1.update( - reduce_outer_traj(traj_1), - self.reduce_opp_dim(obs1), - a1_state, - self.reduce_opp_dim(a1_mem), - ) + if args.agent1 != "LOLA": + a1_state, _, a1_metrics = agent1.update( + reduce_outer_traj(traj_1), + self.reduce_opp_dim(obs1), + a1_state, + self.reduce_opp_dim(a1_mem), + ) + if args.agent1 == "LOLA": + a1_metrics = None + # copy so we don't modify the original during 
simulation + self_state, self_mem = copy_state_and_mem(a1_state, a1_mem) + other_state, other_mem = copy_state_and_mem(a2_state, a2_mem) + # get new state of opponent after their lookahead optimisation + for _ in range(args.lola.num_lookaheads): + _rng_run, _ = jax.random.split(_rng_run) + lookahead_rng = jax.random.split(_rng_run, args.num_opps) + + # we want to batch this num_opps times + other_state, other_mem = agent1.batch_in_lookahead( + lookahead_rng, + self_state, + self_mem, + other_state, + other_mem, + ) + # get our new state after our optimisation based on ops new state + _rng_run, out_look_rng = jax.random.split(_rng_run) + a1_state = agent1.out_lookahead( + out_look_rng, a1_state, a1_mem, other_state, other_mem + ) + + if args.agent2 == "LOLA": + raise NotImplementedError("LOLA not implemented for agent2") # reset memory a1_mem = agent1.batch_reset(a1_mem, False) @@ -426,9 +483,6 @@ def run_loop(self, env_params, agents, num_iters, watchers): """Run training of agents in environment""" print("Training") print("-----------------------") - num_iters = max( - int(num_iters / (self.args.num_envs * self.num_opps)), 1 - ) log_interval = int(max(num_iters / MAX_WANDB_CALLS, 5)) save_interval = self.args.save_interval @@ -491,9 +545,10 @@ def run_loop(self, env_params, agents, num_iters, watchers): flattened_metrics_1 = jax.tree_util.tree_map( lambda x: jnp.mean(x), a1_metrics ) - agent1._logger.metrics = ( - agent1._logger.metrics | flattened_metrics_1 - ) + if self.args.agent1 != "LOLA": + agent1._logger.metrics = ( + agent1._logger.metrics | flattened_metrics_1 + ) # metrics [outer_timesteps, num_opps] flattened_metrics_2 = jax.tree_util.tree_map( lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics diff --git a/pax/runners/runner_marl_nplayer.py b/pax/runners/runner_marl_nplayer.py new file mode 100644 index 00000000..1c43e819 --- /dev/null +++ b/pax/runners/runner_marl_nplayer.py @@ -0,0 +1,630 @@ +import os +import time +from typing import Any, List, NamedTuple + +import jax +import jax.numpy as jnp +from omegaconf import OmegaConf + +import wandb +from pax.utils import MemoryState, TrainingState, copy_state_and_mem, save +from pax.watchers import n_player_ipd_visitation + +MAX_WANDB_CALLS = 1000 + + +class LOLASample(NamedTuple): + obs_self: jnp.ndarray + obs_other: jnp.ndarray + actions_self: jnp.ndarray + actions_other: jnp.ndarray + dones: jnp.ndarray + rewards_self: jnp.ndarray + rewards_other: jnp.ndarray + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class MFOSSample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + meta_actions: jnp.ndarray + + +@jax.jit +def reduce_outer_traj(traj: Sample) -> Sample: + """Used to collapse lax.scan outputs dims""" + # x: [outer_loop, inner_loop, num_opps, num_envs ...] + # x: [timestep, batch_size, ...] 
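# Editor's illustration (not part of the patch): nested lax.scans stack their
# per-step outputs along leading axes, so each trajectory leaf arrives here as
# [outer, inner, num_opps, num_envs, ...]; the code below merges the first two
# axes into timesteps and the next two into a batch. Toy shapes only:
import jax.numpy as jnp

outer, inner, num_opps, num_envs, feat = 3, 5, 2, 4, 7
leaf = jnp.zeros((outer, inner, num_opps, num_envs, feat))
merged = leaf.reshape((outer * inner, num_opps * num_envs) + leaf.shape[4:])
assert merged.shape == (outer * inner, num_opps * num_envs, feat)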
+ num_envs = traj.rewards.shape[2] * traj.rewards.shape[3] + num_timesteps = traj.rewards.shape[0] * traj.rewards.shape[1] + return jax.tree_util.tree_map( + lambda x: x.reshape((num_timesteps, num_envs) + x.shape[4:]), + traj, + ) + + +class NplayerRLRunner: + """ + Reinforcement Learning runner provides a convenient example for quickly writing + a MARL runner for PAX. The MARLRunner class can be used to + run any two RL agents together either in a meta-game or regular game, it composes together agents, + watchers, and the environment. Within the init, we declare vmaps and pmaps for training. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The environment that the agents will run in. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + # flake8: noqa: C901 + def __init__(self, agents, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.num_opps = args.num_opps + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + + def _reshape_opp_dim(x): + # x: [num_opps, num_envs ...] + # x: [batch_size, ...] + batch_size = args.num_envs * args.num_opps + return jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), x + ) + + self.reduce_opp_dim = jax.jit(_reshape_opp_dim) + self.ipd_stats = n_player_ipd_visitation + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # VMAP for num opps: we vmap over the rng but not params + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + + self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + num_outer_steps = self.args.num_outer_steps + agent1, *other_agents = agents + agent1.other_agents = other_agents + + # set up agents + if args.agent1 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent1.batch_init = jax.jit(jax.vmap(agent1.make_initial_state)) + else: + # batch MemoryState not TrainingState + agent1.batch_init = jax.vmap( + agent1.make_initial_state, + (None, 0), + (None, 0), + ) + agent1.batch_reset = jax.jit( + jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 + ) + + agent1.batch_policy = jax.jit( + jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)) + ) + if args.agent1 != "NaiveEx": + # NaiveEx requires env first step to init. 
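# Editor's illustration (not part of the patch): the first agent's policy is
# vmapped with in_axes=(None, 0, 0) and out_axes=(0, None, 0), so one set of
# parameters / training state is shared across the opponent batch while
# observations and memory are mapped, and the (unchanged) state comes back
# unbatched. `policy` is a hypothetical stand-in with a `_policy`-shaped
# signature.
import jax
import jax.numpy as jnp

def policy(state, obs, mem):
    action = obs.sum(axis=-1) + mem  # placeholder computation
    return action, state, mem + 1.0

batch_policy = jax.vmap(policy, in_axes=(None, 0, 0), out_axes=(0, None, 0))
state = jnp.array(0.0)                  # shared "training state"
obs = jnp.ones((2, 4, 3))               # [num_opps, num_envs, obs_dim]
mem = jnp.zeros((2, 4))                 # [num_opps, num_envs]
actions, state, mem = batch_policy(state, obs, mem)
assert actions.shape == (2, 4) and mem.shape == (2, 4) and state.shape == ()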
+ init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) + agent1._state, agent1._mem = agent1.batch_init( + agent1._state.random_key, init_hidden + ) + if args.agent1 == "LOLA": + # batch for num_opps + agent1.batch_in_lookahead = jax.vmap( + agent1.in_lookahead, (0, None, 0, 0, 0), (0, 0) + ) + + # go through opponents, we start with agent2 + for agent_idx, non_first_agent in enumerate(other_agents): + agent_arg = f"agent{agent_idx+2}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + # special case where NaiveEx has a different call signature + non_first_agent.batch_init = jax.jit( + jax.vmap(non_first_agent.make_initial_state) + ) + else: + non_first_agent.batch_init = jax.vmap( + non_first_agent.make_initial_state, (0, None), 0 + ) + non_first_agent.batch_policy = jax.jit( + jax.vmap(non_first_agent._policy) + ) + non_first_agent.batch_reset = jax.jit( + jax.vmap(non_first_agent.reset_memory, (0, None), 0), + static_argnums=1, + ) + non_first_agent.batch_update = jax.jit( + jax.vmap(non_first_agent.update, (1, 0, 0, 0), 0) + ) + + if OmegaConf.select(args, agent_arg) != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile( + non_first_agent._mem.hidden, (args.num_opps, 1, 1) + ) + agent_rng = jax.random.split( + non_first_agent._state.random_key, args.num_opps + ) + ( + non_first_agent._state, + non_first_agent._mem, + ) = non_first_agent.batch_init( + agent_rng, + init_hidden, + ) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, :, 3, :] + new_other_agent_mem = [None] * len(other_agents) + actions = [] + + ( + first_action, + first_agent_state, + new_first_agent_mem, + ) = agent1.batch_policy( + first_agent_state, + first_agent_obs, + first_agent_mem, + ) + actions.append(first_action) + for agent_idx, non_first_agent in enumerate(other_agents): + ( + non_first_action, + other_agent_state[agent_idx], + new_other_agent_mem[agent_idx], + ) = non_first_agent.batch_policy( + other_agent_state[agent_idx], + other_agent_obs[agent_idx], + other_agent_mem[agent_idx], + ) + actions.append(non_first_action) + ( + all_agent_next_obs, + env_state, + all_agent_rewards, + done, + info, + ) = env.step( + env_rng, + env_state, + actions, + env_params, + ) + + first_agent_next_obs, *other_agent_next_obs = all_agent_next_obs + first_agent_reward, *other_agent_rewards = all_agent_rewards + if args.agent1 == "MFOS": + traj1 = MFOSSample( + first_agent_obs, + first_action, + first_agent_reward, + new_first_agent_mem.extras["log_probs"], + new_first_agent_mem.extras["values"], + done, + first_agent_mem.hidden, + first_agent_mem.th, + ) + else: + traj1 = Sample( + first_agent_obs, + first_action, + first_agent_reward, + new_first_agent_mem.extras["log_probs"], + new_first_agent_mem.extras["values"], + done, + first_agent_mem.hidden, + ) + other_traj = [ + Sample( + other_agent_obs[agent_idx], + actions[agent_idx + 1], + other_agent_rewards[agent_idx], + new_other_agent_mem[agent_idx].extras["log_probs"], + new_other_agent_mem[agent_idx].extras["values"], + done, + other_agent_mem[agent_idx].hidden, + ) + for agent_idx in range(len(other_agents)) + ] 
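# Editor's illustration (not part of the patch): each inner step returns a
# tuple of per-agent Sample tuples, and jax.lax.scan stacks every leaf of
# those NamedTuples along a leading time axis. Minimal sketch with a reduced
# Sample holding only rewards:
from typing import NamedTuple

import jax
import jax.numpy as jnp

class MiniSample(NamedTuple):
    rewards: jnp.ndarray

def step(carry, _):
    carry = carry + 1.0
    return carry, (MiniSample(rewards=carry), MiniSample(rewards=-carry))

_, trajs = jax.lax.scan(step, jnp.zeros(()), None, length=5)
assert trajs[0].rewards.shape == (5,)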
+ return ( + rngs, + first_agent_next_obs, + tuple(other_agent_next_obs), + first_agent_reward, + tuple(other_agent_rewards), + first_agent_state, + other_agent_state, + new_first_agent_mem, + new_other_agent_mem, + env_state, + env_params, + ), (traj1, *other_traj) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=self.args.num_inner_steps, + ) + other_agent_metrics = [None] * len(other_agents) + ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode + if args.agent1 == "MFOS": + first_agent_mem = agent1.meta_policy(first_agent_mem) + + # update second agent + for agent_idx, non_first_agent in enumerate(other_agents): + ( + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + other_agent_metrics[agent_idx], + ) = non_first_agent.batch_update( + trajectories[agent_idx + 1], + other_agent_obs[agent_idx], + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + ) + return ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + env_state, + env_params, + ), (trajectories, other_agent_metrics) + + def _rollout( + _rng_run: jnp.ndarray, + first_agent_state: TrainingState, + first_agent_mem: MemoryState, + other_agent_state: List[TrainingState], + other_agent_mem: List[MemoryState], + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] * args.num_opps + ).reshape((args.num_opps, args.num_envs, -1)) + + obs, env_state = env.reset(rngs, _env_params) + rewards = [ + jnp.zeros((args.num_opps, args.num_envs)), + ] * args.num_players + # Player 1 + first_agent_mem = agent1.batch_reset(first_agent_mem, False) + # Other players + _rng_run, other_agent_rng = jax.random.split(_rng_run, 2) + + # Resetting Agents if necessary + if args.agent1 == "NaiveEx": + first_agent_state, first_agent_mem = agent1.batch_init(obs[0]) + + for agent_idx, non_first_agent in enumerate(other_agents): + # indexing starts at 2 for args + agent_arg = f"agent{agent_idx+2}" + # equivalent of args.agent_n + if OmegaConf.select(args, agent_arg) == "NaiveEx": + ( + other_agent_mem[agent_idx], + other_agent_state[agent_idx], + ) = non_first_agent.batch_init(obs[agent_idx + 1]) + + elif self.args.env_type in ["meta"]: + # meta-experiments - init 2nd agent per trial + ( + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + ) = non_first_agent.batch_init( + jax.random.split(other_agent_rng, self.num_opps), + non_first_agent._mem.hidden, + ) + _rng_run, other_agent_rng = jax.random.split(_rng_run, 2) + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + rngs, + obs[0], + tuple(obs[1:]), + rewards[0], + tuple(rewards[1:]), + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + env_state, + _env_params, + ), + None, + length=num_outer_steps, + ) + + ( + rngs, + first_agent_obs, + other_agent_obs, + first_agent_reward, + other_agent_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + env_state, + env_params, + ) = vals + trajectories, other_agent_metrics = stack + + # update outer agent + if args.agent1 != "LOLA": + first_agent_state, _, 
first_agent_metrics = agent1.update( + reduce_outer_traj(trajectories[0]), + self.reduce_opp_dim(first_agent_obs), + first_agent_state, + self.reduce_opp_dim(first_agent_mem), + ) + + elif args.agent1 == "LOLA": + # jax.debug.breakpoint() + # copy so we don't modify the original during simulation + first_agent_metrics = None + + self_state, self_mem = copy_state_and_mem( + first_agent_state, first_agent_mem + ) + other_states, other_mems = [], [] + for agent_idx, non_first_agent in enumerate(other_agents): + other_state, other_mem = copy_state_and_mem( + other_agent_state[agent_idx], + other_agent_mem[agent_idx], + ) + other_states.append(other_state) + other_mems.append(other_mem) + # get new state of opponent after their lookahead optimisation + for _ in range(args.lola.num_lookaheads): + _rng_run, _ = jax.random.split(_rng_run) + lookahead_rng = jax.random.split(_rng_run, args.num_opps) + + # we want to batch this num_opps times + other_states, other_mems = agent1.batch_in_lookahead( + lookahead_rng, + self_state, + self_mem, + other_states, + other_mems, + ) + # get our new state after our optimisation based on ops new state + _rng_run, out_look_rng = jax.random.split(_rng_run) + first_agent_state = agent1.out_lookahead( + out_look_rng, + first_agent_state, + first_agent_mem, + other_states, + other_mems, + ) + # jax.debug.breakpoint() + + if args.agent2 == "LOLA": + raise NotImplementedError("LOLA not implemented for agent2") + + # reset memory + first_agent_mem = agent1.batch_reset(first_agent_mem, False) + for agent_idx, non_first_agent in enumerate(other_agents): + other_agent_mem[agent_idx] = non_first_agent.batch_reset( + other_agent_mem[agent_idx], False + ) + # Stats + if args.env_id == "iterated_nplayer_tensor_game": + total_env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + trajectories[0].observations, + num_players=args.num_players, + ), + ) + total_rewards = [traj.rewards.mean() for traj in trajectories] + else: + total_env_stats = {} + total_rewards = [traj.rewards.mean() for traj in trajectories] + + return ( + total_env_stats, + total_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + first_agent_metrics, + other_agent_metrics, + trajectories, + ) + + # self.rollout = _rollout + self.rollout = jax.jit(_rollout) + + def run_loop(self, env_params, agents, num_iters, watchers): + """Run training of agents in environment""" + print("Training") + print("-----------------------") + # log_interval = int(max(num_iters / MAX_WANDB_CALLS, 5)) + log_interval = 100 + save_interval = self.args.save_interval + + agent1, *other_agents = agents + rng, _ = jax.random.split(self.random_key) + + first_agent_state, first_agent_mem = agent1._state, agent1._mem + other_agent_mem = [None] * len(other_agents) + other_agent_state = [None] * len(other_agents) + + for agent_idx, non_first_agent in enumerate(other_agents): + other_agent_state[agent_idx], other_agent_mem[agent_idx] = ( + non_first_agent._state, + non_first_agent._mem, + ) + + print(f"Num iters {num_iters}") + print(f"Log Interval {log_interval}") + print(f"Save Interval {save_interval}") + # run actual loop + for i in range(num_iters): + rng, rng_run = jax.random.split(rng, 2) + # RL Rollout + ( + env_stats, + total_rewards, + first_agent_state, + other_agent_state, + first_agent_mem, + other_agent_mem, + first_agent_metrics, + other_agent_metrics, + trajectories, + ) = self.rollout( + rng_run, + first_agent_state, + first_agent_mem, + other_agent_state, + 
other_agent_mem, + env_params, + ) + + # saving + if i % save_interval == 0: + log_savepath1 = os.path.join( + self.save_dir, f"agent1_iteration_{i}" + ) + if watchers: + print(f"Saving iteration {i} locally and to WandB") + wandb.save(log_savepath1) + else: + print(f"Saving iteration {i} locally") + + # #logging + # if i % log_interval == 0: + # print(f"Episode {i}") + # for stat in env_stats.keys(): + # print(stat + f": {env_stats[stat].item()}") + # print( + # f"Reward Per Timestep: {[float(reward.mean()) for reward in total_rewards]}" + # ) + # print() + + if watchers: + # metrics [outer_timesteps] + flattened_metrics_1 = jax.tree_util.tree_map( + lambda x: jnp.mean(x), first_agent_metrics + ) + if self.args.agent1 != "LOLA": + agent1._logger.metrics = ( + agent1._logger.metrics | flattened_metrics_1 + ) + + for watcher, agent in zip(watchers, agents): + watcher(agent) + + env_stats = jax.tree_util.tree_map( + lambda x: x.item(), env_stats + ) + rewards_strs = [ + "train/reward_per_timestep/player_" + str(i) + for i in range(1, len(total_rewards) + 1) + ] + rewards_val = [ + float(reward.mean()) for reward in total_rewards + ] + global_welfare = { + f"train/global_welfare_per_timestep": float( + sum(reward.mean().item() for reward in total_rewards) + ) + / len(total_rewards) + } + + rewards_dict = dict(zip(rewards_strs, rewards_val)) + wandb_log = ( + {"train_iteration": i} + | rewards_dict + | env_stats + | global_welfare + ) + wandb.log(wandb_log) + + agents[0]._state = first_agent_state + for agent_idx, non_first_agent in enumerate(other_agents): + agents[agent_idx + 1]._state = other_agent_state[agent_idx] + return agents diff --git a/pax/runners/runner_sarl.py b/pax/runners/runner_sarl.py index 5e0f1973..a573e18b 100644 --- a/pax/runners/runner_sarl.py +++ b/pax/runners/runner_sarl.py @@ -4,8 +4,8 @@ import jax import jax.numpy as jnp -import wandb +import wandb from pax.utils import MemoryState, TrainingState, save # from jax.config import config diff --git a/pax/utils.py b/pax/utils.py index 68f7233b..284e337c 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -93,6 +93,19 @@ def to_numpy(values): return jax.tree_util.tree_map(np.asarray, values) +class LOLATrainingState(NamedTuple): + """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" + + policy_params: hk.Params + value_params: hk.Params + policy_opt_state: optax.GradientTransformation + value_opt_state: optax.GradientTransformation + random_key: jnp.ndarray + timesteps: int + extras: Mapping[str, jnp.ndarray] + hidden: None + + class TrainingState(NamedTuple): """Training state consists of network parameters, optimiser state, random key, timesteps""" @@ -133,3 +146,66 @@ def load(filename: str): with open(filename, "rb") as handle: es_logger = pickle.load(handle) return es_logger + + +def copy_state_and_network(agent): + import copy + + """Copies an agent state and returns the state""" + state = TrainingState( + params=copy.deepcopy(agent._state.params), + opt_state=agent._state.opt_state, + random_key=agent._state.random_key, + timesteps=agent._state.timesteps, + ) + mem = MemoryState( + hidden=copy.deepcopy(agent._mem.hidden), + extras={ + "values": copy.deepcopy(agent._mem.extras["values"]), + "log_probs": copy.deepcopy(agent._mem.extras["log_probs"]), + }, + ) + network = agent.network + return state, mem, network + + +def copy_state_and_mem(state, mem): + import copy + + """Copies an agent state and returns the state""" + state = TrainingState( + 
params=copy.deepcopy(state.params), + opt_state=state.opt_state, + random_key=state.random_key, + timesteps=state.timesteps, + ) + mem = MemoryState( + hidden=copy.deepcopy(mem.hidden), + extras={ + "values": copy.deepcopy(mem.extras["values"]), + "log_probs": copy.deepcopy(mem.extras["log_probs"]), + }, + ) + return state, mem + + +def copy_extended_state_and_network(agent): + import copy + + """Copies an agent state and returns the state""" + state = LOLATrainingState( + policy_params=copy.deepcopy(agent._state.policy_params), + value_params=copy.deepcopy(agent._state.value_params), + policy_opt_state=agent._state.policy_opt_state, + value_opt_state=agent._state.value_opt_state, + random_key=agent._state.random_key, + timesteps=agent._state.timesteps, + extras={ + "values": jnp.zeros(agent._num_envs), + "log_probs": jnp.zeros(agent._num_envs), + }, + hidden=None, + ) + policy_network = agent.policy_network + value_network = agent.value_network + return state, policy_network, value_network diff --git a/pax/watchers.py b/pax/watchers.py index 462f8616..14ca65a8 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -1,4 +1,5 @@ import enum +import itertools import pickle from functools import partial from typing import NamedTuple @@ -11,8 +12,8 @@ import pax.agents.hyper.ppo as HyperPPO import pax.agents.ppo.ppo as PPO from pax.agents.naive_exact import NaiveExact -from pax.envs.iterated_matrix_game import EnvState, IteratedMatrixGame from pax.envs.in_the_matrix import InTheMatrix +from pax.envs.iterated_matrix_game import EnvState, IteratedMatrixGame # five possible states START = jnp.array([[0, 0, 0, 0, 1]]) @@ -425,6 +426,639 @@ def ipd_visitation( } +# we cannot jit on numplayers because we need to use jnp.bincount +def n_player_ipd_visitation( + observations: jnp.ndarray, + num_players: int, +) -> dict: + # obs [num_outer_steps, num_inner_steps, num_opps, num_envs, num_states] + state_actions = jnp.argmax(observations, axis=-1) + # 2**num_players is the number of possible states plus start state + hist = jnp.bincount(state_actions.flatten(), length=2**num_players + 1) + state_freq = hist + state_probs = hist / hist.sum() + + num_def_vis = jnp.zeros( + (2 * num_players + 1,) + ) # 2 choices for first player * numpl different amount of defectors + start state + for state_idx in range(2**num_players): # not dealing with start state + # check what first agent did + if state_idx >= 2 ** (num_players - 1): # if first agent defected + num_def_state_idx = num_players # we start with D, add more later + new_state_idx = state_idx - 2 ** ( + num_players - 1 + ) # get rid of first agent + else: + num_def_state_idx = 0 # we start with C, add more later + new_state_idx = state_idx + + # count how many defectors there are amongst opponents + num_def = bin(new_state_idx).count("1") + # add it to the index + num_def_state_idx = num_def_state_idx + num_def + num_def_vis = num_def_vis.at[num_def_state_idx].add(hist[state_idx]) + + # dealing with start state + num_def_vis = num_def_vis.at[2 * num_players].add( + hist[2**num_players] + ) # num start state vis is the last + grouped_state_freq = num_def_vis + grouped_state_probs = grouped_state_freq / grouped_state_freq.sum() + + # generate the dict keys for logging + letters = ["C", "D"] + combinations = list(itertools.product(letters, repeat=num_players)) + combinations = ["".join(c) for c in combinations] + ["START"] + + visitation_strs = ["state_visitation/" + c for c in combinations] + prob_strs = ["state_probability/" + c for c in combinations] + + def 
generate_grouped_combs_strs(num_players): + # eg 4 pl order is + # C3C0D, C2C1D, C1C2D, C0C3D, idx (0...num_pl-1 ) + # D3C0D, D2C1D, D1C2D, D0C3D, idx (num_pl... 2*num_pl-1 ) + # START idx (2*num_pl) + grouped_combs = [] + for n in range(0, num_players): # num_players-1 opponents + string_1 = f"C{n}D" + grouped_combs.append(string_1) + for n in range(0, num_players): + string_2 = f"D{n}D" + grouped_combs.append(string_2) + grouped_combs.append("START") + return grouped_combs + + grouped_comb_strs = generate_grouped_combs_strs(num_players) + grouped_visitation_strs = [ + "grouped_state_visitation/" + c for c in grouped_comb_strs + ] + grouped_prob_strs = [ + "grouped_state_probability/" + c for c in grouped_comb_strs + ] + + visitation_dict = ( + dict(zip(visitation_strs, state_freq)) + | dict(zip(prob_strs, state_probs)) + | dict(zip(grouped_visitation_strs, grouped_state_freq)) + | dict(zip(grouped_prob_strs, grouped_state_probs)) + ) + return visitation_dict + + +def tensor_ipd_visitation( + observations: jnp.ndarray, +) -> dict: + # obs [num_outer_steps, num_inner_steps, num_opps, num_envs, num_states] + state_actions = jnp.argmax(observations, axis=-1) + hist = jnp.bincount(state_actions.flatten(), length=9) + state_freq = hist + state_probs = hist / hist.sum() + return { + "state_visitation/CCC": state_freq[0], + "state_visitation/CCD": state_freq[1], + "state_visitation/CDC": state_freq[2], + "state_visitation/CDD": state_freq[3], + "state_visitation/DCC": state_freq[4], + "state_visitation/DCD": state_freq[5], + "state_visitation/DDC": state_freq[6], + "state_visitation/DDD": state_freq[7], + "state_visitation/START": state_freq[8], + "state_probability/CCC": state_probs[0], + "state_probability/CCD": state_probs[1], + "state_probability/CDC": state_probs[2], + "state_probability/CDD": state_probs[3], + "state_probability/DCC": state_probs[4], + "state_probability/DCD": state_probs[5], + "state_probability/DDC": state_probs[6], + "state_probability/DDD": state_probs[7], + "state_probability/START": state_probs[8], + } + + +def third_party_punishment_visitation( + log_obs: jnp.ndarray, +) -> dict: + # prev_actions: [num_outer_steps, num_inner_steps, num_opps, num_envs, 6] + # the 6 are binary of c/d: pl1 vs pl2, pl1 vs pl3, pl2 vs pl3 of the prev step + # punishments: [num_outer_steps, num_inner_steps, num_opps, num_envs, 3] + # the 3 are nums 0,1,2,3 for pl1 pl2 pl3 punishment actions + # (0 for no punish, 1,2 for punish first or second opponent, 3 for punish both) + + prev_actions, punishments = log_obs + # flatten out + prev_actions = prev_actions.reshape(-1, 6) + punishments = punishments.reshape(-1, 3) + + pl1_vs_pl2_bin = prev_actions[:, 0:2] # pl1 vs pl2 + pl1_vs_pl3_bin = prev_actions[:, 2:4] # pl1 vs pl3 + pl2_vs_pl3_bin = prev_actions[:, 4:6] # pl2 vs pl3 + b2i = 2 ** jnp.arange(2 - 1, -1, -1) + pl1_vs_pl2 = (pl1_vs_pl2_bin * b2i).sum(axis=-1) + pl1_vs_pl3 = (pl1_vs_pl3_bin * b2i).sum(axis=-1) + pl2_vs_pl3 = (pl2_vs_pl3_bin * b2i).sum(axis=-1) + + # all 3 games combined + all_games_actions = jnp.concatenate([pl1_vs_pl2, pl1_vs_pl3, pl2_vs_pl3]) + hist = jnp.bincount(all_games_actions, length=4) + action_freq = hist + action_probs = hist / hist.sum() + + # games action breakdown + pl1_v_pl2_hist = jnp.bincount(pl1_vs_pl2, length=4) + pl1_v_pl2_action_freq = pl1_v_pl2_hist + pl1_v_pl2_action_probs = pl1_v_pl2_hist / pl1_v_pl2_hist.sum() + + pl1_v_pl3_hist = jnp.bincount(pl1_vs_pl3, length=4) + pl1_v_pl3_action_freq = pl1_v_pl3_hist + pl1_v_pl3_action_probs = pl1_v_pl3_hist / 
pl1_v_pl3_hist.sum() + + pl2_v_pl3_hist = jnp.bincount(pl2_vs_pl3, length=4) + pl2_v_pl3_action_freq = pl2_v_pl3_hist + pl2_v_pl3_action_probs = pl2_v_pl3_hist / pl2_v_pl3_hist.sum() + + # player actions breakdown + pl1_total_defects_prob = ( + pl1_vs_pl2_bin[:, 0] + pl1_vs_pl3_bin[:, 0] + ).sum() / len(pl1_vs_pl2_bin[:, 0]) + pl2_total_defects_prob = ( + pl1_vs_pl2_bin[:, 1] + pl2_vs_pl3_bin[:, 0] + ).sum() / len(pl1_vs_pl2_bin[:, 0]) + pl3_total_defects_prob = ( + pl1_vs_pl3_bin[:, 1] + pl2_vs_pl3_bin[:, 1] + ).sum() / len(pl1_vs_pl2_bin[:, 0]) + + # pl1 punished if pl3 gives 1 or 3 + # pl1 punished if pl2 gives 2 or 3 + pl1_punished_prob = ( + jnp.where(punishments[:, 2] % 2 == 1, 1, 0) + + jnp.where(punishments[:, 1] > 1, 1, 0) + ).sum() / len(punishments[:, 0]) + # pl2 punished if pl3 gives 2 or 3 + # pl2 punished if pl1 gives 1 or 3 + pl2_punished_prob = ( + jnp.where(punishments[:, 2] > 1, 1, 0) + + jnp.where(punishments[:, 0] % 2 == 1, 1, 0) + ).sum() / len(punishments[:, 0]) + # pl3 punished if pl1 gives 2 or 3 + # pl3 punished if pl2 gives 1 or 3 + pl3_punished_prob = ( + jnp.where(punishments[:, 0] > 1, 1, 0) + + jnp.where(punishments[:, 1] % 2 == 1, 1, 0) + ).sum() / len(punishments[:, 0]) + + pl1_num_punishes = ( + jnp.where(punishments[:, 0] == 1, 1, 0) + + jnp.where(punishments[:, 0] == 2, 1, 0) + + jnp.where(punishments[:, 0] == 3, 2, 0) + ).sum() / len(punishments[:, 0]) + + pl2_num_punishes = ( + jnp.where(punishments[:, 1] == 1, 1, 0) + + jnp.where(punishments[:, 1] == 2, 1, 0) + + jnp.where(punishments[:, 1] == 3, 2, 0) + ).sum() / len(punishments[:, 0]) + + pl3_num_punishes = ( + jnp.where(punishments[:, 2] == 1, 1, 0) + + jnp.where(punishments[:, 2] == 2, 1, 0) + + jnp.where(punishments[:, 2] == 3, 2, 0) + ).sum() / len(punishments[:, 0]) + + pl1_punished_defecting_pl2 = ( + jnp.where(punishments[:, 0] % 2 == 1, 1, 0) + * jnp.where( + # pl2 defected against pl3 is states 2 or 3 + pl2_vs_pl3 > 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + pl1_punished_defecting_pl3 = ( + jnp.where(punishments[:, 0] > 1, 1, 0) + * jnp.where( + # pl3 defected against pl2 is states 1 or 3 + pl2_vs_pl3 % 2 == 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + pl2_punished_defecting_pl3 = ( + jnp.where(punishments[:, 1] % 2 == 1, 1, 0) + * jnp.where( + # pl3 defected against pl1 is states 1 or 3 + pl1_vs_pl3 % 2 == 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + pl2_punished_defecting_pl1 = ( + # pl2 punished pl1 is states 2 or 3 + jnp.where(punishments[:, 1] > 1, 1, 0) + * jnp.where( + # pl1 defected against pl3 is states 2 or 3 + pl1_vs_pl3 > 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + pl3_punished_defecting_pl1 = ( + jnp.where(punishments[:, 2] % 2 == 1, 1, 0) + * jnp.where( + # pl1 defected against pl2 is states 2 or 3 + pl1_vs_pl2 > 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + pl3_punished_defecting_pl2 = ( + jnp.where(punishments[:, 2] > 1, 1, 0) + * jnp.where( + # pl2 defected against pl1 is states 1 or 3 + pl1_vs_pl2 % 2 == 1, + 1, + 0, + ) + ).sum() / len(punishments[:, 0]) + + num_pl1pl2_defects = ( + jnp.where(pl1_vs_pl2 > 1, 1, 0) + jnp.where(pl1_vs_pl2 % 2 == 1, 1, 0) + ).sum() / len(pl1_vs_pl2) + + num_pl1pl3_defects = ( + jnp.where(pl1_vs_pl3 > 1, 1, 0) + jnp.where(pl1_vs_pl3 % 2 == 1, 1, 0) + ).sum() / len(pl1_vs_pl3) + + num_pl2pl3_defects = ( + jnp.where(pl2_vs_pl3 > 1, 1, 0) + jnp.where(pl2_vs_pl3 % 2 == 1, 1, 0) + ).sum() / len(pl2_vs_pl3) + + pl1_punishes_defecting_players = ( + pl1_punished_defecting_pl2 + 
pl1_punished_defecting_pl3 + ) + pl1_punished_defect_to_total_punishes = ( + pl1_punishes_defecting_players / pl1_num_punishes + ) + pl1_punished_defects_to_total_opn_defects = ( + pl1_punishes_defecting_players / num_pl2pl3_defects + ) + + pl2_punishes_defecting_players = ( + pl2_punished_defecting_pl1 + pl2_punished_defecting_pl3 + ) + pl2_punished_defect_to_total_punishes = ( + pl2_punishes_defecting_players / pl2_num_punishes + ) + pl2_punished_defects_to_total_opn_defects = ( + pl2_punishes_defecting_players / num_pl1pl3_defects + ) + + pl3_punishes_defecting_players = ( + pl3_punished_defecting_pl1 + pl3_punished_defecting_pl2 + ) + pl3_punished_defect_to_total_punishes = ( + pl3_punishes_defecting_players / pl3_num_punishes + ) + pl3_punished_defects_to_total_opn_defects = ( + pl3_punishes_defecting_players / num_pl1pl2_defects + ) + + total_punishment = pl1_num_punishes + pl2_num_punishes + pl3_num_punishes + + # generate the dict keys for logging + letters = ["C", "D"] + combinations = list(itertools.product(letters, repeat=2)) + combinations = ["".join(c) for c in combinations] + + all_game_visitation_strs = [ + "all_games_state_visitation/" + c for c in combinations + ] + all_game_prob_strs = [ + "all_game_state_probability/" + c for c in combinations + ] + + pl1_v_pl2_visitation_strs = [ + "pl1_v_pl2_state_visitation/" + c for c in combinations + ] + pl1_v_pl2_prob_strs = [ + "pl1_v_pl2_state_probability/" + c for c in combinations + ] + + pl1_v_pl3_visitation_strs = [ + "pl1_v_pl3_state_visitation/" + c for c in combinations + ] + pl1_v_pl3_prob_strs = [ + "pl1_v_pl3_state_probability/" + c for c in combinations + ] + + pl2_v_pl3_visitation_strs = [ + "pl2_v_pl3_state_visitation/" + c for c in combinations + ] + pl2_v_pl3_prob_strs = [ + "pl2_v_pl3_state_probability/" + c for c in combinations + ] + + pl1_total_defects_prob_str = "pl1_total_defects_prob" + pl2_total_defects_prob_str = "pl2_total_defects_prob" + pl3_total_defects_prob_str = "pl3_total_defects_prob" + + pl1_punished_prob_str = "pl1_punished_prob" + pl2_punished_prob_str = "pl2_punished_prob" + pl3_punished_prob_str = "pl3_punished_prob" + + pl1_num_punishes_str = "pl1_num_punishes" + pl2_num_punishes_str = "pl2_num_punishes" + pl3_num_punishes_str = "pl3_num_punishes" + + pl1_punished_defect_to_total_punishes_str = ( + "pl1_punished_defect_to_total_punishes" + ) + pl1_punished_defects_to_total_opn_defects_str = ( + "pl1_punished_defects_to_total_opn_defects" + ) + pl2_punished_defect_to_total_punishes_str = ( + "pl2_punished_defect_to_total_punishes" + ) + pl2_punished_defects_to_total_opn_defects_str = ( + "pl2_punished_defects_to_total_opn_defects" + ) + pl3_punished_defect_to_total_punishes_str = ( + "pl3_punished_defect_to_total_punishes" + ) + pl3_punished_defects_to_total_opn_defects_str = ( + "pl3_punished_defects_to_total_opn_defects" + ) + + total_punishment_str = "total_punishment" + + visitation_dict = ( + dict(zip(all_game_visitation_strs, action_freq)) + | dict(zip(all_game_prob_strs, action_probs)) + | dict(zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq)) + | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs)) + | dict(zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq)) + | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs)) + | dict(zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq)) + | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs)) + | {pl1_total_defects_prob_str: pl1_total_defects_prob} + | {pl2_total_defects_prob_str: pl2_total_defects_prob} + | 
{pl3_total_defects_prob_str: pl3_total_defects_prob} + | {pl1_punished_prob_str: pl1_punished_prob} + | {pl2_punished_prob_str: pl2_punished_prob} + | {pl3_punished_prob_str: pl3_punished_prob} + | {pl1_num_punishes_str: pl1_num_punishes} + | {pl2_num_punishes_str: pl2_num_punishes} + | {pl3_num_punishes_str: pl3_num_punishes} + | {total_punishment_str: total_punishment} + | { + pl1_punished_defect_to_total_punishes_str: pl1_punished_defect_to_total_punishes + } + | { + pl1_punished_defects_to_total_opn_defects_str: pl1_punished_defects_to_total_opn_defects + } + | { + pl2_punished_defect_to_total_punishes_str: pl2_punished_defect_to_total_punishes + } + | { + pl2_punished_defects_to_total_opn_defects_str: pl2_punished_defects_to_total_opn_defects + } + | { + pl3_punished_defect_to_total_punishes_str: pl3_punished_defect_to_total_punishes + } + | { + pl3_punished_defects_to_total_opn_defects_str: pl3_punished_defects_to_total_opn_defects + } + ) + + return visitation_dict + + +def third_party_random_visitation( + log_obs: jnp.ndarray, +) -> dict: + + (prev_actions, curr_actions) = log_obs + prev_actions = jnp.moveaxis( + jnp.array(prev_actions).reshape(3, -1), 0, -1 + ) # shape (-1,3) + curr_actions = jnp.moveaxis( + jnp.array(curr_actions).reshape(3, -1), 0, -1 + ) # shape (-1,3) + len_idx = prev_actions.shape[0] + + prev_cd_actions = jnp.where(prev_actions > 3, 1, 0) + punish_actions = curr_actions % 4 + + pl1_defect_prob = sum(prev_cd_actions[:, 0]) / len(prev_cd_actions[:, 0]) + pl2_defect_prob = sum(prev_cd_actions[:, 1]) / len(prev_cd_actions[:, 1]) + pl3_defect_prob = sum(prev_cd_actions[:, 2]) / len(prev_cd_actions[:, 2]) + + game1 = jnp.stack([prev_cd_actions[:, 0], prev_cd_actions[:, 1]], axis=-1) + game2 = jnp.stack([prev_cd_actions[:, 1], prev_cd_actions[:, 2]], axis=-1) + game3 = jnp.stack([prev_cd_actions[:, 2], prev_cd_actions[:, 0]], axis=-1) + + selected_game_actions = jnp.stack( # len_idx3x2 + [game1, game2, game3], axis=1 + ) + + b2i = 2 ** jnp.arange(2 - 1, -1, -1) + pl1_vs_pl2 = (game1 * b2i).sum(axis=-1) + pl2_vs_pl3 = (game2 * b2i).sum(axis=-1) + pl3_vs_pl1 = (game3 * b2i).sum(axis=-1) + # all 3 games combined + all_games_actions = jnp.concatenate([pl1_vs_pl2, pl2_vs_pl3, pl3_vs_pl1]) + hist = jnp.bincount(all_games_actions, length=4) + action_probs = hist / hist.sum() + + # games action breakdown + pl1_v_pl2_hist = jnp.bincount(pl1_vs_pl2, length=4) + pl1_v_pl2_action_probs = pl1_v_pl2_hist / pl1_v_pl2_hist.sum() + + pl2_v_pl3_hist = jnp.bincount(pl2_vs_pl3, length=4) + pl2_v_pl3_action_probs = pl2_v_pl3_hist / pl2_v_pl3_hist.sum() + + pl3_v_pl1_hist = jnp.bincount(pl3_vs_pl1, length=4) + pl3_v_pl1_action_probs = pl3_v_pl1_hist / pl3_v_pl1_hist.sum() + + # pl1 got punished if pl2 punishes second player or pl3 punishes first player + pun_pl1 = jnp.where(punish_actions[:, 1] > 1, 1, 0) + jnp.where( + punish_actions[:, 2] % 2 == 1, 1, 0 + ) + # pl2 got punished if pl3 punishes second player or pl1 punishes first player + pun_pl2 = jnp.where(punish_actions[:, 2] > 1, 1, 0) + jnp.where( + punish_actions[:, 0] % 2 == 1, 1, 0 + ) + # pl3 got punished if pl1 punishes second player or pl2 punishes first player + pun_pl3 = jnp.where(punish_actions[:, 0] > 1, 1, 0) + jnp.where( + punish_actions[:, 1] % 2 == 1, 1, 0 + ) + + game_selected_punish = (pun_pl1 + pun_pl2 + pun_pl3).sum() / len_idx + + # how man defects they punished + intr_pl1 = jnp.where(punish_actions[:, 0] % 2 == 1, 1, 0) * jnp.where( + prev_cd_actions[:, 1] == 1, 1, 0 + ) + jnp.where(punish_actions[:, 0] > 1, 1, 
0) * jnp.where( + prev_cd_actions[:, 2] == 1, 1, 0 + ) + intr_pl2 = jnp.where(punish_actions[:, 1] % 2 == 1, 1, 0) * jnp.where( + prev_cd_actions[:, 2] == 1, 1, 0 + ) + jnp.where(punish_actions[:, 1] > 1, 1, 0) * jnp.where( + prev_cd_actions[:, 0] == 1, 1, 0 + ) + intr_pl3 = jnp.where(punish_actions[:, 2] % 2 == 1, 1, 0) * jnp.where( + prev_cd_actions[:, 0] == 1, 1, 0 + ) + jnp.where(punish_actions[:, 2] > 1, 1, 0) * jnp.where( + prev_cd_actions[:, 1] == 1, 1, 0 + ) + + pl1_punish_defect_vs_total_defect = intr_pl1.sum() / game2.sum() + pl2_punish_defect_vs_total_defect = intr_pl2.sum() / game3.sum() + pl3_punish_defect_vs_total_defect = intr_pl3.sum() / game1.sum() + + pl1_total_punish = ( + jnp.where(punish_actions[:, 0] == 1, 1, 0) + + jnp.where(punish_actions[:, 0] == 2, 1, 0) + + jnp.where(punish_actions[:, 0] == 3, 2, 0) + ).sum() + + pl2_total_punish = ( + jnp.where(punish_actions[:, 1] == 1, 1, 0) + + jnp.where(punish_actions[:, 1] == 2, 1, 0) + + jnp.where(punish_actions[:, 1] == 3, 2, 0) + ).sum() + + pl3_total_punish = ( + jnp.where(punish_actions[:, 2] == 1, 1, 0) + + jnp.where(punish_actions[:, 2] == 2, 1, 0) + + jnp.where(punish_actions[:, 2] == 3, 2, 0) + ).sum() + + pl1_punish_defect_vs_total_punish = intr_pl1.sum() / ( + pl1_total_punish + 0.0001 + ) + pl1_punish_prob = pl1_total_punish / len(punish_actions[:, 0]) + + pl2_punish_defect_vs_total_punish = intr_pl2.sum() / ( + pl2_total_punish + 0.0001 + ) + pl2_punish_prob = pl2_total_punish / len(punish_actions[:, 1]) + + pl3_punish_defect_vs_total_punish = intr_pl3.sum() / ( + pl3_total_punish + 0.0001 + ) + pl3_punish_prob = pl3_total_punish / len(punish_actions[:, 2]) + + # generate the dict keys for logging + combinations = ["CC", "CD", "DC", "DD"] + + game_prob_strs = [ + "total_game_state_probability/" + c for c in combinations + ] + game_selected_punish_str = "game_selected_punish_prob" + + game1_prob_strs = ["game1_state_probability/" + c for c in combinations] + game2_prob_strs = ["game2_state_probability/" + c for c in combinations] + game3_prob_strs = ["game3_state_probability/" + c for c in combinations] + + pl1_defects_prob_str = "pl1_defects_prob" + pl2_defects_prob_str = "pl2_defects_prob" + pl3_defects_prob_str = "pl3_defects_prob" + + pl1_punish_prob_str = "pl1_punish_prob" + pl2_punish_prob_str = "pl2_punish_prob" + pl3_punish_prob_str = "pl3_punish_prob" + + pl1_punished_defect_to_total_punishes_str = ( + "pl1_punished_defect_to_total_punishes" + ) + pl1_punished_defects_to_total_opn_defects_str = ( + "pl1_punished_defects_to_total_opn_defects" + ) + pl2_punished_defect_to_total_punishes_str = ( + "pl2_punished_defect_to_total_punishes" + ) + pl2_punished_defects_to_total_opn_defects_str = ( + "pl2_punished_defects_to_total_opn_defects" + ) + pl3_punished_defect_to_total_punishes_str = ( + "pl3_punished_defect_to_total_punishes" + ) + pl3_punished_defects_to_total_opn_defects_str = ( + "pl3_punished_defects_to_total_opn_defects" + ) + + visitation_dict = ( + dict(zip(game_prob_strs, action_probs)) + | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs)) + | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs)) + | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs)) + | {game_selected_punish_str: game_selected_punish} + | {pl1_defects_prob_str: pl1_defect_prob} + | {pl2_defects_prob_str: pl2_defect_prob} + | {pl3_defects_prob_str: pl3_defect_prob} + | {pl1_punish_prob_str: pl1_punish_prob} + | {pl2_punish_prob_str: pl2_punish_prob} + | {pl3_punish_prob_str: pl3_punish_prob} + | { + 
pl1_punished_defect_to_total_punishes_str: pl1_punish_defect_vs_total_punish + } + | { + pl1_punished_defects_to_total_opn_defects_str: pl1_punish_defect_vs_total_defect + } + | { + pl2_punished_defect_to_total_punishes_str: pl2_punish_defect_vs_total_punish + } + | { + pl2_punished_defects_to_total_opn_defects_str: pl2_punish_defect_vs_total_defect + } + | { + pl3_punished_defect_to_total_punishes_str: pl3_punish_defect_vs_total_punish + } + | { + pl3_punished_defects_to_total_opn_defects_str: pl3_punish_defect_vs_total_defect + } + ) + return visitation_dict + + +# def tensor_ipd_coop_probs( +# observations: jnp.ndarray, +# actions: jnp.ndarray, +# final_obs: jnp.ndarray, +# agent_idx: int = 1, +# ) -> dict: +# num_timesteps = observations.shape[0] * observations.shape[1] +# # obs = [0....8], a = [0, 1] +# # combine = [0, .... 17] +# state_actions = 2 * jnp.argmax(observations, axis=-1) + actions +# state_actions = jnp.reshape( +# state_actions, +# (num_timesteps,) + state_actions.shape[2:], +# ) +# final_obs = jax.lax.expand_dims(2 * jnp.argmax(final_obs, axis=-1), [0]) +# state_actions = jnp.append(state_actions, final_obs, axis=0) +# hist = jnp.bincount(state_actions.flatten(), length=18) +# state_freq = hist.reshape((int(hist.shape[0] / 2), 2)).sum(axis=1) +# action_probs = jnp.nan_to_num(hist[::2] / state_freq) +# # THIS IS FROM AGENTS OWN PERSPECTIVE +# return { +# f"cooperation_probability/{agent_idx}/CCC": action_probs[0], +# f"cooperation_probability/{agent_idx}/CCD": action_probs[1], +# f"cooperation_probability/{agent_idx}/CDC": action_probs[2], +# f"cooperation_probability/{agent_idx}/CDD": action_probs[3], +# f"cooperation_probability/{agent_idx}/DCC": action_probs[4], +# f"cooperation_probability/{agent_idx}/DCD": action_probs[5], +# f"cooperation_probability/{agent_idx}/DDC": action_probs[6], +# f"cooperation_probability/{agent_idx}/DDD": action_probs[7], +# f"cooperation_probability/{agent_idx}/START": action_probs[8], +# } + + def cg_visitation(state: NamedTuple) -> dict: # [num_opps, num_envs, num_outer_episodes] total_1 = state.red_coop + state.red_defect diff --git a/setup.py b/setup.py index e51ce348..9cca7ccc 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ import os -import setuptools import subprocess +import setuptools + def _parse_requirements(requirements_txt_path): with open(requirements_txt_path) as fp: diff --git a/test/envs/test_in_the_matrix.py b/test/envs/test_in_the_matrix.py index 27ea81f1..2b0e1365 100644 --- a/test/envs/test_in_the_matrix.py +++ b/test/envs/test_in_the_matrix.py @@ -1,11 +1,7 @@ import jax import jax.numpy as jnp -from pax.envs.in_the_matrix import ( - InTheMatrix, - EnvParams, - EnvState, -) +from pax.envs.in_the_matrix import EnvParams, EnvState, InTheMatrix def test_ipditm_shapes(): diff --git a/test/envs/test_iterated_tensor_game_n_player.py b/test/envs/test_iterated_tensor_game_n_player.py new file mode 100644 index 00000000..bffbd423 --- /dev/null +++ b/test/envs/test_iterated_tensor_game_n_player.py @@ -0,0 +1,874 @@ +import jax +import jax.numpy as jnp +import pytest + +from pax.agents.strategies import TitForTat +from pax.agents.tensor_strategies import TitForTatStrictStay +from pax.envs.iterated_tensor_game_n_player import ( + EnvParams, + IteratedTensorGameNPlayer, +) + +####### 2 PLAYERS ####### +payoff_table_2pl = [ + [4, jnp.nan], + [2, 5], + [jnp.nan, 3], +] +cc_p1, cc_p2 = payoff_table_2pl[0][0], payoff_table_2pl[0][0] +cd_p1, cd_p2 = payoff_table_2pl[1][0], payoff_table_2pl[1][1] +dc_p1, dc_p2 = 
payoff_table_2pl[1][1], payoff_table_2pl[1][0] +dd_p1, dd_p2 = payoff_table_2pl[2][1], payoff_table_2pl[2][1] +# states +cc_obs = 0 +cd_obs = 1 +dc_obs = 2 +dd_obs = 3 + +####### 3 PLAYERS ####### +payoff_table_3pl = [ + [4, jnp.nan], + [2.66, 5.66], + [1.33, 4.33], + [jnp.nan, 3], +] +# this is for verification +ccc_p1, ccc_p2, ccc_p3 = ( + payoff_table_3pl[0][0], + payoff_table_3pl[0][0], + payoff_table_3pl[0][0], +) +ccd_p1, ccd_p2, ccd_p3 = ( + payoff_table_3pl[1][0], + payoff_table_3pl[1][0], + payoff_table_3pl[1][1], +) +cdc_p1, cdc_p2, cdc_p3 = ( + payoff_table_3pl[1][0], + payoff_table_3pl[1][1], + payoff_table_3pl[1][0], +) +cdd_p1, cdd_p2, cdd_p3 = ( + payoff_table_3pl[2][0], + payoff_table_3pl[2][1], + payoff_table_3pl[2][1], +) +dcc_p1, dcc_p2, dcc_p3 = ( + payoff_table_3pl[1][1], + payoff_table_3pl[1][0], + payoff_table_3pl[1][0], +) +dcd_p1, dcd_p2, dcd_p3 = ( + payoff_table_3pl[2][1], + payoff_table_3pl[2][0], + payoff_table_3pl[2][1], +) +ddc_p1, ddc_p2, ddc_p3 = ( + payoff_table_3pl[2][1], + payoff_table_3pl[2][1], + payoff_table_3pl[2][0], +) +ddd_p1, ddd_p2, ddd_p3 = ( + payoff_table_3pl[3][1], + payoff_table_3pl[3][1], + payoff_table_3pl[3][1], +) + +# states +ccc_obs = 0 +ccd_obs = 1 +cdc_obs = 2 +cdd_obs = 3 +dcc_obs = 4 +dcd_obs = 5 +ddc_obs = 6 +ddd_obs = 7 + +####### 4 PLAYERS ####### +payoff_table_4pl = [ + [4, jnp.nan], + [3, 6], + [2, 5], + [1, 4], + [jnp.nan, 3], +] + +cccc_p1, cccc_p2, cccc_p3, cccc_p4 = ( + payoff_table_4pl[0][0], + payoff_table_4pl[0][0], + payoff_table_4pl[0][0], + payoff_table_4pl[0][0], +) +cccd_p1, cccd_p2, cccd_p3, cccd_p4 = ( + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], + payoff_table_4pl[1][1], +) +ccdc_p1, ccdc_p2, ccdc_p3, ccdc_p4 = ( + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], + payoff_table_4pl[1][1], + payoff_table_4pl[1][0], +) +ccdd_p1, ccdd_p2, ccdd_p3, ccdd_p4 = ( + payoff_table_4pl[2][0], + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], + payoff_table_4pl[2][1], +) +cdcc_p1, cdcc_p2, cdcc_p3, cdcc_p4 = ( + payoff_table_4pl[1][0], + payoff_table_4pl[1][1], + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], +) +cdcd_p1, cdcd_p2, cdcd_p3, cdcd_p4 = ( + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], +) +cddc_p1, cddc_p2, cddc_p3, cddc_p4 = ( + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], +) +cddd_p1, cddd_p2, cddd_p3, cddd_p4 = ( + payoff_table_4pl[3][0], + payoff_table_4pl[3][1], + payoff_table_4pl[3][1], + payoff_table_4pl[3][1], +) +dccc_p1, dccc_p2, dccc_p3, dccc_p4 = ( + payoff_table_4pl[1][1], + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], + payoff_table_4pl[1][0], +) +dccd_p1, dccd_p2, dccd_p3, dccd_p4 = ( + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], +) +dcdc_p1, dcdc_p2, dcdc_p3, dcdc_p4 = ( + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], +) +dcdd_p1, dcdd_p2, dcdd_p3, dcdd_p4 = ( + payoff_table_4pl[3][1], + payoff_table_4pl[3][0], + payoff_table_4pl[3][1], + payoff_table_4pl[3][1], +) +ddcc_p1, ddcc_p2, ddcc_p3, ddcc_p4 = ( + payoff_table_4pl[2][1], + payoff_table_4pl[2][1], + payoff_table_4pl[2][0], + payoff_table_4pl[2][0], +) +ddcd_p1, ddcd_p2, ddcd_p3, ddcd_p4 = ( + payoff_table_4pl[3][1], + payoff_table_4pl[3][1], + payoff_table_4pl[3][0], + payoff_table_4pl[3][1], +) +dddc_p1, dddc_p2, dddc_p3, dddc_p4 = ( + payoff_table_4pl[3][1], + 
payoff_table_4pl[3][1], + payoff_table_4pl[3][1], + payoff_table_4pl[3][0], +) +dddd_p1, dddd_p2, dddd_p3, dddd_p4 = ( + payoff_table_4pl[4][1], + payoff_table_4pl[4][1], + payoff_table_4pl[4][1], + payoff_table_4pl[4][1], +) +# states +cccc_obs = 0 +cccd_obs = 1 +ccdc_obs = 2 +ccdd_obs = 3 +cdcc_obs = 4 +cdcd_obs = 5 +cddc_obs = 6 +cddd_obs = 7 +dccc_obs = 8 +dccd_obs = 9 +dcdc_obs = 10 +dcdd_obs = 11 +ddcc_obs = 12 +ddcd_obs = 13 +dddc_obs = 14 +dddd_obs = 15 + +# ###### Begin actual tests ####### +@pytest.mark.parametrize("payoff", [payoff_table_2pl]) +def test_single_batch_2pl(payoff) -> None: + num_envs = 5 + rng = jax.random.PRNGKey(0) + num_players = 2 + len_one_hot = 2**num_players + 1 + ##### setup + env = IteratedTensorGameNPlayer( + num_players=2, num_inner_steps=5, num_outer_steps=1 + ) + env_params = EnvParams(payoff_table=payoff) + + action = jnp.ones((num_envs,), dtype=jnp.float32) + r_array = jnp.ones((num_envs,), dtype=jnp.float32) + obs_array = jnp.ones((num_envs,), dtype=jnp.float32) + + # we want to batch over actions + env.step = jax.vmap( + env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + + ###### test 2 player + # cc + obs, env_state, rewards, done, info = env.step( + rng, env_state, (0 * action, 0 * action), env_params + ) + expected_reward1 = cc_p1 * r_array + expected_reward2 = cc_p2 * r_array + expected_obs1 = jax.nn.one_hot( + cc_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + expected_obs2 = jax.nn.one_hot( + cc_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + + ##dc + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 0 * action), env_params + ) + + expected_obs1 = jax.nn.one_hot( + dc_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + expected_obs2 = jax.nn.one_hot( + cd_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + assert jnp.array_equal(rewards[0], dc_p1 * r_array) + assert jnp.array_equal(rewards[1], dc_p2 * r_array) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + + ##cd + obs, env_state, rewards, done, info = env.step( + rng, env_state, (0 * action, 1 * action), env_params + ) + expected_obs1 = jax.nn.one_hot( + cd_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + expected_obs2 = jax.nn.one_hot( + dc_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + assert jnp.array_equal(rewards[0], cd_p1 * r_array) + assert jnp.array_equal(rewards[1], cd_p2 * r_array) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + + ##dd + obs, env_state, rewards, done, info = env.step( + rng, env_state, (1 * action, 1 * action), env_params + ) + expected_obs1 = jax.nn.one_hot( + dd_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + expected_obs2 = jax.nn.one_hot( + dd_obs * obs_array, len_one_hot, dtype=jnp.int8 + ) + assert jnp.array_equal(rewards[0], dd_p1 * r_array) + assert jnp.array_equal(rewards[1], dd_p2 * r_array) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + + +@pytest.mark.parametrize("payoff", [payoff_table_4pl]) +def test_single_batch_4pl(payoff) -> None: + num_envs = 5 + rng = jax.random.PRNGKey(0) + num_players = 4 + len_one_hot = 2**num_players + 1 + ##### setup + env = IteratedTensorGameNPlayer( + 
num_players=num_players, num_inner_steps=100, num_outer_steps=1 + ) + env_params = EnvParams(payoff_table=payoff) + + action = jnp.ones((num_envs,), dtype=jnp.float32) + r_array = jnp.ones((num_envs,), dtype=jnp.float32) + obs_array = jnp.ones((num_envs,), dtype=jnp.float32) + + # we want to batch over actions + env.step = jax.vmap( + env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + + ###### test 4 player + # cccc + cccc = (0 * action, 0 * action, 0 * action, 0 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, cccc, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + cccc_p1 * r_array, + cccc_p2 * r_array, + cccc_p3 * r_array, + cccc_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(cccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + + # cccd + cccd = (0 * action, 0 * action, 0 * action, 1 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, cccd, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + cccd_p1 * r_array, + cccd_p2 * r_array, + cccd_p3 * r_array, + cccd_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(cccd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ccdc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cdcc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # ccdc + ccdc = (0 * action, 0 * action, 1 * action, 0 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, ccdc, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + ccdc_p1 * r_array, + ccdc_p2 * r_array, + ccdc_p3 * r_array, + ccdc_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(ccdc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cdcc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cccd_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert 
jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # ccdd + ccdd = (0 * action, 0 * action, 1 * action, 1 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, ccdd, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + ccdd_p1 * r_array, + ccdd_p2 * r_array, + ccdd_p3 * r_array, + ccdd_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(ccdd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cddc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ddcc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dccd_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # cdcc + cdcc = (0 * action, 1 * action, 0 * action, 0 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, cdcc, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + cdcc_p1 * r_array, + cdcc_p2 * r_array, + cdcc_p3 * r_array, + cdcc_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(cdcc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cccd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ccdc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # cdcd + cdcd = (0 * action, 1 * action, 0 * action, 1 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, cdcd, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + cdcd_p1 * r_array, + cdcd_p2 * r_array, + cdcd_p3 * r_array, + cdcd_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(cdcd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dcdc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cdcd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dcdc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # cddc skipped + # cddd skipped + # dccd skipped + # dcdc skipped + # dcdd + dcdd = (1 * action, 0 * action, 1 * action, 1 * action) + 
obs, env_state, rewards, done, info = env.step( + rng, env_state, dcdd, env_params + ) + expected_reward1, expected_reward2, expected_reward3, expected_reward4 = ( + dcdd_p1 * r_array, + dcdd_p2 * r_array, + dcdd_p3 * r_array, + dcdd_p4 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3, expected_obs4 = ( + jax.nn.one_hot(dcdd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cddd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(dddc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ddcd_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(rewards[3], expected_reward4) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + assert jnp.array_equal(obs[3], expected_obs4) + # ddcc skipped + # ddcd skipped + # dddc skipped + # dddd skipped + + +@pytest.mark.parametrize("payoff", [payoff_table_3pl]) +def test_single_batch_3pl(payoff) -> None: + num_envs = 5 + rng = jax.random.PRNGKey(0) + num_players = 3 + len_one_hot = 2**num_players + 1 + ##### setup + env = IteratedTensorGameNPlayer( + num_players=num_players, num_inner_steps=100, num_outer_steps=1 + ) + env_params = EnvParams(payoff_table=payoff) + + action = jnp.ones((num_envs,), dtype=jnp.float32) + r_array = jnp.ones((num_envs,), dtype=jnp.float32) + obs_array = jnp.ones((num_envs,), dtype=jnp.float32) + + # we want to batch over actions + env.step = jax.vmap( + env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0) + ) + obs, env_state = env.reset(rng, env_params) + + ###### test 3 player + # ccc + ccc = (0 * action, 0 * action, 0 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, ccc, env_params + ) + expected_reward1, expected_reward2, expected_reward3 = ( + ccc_p1 * r_array, + ccc_p2 * r_array, + ccc_p3 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3 = ( + jax.nn.one_hot(ccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ccc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + + # dcd + dcd = (1 * action, 0 * action, 1 * action) + obs, env_state, rewards, done, info = env.step( + rng, env_state, dcd, env_params + ) + expected_reward1, expected_reward2, expected_reward3 = ( + dcd_p1 * r_array, + dcd_p2 * r_array, + dcd_p3 * r_array, + ) + expected_obs1, expected_obs2, expected_obs3 = ( + jax.nn.one_hot(dcd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(cdd_obs * obs_array, len_one_hot, dtype=jnp.int8), + jax.nn.one_hot(ddc_obs * obs_array, len_one_hot, dtype=jnp.int8), + ) + assert jnp.array_equal(rewards[0], expected_reward1) + assert jnp.array_equal(rewards[1], expected_reward2) + assert jnp.array_equal(rewards[2], expected_reward3) + assert jnp.array_equal(obs[0], expected_obs1) + assert jnp.array_equal(obs[1], expected_obs2) + assert jnp.array_equal(obs[2], expected_obs3) + + # dcc + dcc = (1 * action, 0 * action, 0 * action) + obs, env_state, 
rewards, done, info = env.step(
+        rng, env_state, dcc, env_params
+    )
+    expected_reward1, expected_reward2, expected_reward3 = (
+        dcc_p1 * r_array,
+        dcc_p2 * r_array,
+        dcc_p3 * r_array,
+    )
+    expected_obs1, expected_obs2, expected_obs3 = (
+        jax.nn.one_hot(dcc_obs * obs_array, len_one_hot, dtype=jnp.int8),
+        jax.nn.one_hot(ccd_obs * obs_array, len_one_hot, dtype=jnp.int8),
+        jax.nn.one_hot(cdc_obs * obs_array, len_one_hot, dtype=jnp.int8),
+    )
+    assert jnp.array_equal(rewards[0], expected_reward1)
+    assert jnp.array_equal(rewards[1], expected_reward2)
+    assert jnp.array_equal(rewards[2], expected_reward3)
+    assert jnp.array_equal(obs[0], expected_obs1)
+    assert jnp.array_equal(obs[1], expected_obs2)
+    assert jnp.array_equal(obs[2], expected_obs3)
+
+    # ddd
+    ddd = (1 * action, 1 * action, 1 * action)
+    obs, env_state, rewards, done, info = env.step(
+        rng, env_state, ddd, env_params
+    )
+    expected_reward1, expected_reward2, expected_reward3 = (
+        ddd_p1 * r_array,
+        ddd_p2 * r_array,
+        ddd_p3 * r_array,
+    )
+    expected_obs1, expected_obs2, expected_obs3 = (
+        jax.nn.one_hot(ddd_obs * obs_array, len_one_hot, dtype=jnp.int8),
+        jax.nn.one_hot(ddd_obs * obs_array, len_one_hot, dtype=jnp.int8),
+        jax.nn.one_hot(ddd_obs * obs_array, len_one_hot, dtype=jnp.int8),
+    )
+    assert jnp.array_equal(rewards[0], expected_reward1)
+    assert jnp.array_equal(rewards[1], expected_reward2)
+    assert jnp.array_equal(rewards[2], expected_reward3)
+    assert jnp.array_equal(obs[0], expected_obs1)
+    assert jnp.array_equal(obs[1], expected_obs2)
+    assert jnp.array_equal(obs[2], expected_obs3)
+
+
+def test_batch_diff_actions() -> None:
+    num_envs = 5
+    all_ones = jnp.ones((num_envs,))
+
+    rng = jax.random.PRNGKey(0)
+
+    # 5 envs, with actions ccc, dcc, cdc, ddd, ccd
+    pl1_actions = jnp.array([0, 1, 0, 1, 0], dtype=jnp.float32)
+    pl2_actions = jnp.array([0, 0, 1, 1, 0], dtype=jnp.float32)
+    pl3_actions = jnp.array([0, 0, 0, 1, 1], dtype=jnp.float32)
+    given_actions = (pl1_actions, pl2_actions, pl3_actions)
+    pl1_reward = jnp.array(
+        [ccc_p1, dcc_p1, cdc_p1, ddd_p1, ccd_p1], dtype=jnp.float32
+    )
+    pl2_reward = jnp.array(
+        [ccc_p2, dcc_p2, cdc_p2, ddd_p2, ccd_p2], dtype=jnp.float32
+    )
+    pl3_reward = jnp.array(
+        [ccc_p3, dcc_p3, cdc_p3, ddd_p3, ccd_p3], dtype=jnp.float32
+    )
+    exp_rewards = (pl1_reward, pl2_reward, pl3_reward)
+    # each player observes the joint action from their own perspective
+    pl1_states = jnp.array(
+        [ccc_obs, dcc_obs, cdc_obs, ddd_obs, ccd_obs], dtype=jnp.float32
+    )
+    pl2_states = jnp.array(
+        [ccc_obs, ccd_obs, dcc_obs, ddd_obs, cdc_obs], dtype=jnp.float32
+    )
+    pl3_states = jnp.array(
+        [ccc_obs, cdc_obs, ccd_obs, ddd_obs, dcc_obs], dtype=jnp.float32
+    )
+    exp_obs = [
+        jax.nn.one_hot(pl1_states, 9, dtype=jnp.int8),
+        jax.nn.one_hot(pl2_states, 9, dtype=jnp.int8),
+        jax.nn.one_hot(pl3_states, 9, dtype=jnp.int8),
+    ]
+
+    env = IteratedTensorGameNPlayer(
+        num_players=3, num_inner_steps=10, num_outer_steps=1
+    )
+    env_params = EnvParams(payoff_table=payoff_table_3pl)
+    # we want to batch over envs purely by actions
+    env.step = jax.vmap(
+        env.step, in_axes=(None, None, 0, None), out_axes=(0, None, 0, 0, 0)
+    )
+    obs, env_state = env.reset(rng, env_params)
+    obs, env_state, rewards, done, info = env.step(
+        rng,
+        env_state,
+        given_actions,
+        env_params,
+    )
+    assert jnp.array_equal(rewards[0], exp_rewards[0])
+    assert jnp.array_equal(rewards[1], exp_rewards[1])
+    assert jnp.array_equal(rewards[2], exp_rewards[2])
+    assert jnp.array_equal(obs[0], exp_obs[0])
+    assert jnp.array_equal(obs[1], exp_obs[1])
+    assert jnp.array_equal(obs[2], exp_obs[2])
+    assert (done == False).all()
+
+
+def test_tit_for_tat_strict_match_3pl() -> None:
+    # just tests that they all cooperate
+    num_envs = 5
+    rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape(
+        num_envs, -1
+    )
+    env = IteratedTensorGameNPlayer(
+        num_players=3, num_inner_steps=5, num_outer_steps=1
+    )
+    env_params = EnvParams(payoff_table=payoff_table_3pl)
+
+    env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None))
+    env.step = jax.vmap(
+        env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0)
+    )
+
+    obs, env_state = env.reset(rngs, env_params)
+    tit_for_tat = TitForTatStrictStay(num_envs)
+
+    action_0, _, _ = tit_for_tat._policy(None, obs[0], None)
+    action_1, _, _ = tit_for_tat._policy(None, obs[1], None)
+    action_2, _, _ = tit_for_tat._policy(None, obs[2], None)
+    assert jnp.array_equal(action_0, action_1)
+    assert jnp.array_equal(action_0, action_2)
+
+    for _ in range(10):
+        obs, env_state, rewards, done, info = env.step(
+            rngs, env_state, (action_0, action_1, action_2), env_params
+        )
+        assert jnp.array_equal(
+            rewards[0],
+            rewards[1],
+        )
+        assert jnp.array_equal(
+            rewards[0],
+            rewards[2],
+        )
+
+
+def test_tit_for_tat_strict_match_4pl() -> None:
+    # just tests that they all cooperate
+    num_envs = 5
+    rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape(
+        num_envs, -1
+    )
+    env = IteratedTensorGameNPlayer(
+        num_players=4, num_inner_steps=5, num_outer_steps=1
+    )
+    env_params = EnvParams(payoff_table=payoff_table_4pl)
+
+    env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None))
+    env.step = jax.vmap(
+        env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0)
+    )
+
+    obs, env_state = env.reset(rngs, env_params)
+    tit_for_tat = TitForTatStrictStay(num_envs)
+
+    action_0, _, _ = tit_for_tat._policy(None, obs[0], None)
+    action_1, _, _ = tit_for_tat._policy(None, obs[1], None)
+    action_2, _, _ = tit_for_tat._policy(None, obs[2], None)
+    action_3, _, _ = tit_for_tat._policy(None, obs[3], None)
+    assert jnp.array_equal(action_0, action_1)
+    assert jnp.array_equal(action_0, action_2)
+    assert jnp.array_equal(action_0, action_3)
+
+    for _ in range(10):
+        obs, env_state, rewards, done, info = env.step(
+            rngs,
+            env_state,
+            (action_0, action_1, action_2, action_3),
+            env_params,
+        )
+        assert jnp.array_equal(
+            rewards[0],
+            rewards[1],
+        )
+        assert jnp.array_equal(
+            rewards[0],
+            rewards[2],
+        )
+        assert jnp.array_equal(
+            rewards[0],
+            rewards[3],
+        )
+
+
+def test_longer_game_4pl() -> None:
+    num_envs = 2
+    num_outer_steps = 25
+    num_inner_steps = 2
+    env = IteratedTensorGameNPlayer(
+        num_players=4,
+        num_inner_steps=num_inner_steps,
+        num_outer_steps=num_outer_steps,
+    )
+
+    # batch over actions and env_states
+    env.reset = jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None))
+    env.step = jax.vmap(
+        env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0)
+    )
+
+    env_params = EnvParams(payoff_table=payoff_table_4pl)
+    rngs = jnp.concatenate(num_envs * [jax.random.PRNGKey(0)]).reshape(
+        num_envs, -1
+    )
+    obs, env_state = env.reset(rngs, env_params)
+    agent = TitForTatStrictStay(num_envs)
+    r1 = []
+    r2 = []
+    r3 = []
+    r4 = []
+    for _ in range(num_outer_steps):
+        for _ in range(num_inner_steps):
+            action, _, _ = agent._policy(None, obs[0], None)
+            obs, env_state, rewards, done, info = env.step(
+                rngs, env_state, (action, action, action, action), env_params
+            )
+            r1.append(rewards[0])
+            r2.append(rewards[1])
+            r3.append(rewards[2])
+            r4.append(rewards[3])
+            # all cooperate with this strategy
+            assert jnp.array_equal(rewards[0], rewards[1])
+            assert jnp.array_equal(rewards[0], rewards[2])
+            assert jnp.array_equal(rewards[0], rewards[3])
+
+    assert (done == True).all()
+
+    assert jnp.mean(jnp.stack(r1)) == 4
+    assert jnp.mean(jnp.stack(r2)) == 4
+    assert jnp.mean(jnp.stack(r3)) == 4
+    assert jnp.mean(jnp.stack(r4)) == 4
+
+
+def test_done():
+    num_inner_steps = 5
+    env = IteratedTensorGameNPlayer(
+        num_players=4, num_inner_steps=num_inner_steps, num_outer_steps=1
+    )
+    env_params = EnvParams(
+        payoff_table=payoff_table_4pl,
+    )
+    rng = jax.random.PRNGKey(0)
+    obs, env_state = env.reset(rng, env_params)
+    action = 0
+
+    for _ in range(num_inner_steps - 1):
+        obs, env_state, rewards, done, info = env.step(
+            rng, env_state, (action, action, action, action), env_params
+        )
+        # check not start state (index 2**num_players)
+        assert (done == False).all()
+        assert (obs[0].argmax() != 2**4).all()
+        assert (obs[1].argmax() != 2**4).all()
+        assert (obs[2].argmax() != 2**4).all()
+        assert (obs[3].argmax() != 2**4).all()
+
+    # check the final step ends the episode
+    obs, env_state, rewards, done, info = env.step(
+        rng, env_state, (action, action, action, action), env_params
+    )
+    assert (done == True).all()
+
+    # check we are back at the start state
+    assert jnp.array_equal(obs[0].argmax(), 2**4)
+    assert jnp.array_equal(obs[1].argmax(), 2**4)
+    assert jnp.array_equal(obs[2].argmax(), 2**4)
+    assert jnp.array_equal(obs[3].argmax(), 2**4)
+
+
+def test_reset():
+    rng = jax.random.PRNGKey(0)
+    env = IteratedTensorGameNPlayer(
+        num_players=4, num_inner_steps=5, num_outer_steps=20
+    )
+    env_params = EnvParams(
+        payoff_table=payoff_table_4pl,
+    )
+    action = 0
+
+    obs, env_state = env.reset(rng, env_params)
+    # step one fewer time than num_inner_steps
+    for _ in range(4):
+        obs, env_state, rewards, done, info = env.step(
+            rng, env_state, (action, action, action, action), env_params
+        )
+        assert done == False
+
+    obs, env_state = env.reset(rng, env_params)
+    # not done because we reset the env
+    for _ in range(4):
+        obs, env_state, rewards, done, info = env.step(
+            rng, env_state, (action, action, action, action), env_params
+        )
+        assert done == False
diff --git a/test/test_tensor_strategies.py b/test/test_tensor_strategies.py
new file mode 100644
index 00000000..c6e55701
--- /dev/null
+++ b/test/test_tensor_strategies.py
@@ -0,0 +1,614 @@
+import jax
+import jax.numpy as jnp
+
+from pax.agents.tensor_strategies import (
+    TitForTatCooperate,
+    TitForTatDefect,
+    TitForTatHarsh,
+    TitForTatSoft,
+    TitForTatStrictStay,
+    TitForTatStrictSwitch,
+)
+
+batch_number = 3
+# All obs are one-hot encodings of the previous joint action only,
+# with shape [batch, dim].
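+# The one-hot index encodes that joint action in binary, with the
+# observing player's own action as the most significant bit
+# (cooperate = 0, defect = 1); the final index marks the start state.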
+ccc_obs = jnp.asarray(batch_number * [[1, 0, 0, 0, 0, 0, 0, 0, 0]])
+ccd_obs = jnp.asarray(batch_number * [[0, 1, 0, 0, 0, 0, 0, 0, 0]])
+cdc_obs = jnp.asarray(batch_number * [[0, 0, 1, 0, 0, 0, 0, 0, 0]])
+cdd_obs = jnp.asarray(batch_number * [[0, 0, 0, 1, 0, 0, 0, 0, 0]])
+dcc_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 1, 0, 0, 0, 0]])
+dcd_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 1, 0, 0, 0]])
+ddc_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 0, 1, 0, 0]])
+ddd_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 0, 0, 1, 0]])
+initial_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 1]])
+
+cooperate_action = 0 * jnp.ones((batch_number,), dtype=jnp.int32)
+defect_action = 1 * jnp.ones((batch_number,), dtype=jnp.int32)
+
+
+def test_titfortat_strict_stay():
+    # Switch to what the opponents did if the other two played the same
+    # move; otherwise keep playing the previous move
+    agent = TitForTatStrictStay(num_envs=1)
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+
+def test_titfortat_strict_switch():
+    # Play what the opponents did if the other two played the same move;
+    # otherwise switch away from the previous move
+    agent = TitForTatStrictSwitch(num_envs=1)
+
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+
+def test_titfortat_cooperate():
+    agent = TitForTatCooperate(num_envs=1)
+
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+
+def test_titfortat_defect():
+    agent = TitForTatDefect(num_envs=1)
+
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+
+def test_titfortat_harsh():
+    # Defect if any opponent defected
+    agent = TitForTatHarsh(num_envs=1)
+    # test with 3 players
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+    initial_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 1]])
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    # test with 2 players
+    cc_obs = jnp.asarray(batch_number * [[1, 0, 0, 0, 0]])
+    cd_obs = jnp.asarray(batch_number * [[0, 1, 0, 0, 0]])
+    dc_obs = jnp.asarray(batch_number * [[0, 0, 1, 0, 0]])
+    dd_obs = jnp.asarray(batch_number * [[0, 0, 0, 1, 0]])
+    initial_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 1]])
+    action, _, _ = agent._policy(None, cc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    # test with 4 players for a few actions
+    cccc_obs = jnp.asarray(
+        batch_number * [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    dccc_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    initial_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]
+    )
+
+    dccd_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    ccdc_obs = jnp.asarray(
+        batch_number * [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    dddd_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]]
+    )
+
+    action, _, _ = agent._policy(None, cccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dccd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, ccdc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+
+def test_titfortat_soft():
+    # Defect if more than half of the opponents defected
+    agent = TitForTatSoft(num_envs=1)
+    # test with 3 players
+    action, _, _ = agent._policy(None, ccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cdd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dcc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dcd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+    initial_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 1]])
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    # test with 2 players
+    cc_obs = jnp.asarray(batch_number * [[1, 0, 0, 0, 0]])
+    cd_obs = jnp.asarray(batch_number * [[0, 1, 0, 0, 0]])
+    dc_obs = jnp.asarray(batch_number * [[0, 0, 1, 0, 0]])
+    dd_obs = jnp.asarray(batch_number * [[0, 0, 0, 1, 0]])
+    initial_obs = jnp.asarray(batch_number * [[0, 0, 0, 0, 1]])
+    action, _, _ = agent._policy(None, cc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, cd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    # test with 4 players for a few actions
+    cccc_obs = jnp.asarray(
+        batch_number * [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    dccc_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    initial_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]
+    )
+
+    dccd_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    ccdc_obs = jnp.asarray(
+        batch_number * [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+    )
+    dddd_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]]
+    )
+    dddc_obs = jnp.asarray(
+        batch_number * [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]
+    )
+    action, _, _ = agent._policy(None, cccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dccc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dccd_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, ccdc_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
+
+    action, _, _ = agent._policy(None, dddd_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, dddc_obs, None)
+    assert jnp.array_equal(defect_action, action)
+
+    action, _, _ = agent._policy(None, initial_obs, None)
+    assert jnp.array_equal(cooperate_action, action)
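For reference, the decision rules that the TitForTat* tests above encode can be written directly against this one-hot state. The sketch below is illustrative only and is not part of the diff or of pax.agents.tensor_strategies; the helper names are made up, and it assumes the binary state encoding documented above (the observer's own previous action as the most significant bit, cooperate = 0, defect = 1, with the final index reserved for the start state).

import jax.numpy as jnp


def count_opponent_defections(obs: jnp.ndarray, num_players: int) -> jnp.ndarray:
    """Number of opponents that defected in the previous joint action.

    `obs` is a batch of one-hot states of length 2**num_players + 1,
    where the final slot is the start state (treated as no defections).
    """
    idx = jnp.argmax(obs, axis=-1)
    is_start = idx == 2**num_players
    # The low num_players - 1 bits are the opponents' previous actions.
    opponent_bits = jnp.stack(
        [(idx >> shift) & 1 for shift in range(num_players - 1)], axis=-1
    )
    return jnp.where(is_start, 0, jnp.sum(opponent_bits, axis=-1))


def soft_tit_for_tat(obs: jnp.ndarray, num_players: int) -> jnp.ndarray:
    # Defect (1) only if more than half of the opponents defected.
    defections = count_opponent_defections(obs, num_players)
    return jnp.where(2 * defections > num_players - 1, 1, 0)


def harsh_tit_for_tat(obs: jnp.ndarray, num_players: int) -> jnp.ndarray:
    # Defect (1) if any opponent defected.
    return jnp.where(count_opponent_defections(obs, num_players) > 0, 1, 0)

Applied to the module-level constants, e.g. soft_tit_for_tat(ddd_obs, 3) gives all-defect while harsh_tit_for_tat(dcc_obs, 3) gives all-cooperate, matching the actions asserted in test_titfortat_soft and test_titfortat_harsh.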