From dd6555bc6c01059f89345c5d1d1afdb9ea631770 Mon Sep 17 00:00:00 2001 From: Aidandos Date: Tue, 24 Oct 2023 12:56:57 +0000 Subject: [PATCH] restructuring and deleting some runners --- .gitignore | 5 + pax/experiment.py | 14 +- .../runner_evo_mixed_IPD_payoffs.py | 2 + .../{ => experimental}/runner_evo_mixed_lr.py | 13 +- .../runner_evo_mixed_payoffs.py | 1 + .../runner_evo_mixed_payoffs_gen.py | 1 + .../runner_evo_mixed_payoffs_input.py | 8 +- .../runner_evo_mixed_payoffs_only_opp.py | 2 + ...runner_evo_mixed_payoffs_input_onlymeta.py | 661 ------------------ pax/runners/runner_evo_mixed_payoffs_pred.py | 645 ----------------- 10 files changed, 31 insertions(+), 1321 deletions(-) rename pax/runners/{ => experimental}/runner_evo_mixed_IPD_payoffs.py (99%) rename pax/runners/{ => experimental}/runner_evo_mixed_lr.py (97%) rename pax/runners/{ => experimental}/runner_evo_mixed_payoffs.py (99%) rename pax/runners/{ => experimental}/runner_evo_mixed_payoffs_gen.py (99%) rename pax/runners/{ => experimental}/runner_evo_mixed_payoffs_input.py (98%) rename pax/runners/{ => experimental}/runner_evo_mixed_payoffs_only_opp.py (99%) delete mode 100644 pax/runners/runner_evo_mixed_payoffs_input_onlymeta.py delete mode 100644 pax/runners/runner_evo_mixed_payoffs_pred.py diff --git a/.gitignore b/.gitignore index f8f04061..94f827de 100644 --- a/.gitignore +++ b/.gitignore @@ -114,3 +114,8 @@ experiment.log # Pax pax/version.py + +*.gif +*.json +*.png +*.sh diff --git a/pax/experiment.py b/pax/experiment.py index b714bad7..090e0957 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -65,13 +65,13 @@ from pax.runners.runner_evo import EvoRunner from pax.runners.runner_evo_multishaper import MultishaperEvoRunner from pax.runners.runner_evo_hardstop import EvoHardstopRunner -from pax.runners.runner_evo_mixed_lr import EvoMixedLRRunner -from pax.runners.runner_evo_mixed_payoffs import EvoMixedPayoffRunner -from pax.runners.runner_evo_mixed_IPD_payoffs import EvoMixedIPDPayoffRunner -from pax.runners.runner_evo_mixed_payoffs_input import EvoMixedPayoffInputRunner -from pax.runners.runner_evo_mixed_payoffs_gen import EvoMixedPayoffGenRunner -from pax.runners.runner_evo_mixed_payoffs_pred import EvoMixedPayoffPredRunner -from pax.runners.runner_evo_mixed_payoffs_only_opp import EvoMixedPayoffOnlyOppRunner +from pax.runners.experimental.runner_evo_mixed_lr import EvoMixedLRRunner +from pax.runners.experimental.runner_evo_mixed_payoffs import EvoMixedPayoffRunner +from pax.runners.experimental.runner_evo_mixed_IPD_payoffs import EvoMixedIPDPayoffRunner +from pax.runners.experimental.runner_evo_mixed_payoffs_input import EvoMixedPayoffInputRunner +from pax.runners.experimental.runner_evo_mixed_payoffs_gen import EvoMixedPayoffGenRunner +from pax.runners.experimental.runner_evo_mixed_payoffs_pred import EvoMixedPayoffPredRunner +from pax.runners.experimental.runner_evo_mixed_payoffs_only_opp import EvoMixedPayoffOnlyOppRunner from pax.runners.runner_evo_scanned import EvoScannedRunner from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer diff --git a/pax/runners/runner_evo_mixed_IPD_payoffs.py b/pax/runners/experimental/runner_evo_mixed_IPD_payoffs.py similarity index 99% rename from pax/runners/runner_evo_mixed_IPD_payoffs.py rename to pax/runners/experimental/runner_evo_mixed_IPD_payoffs.py index f6ba9cdc..8c88236e 100644 --- a/pax/runners/runner_evo_mixed_IPD_payoffs.py +++ b/pax/runners/experimental/runner_evo_mixed_IPD_payoffs.py @@ -37,6 +37,8 @@ class EvoMixedIPDPayoffRunner: It 
composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Each opponent has a different payoff matrix that follows the IPD conditions but each member + of the evo population plays against the same payoff matrices to ensure fair comparison. Args: agents (Tuple[agents]): The set of agents that will run in the experiment. Note, ordering is diff --git a/pax/runners/runner_evo_mixed_lr.py b/pax/runners/experimental/runner_evo_mixed_lr.py similarity index 97% rename from pax/runners/runner_evo_mixed_lr.py rename to pax/runners/experimental/runner_evo_mixed_lr.py index ac1c3b5a..bb8942f7 100644 --- a/pax/runners/runner_evo_mixed_lr.py +++ b/pax/runners/experimental/runner_evo_mixed_lr.py @@ -37,6 +37,9 @@ class EvoMixedLRRunner: It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Each opponent has a different learning rate, but the members of the population + play against the same learning rates to ensure a fair comparison. + Args: agents (Tuple[agents]): The set of agents that will run in the experiment. Note, ordering is @@ -212,7 +215,7 @@ def _inner_rollout(carry, unused): obs2, a2_mem, ) - jax.debug.print("env_params: {x}", x=env_params) + # jax.debug.print("env_params: {x}", x=env_params) (next_obs1, next_obs2), env_state, rewards, done, info = env.step( env_rng, env_state, @@ -338,10 +341,10 @@ def _rollout( a2_rng, agent2._mem.hidden, ) - # generate an array of shape [10] - random_numbers = jax.random.uniform(_rng_run, minval=1.0, maxval=1.0, shape=(10,)) - # # repeat the array 1000 times along the first dimension - learning_rates = jnp.tile(random_numbers, (1000, 1)) + # generate an array of shape [args.num_opps] + random_numbers = jax.random.uniform(_rng_run, minval=1e-5, maxval=1.0, shape=(args.num_opps,)) + # # repeat the array popsize-times along the first dimension + learning_rates = jnp.tile(random_numbers, (args.popsize, 1)) a2_state.opt_state[2].hyperparams['step_size'] = learning_rates # jax.debug.breakpoint() diff --git a/pax/runners/runner_evo_mixed_payoffs.py b/pax/runners/experimental/runner_evo_mixed_payoffs.py similarity index 99% rename from pax/runners/runner_evo_mixed_payoffs.py rename to pax/runners/experimental/runner_evo_mixed_payoffs.py index da3f9da7..00a254a2 100644 --- a/pax/runners/runner_evo_mixed_payoffs.py +++ b/pax/runners/experimental/runner_evo_mixed_payoffs.py @@ -37,6 +37,7 @@ class EvoMixedPayoffRunner: It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Payoff matrix is randomly sampled at each rollout. Each opponent has a different payoff matrix. Args: agents (Tuple[agents]): The set of agents that will run in the experiment. Note, ordering is diff --git a/pax/runners/runner_evo_mixed_payoffs_gen.py b/pax/runners/experimental/runner_evo_mixed_payoffs_gen.py similarity index 99% rename from pax/runners/runner_evo_mixed_payoffs_gen.py rename to pax/runners/experimental/runner_evo_mixed_payoffs_gen.py index b9ae0bd5..f68f9fc6 100644 --- a/pax/runners/runner_evo_mixed_payoffs_gen.py +++ b/pax/runners/experimental/runner_evo_mixed_payoffs_gen.py @@ -37,6 +37,7 @@ class EvoMixedPayoffGenRunner: It composes together agents, watchers, and the environment. 
Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Payoff matrix is randomly sampled at each rollout. Each opponent has the same payoff matrix. Args: agents (Tuple[agents]): The set of agents that will run in the experiment. Note, ordering is diff --git a/pax/runners/runner_evo_mixed_payoffs_input.py b/pax/runners/experimental/runner_evo_mixed_payoffs_input.py similarity index 98% rename from pax/runners/runner_evo_mixed_payoffs_input.py rename to pax/runners/experimental/runner_evo_mixed_payoffs_input.py index dafff84f..9601852e 100644 --- a/pax/runners/runner_evo_mixed_payoffs_input.py +++ b/pax/runners/experimental/runner_evo_mixed_payoffs_input.py @@ -37,6 +37,8 @@ class EvoMixedPayoffInputRunner: It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Add payoff matrices as input to agents so they don't have to figure out payoff matrices on the go. + Either randomly sample and set a payoff matrix Args: agents (Tuple[agents]): The set of agents that will run in the experiment. Note, ordering is @@ -201,8 +203,8 @@ def _inner_rollout(carry, unused): # a1_rng = rngs[:, :, :, 1, :] # a2_rng = rngs[:, :, :, 2, :] rngs = rngs[:, :, :, 3, :] - print("OBS1 shape: ", obs1.shape) - print("env params shape: ", env_params.payoff_matrix.shape) + # print("OBS1 shape: ", obs1.shape) + # print("env params shape: ", env_params.payoff_matrix.shape) # flatten the payoff matrix and append it to the observations # the observations have shape (500, 10, 2, 5) and the payoff matrix has shape (10, 4, 2) # we want to append the payoff matrix to the observations so that the observations have shape (500, 10, 2, 5+8) @@ -290,7 +292,7 @@ def _outer_rollout(carry, unused): # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) - print("OBS2 shape: ", obs2.shape) + # print("OBS2 shape: ", obs2.shape) # payoff_matrix = env_params.payoff_matrix.reshape((10, 8)) # payoff_matrix = jnp.tile(jnp.expand_dims(jnp.tile(payoff_matrix, (500, 1, 1)), 2), (1, 1, 2, 1)) # obs2_update = jnp.concatenate((obs2, payoff_matrix), axis=3) diff --git a/pax/runners/runner_evo_mixed_payoffs_only_opp.py b/pax/runners/experimental/runner_evo_mixed_payoffs_only_opp.py similarity index 99% rename from pax/runners/runner_evo_mixed_payoffs_only_opp.py rename to pax/runners/experimental/runner_evo_mixed_payoffs_only_opp.py index 50324d4d..873aeefc 100644 --- a/pax/runners/runner_evo_mixed_payoffs_only_opp.py +++ b/pax/runners/experimental/runner_evo_mixed_payoffs_only_opp.py @@ -37,6 +37,8 @@ class EvoMixedPayoffOnlyOppRunner: It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. + Opponent plays a noisy payoff function of the original IPD payoff matrix. + Same noise applied to all opponents. Args: agents (Tuple[agents]): The set of agents that will run in the experiment. 
Note, ordering is diff --git a/pax/runners/runner_evo_mixed_payoffs_input_onlymeta.py b/pax/runners/runner_evo_mixed_payoffs_input_onlymeta.py deleted file mode 100644 index cb787ca7..00000000 --- a/pax/runners/runner_evo_mixed_payoffs_input_onlymeta.py +++ /dev/null @@ -1,661 +0,0 @@ -import os -import time -from datetime import datetime -from typing import Any, Callable, NamedTuple - -import jax -import jax.numpy as jnp -from evosax import FitnessShaper - -import wandb -from pax.utils import MemoryState, TrainingState, save - -# TODO: import when evosax library is updated -# from evosax.utils import ESLog -from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats - -MAX_WANDB_CALLS = 1000 - - -class Sample(NamedTuple): - """Object containing a batch of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - - -class EvoMixedPayoffInputRunner: - """ - Evoluationary Strategy runner provides a convenient example for quickly writing - a MARL runner for PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. - It composes together agents, watchers, and the environment. - Within the init, we declare vmaps and pmaps for training. - The environment provided must conform to a meta-environment. - Args: - agents (Tuple[agents]): - The set of agents that will run in the experiment. Note, ordering is - important for logic used in the class. - env (gymnax.envs.Environment): - The meta-environment that the agents will run in. - strategy (evosax.Strategy): - The evolutionary strategy that will be used to train the agents. - param_reshaper (evosax.param_reshaper.ParameterReshaper): - A function that reshapes the parameters of the agents into a format that can be - used by the strategy. - save_dir (string): - The directory to save the model to. - args (NamedTuple): - A tuple of experiment arguments used (usually provided by HydraConfig). - """ - - def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args - ): - self.args = args - self.algo = args.es.algo - self.es_params = es_params - self.generations = 0 - self.num_opps = args.num_opps - self.param_reshaper = param_reshaper - self.popsize = args.popsize - self.random_key = jax.random.PRNGKey(args.seed) - self.start_datetime = datetime.now() - self.save_dir = save_dir - self.start_time = time.time() - self.strategy = strategy - self.top_k = args.top_k - self.train_steps = 0 - self.train_episodes = 0 - self.ipd_stats = jax.jit(ipd_visitation) - self.cg_stats = jax.jit(jax.vmap(cg_visitation)) - self.ipditm_stats = jax.jit( - jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) - ) - - # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) - # Evo Runner also has an additional pmap dim (num_devices, ...) 
- # For the env we vmap over the rng but not params - - # num envs - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - - # num opps - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, 0), 0 # rng, state, actions, params - ) - # pop size - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( - jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - ) - self.split = jax.vmap( - jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), - (0, None), - ) - - self.num_outer_steps = args.num_outer_steps - agent1, agent2 = agents - - # vmap agents accordingly - # agent 1 is batched over popsize and num_opps - agent1.batch_init = jax.vmap( - jax.vmap( - agent1.make_initial_state, - (None, 0), # (params, rng) - (None, 0), # (TrainingState, MemoryState) - ), - # both for Population - ) - agent1.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent1.batch_policy = jax.jit( - jax.vmap( - jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), - ) - ) - - if args.agent2 == "NaiveEx": - # special case where NaiveEx has a different call signature - agent2.batch_init = jax.jit( - jax.vmap(jax.vmap(agent2.make_initial_state)) - ) - else: - agent2.batch_init = jax.jit( - jax.vmap( - jax.vmap(agent2.make_initial_state, (0, None), 0), - (0, None), - 0, - ) - ) - - agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) - agent2.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent2.batch_update = jax.jit( - jax.vmap( - jax.vmap(agent2.update, (1, 0, 0, 0)), - (1, 0, 0, 0), - ) - ) - if args.agent2 != "NaiveEx": - # NaiveEx requires env first step to init. 
- init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - - a2_rng = jnp.concatenate( - [jax.random.split(agent2._state.random_key, args.num_opps)] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - - agent2._state, agent2._mem = agent2.batch_init( - a2_rng, - init_hidden, - ) - - # jit evo - strategy.ask = jax.jit(strategy.ask) - strategy.tell = jax.jit(strategy.tell) - param_reshaper.reshape = jax.jit(param_reshaper.reshape) - - def _inner_rollout(carry, unused): - """Runner for inner episode""" - ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ) = carry - - # unpack rngs - rngs = self.split(rngs, 4) - env_rng = rngs[:, :, :, 0, :] - - # a1_rng = rngs[:, :, :, 1, :] - # a2_rng = rngs[:, :, :, 2, :] - rngs = rngs[:, :, :, 3, :] - print("OBS1 shape: ", obs1.shape) - print("env params shape: ", env_params.payoff_matrix.shape) - # flatten the payoff matrix and append it to the observations - # the observations have shape (500, 10, 2, 5) and the payoff matrix has shape (10, 4, 2) - # we want to append the payoff matrix to the observations so that the observations have shape (500, 10, 2, 5+8) - # we want to flatten the payoff matrix so that it has shape (10, 8) - # This is the code - payoff_matrix = env_params.payoff_matrix.reshape((10, 8)) - payoff_matrix = jnp.tile(jnp.expand_dims(jnp.tile(payoff_matrix, (500, 1, 1)), 2), (1, 1, 2, 1)) - obs1 = jnp.concatenate((obs1, payoff_matrix), axis=3) - # obs2 = jnp.concatenate((obs2, payoff_matrix), axis=3) - - a1, a1_state, new_a1_mem = agent1.batch_policy( - a1_state, - obs1, - a1_mem, - ) - a2, a2_state, new_a2_mem = agent2.batch_policy( - a2_state, - obs2, - a2_mem, - ) - (next_obs1, next_obs2), env_state, rewards, done, info = env.step( - env_rng, - env_state, - (a1, a2), - env_params, - ) - - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) - traj2 = Sample( - obs2, - a2, - rewards[1], - new_a2_mem.extras["log_probs"], - new_a2_mem.extras["values"], - done, - a2_mem.hidden, - ) - return ( - rngs, - next_obs1, - next_obs2, - rewards[0], - rewards[1], - a1_state, - new_a1_mem, - a2_state, - new_a2_mem, - env_state, - env_params, - ), ( - traj1, - traj2, - ) - - def _outer_rollout(carry, unused): - """Runner for trial""" - # play episode of the game - vals, trajectories = jax.lax.scan( - _inner_rollout, - carry, - None, - length=args.num_inner_steps, - ) - ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ) = vals - # MFOS has to take a meta-action for each episode - if args.agent1 == "MFOS": - a1_mem = agent1.meta_policy(a1_mem) - # print("OBS2 shape: ", obs2.shape) - # payoff_matrix = env_params.payoff_matrix.reshape((10, 8)) - # payoff_matrix = jnp.tile(jnp.expand_dims(jnp.tile(payoff_matrix, (500, 1, 1)), 2), (1, 1, 2, 1)) - # obs2_update = jnp.concatenate((obs2, payoff_matrix), axis=3) - - # update second agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2_update, - a2_state, - a2_mem, - ) - return ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ), (*trajectories, a2_metrics) - - def _rollout( - _params: jnp.ndarray, - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _env_params: Any, - ): - # env reset - env_rngs = jnp.concatenate( - [jax.random.split(_rng_run, args.num_envs)] - * 
args.num_opps - * args.popsize - ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) - # set payoff matrix to random integers of shape [4,2] - _rng_run, payoff_rng = jax.random.split(_rng_run) - # payoff_matrix = -jax.random.randint(payoff_rng, minval=0, maxval=10, shape=(4,2), dtype=jnp.int8) - payoff_matrix = jnp.array([[-1, -1], [-3, 0], [0, -3], [-2, -2]]) - payoff_matrix = jnp.tile(payoff_matrix, (args.num_opps, 1, 1)) - # jax.debug.breakpoint() - _env_params.payoff_matrix = payoff_matrix - - obs, env_state = env.reset(env_rngs, _env_params) - rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] - - # Player 1 - _a1_state = _a1_state._replace(params=_params) - _a1_mem = agent1.batch_reset(_a1_mem, False) - # Player 2 - if args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(obs[1]) - - else: - # meta-experiments - init 2nd agent per trial - a2_rng = jnp.concatenate( - [jax.random.split(_rng_run, args.num_opps)] * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - a2_state, a2_mem = agent2.batch_init( - a2_rng, - agent2._mem.hidden, - ) - # generate an array of shape [10] - # random_numbers = jax.random.uniform(_rng_run, minval=1e-5, maxval=1.0, shape=(10,)) - # # repeat the array 1000 times along the first dimension - # learning_rates = jnp.tile(random_numbers, (1000, 1)) - # a2_state.opt_state[2].hyperparams['step_size'] = learning_rates - # jax.debug.breakpoint() - - # run trials - vals, stack = jax.lax.scan( - _outer_rollout, - ( - env_rngs, - *obs, - *rewards, - _a1_state, - _a1_mem, - a2_state, - a2_mem, - env_state, - _env_params, - ), - None, - length=self.num_outer_steps, - ) - - ( - env_rngs, - obs1, - obs2, - r1, - r2, - _a1_state, - _a1_mem, - a2_state, - a2_mem, - env_state, - _env_params, - ) = vals - traj_1, traj_2, a2_metrics = stack - - # Fitness - fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) - # Stats - if args.env_id == "coin_game": - env_stats = jax.tree_util.tree_map( - lambda x: x, - self.cg_stats(env_state), - ) - - rewards_1 = traj_1.rewards.sum(axis=1).mean() - rewards_2 = traj_2.rewards.sum(axis=1).mean() - - elif args.env_id in [ - "iterated_matrix_game", - ]: - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipd_stats( - traj_1.observations, - traj_1.actions, - obs1, - ), - ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - - elif args.env_id == "InTheMatrix": - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipditm_stats( - env_state, - traj_1, - traj_2, - args.num_envs, - ), - ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - else: - env_stats = {} - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - - return ( - fitness, - other_fitness, - env_stats, - rewards_1, - rewards_2, - a2_metrics, - ) - - self.rollout = jax.pmap( - _rollout, - in_axes=(0, None, None, None, None), - ) - - print( - f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" - ) - - def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, - ): - """Run training of agents in environment""" - print("Training") - print("------------------------------") - log_interval = max(num_iters / MAX_WANDB_CALLS, 5) - print(f"Number of Generations: {num_iters}") - print(f"Number of Meta Episodes: {self.num_outer_steps}") - print(f"Population Size: {self.popsize}") - print(f"Number of 
Environments: {self.args.num_envs}") - print(f"Number of Opponent: {self.args.num_opps}") - print(f"Log Interval: {log_interval}") - print("------------------------------") - # Initialize agents and RNG - agent1, agent2 = agents - rng, _ = jax.random.split(self.random_key) - - # Initialize evolution - num_gens = num_iters - strategy = self.strategy - es_params = self.es_params - param_reshaper = self.param_reshaper - popsize = self.popsize - num_opps = self.num_opps - evo_state = strategy.initialize(rng, es_params) - fit_shaper = FitnessShaper( - maximize=self.args.es.maximise, - centered_rank=self.args.es.centered_rank, - w_decay=self.args.es.w_decay, - z_score=self.args.es.z_score, - ) - es_logging = ESLog( - param_reshaper.total_params, - num_gens, - top_k=self.top_k, - maximize=True, - ) - log = es_logging.initialize() - - # Reshape a single agent's params before vmapping - init_hidden = jnp.tile( - agent1._mem.hidden, - (popsize, num_opps, 1, 1), - ) - a1_rng = jax.random.split(rng, popsize) - agent1._state, agent1._mem = agent1.batch_init( - a1_rng, - init_hidden, - ) - - a1_state, a1_mem = agent1._state, agent1._mem - - for gen in range(num_gens): - rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) - - # Ask - x, evo_state = strategy.ask(rng_evo, evo_state, es_params) - params = param_reshaper.reshape(x) - if self.args.num_devices == 1: - params = jax.tree_util.tree_map( - lambda x: jax.lax.expand_dims(x, (0,)), params - ) - # Evo Rollout - # jax.debug.breakpoint() - ( - fitness, - other_fitness, - env_stats, - rewards_1, - rewards_2, - a2_metrics, - ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) - - # Aggregate over devices - fitness = jnp.reshape(fitness, popsize * self.args.num_devices) - env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) - - # Tell - fitness_re = fit_shaper.apply(x, fitness) - - if self.args.es.mean_reduce: - fitness_re = fitness_re - fitness_re.mean() - evo_state = strategy.tell(x, fitness_re, evo_state, es_params) - - # Logging - log = es_logging.update(log, x, fitness) - - # Saving - if gen % self.args.save_interval == 0: - log_savepath = os.path.join(self.save_dir, f"generation_{gen}") - if self.args.num_devices > 1: - top_params = param_reshaper.reshape( - log["top_gen_params"][0 : self.args.num_devices] - ) - top_params = jax.tree_util.tree_map( - lambda x: x[0].reshape(x[0].shape[1:]), top_params - ) - else: - top_params = param_reshaper.reshape( - log["top_gen_params"][0:1] - ) - top_params = jax.tree_util.tree_map( - lambda x: x.reshape(x.shape[1:]), top_params - ) - save(top_params, log_savepath) - if watchers: - print(f"Saving generation {gen} locally and to WandB") - wandb.save(log_savepath) - else: - print(f"Saving iteration {gen} locally") - - if gen % log_interval == 0: - print(f"Generation: {gen}") - print( - "--------------------------------------------------------------------------" - ) - print( - f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" - ) - print( - f"Reward Per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" - ) - print( - f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" - ) - print( - "--------------------------------------------------------------------------" - ) - print( - f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" - f" | Std: {log['log_top_gen_std'][gen]}" - ) - print( - "--------------------------------------------------------------------------" - ) - print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") - print(f"Agent 
{2} | Fitness: {log['top_gen_fitness'][1]}") - print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") - print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") - print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") - print() - - if watchers: - wandb_log = { - "train_iteration": gen, - "train/fitness/player_1": float(fitness.mean()), - "train/fitness/player_2": float(other_fitness.mean()), - "train/fitness/top_overall_mean": log["log_top_mean"][gen], - "train/fitness/top_overall_std": log["log_top_std"][gen], - "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], - "train/fitness/top_gen_std": log["log_top_gen_std"][gen], - "train/fitness/gen_std": log["log_gen_std"][gen], - "train/time/minutes": float( - (time.time() - self.start_time) / 60 - ), - "train/time/seconds": float( - (time.time() - self.start_time) - ), - "train/reward_per_timestep/player_1": float( - rewards_1.mean() - ), - "train/reward_per_timestep/player_2": float( - rewards_2.mean() - ), - } - wandb_log.update(env_stats) - # loop through population - for idx, (overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) - ): - wandb_log[ - f"train/fitness/top_overall_agent_{idx+1}" - ] = overall_fitness - wandb_log[ - f"train/fitness/top_gen_agent_{idx+1}" - ] = gen_fitness - - # player 2 metrics - # metrics [outer_timesteps, num_opps] - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics - ) - - agent2._logger.metrics.update(flattened_metrics) - for watcher, agent in zip(watchers, agents): - watcher(agent) - wandb_log = jax.tree_util.tree_map( - lambda x: x.item() if isinstance(x, jax.Array) else x, - wandb_log, - ) - wandb.log(wandb_log) - - return agents diff --git a/pax/runners/runner_evo_mixed_payoffs_pred.py b/pax/runners/runner_evo_mixed_payoffs_pred.py deleted file mode 100644 index 797d6870..00000000 --- a/pax/runners/runner_evo_mixed_payoffs_pred.py +++ /dev/null @@ -1,645 +0,0 @@ -import os -import time -from datetime import datetime -from typing import Any, Callable, NamedTuple - -import jax -import jax.numpy as jnp -from evosax import FitnessShaper - -import wandb -from pax.utils import MemoryState, TrainingState, save - -# TODO: import when evosax library is updated -# from evosax.utils import ESLog -from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats - -MAX_WANDB_CALLS = 1000 - - -class Sample(NamedTuple): - """Object containing a batch of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - - -class EvoMixedPayoffPredRunner: - """ - Evoluationary Strategy runner provides a convenient example for quickly writing - a MARL runner for PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. - It composes together agents, watchers, and the environment. - Within the init, we declare vmaps and pmaps for training. - The environment provided must conform to a meta-environment. - Args: - agents (Tuple[agents]): - The set of agents that will run in the experiment. Note, ordering is - important for logic used in the class. - env (gymnax.envs.Environment): - The meta-environment that the agents will run in. - strategy (evosax.Strategy): - The evolutionary strategy that will be used to train the agents. 
- param_reshaper (evosax.param_reshaper.ParameterReshaper): - A function that reshapes the parameters of the agents into a format that can be - used by the strategy. - save_dir (string): - The directory to save the model to. - args (NamedTuple): - A tuple of experiment arguments used (usually provided by HydraConfig). - """ - - def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args - ): - self.args = args - self.algo = args.es.algo - self.es_params = es_params - self.generations = 0 - self.num_opps = args.num_opps - self.param_reshaper = param_reshaper - self.popsize = args.popsize - self.random_key = jax.random.PRNGKey(args.seed) - self.start_datetime = datetime.now() - self.save_dir = save_dir - self.start_time = time.time() - self.strategy = strategy - self.top_k = args.top_k - self.train_steps = 0 - self.train_episodes = 0 - self.ipd_stats = jax.jit(ipd_visitation) - self.cg_stats = jax.jit(jax.vmap(cg_visitation)) - self.ipditm_stats = jax.jit( - jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) - ) - - # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) - # Evo Runner also has an additional pmap dim (num_devices, ...) - # For the env we vmap over the rng but not params - - # num envs - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - - # num opps - env.reset = jax.vmap(env.reset, (0, None), 0) - env.step = jax.vmap( - env.step, (0, 0, 0, 0), 0 # rng, state, actions, params - ) - # pop size - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( - jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - ) - self.split = jax.vmap( - jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), - (0, None), - ) - - self.num_outer_steps = args.num_outer_steps - agent1, agent2 = agents - - # vmap agents accordingly - # agent 1 is batched over popsize and num_opps - agent1.batch_init = jax.vmap( - jax.vmap( - agent1.make_initial_state, - (None, 0), # (params, rng) - (None, 0), # (TrainingState, MemoryState) - ), - # both for Population - ) - agent1.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent1.batch_policy = jax.jit( - jax.vmap( - jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), - ) - ) - - if args.agent2 == "NaiveEx": - # special case where NaiveEx has a different call signature - agent2.batch_init = jax.jit( - jax.vmap(jax.vmap(agent2.make_initial_state)) - ) - else: - agent2.batch_init = jax.jit( - jax.vmap( - jax.vmap(agent2.make_initial_state, (0, None), 0), - (0, None), - 0, - ) - ) - - agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) - agent2.batch_reset = jax.jit( - jax.vmap( - jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 - ), - static_argnums=1, - ) - - agent2.batch_update = jax.jit( - jax.vmap( - jax.vmap(agent2.update, (1, 0, 0, 0)), - (1, 0, 0, 0), - ) - ) - if args.agent2 != "NaiveEx": - # NaiveEx requires env first step to init. 
- init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - - a2_rng = jnp.concatenate( - [jax.random.split(agent2._state.random_key, args.num_opps)] - * args.popsize - ).reshape(args.popsize, args.num_opps, -1) - - agent2._state, agent2._mem = agent2.batch_init( - a2_rng, - init_hidden, - ) - - # jit evo - strategy.ask = jax.jit(strategy.ask) - strategy.tell = jax.jit(strategy.tell) - param_reshaper.reshape = jax.jit(param_reshaper.reshape) - - def _inner_rollout(carry, unused): - """Runner for inner episode""" - ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ) = carry - - # unpack rngs - rngs = self.split(rngs, 4) - env_rng = rngs[:, :, :, 0, :] - - # a1_rng = rngs[:, :, :, 1, :] - # a2_rng = rngs[:, :, :, 2, :] - rngs = rngs[:, :, :, 3, :] - - a1, a1_state, new_a1_mem = agent1.batch_policy( - a1_state, - obs1, - a1_mem, - ) - a2, a2_state, new_a2_mem = agent2.batch_policy( - a2_state, - obs2, - a2_mem, - ) - (next_obs1, next_obs2), env_state, rewards, done, info = env.step( - env_rng, - env_state, - (a1, a2), - env_params, - ) - - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) - traj2 = Sample( - obs2, - a2, - rewards[1], - new_a2_mem.extras["log_probs"], - new_a2_mem.extras["values"], - done, - a2_mem.hidden, - ) - return ( - rngs, - next_obs1, - next_obs2, - rewards[0], - rewards[1], - a1_state, - new_a1_mem, - a2_state, - new_a2_mem, - env_state, - env_params, - ), ( - traj1, - traj2, - ) - - def _outer_rollout(carry, unused): - """Runner for trial""" - # play episode of the game - vals, trajectories = jax.lax.scan( - _inner_rollout, - carry, - None, - length=args.num_inner_steps, - ) - ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ) = vals - # MFOS has to take a meta-action for each episode - if args.agent1 == "MFOS": - a1_mem = agent1.meta_policy(a1_mem) - - # update second agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2, - a2_state, - a2_mem, - ) - return ( - rngs, - obs1, - obs2, - r1, - r2, - a1_state, - a1_mem, - a2_state, - a2_mem, - env_state, - env_params, - ), (*trajectories, a2_metrics) - - def _rollout( - _params: jnp.ndarray, - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _env_params: Any, - ): - # env reset - env_rngs = jnp.concatenate( - [jax.random.split(_rng_run, args.num_envs)] - * args.num_opps - * args.popsize - ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) - # set payoff matrix to random integers of shape [4,2] - _rng_run, payoff_rng = jax.random.split(_rng_run) - payoff_matrix = -jax.random.randint(payoff_rng, minval=0, maxval=10, shape=(4,2), dtype=jnp.int8) - payoff_matrix = jnp.tile(payoff_matrix, (args.num_opps, 1, 1)) - # jax.debug.breakpoint() - _env_params.payoff_matrix = payoff_matrix - - obs, env_state = env.reset(env_rngs, _env_params) - rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] - - # Player 1 - _a1_state = _a1_state._replace(params=_params) - _a1_mem = agent1.batch_reset(_a1_mem, False) - # Player 2 - if args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(obs[1]) - - else: - # meta-experiments - init 2nd agent per trial - a2_rng = jnp.concatenate( - [jax.random.split(_rng_run, args.num_opps)] * args.popsize - ).reshape(args.popsize, 
args.num_opps, -1) - a2_state, a2_mem = agent2.batch_init( - a2_rng, - agent2._mem.hidden, - ) - # generate an array of shape [10] - # random_numbers = jax.random.uniform(_rng_run, minval=1e-5, maxval=1.0, shape=(10,)) - # # repeat the array 1000 times along the first dimension - # learning_rates = jnp.tile(random_numbers, (1000, 1)) - # a2_state.opt_state[2].hyperparams['step_size'] = learning_rates - # jax.debug.breakpoint() - - # run trials - vals, stack = jax.lax.scan( - _outer_rollout, - ( - env_rngs, - *obs, - *rewards, - _a1_state, - _a1_mem, - a2_state, - a2_mem, - env_state, - _env_params, - ), - None, - length=self.num_outer_steps, - ) - - ( - env_rngs, - obs1, - obs2, - r1, - r2, - _a1_state, - _a1_mem, - a2_state, - a2_mem, - env_state, - _env_params, - ) = vals - traj_1, traj_2, a2_metrics = stack - - # Fitness - fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) - # Stats - if args.env_id == "coin_game": - env_stats = jax.tree_util.tree_map( - lambda x: x, - self.cg_stats(env_state), - ) - - rewards_1 = traj_1.rewards.sum(axis=1).mean() - rewards_2 = traj_2.rewards.sum(axis=1).mean() - - elif args.env_id in [ - "iterated_matrix_game", - ]: - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipd_stats( - traj_1.observations, - traj_1.actions, - obs1, - ), - ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - - elif args.env_id == "InTheMatrix": - env_stats = jax.tree_util.tree_map( - lambda x: x.mean(), - self.ipditm_stats( - env_state, - traj_1, - traj_2, - args.num_envs, - ), - ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - else: - env_stats = {} - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - - return ( - fitness, - other_fitness, - env_stats, - rewards_1, - rewards_2, - a2_metrics, - ) - - self.rollout = jax.pmap( - _rollout, - in_axes=(0, None, None, None, None), - ) - - print( - f"Time to Compile Jax Methods: {time.time() - self.start_time} Seconds" - ) - - def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, - ): - """Run training of agents in environment""" - print("Training") - print("------------------------------") - log_interval = max(num_iters / MAX_WANDB_CALLS, 5) - print(f"Number of Generations: {num_iters}") - print(f"Number of Meta Episodes: {self.num_outer_steps}") - print(f"Population Size: {self.popsize}") - print(f"Number of Environments: {self.args.num_envs}") - print(f"Number of Opponent: {self.args.num_opps}") - print(f"Log Interval: {log_interval}") - print("------------------------------") - # Initialize agents and RNG - agent1, agent2 = agents - rng, _ = jax.random.split(self.random_key) - - # Initialize evolution - num_gens = num_iters - strategy = self.strategy - es_params = self.es_params - param_reshaper = self.param_reshaper - popsize = self.popsize - num_opps = self.num_opps - evo_state = strategy.initialize(rng, es_params) - fit_shaper = FitnessShaper( - maximize=self.args.es.maximise, - centered_rank=self.args.es.centered_rank, - w_decay=self.args.es.w_decay, - z_score=self.args.es.z_score, - ) - es_logging = ESLog( - param_reshaper.total_params, - num_gens, - top_k=self.top_k, - maximize=True, - ) - log = es_logging.initialize() - - # Reshape a single agent's params before vmapping - init_hidden = jnp.tile( - agent1._mem.hidden, - (popsize, num_opps, 1, 1), - ) - a1_rng = jax.random.split(rng, popsize) - agent1._state, agent1._mem = agent1.batch_init( - a1_rng, - 
init_hidden, - ) - - a1_state, a1_mem = agent1._state, agent1._mem - - for gen in range(num_gens): - rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) - - # Ask - x, evo_state = strategy.ask(rng_evo, evo_state, es_params) - params = param_reshaper.reshape(x) - if self.args.num_devices == 1: - params = jax.tree_util.tree_map( - lambda x: jax.lax.expand_dims(x, (0,)), params - ) - # Evo Rollout - # jax.debug.breakpoint() - ( - fitness, - other_fitness, - env_stats, - rewards_1, - rewards_2, - a2_metrics, - ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) - - # Aggregate over devices - fitness = jnp.reshape(fitness, popsize * self.args.num_devices) - env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) - - # Tell - fitness_re = fit_shaper.apply(x, fitness) - - if self.args.es.mean_reduce: - fitness_re = fitness_re - fitness_re.mean() - evo_state = strategy.tell(x, fitness_re, evo_state, es_params) - - # Logging - log = es_logging.update(log, x, fitness) - - # Saving - if gen % self.args.save_interval == 0: - log_savepath = os.path.join(self.save_dir, f"generation_{gen}") - if self.args.num_devices > 1: - top_params = param_reshaper.reshape( - log["top_gen_params"][0 : self.args.num_devices] - ) - top_params = jax.tree_util.tree_map( - lambda x: x[0].reshape(x[0].shape[1:]), top_params - ) - else: - top_params = param_reshaper.reshape( - log["top_gen_params"][0:1] - ) - top_params = jax.tree_util.tree_map( - lambda x: x.reshape(x.shape[1:]), top_params - ) - save(top_params, log_savepath) - if watchers: - print(f"Saving generation {gen} locally and to WandB") - wandb.save(log_savepath) - else: - print(f"Saving iteration {gen} locally") - - if gen % log_interval == 0: - print(f"Generation: {gen}") - print( - "--------------------------------------------------------------------------" - ) - print( - f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" - ) - print( - f"Reward Per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" - ) - print( - f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" - ) - print( - "--------------------------------------------------------------------------" - ) - print( - f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" - f" | Std: {log['log_top_gen_std'][gen]}" - ) - print( - "--------------------------------------------------------------------------" - ) - print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") - print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") - print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") - print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") - print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") - print() - - if watchers: - wandb_log = { - "train_iteration": gen, - "train/fitness/player_1": float(fitness.mean()), - "train/fitness/player_2": float(other_fitness.mean()), - "train/fitness/top_overall_mean": log["log_top_mean"][gen], - "train/fitness/top_overall_std": log["log_top_std"][gen], - "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], - "train/fitness/top_gen_std": log["log_top_gen_std"][gen], - "train/fitness/gen_std": log["log_gen_std"][gen], - "train/time/minutes": float( - (time.time() - self.start_time) / 60 - ), - "train/time/seconds": float( - (time.time() - self.start_time) - ), - "train/reward_per_timestep/player_1": float( - rewards_1.mean() - ), - "train/reward_per_timestep/player_2": float( - rewards_2.mean() - ), - } - wandb_log.update(env_stats) - # loop through population - for idx, 
(overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) - ): - wandb_log[ - f"train/fitness/top_overall_agent_{idx+1}" - ] = overall_fitness - wandb_log[ - f"train/fitness/top_gen_agent_{idx+1}" - ] = gen_fitness - - # player 2 metrics - # metrics [outer_timesteps, num_opps] - flattened_metrics = jax.tree_util.tree_map( - lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics - ) - - agent2._logger.metrics.update(flattened_metrics) - for watcher, agent in zip(watchers, agents): - watcher(agent) - wandb_log = jax.tree_util.tree_map( - lambda x: x.item() if isinstance(x, jax.Array) else x, - wandb_log, - ) - wandb.log(wandb_log) - - return agents
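
Illustrative sketch of the per-opponent learning-rate randomisation that runner_evo_mixed_lr.py now performs: one learning rate is sampled per opponent and tiled across the population, so every population member faces the same set of learning rates. The num_opps/popsize values below are placeholders for args.num_opps/args.popsize, and writing the result into a2_state.opt_state[2].hyperparams['step_size'] (quoted from the patch) assumes the opponent's optimiser was built with injected hyperparameters.

import jax
import jax.numpy as jnp

num_opps, popsize = 10, 1000  # illustrative stand-ins for args.num_opps / args.popsize
rng = jax.random.PRNGKey(0)

# one learning rate per opponent, shared by every member of the evo population
random_numbers = jax.random.uniform(rng, minval=1e-5, maxval=1.0, shape=(num_opps,))
learning_rates = jnp.tile(random_numbers, (popsize, 1))  # shape (popsize, num_opps)

# in the runner this array is then assigned to the injected optimiser hyperparameter:
# a2_state.opt_state[2].hyperparams['step_size'] = learning_rates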
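
Sketch of the observation augmentation described in the comments of runner_evo_mixed_payoffs_input.py: the per-opponent payoff matrix is flattened, broadcast to the observation batch shape, and concatenated onto agent 1's observations so the agent does not have to infer the payoffs during play. The shapes (500 envs, 10 opponents, 2 players, 5 observation features) are taken from the patch comments.

import jax.numpy as jnp

num_envs, num_opps = 500, 10
obs1 = jnp.zeros((num_envs, num_opps, 2, 5))   # (envs, opps, players, features)
payoff_matrix = jnp.zeros((num_opps, 4, 2))    # one 4x2 payoff table per opponent

flat = payoff_matrix.reshape((num_opps, 8))                 # (10, 8)
flat = jnp.tile(flat, (num_envs, 1, 1))                     # (500, 10, 8)
flat = jnp.tile(jnp.expand_dims(flat, 2), (1, 1, 2, 1))     # (500, 10, 2, 8)
obs1 = jnp.concatenate((obs1, flat), axis=3)                # (500, 10, 2, 13)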
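
The EvoMixedIPDPayoffRunner docstring states that each opponent's payoff matrix follows the IPD conditions, but the patch does not show how those matrices are sampled. The following is only a hypothetical way to draw a 4x2 payoff table satisfying T > R > P > S and 2R > T + S, laid out in the same (CC, CD, DC, DD) row order as the fixed IPD matrix used elsewhere in the runners.

import jax
import jax.numpy as jnp

def sample_ipd_payoffs(rng):
    # draw four payoffs and sort them ascending so that S < P < R < T
    s, p, r, t = jnp.sort(jax.random.uniform(rng, shape=(4,), minval=-4.0, maxval=0.0))
    # enforce the second IPD condition 2R > T + S by capping T (ties ignored in this sketch)
    t = jnp.minimum(t, 2.0 * r - s - 1e-3)
    # rows are the joint actions (CC, CD, DC, DD); columns are player 1 and player 2
    return jnp.array([[r, r], [s, t], [t, s], [p, p]])

payoff_matrix = sample_ipd_payoffs(jax.random.PRNGKey(0))  # shape (4, 2)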
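
The EvoMixedPayoffOnlyOppRunner docstring says the opponent plays against a noisy version of the original IPD payoff matrix, with the same noise applied to all opponents. A minimal sketch of that idea follows; the Gaussian noise model and the 0.1 scale are assumptions, not taken from the runner.

import jax
import jax.numpy as jnp

num_opps = 10
ipd_payoffs = jnp.array([[-1.0, -1.0], [-3.0, 0.0], [0.0, -3.0], [-2.0, -2.0]])

rng = jax.random.PRNGKey(0)
noise = 0.1 * jax.random.normal(rng, shape=ipd_payoffs.shape)  # a single noise draw
noisy_payoffs = ipd_payoffs + noise                            # same perturbation for every opponent
payoff_matrix = jnp.tile(noisy_payoffs, (num_opps, 1, 1))      # shape (num_opps, 4, 2)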