From a8a817c9f030da8ca0614252552b698aa6a9d003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Pr=C3=B6schel?= Date: Wed, 18 Oct 2023 15:13:12 +0200 Subject: [PATCH] tests failing --- .github/workflows/test.yaml | 5 +- optim_test.py | 45 -- pax/agents/lola/lola.py | 8 +- pax/agents/mfos_ppo/ppo_gru.py | 3 +- pax/agents/naive/naive.py | 6 +- pax/agents/naive/network.py | 19 +- pax/agents/ppo/batched_envs.py | 2 +- pax/agents/ppo/networks.py | 135 ++--- pax/agents/ppo/ppo.py | 144 +++--- pax/agents/ppo/ppo_gru.py | 8 +- pax/agents/tensor_strategies.py | 4 +- pax/envs/cournot.py | 24 +- pax/envs/fishery.py | 32 +- pax/envs/infinite_matrix_game.py | 4 +- pax/envs/rice/c_rice.py | 470 ++++++++++++------ pax/envs/rice/rice.py | 441 ++++++++++------ pax/envs/rice/sarl_rice.py | 86 ++-- pax/experiment.py | 84 +++- pax/runners/runner_eval.py | 149 ++++-- pax/runners/runner_eval_multishaper.py | 38 +- pax/runners/runner_evo.py | 124 +++-- pax/runners/runner_evo_multishaper.py | 113 +++-- pax/runners/runner_marl.py | 41 +- pax/runners/runner_marl_nplayer.py | 30 +- pax/runners/runner_weight_sharing.py | 48 +- pax/utils.py | 1 - pax/version.py | 4 +- pax/watchers/__init__.py | 166 ++++--- pax/watchers/c_rice.py | 94 +++- pax/watchers/cournot.py | 8 +- pax/watchers/fishery.py | 49 +- pax/watchers/rice.py | 93 +++- test/envs/test_cournot.py | 15 +- test/envs/test_fishery.py | 20 +- .../test_iterated_tensor_game_n_player.py | 28 +- test/envs/test_rice.py | 64 +-- test/test_strategies.py | 52 -- test/test_tensor_strategies.py | 1 + 38 files changed, 1586 insertions(+), 1072 deletions(-) delete mode 100644 optim_test.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index dfc31072..e67a3b83 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,8 +1,8 @@ # .github/workflows/app.yaml name: PyTest -on: +on: pull_request: - branches: + branches: - main jobs: @@ -21,6 +21,7 @@ jobs: python -m pip install --upgrade pip - name: Install from repo in the test mode run: "pip install -e '.[dev]'" + - uses: pre-commit/action@v3.0.0 - name: Test with pytest run: | pip install pytest diff --git a/optim_test.py b/optim_test.py deleted file mode 100644 index 4ac26e27..00000000 --- a/optim_test.py +++ /dev/null @@ -1,45 +0,0 @@ -import os - -import jax -import jax.numpy as jnp -import jaxopt -import optax - -from pax.envs.rice.rice import EnvParams -from pax.envs.rice.sarl_rice import SarlRice -from jaxopt import ProjectedGradient, OptaxSolver -from jaxopt.projection import projection_non_negative - -jax.config.update("jax_enable_x64", True) - - -ep_length = 20 -env_dir = os.path.join("./pax/envs/rice") -sarl_env = SarlRice(config_folder=os.path.join(env_dir, "5_regions"), episode_length=ep_length) -env_params = EnvParams() - - -def objective(params, rng): - obs, state = sarl_env.reset(rng, env_params) - rewards = 0 - for i in range(ep_length): - obs, state, reward, done, info = sarl_env.step(rng, state, params[i], env_params) - rewards += jnp.asarray(reward).sum() - - return rewards - - -rng = jax.random.PRNGKey(0) - -w_init = jnp.ones((ep_length, sarl_env.num_actions)) * 0.5 - -opt = optax.adam(0.0001) -solver = OptaxSolver(opt=opt, fun=objective, maxiter=100) -# solver = jaxopt.LBFGS(fun=objective, maxiter=100) -# pg = ProjectedGradient(fun=objective, projection=projection_non_negative) -pg_sol = solver.run(w_init, rng=rng) -params, state = pg_sol -print(pg_sol) -final_reward = objective(params, rng) -print(state) -print(f"reward {final_reward}") diff --git a/pax/agents/lola/lola.py b/pax/agents/lola/lola.py index d36225ce..a2c5532a 100644 --- a/pax/agents/lola/lola.py +++ b/pax/agents/lola/lola.py @@ -289,7 +289,9 @@ def inner_loss( "loss_value": value_objective, } - def make_initial_state(key: Any, hidden) -> Tuple[TrainingState, MemoryState]: + def make_initial_state( + key: Any, hidden + ) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) dummy_obs = jnp.zeros(shape=obs_spec) @@ -524,7 +526,7 @@ def lola_inlookahead_rollout(carry, unused): ) my_mem = carry[7] other_mems = carry[8] - # jax.debug.breakpoint() + # flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs vmap_trajectories = jax.tree_map( lambda x: jnp.swapaxes(x, 0, 1), trajectories @@ -539,7 +541,7 @@ def lola_inlookahead_rollout(carry, unused): rewards_self=vmap_trajectories[0].rewards_self, rewards_other=[traj.rewards for traj in vmap_trajectories[1:]], ) - # jax.debug.breakpoint() + # get gradients of opponents other_gradients = [] for idx in range(len((self.other_agents))): diff --git a/pax/agents/mfos_ppo/ppo_gru.py b/pax/agents/mfos_ppo/ppo_gru.py index 3e6605a8..2e130fa7 100644 --- a/pax/agents/mfos_ppo/ppo_gru.py +++ b/pax/agents/mfos_ppo/ppo_gru.py @@ -11,7 +11,8 @@ from pax.agents.agent import AgentInterface from pax.agents.mfos_ppo.networks import ( make_mfos_ipditm_network, - make_mfos_network, make_mfos_continuous_network, + make_mfos_network, + make_mfos_continuous_network, ) from pax.envs.rice.rice import Rice from pax.utils import TrainingState, get_advantages diff --git a/pax/agents/naive/naive.py b/pax/agents/naive/naive.py index f55fec1d..4eeaf9fb 100644 --- a/pax/agents/naive/naive.py +++ b/pax/agents/naive/naive.py @@ -9,7 +9,11 @@ from pax import utils from pax.agents.agent import AgentInterface -from pax.agents.naive.network import make_coingame_network, make_network, make_rice_network +from pax.agents.naive.network import ( + make_coingame_network, + make_network, + make_rice_network, +) from pax.utils import MemoryState, TrainingState, get_advantages diff --git a/pax/agents/naive/network.py b/pax/agents/naive/network.py index 7610166f..467f0d33 100644 --- a/pax/agents/naive/network.py +++ b/pax/agents/naive/network.py @@ -10,9 +10,9 @@ class CategoricalValueHead(hk.Module): """Network head that produces a categorical distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, + self, + num_values: int, + name: Optional[str] = None, ): super().__init__(name=name) self._logit_layer = hk.Linear( @@ -34,10 +34,10 @@ class ContinuousValueHead(hk.Module): """Network head that produces a continuous distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, - mean_activation: Optional[str] = None, + self, + num_values: int, + name: Optional[str] = None, + mean_activation: Optional[str] = None, ): super().__init__(name=name) self.mean_action = mean_activation @@ -65,7 +65,10 @@ def __call__(self, inputs: jnp.ndarray): scales = self._scale_layer(inputs) value = jnp.squeeze(self._value_layer(inputs), axis=-1) scales = jnp.maximum(scales, 0.01) - return distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), value + return ( + distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), + value, + ) class CNN(hk.Module): diff --git a/pax/agents/ppo/batched_envs.py b/pax/agents/ppo/batched_envs.py index 2497f302..7713a1e0 100644 --- a/pax/agents/ppo/batched_envs.py +++ b/pax/agents/ppo/batched_envs.py @@ -30,7 +30,7 @@ def step(self, actions: jnp.ndarray) -> TimeStep: rewards = [] observations = [] discounts = [] - for env, action in zip(self.envs, actions): + for env, action in zip(self.envs, actions, strict=True): t = env.step(int(action)) if t.step_type == 2: t_reset = env.reset() diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py index 09bcd18f..9bd123a2 100644 --- a/pax/agents/ppo/networks.py +++ b/pax/agents/ppo/networks.py @@ -16,9 +16,9 @@ class CategoricalValueHead(hk.Module): """Network head that produces a categorical distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, + self, + num_values: int, + name: Optional[str] = None, ): super().__init__(name=name) self._logit_layer = hk.Linear( @@ -42,9 +42,9 @@ class CategoricalValueHead_ipd(hk.Module): """Network head that produces a categorical distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, + self, + num_values: int, + name: Optional[str] = None, ): super().__init__(name=name) self._logit_layer = hk.Linear( @@ -68,9 +68,9 @@ class CategoricalValueHeadSeparate(hk.Module): """Network head that produces a categorical distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, + self, + num_values: int, + name: Optional[str] = None, ): super().__init__(name=name) self._action_body = hk.nets.MLP( @@ -112,10 +112,10 @@ class CategoricalValueHeadSeparate_ipditm(hk.Module): """Network head that produces a categorical distribution and value.""" def __init__( - self, - num_values: int, - hidden_size: int, - name: Optional[str] = None, + self, + num_values: int, + hidden_size: int, + name: Optional[str] = None, ): super().__init__(name=name) self._action_body = hk.nets.MLP( @@ -157,11 +157,11 @@ class ContinuousValueHead(hk.Module): """Network head that produces a continuous distribution and value.""" def __init__( - self, - num_values: int, - name: Optional[str] = None, - mean_activation: Optional[str] = None, - with_bias=False, + self, + num_values: int, + name: Optional[str] = None, + mean_activation: Optional[str] = None, + with_bias=False, ): super().__init__(name=name) self.mean_action = mean_activation @@ -189,7 +189,10 @@ def __call__(self, inputs: jnp.ndarray): scales = self._scale_layer(inputs) value = jnp.squeeze(self._value_layer(inputs), axis=-1) scales = jnp.maximum(scales, 0.01) - return distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), value + return ( + distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), + value, + ) class Tabular(hk.Module): @@ -207,7 +210,7 @@ def __init__(self, num_values: int): ) def _input_to_onehot(input: jnp.ndarray): - chunks = jnp.array([9 ** 3, 9 ** 2, 9, 1], dtype=jnp.int32) + chunks = jnp.array([9**3, 9**2, 9, 1], dtype=jnp.int32) idx = input.nonzero(size=4)[0] idx = jnp.mod(idx, 9) idx = chunks * idx @@ -367,7 +370,9 @@ def forward_fn(inputs): activate_final=True, activation=jnp.tanh, ), - ContinuousValueHead(num_values=num_actions, name="cournot_value_head"), + ContinuousValueHead( + num_values=num_actions, name="cournot_value_head" + ), ] ) policy_value_network = hk.Sequential(layers) @@ -391,7 +396,9 @@ def forward_fn(inputs): activate_final=True, activation=jnp.tanh, ), - ContinuousValueHead(num_values=num_actions, name="fishery_value_head"), + ContinuousValueHead( + num_values=num_actions, name="fishery_value_head" + ), ] ) policy_value_network = hk.Sequential(layers) @@ -404,7 +411,9 @@ def forward_fn(inputs): def make_rice_sarl_network(num_actions: int, hidden_size: int): """Continuous action space network with values clipped between 0 and 1""" if float_precision == jnp.float16: - policy = jmp.get_policy('params=float16,compute=float16,output=float32') + policy = jmp.get_policy( + "params=float16,compute=float16,output=float32" + ) hk.mixed_precision.set_policy(hk.nets.MLP, policy) def forward_fn(inputs): @@ -416,7 +425,11 @@ def forward_fn(inputs): activate_final=True, activation=jax.nn.relu, ), - ContinuousValueHead(num_values=num_actions, name="rice_value_head", mean_activation="sigmoid"), + ContinuousValueHead( + num_values=num_actions, + name="rice_value_head", + mean_activation="sigmoid", + ), ] policy_value_network = hk.Sequential(layers) @@ -427,13 +440,13 @@ def forward_fn(inputs): def make_coingame_network( - num_actions: int, - tabular: bool, - with_cnn: bool, - separate: bool, - hidden_size: int, - output_channels: int, - kernel_shape: int, + num_actions: int, + tabular: bool, + with_cnn: bool, + separate: bool, + hidden_size: int, + output_channels: int, + kernel_shape: int, ): def forward_fn(inputs): layers = [] @@ -510,12 +523,12 @@ def forward_fn(inputs): def make_ipditm_network( - num_actions: int, - separate: bool, - with_cnn: bool, - hidden_size: int, - output_channels: int, - kernel_shape: int, + num_actions: int, + separate: bool, + with_cnn: bool, + hidden_size: int, + output_channels: int, + kernel_shape: int, ): def forward_fn(inputs: dict): layers = [] @@ -543,7 +556,7 @@ def make_GRU_ipd_network(num_actions: int, hidden_size: int): hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" gru = hk.GRU(hidden_size) @@ -561,7 +574,7 @@ def make_GRU_cartpole_network(num_actions: int): hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" torso = hk.nets.MLP( @@ -582,16 +595,16 @@ def forward_fn( def make_GRU_coingame_network( - num_actions: int, - with_cnn: bool, - hidden_size: int, - output_channels: int, - kernel_shape: Tuple[int], + num_actions: int, + with_cnn: bool, + hidden_size: int, + output_channels: int, + kernel_shape: Tuple[int], ): hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: if with_cnn: @@ -622,16 +635,16 @@ def forward_fn( def make_GRU_ipditm_network( - num_actions: int, - hidden_size: int, - separate: bool, - output_channels: int, - kernel_shape: Tuple[int], + num_actions: int, + hidden_size: int, + separate: bool, + output_channels: int, + kernel_shape: Tuple[int], ): hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" torso = CNN_ipditm(output_channels, kernel_shape) @@ -657,13 +670,13 @@ def forward_fn( def make_GRU_fishery_network( - num_actions: int, - hidden_size: int, + num_actions: int, + hidden_size: int, ): hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]: """forward function""" gru = hk.GRU( @@ -683,9 +696,9 @@ def forward_fn( def make_GRU_rice_network( - num_actions: int, - hidden_size: int, - v2=False, + num_actions: int, + hidden_size: int, + v2=False, ): # if float_precision == jnp.float16: # policy = jmp.get_policy('params=float16,compute=float16,output=float32') @@ -693,7 +706,7 @@ def make_GRU_rice_network( hidden_state = jnp.zeros((1, hidden_size)) def forward_fn( - inputs: jnp.ndarray, state: jnp.ndarray + inputs: jnp.ndarray, state: jnp.ndarray ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]: gru = hk.GRU( hidden_size, @@ -703,7 +716,11 @@ def forward_fn( ) if v2: - cvh = ContinuousValueHead(num_values=num_actions, mean_activation="sigmoid", with_bias=True) + cvh = ContinuousValueHead( + num_values=num_actions, + mean_activation="sigmoid", + with_bias=True, + ) else: cvh = ContinuousValueHead(num_values=num_actions) embedding, state = gru(inputs, state) diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index 7c15fa3e..08cb8009 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -12,13 +12,21 @@ from pax.agents.ppo.networks import ( make_coingame_network, make_ipditm_network, - make_sarl_network, make_cournot_network, - make_fishery_network, make_rice_sarl_network, + make_sarl_network, + make_cournot_network, + make_fishery_network, + make_rice_sarl_network, ) from pax.envs.rice.c_rice import ClubRice from pax.envs.rice.rice import Rice from pax.envs.rice.sarl_rice import SarlRice -from pax.utils import Logger, MemoryState, TrainingState, get_advantages, float_precision +from pax.utils import ( + Logger, + MemoryState, + TrainingState, + get_advantages, + float_precision, +) class Batch(NamedTuple): @@ -40,29 +48,29 @@ class PPO(AgentInterface): """A simple PPO agent using JAX""" def __init__( - self, - network: NamedTuple, - optimizer: optax.GradientTransformation, - random_key: jnp.ndarray, - obs_spec: Tuple, - num_envs: int = 4, - num_minibatches: int = 16, - num_epochs: int = 4, - clip_value: bool = True, - value_coeff: float = 0.5, - anneal_entropy: bool = False, - entropy_coeff_start: float = 0.1, - entropy_coeff_end: float = 0.01, - entropy_coeff_horizon: int = 3_000_000, - ppo_clipping_epsilon: float = 0.2, - gamma: float = 0.99, - gae_lambda: float = 0.95, - tabular: bool = False, - player_id: int = 0, + self, + network: NamedTuple, + optimizer: optax.GradientTransformation, + random_key: jnp.ndarray, + obs_spec: Tuple, + num_envs: int = 4, + num_minibatches: int = 16, + num_epochs: int = 4, + clip_value: bool = True, + value_coeff: float = 0.5, + anneal_entropy: bool = False, + entropy_coeff_start: float = 0.1, + entropy_coeff_end: float = 0.01, + entropy_coeff_horizon: int = 3_000_000, + ppo_clipping_epsilon: float = 0.2, + gamma: float = 0.99, + gae_lambda: float = 0.95, + tabular: bool = False, + player_id: int = 0, ): @jax.jit def policy( - state: TrainingState, observation: jnp.ndarray, mem: MemoryState + state: TrainingState, observation: jnp.ndarray, mem: MemoryState ): """Agent policy to select actions and calculate agent specific information""" key, subkey = jax.random.split(state.random_key) @@ -78,7 +86,7 @@ def policy( @jax.jit def gae_advantages( - rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray + rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray ) -> jnp.ndarray: """Calculates the gae advantages from a sequence. Note that the arguments are of length = rollout length + 1""" @@ -107,14 +115,14 @@ def gae_advantages( return advantages, target_values def loss( - params: hk.Params, - timesteps: int, - observations: jnp.ndarray, - actions: jnp.array, - behavior_log_probs: jnp.array, - target_values: jnp.array, - advantages: jnp.array, - behavior_values: jnp.array, + params: hk.Params, + timesteps: int, + observations: jnp.ndarray, + actions: jnp.array, + behavior_log_probs: jnp.array, + target_values: jnp.array, + advantages: jnp.array, + behavior_values: jnp.array, ): """Surrogate loss using clipped probability ratios.""" distribution, values = network.apply(params, observations) @@ -136,7 +144,7 @@ def loss( # Value loss: MSE value_cost = value_coeff unclipped_value_error = target_values - values - unclipped_value_loss = unclipped_value_error ** 2 + unclipped_value_loss = unclipped_value_error**2 # Value clipping if clip_value: @@ -147,7 +155,7 @@ def loss( ppo_clipping_epsilon, ) clipped_value_error = target_values - clipped_values - clipped_value_loss = clipped_value_error ** 2 + clipped_value_loss = clipped_value_error**2 value_loss = jnp.mean( jnp.fmax(unclipped_value_loss, clipped_value_loss) ) @@ -159,8 +167,8 @@ def loss( if anneal_entropy: fraction = jnp.fmax(1 - timesteps / entropy_coeff_horizon, 0) entropy_cost = ( - fraction * entropy_coeff_start - + (1 - fraction) * entropy_coeff_end + fraction * entropy_coeff_start + + (1 - fraction) * entropy_coeff_end ) # Constant Entropy term else: @@ -169,9 +177,9 @@ def loss( # Total loss: Minimize policy and value loss; maximize entropy total_loss = ( - policy_loss - + entropy_cost * entropy_loss - + value_loss * value_cost + policy_loss + + entropy_cost * entropy_loss + + value_loss * value_cost ) return total_loss, { @@ -184,7 +192,7 @@ def loss( @jax.jit def sgd_step( - state: TrainingState, sample: NamedTuple + state: TrainingState, sample: NamedTuple ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: """Performs a minibatch SGD step, returning new state and metrics.""" @@ -241,8 +249,8 @@ def sgd_step( @jax.jit def model_update_minibatch( - carry: Tuple[hk.Params, optax.OptState, int], - minibatch: Batch, + carry: Tuple[hk.Params, optax.OptState, int], + minibatch: Batch, ) -> Tuple[ Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray] ]: @@ -250,9 +258,9 @@ def model_update_minibatch( params, opt_state, timesteps = carry # Normalize advantages at the minibatch level before using them. advantages = ( - minibatch.advantages - - jnp.mean(minibatch.advantages, axis=0) - ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) + minibatch.advantages + - jnp.mean(minibatch.advantages, axis=0) + ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) gradients, metrics = grad_fn( params, timesteps, @@ -274,10 +282,10 @@ def model_update_minibatch( @jax.jit def model_update_epoch( - carry: Tuple[ - jnp.ndarray, hk.Params, optax.OptState, int, Batch - ], - unused_t: Tuple[()], + carry: Tuple[ + jnp.ndarray, hk.Params, optax.OptState, int, Batch + ], + unused_t: Tuple[()], ) -> Tuple[ Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch], Dict[str, jnp.ndarray], @@ -342,7 +350,7 @@ def model_update_epoch( return new_state, new_memory, metrics def make_initial_state( - key: Any, hidden: jnp.ndarray + key: Any, hidden: jnp.ndarray ) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) @@ -379,7 +387,7 @@ def make_initial_state( ) def prepare_batch( - traj_batch: NamedTuple, done: Any, action_extras: dict + traj_batch: NamedTuple, done: Any, action_extras: dict ): # Rollouts complete -> Training begins # Add an additional rollout step for advantage calculation @@ -439,11 +447,11 @@ def reset_memory(self, memory, eval=False) -> MemoryState: return memory def update( - self, - traj_batch, - obs: jnp.ndarray, - state: TrainingState, - mem: MemoryState, + self, + traj_batch, + obs: jnp.ndarray, + state: TrainingState, + mem: MemoryState, ): """Update the agent -> only called at the end of a trajectory""" _, _, mem = self._policy(state, obs, mem) @@ -453,7 +461,7 @@ def update( ) state, mem, metrics = self._sgd_step(state, traj_batch) self._logger.metrics["sgd_steps"] += ( - self._num_minibatches * self._num_epochs + self._num_minibatches * self._num_epochs ) self._logger.metrics["loss_total"] = metrics["loss_total"] self._logger.metrics["loss_policy"] = metrics["loss_policy"] @@ -465,14 +473,14 @@ def update( def make_agent( - args, - agent_args, - obs_spec, - action_spec, - seed: int, - num_iterations: int, - player_id: int, - tabular=False, + args, + agent_args, + obs_spec, + action_spec, + seed: int, + num_iterations: int, + player_id: int, + tabular=False, ): """Make PPO agent""" print(f"Making network for {args.env_id}") @@ -508,11 +516,13 @@ def make_agent( elif args.runner == "sarl": network = make_sarl_network(action_spec) else: - raise NotImplementedError(f"No ppo network implemented for env {args.env_id}") + raise NotImplementedError( + f"No ppo network implemented for env {args.env_id}" + ) # Optimizer transition_steps = ( - num_iterations * agent_args.num_epochs * agent_args.num_minibatches + num_iterations * agent_args.num_epochs * agent_args.num_minibatches ) if agent_args.lr_scheduling: diff --git a/pax/agents/ppo/ppo_gru.py b/pax/agents/ppo/ppo_gru.py index 166a2add..a1f8317b 100644 --- a/pax/agents/ppo/ppo_gru.py +++ b/pax/agents/ppo/ppo_gru.py @@ -13,7 +13,9 @@ make_GRU_cartpole_network, make_GRU_coingame_network, make_GRU_ipd_network, - make_GRU_ipditm_network, make_GRU_fishery_network, make_GRU_rice_network, + make_GRU_ipditm_network, + make_GRU_fishery_network, + make_GRU_rice_network, ) from pax.envs.rice.rice import Rice from pax.envs.rice.c_rice import ClubRice @@ -562,7 +564,9 @@ def make_gru_agent( agent_args.kernel_shape, ) else: - raise NotImplementedError(f"No gru network implemented for env {args.env_id}") + raise NotImplementedError( + f"No gru network implemented for env {args.env_id}" + ) gru_dim = initial_hidden_state.shape[1] diff --git a/pax/agents/tensor_strategies.py b/pax/agents/tensor_strategies.py index 429be0d5..478ca0fe 100644 --- a/pax/agents/tensor_strategies.py +++ b/pax/agents/tensor_strategies.py @@ -190,8 +190,8 @@ def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: # 0th state is all c...cc # 1st state is all c...cd # (2**num_playerth)-1 state is d...dd - # we cooperate in states _CCCC (0th and (2**num_player-1)th) and in start state (state 2**num_player)th ) - # otherwise we defect + # we cooperate in states _CCCC (0th and (2**num_player-1)th) + # and in start state (state 2**num_player)th ) otherwise we defect # 0 is cooperate, 1 is defect action = jnp.where(obs % jnp.exp2(num_players - 1) == 0, 0, 1) return action diff --git a/pax/envs/cournot.py b/pax/envs/cournot.py index d5335e25..83f91513 100644 --- a/pax/envs/cournot.py +++ b/pax/envs/cournot.py @@ -18,16 +18,17 @@ class EnvParams: b: float marginal_cost: float + class CournotGame(environment.Environment): def __init__(self, num_players: int, num_inner_steps: int): super().__init__() self.num_players = num_players def _step( - key: chex.PRNGKey, - state: EnvState, - actions: Tuple[float, ...], - params: EnvParams, + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, ): assert len(actions) == num_players t = state.outer_t @@ -61,7 +62,7 @@ def _step( ) def _reset( - key: chex.PRNGKey, params: EnvParams + key: chex.PRNGKey, params: EnvParams ) -> Tuple[Tuple, EnvState]: state = EnvState( inner_t=jnp.zeros((), dtype=jnp.int8), @@ -84,15 +85,18 @@ def num_actions(self) -> int: """Number of actions possible in environment.""" return 1 - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: """Action space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=(1,)) + return spaces.Box(low=0, high=float("inf"), shape=(1,)) def observation_space(self, params: EnvParams) -> spaces.Box: """Observation space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=self.num_players + 1, dtype=jnp.float32) + return spaces.Box( + low=0, + high=float("inf"), + shape=self.num_players + 1, + dtype=jnp.float32, + ) @staticmethod def nash_policy(params: EnvParams) -> float: diff --git a/pax/envs/fishery.py b/pax/envs/fishery.py index 86497002..c86dc23e 100644 --- a/pax/envs/fishery.py +++ b/pax/envs/fishery.py @@ -33,8 +33,8 @@ def to_obs_array(params: EnvParams) -> jnp.ndarray: Biological sub-model: -The stock of fish is given by G and grows at a natural rate of g. -The stock is harvested at a rate of E (sum over agent actions) and the harvest is sold at a price of p. +The stock of fish is given by G and grows at a natural rate of g. +The stock is harvested at a rate of E (sum over agent actions) and the harvest is sold at a price of p. g growth rate e efficiency parameter @@ -53,10 +53,10 @@ def __init__(self, num_players: int, num_inner_steps: int): self.num_players = num_players def _step( - key: chex.PRNGKey, - state: EnvState, - actions: Tuple[float, ...], - params: EnvParams, + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, ): t = state.inner_t + 1 key, _ = jax.random.split(key, 2) @@ -71,10 +71,7 @@ def _step( # Prevent s from dropping below 0 H = jnp.clip(E * state.s * params.e, a_max=s_growth) s_next = s_growth - H - next_state = EnvState( - inner_t=t, outer_t=state.outer_t, - s=s_next - ) + next_state = EnvState(inner_t=t, outer_t=state.outer_t, s=s_next) reset_obs, reset_state = _reset(key, params) reset_state = reset_state.replace(outer_t=state.outer_t + 1) @@ -111,12 +108,12 @@ def _step( ) def _reset( - key: chex.PRNGKey, params: EnvParams + key: chex.PRNGKey, params: EnvParams ) -> Tuple[Tuple, EnvState]: state = EnvState( inner_t=jnp.zeros((), dtype=jnp.int16), outer_t=jnp.zeros((), dtype=jnp.int16), - s=params.s_0 + s=params.s_0, ) obs = jax.random.uniform(key, (num_players,)) obs = jnp.concatenate([obs, jnp.array([state.s])]) @@ -130,15 +127,18 @@ def num_actions(self) -> int: """Number of actions possible in environment.""" return 1 - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: """Action space of the environment.""" return spaces.Box(low=0, high=params.s_max, shape=(1,)) def observation_space(self, params: EnvParams) -> spaces.Box: """Observation space of the environment.""" - return spaces.Box(low=0, high=float('inf'), shape=self.num_players + 1, dtype=jnp.float32) + return spaces.Box( + low=0, + high=float("inf"), + shape=self.num_players + 1, + dtype=jnp.float32, + ) @staticmethod def equilibrium(params: EnvParams) -> float: diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py index 9199c5f6..33ea79d5 100644 --- a/pax/envs/infinite_matrix_game.py +++ b/pax/envs/infinite_matrix_game.py @@ -104,9 +104,7 @@ def num_actions(self) -> int: """Number of actions possible in environment.""" return 5 - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: """Action space of the environment.""" return spaces.Box(low=0, high=1, shape=(5,)) diff --git a/pax/envs/rice/c_rice.py b/pax/envs/rice/c_rice.py index 683cd117..ddf06c12 100644 --- a/pax/envs/rice/c_rice.py +++ b/pax/envs/rice/c_rice.py @@ -7,11 +7,34 @@ from gymnax.environments import environment, spaces from jax import Array -from pax.envs.rice.rice import EnvState, EnvParams, get_consumption, get_armington_agg, get_utility, get_social_welfare, \ - get_global_temperature, load_rice_params, get_exogenous_emissions, get_mitigation_cost, get_carbon_intensity, \ - get_abatement_cost, get_damages, get_production, get_gross_output, get_investment, zero_diag, \ - get_max_potential_exports, get_land_emissions, get_aux_m, get_global_carbon_mass, get_capital_depreciation, \ - get_capital, get_labor, get_production_factor, get_carbon_price +from pax.envs.rice.rice import ( + EnvState, + EnvParams, + get_consumption, + get_armington_agg, + get_utility, + get_social_welfare, + get_global_temperature, + load_rice_params, + get_exogenous_emissions, + get_mitigation_cost, + get_carbon_intensity, + get_abatement_cost, + get_damages, + get_production, + get_gross_output, + get_investment, + zero_diag, + get_max_potential_exports, + get_land_emissions, + get_aux_m, + get_global_carbon_mass, + get_capital_depreciation, + get_capital, + get_labor, + get_production_factor, + get_carbon_price, +) from pax.utils import float_precision eps = 1e-5 @@ -23,32 +46,35 @@ 1. Regions can join the club at any time 2. The club has a fixed tariff rate 3. Club members must implement a minimum mitigation rate - -If the mediator is enabled it can choose the club tariff rate and the minimum mitigation rate. - + +If the mediator is enabled it can choose the club tariff rate and the minimum mitigation rate. + """ class ClubRice(environment.Environment): env_id: str = "C-Rice-N" - def __init__(self, - config_folder: str, - has_mediator=False, - mediator_climate_weight=0, - mediator_utility_weight=1, - mediator_climate_objective=False, - default_club_tariff_rate=0.1, - default_club_mitigation_rate=0.3, - episode_length: int = 20, - ): + def __init__( + self, + config_folder: str, + has_mediator=False, + mediator_climate_weight=0, + mediator_utility_weight=1, + mediator_climate_objective=False, + default_club_tariff_rate=0.1, + default_club_mitigation_rate=0.3, + episode_length: int = 20, + ): super().__init__() params, num_regions = load_rice_params(config_folder) self.has_mediator = has_mediator self.num_players = num_regions self.episode_length = episode_length - self.num_actors = self.num_players + 1 if self.has_mediator else self.num_players + self.num_actors = ( + self.num_players + 1 if self.has_mediator else self.num_players + ) self.rice_constant = params["_RICE_GLOBAL_CONSTANT"] self.dice_constant = params["_DICE_CONSTANT"] self.region_constants = params["_REGIONS"] @@ -67,41 +93,62 @@ def __init__(self, self.join_club_action_n = 1 self.actions_n = ( - self.savings_action_n - + self.mitigation_rate_action_n - + self.export_action_n - + self.import_actions_n - + self.tariff_actions_n - + self.join_club_action_n + self.savings_action_n + + self.mitigation_rate_action_n + + self.export_action_n + + self.import_actions_n + + self.tariff_actions_n + + self.join_club_action_n ) # Determine the index of each action to slice them in the step function self.savings_action_index = 0 - self.mitigation_rate_action_index = self.savings_action_index + self.savings_action_n - self.export_action_index = self.mitigation_rate_action_index + self.mitigation_rate_action_n - self.tariffs_action_index = self.export_action_index + self.export_action_n - self.desired_imports_action_index = self.tariffs_action_index + self.tariff_actions_n - self.join_club_action_index = self.desired_imports_action_index + self.join_club_action_n + self.mitigation_rate_action_index = ( + self.savings_action_index + self.savings_action_n + ) + self.export_action_index = ( + self.mitigation_rate_action_index + self.mitigation_rate_action_n + ) + self.tariffs_action_index = ( + self.export_action_index + self.export_action_n + ) + self.desired_imports_action_index = ( + self.tariffs_action_index + self.tariff_actions_n + ) + self.join_club_action_index = ( + self.desired_imports_action_index + self.join_club_action_n + ) # Parameters for armington aggregation utility self.sub_rate = jnp.asarray(0.5, dtype=float_precision) self.dom_pref = jnp.asarray(0.5, dtype=float_precision) - self.for_pref = jnp.asarray([0.5 / (self.num_players - 1)] * self.num_players, dtype=float_precision) - self.default_club_tariff_rate = jnp.asarray(default_club_tariff_rate, dtype=float_precision) - self.default_club_mitigation_rate = jnp.asarray(default_club_mitigation_rate, dtype=float_precision) + self.for_pref = jnp.asarray( + [0.5 / (self.num_players - 1)] * self.num_players, + dtype=float_precision, + ) + self.default_club_tariff_rate = jnp.asarray( + default_club_tariff_rate, dtype=float_precision + ) + self.default_club_mitigation_rate = jnp.asarray( + default_club_mitigation_rate, dtype=float_precision + ) if mediator_climate_objective: self.mediator_climate_weight = 1 self.mediator_utility_weight = 0 else: - self.mediator_climate_weight = jnp.asarray(mediator_climate_weight, dtype=float_precision) - self.mediator_utility_weight = jnp.asarray(mediator_utility_weight, dtype=float_precision) + self.mediator_climate_weight = jnp.asarray( + mediator_climate_weight, dtype=float_precision + ) + self.mediator_utility_weight = jnp.asarray( + mediator_utility_weight, dtype=float_precision + ) def _step( - key: chex.PRNGKey, - state: EnvState, - actions: Tuple[float, ...], - params: EnvParams, + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, ): t = state.inner_t + 1 # Rice equations expect to start at t=1 key, _ = jax.random.split(key, 2) @@ -109,7 +156,10 @@ def _step( t_at = state.global_temperature[0] global_exogenous_emissions = get_exogenous_emissions( - self.dice_constant["xf_0"], self.dice_constant["xf_1"], self.dice_constant["xt_f"], t + self.dice_constant["xf_0"], + self.dice_constant["xf_1"], + self.dice_constant["xt_f"], + t, ) actions = jnp.asarray(actions).astype(float_precision).squeeze() @@ -121,24 +171,28 @@ def _step( region_actions = actions # TODO it'd be better if this variable was categorical in the agent network - club_membership_all = jnp.round(region_actions[:, self.join_club_action_index]).astype(jnp.int8) + club_membership_all = jnp.round( + region_actions[:, self.join_club_action_index] + ).astype(jnp.int8) mitigation_cost_all = get_mitigation_cost( self.rice_constant["xp_b"], self.rice_constant["xtheta_2"], self.rice_constant["xdelta_pb"], state.intensity_all, - t + t, ) intensity_all = get_carbon_intensity( state.intensity_all, self.rice_constant["xg_sigma"], self.rice_constant["xdelta_sigma"], self.dice_constant["xDelta"], - t + t, ) if has_mediator: - club_mitigation_rate = actions[0, self.mitigation_rate_action_index] + club_mitigation_rate = actions[ + 0, self.mitigation_rate_action_index + ] club_tariff_rate = actions[0, self.tariffs_action_index] else: # Get the maximum carbon price of non-members from the last timestep @@ -155,47 +209,78 @@ def _step( mitigation_rate_all = jnp.where( club_membership_all == 1, club_mitigation_rate, - region_actions[:, self.mitigation_rate_action_index] + region_actions[:, self.mitigation_rate_action_index], ) abatement_cost_all = get_abatement_cost( - mitigation_rate_all, mitigation_cost_all, - self.rice_constant["xtheta_2"] + mitigation_rate_all, + mitigation_cost_all, + self.rice_constant["xtheta_2"], + ) + damages_all = get_damages( + t_at, + self.region_params["xa_1"], + self.region_params["xa_2"], + self.region_params["xa_3"], ) - damages_all = get_damages(t_at, self.region_params["xa_1"], self.region_params["xa_2"], - self.region_params["xa_3"]) production_all = get_production( state.production_factor_all, state.capital_all, state.labor_all, self.region_params["xgamma"], ) - gross_output_all = get_gross_output(damages_all, abatement_cost_all, production_all) + gross_output_all = get_gross_output( + damages_all, abatement_cost_all, production_all + ) balance_all = state.balance_all * (1 + self.rice_constant["r"]) - investment_all = get_investment(region_actions[:, self.savings_action_index], gross_output_all) + investment_all = get_investment( + region_actions[:, self.savings_action_index], gross_output_all + ) # Trade - desired_imports = region_actions[:, - self.desired_imports_action_index: self.desired_imports_action_index + self.import_actions_n] + desired_imports = region_actions[ + :, + self.desired_imports_action_index : self.desired_imports_action_index + + self.import_actions_n, + ] # Countries cannot import from themselves desired_imports = zero_diag(desired_imports) total_desired_imports = desired_imports.sum(axis=1) - clipped_desired_imports = jnp.clip(total_desired_imports, 0, gross_output_all) - desired_imports = desired_imports * ( - # Transpose to apply the scaling row and not column-wise - clipped_desired_imports / (total_desired_imports + eps))[:, - jnp.newaxis] + clipped_desired_imports = jnp.clip( + total_desired_imports, 0, gross_output_all + ) + desired_imports = ( + desired_imports + * ( + # Transpose to apply the scaling row and not column-wise + clipped_desired_imports + / (total_desired_imports + eps) + )[:, jnp.newaxis] + ) init_capital_multiplier = 10.0 - debt_ratio = init_capital_multiplier * balance_all / self.region_params["xK_0"] + debt_ratio = ( + init_capital_multiplier + * balance_all + / self.region_params["xK_0"] + ) debt_ratio = jnp.clip(debt_ratio, -1.0, 0.0) scaled_imports = desired_imports * (1 + debt_ratio) max_potential_exports = get_max_potential_exports( - region_actions[:, self.export_action_index], gross_output_all, investment_all) + region_actions[:, self.export_action_index], + gross_output_all, + investment_all, + ) # Summing along columns yields the total number of exports requested by other regions total_desired_exports = jnp.sum(scaled_imports, axis=0) - clipped_desired_exports = jnp.clip(total_desired_exports, 0, max_potential_exports) - scaled_imports = scaled_imports * clipped_desired_exports / (total_desired_exports + eps) + clipped_desired_exports = jnp.clip( + total_desired_exports, 0, max_potential_exports + ) + scaled_imports = ( + scaled_imports + * clipped_desired_exports + / (total_desired_exports + eps) + ) prev_tariffs = state.future_tariff tariffed_imports = scaled_imports * (1 - prev_tariffs) @@ -205,13 +290,28 @@ def _step( total_exports = scaled_imports.sum(axis=0) balance_all = balance_all + self.dice_constant["xDelta"] * ( - total_exports - scaled_imports.sum(axis=1)) + total_exports - scaled_imports.sum(axis=1) + ) - c_dom = get_consumption(gross_output_all, investment_all, total_exports) - consumption_all = get_armington_agg(c_dom, tariffed_imports, self.sub_rate, self.dom_pref, self.for_pref) - utility_all = get_utility(state.labor_all, consumption_all, self.rice_constant["xalpha"]) - social_welfare_all = get_social_welfare(utility_all, self.rice_constant["xrho"], - self.dice_constant["xDelta"], t) + c_dom = get_consumption( + gross_output_all, investment_all, total_exports + ) + consumption_all = get_armington_agg( + c_dom, + tariffed_imports, + self.sub_rate, + self.dom_pref, + self.for_pref, + ) + utility_all = get_utility( + state.labor_all, consumption_all, self.rice_constant["xalpha"] + ) + social_welfare_all = get_social_welfare( + utility_all, + self.rice_constant["xrho"], + self.dice_constant["xDelta"], + t, + ) # Update ecology m_at = state.global_carbon_mass[0] @@ -226,14 +326,17 @@ def _step( ) global_land_emissions = get_land_emissions( - self.dice_constant["xE_L0"], self.dice_constant["xdelta_EL"], t, self.num_players + self.dice_constant["xE_L0"], + self.dice_constant["xdelta_EL"], + t, + self.num_players, ) aux_m_all = get_aux_m( state.intensity_all, mitigation_rate_all, production_all, - global_land_emissions + global_land_emissions, ) global_carbon_mass = get_global_carbon_mass( @@ -247,11 +350,16 @@ def _step( self.rice_constant["xdelta_K"], self.dice_constant["xDelta"] ) capital_all = get_capital( - capital_depreciation, state.capital_all, + capital_depreciation, + state.capital_all, self.dice_constant["xDelta"], - investment_all + investment_all, + ) + labor_all = get_labor( + state.labor_all, + self.region_params["xL_a"], + self.region_params["xl_g"], ) - labor_all = get_labor(state.labor_all, self.region_params["xL_a"], self.region_params["xl_g"]) production_factor_all = get_production_factor( state.production_factor_all, self.region_params["xg_A"], @@ -259,39 +367,45 @@ def _step( self.dice_constant["xDelta"], t, ) - carbon_price_all = get_carbon_price(mitigation_cost_all, intensity_all, - mitigation_rate_all, - self.rice_constant["xtheta_2"], - damages_all) + carbon_price_all = get_carbon_price( + mitigation_cost_all, + intensity_all, + mitigation_rate_all, + self.rice_constant["xtheta_2"], + damages_all, + ) - desired_future_tariffs = region_actions[:, - self.tariffs_action_index: self.tariffs_action_index + self.num_players] + desired_future_tariffs = region_actions[ + :, + self.tariffs_action_index : self.tariffs_action_index + + self.num_players, + ] # Club members impose a minimum tariff of the club tariff rate future_tariffs = jnp.where( (club_membership_all == 1).reshape(-1, 1), jnp.clip(desired_future_tariffs, a_min=club_tariff_rate), - desired_future_tariffs + desired_future_tariffs, ) # Club members don't impose tariffs on themselves or other club members - membership_mask = club_membership_all.reshape(-1, 1) * club_membership_all + membership_mask = ( + club_membership_all.reshape(-1, 1) * club_membership_all + ) future_tariffs = future_tariffs * (1 - membership_mask) future_tariffs = zero_diag(future_tariffs) next_state = EnvState( - inner_t=state.inner_t + 1, outer_t=state.outer_t, + inner_t=state.inner_t + 1, + outer_t=state.outer_t, global_temperature=global_temperature, global_carbon_mass=global_carbon_mass, global_exogenous_emissions=global_exogenous_emissions, global_land_emissions=global_land_emissions, - labor_all=labor_all, capital_all=capital_all, production_factor_all=production_factor_all, intensity_all=intensity_all, balance_all=balance_all, - future_tariff=future_tariffs, - gross_output_all=gross_output_all, investment_all=investment_all, production_all=production_all, @@ -302,7 +416,6 @@ def _step( consumption_all=consumption_all, damages_all=damages_all, abatement_cost_all=abatement_cost_all, - tariff_revenue_all=tariff_revenue_all, carbon_price_all=carbon_price_all, club_membership_all=club_membership_all, @@ -313,12 +426,22 @@ def _step( obs = [] if self.has_mediator: - obs.append(self._generate_mediator_observation(next_state, club_mitigation_rate, club_tariff_rate)) + obs.append( + self._generate_mediator_observation( + next_state, club_mitigation_rate, club_tariff_rate + ) + ) for i in range(self.num_players): - obs.append(self._generate_observation(i, next_state, club_mitigation_rate, club_tariff_rate)) - - obs = jax.tree_map(lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs)) + obs.append( + self._generate_observation( + i, next_state, club_mitigation_rate, club_tariff_rate + ) + ) + + obs = jax.tree_map( + lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs) + ) result_state = jax.tree_map( lambda x, y: jnp.where(done, x, y), @@ -328,11 +451,21 @@ def _step( rewards = result_state.utility_all if self.has_mediator: - temp_increase = next_state.global_temperature[0] - state.global_temperature[0] + temp_increase = ( + next_state.global_temperature[0] + - state.global_temperature[0] + ) # Rescale the social reward to make the weights comparable - social_reward = result_state.utility_all.sum() / (self.num_players * 10) - mediator_reward = - self.mediator_climate_weight * temp_increase + self.mediator_utility_weight * social_reward - rewards = jnp.insert(result_state.utility_all, 0, mediator_reward) + social_reward = result_state.utility_all.sum() / ( + self.num_players * 10 + ) + mediator_reward = ( + -self.mediator_climate_weight * temp_increase + + self.mediator_utility_weight * social_reward + ) + rewards = jnp.insert( + result_state.utility_all, 0, mediator_reward + ) return ( tuple(obs), @@ -343,15 +476,23 @@ def _step( ) def _reset( - key: chex.PRNGKey, params: EnvParams + key: chex.PRNGKey, params: EnvParams ) -> Tuple[Tuple, EnvState]: state = self._get_initial_state() obs = [] club_state = jax.random.uniform(key, (2,)) if self.has_mediator: - obs.append(self._generate_mediator_observation(state, club_state[0], club_state[1])) + obs.append( + self._generate_mediator_observation( + state, club_state[0], club_state[1] + ) + ) for i in range(self.num_players): - obs.append(self._generate_observation(i, state, club_state[0], club_state[1])) + obs.append( + self._generate_observation( + i, state, club_state[0], club_state[1] + ) + ) return tuple(obs), state self.step = jax.jit(_step) @@ -361,66 +502,98 @@ def _get_initial_state(self) -> EnvState: return EnvState( inner_t=jnp.zeros((), dtype=jnp.int16), outer_t=jnp.zeros((), dtype=jnp.int16), - global_temperature=jnp.array([self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]]), + global_temperature=jnp.array( + [self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]] + ), global_carbon_mass=jnp.array( - [self.dice_constant["xM_AT_0"], self.dice_constant["xM_UP_0"], self.dice_constant["xM_LO_0"]]), + [ + self.dice_constant["xM_AT_0"], + self.dice_constant["xM_UP_0"], + self.dice_constant["xM_LO_0"], + ] + ), global_exogenous_emissions=jnp.zeros((), dtype=float_precision), global_land_emissions=jnp.zeros((), dtype=float_precision), labor_all=self.region_params["xL_0"], capital_all=self.region_params["xK_0"], production_factor_all=self.region_params["xA_0"], intensity_all=self.region_params["xsigma_0"], - balance_all=jnp.zeros(self.num_players, dtype=float_precision), - future_tariff=jnp.zeros((self.num_players, self.num_players), dtype=float_precision), - - gross_output_all=jnp.zeros(self.num_players, dtype=float_precision), + future_tariff=jnp.zeros( + (self.num_players, self.num_players), dtype=float_precision + ), + gross_output_all=jnp.zeros( + self.num_players, dtype=float_precision + ), investment_all=jnp.zeros(self.num_players, dtype=float_precision), production_all=jnp.zeros(self.num_players, dtype=float_precision), utility_all=jnp.zeros(self.num_players, dtype=float_precision), - social_welfare_all=jnp.zeros(self.num_players, dtype=float_precision), + social_welfare_all=jnp.zeros( + self.num_players, dtype=float_precision + ), capital_depreciation_all=jnp.zeros(0, dtype=float_precision), - mitigation_cost_all=jnp.zeros(self.num_players, dtype=float_precision), + mitigation_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), consumption_all=jnp.zeros(self.num_players, dtype=float_precision), damages_all=jnp.zeros(self.num_players, dtype=float_precision), - abatement_cost_all=jnp.zeros(self.num_players, dtype=float_precision), - - tariff_revenue_all=jnp.zeros(self.num_players, dtype=float_precision), - carbon_price_all=jnp.zeros(self.num_players, dtype=float_precision), + abatement_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + tariff_revenue_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + carbon_price_all=jnp.zeros( + self.num_players, dtype=float_precision + ), club_membership_all=jnp.zeros(self.num_players, dtype=jnp.int8), ) - def _generate_observation(self, index: int, state: EnvState, club_mitigation_rate, club_tariff_rate) -> Array: - return jnp.concatenate([ - # Public features - jnp.asarray([index]), - jnp.asarray([state.inner_t]), - jnp.array([club_mitigation_rate]), - jnp.array([club_tariff_rate]), - state.global_temperature, - state.global_carbon_mass, - state.gross_output_all, - state.investment_all, - state.abatement_cost_all, - state.tariff_revenue_all, - state.club_membership_all, - ], dtype=float_precision) - - def _generate_mediator_observation(self, state: EnvState, club_mitigation_rate, club_tariff_rate) -> Array: - return jnp.concatenate([ - # Public features - jnp.zeros(1), - jnp.asarray([state.inner_t]), - jnp.array([club_mitigation_rate]), - jnp.array([club_tariff_rate]), - state.global_temperature, - state.global_carbon_mass, - state.gross_output_all, - state.investment_all, - state.abatement_cost_all, - state.tariff_revenue_all, - state.club_membership_all, - ], dtype=float_precision) + def _generate_observation( + self, + index: int, + state: EnvState, + club_mitigation_rate, + club_tariff_rate, + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.asarray([index]), + jnp.asarray([state.inner_t]), + jnp.array([club_mitigation_rate]), + jnp.array([club_tariff_rate]), + state.global_temperature, + state.global_carbon_mass, + state.gross_output_all, + state.investment_all, + state.abatement_cost_all, + state.tariff_revenue_all, + state.club_membership_all, + ], + dtype=float_precision, + ) + + def _generate_mediator_observation( + self, state: EnvState, club_mitigation_rate, club_tariff_rate + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.zeros(1), + jnp.asarray([state.inner_t]), + jnp.array([club_mitigation_rate]), + jnp.array([club_tariff_rate]), + state.global_temperature, + state.global_carbon_mass, + state.gross_output_all, + state.investment_all, + state.abatement_cost_all, + state.tariff_revenue_all, + state.club_membership_all, + ], + dtype=float_precision, + ) @property def name(self) -> str: @@ -430,16 +603,21 @@ def name(self) -> str: def num_actions(self) -> int: return self.actions_n - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: return spaces.Box(low=0, high=1, shape=(self.actions_n,)) def observation_space(self, params: EnvParams) -> spaces.Box: init_state = self._get_initial_state() obs = self._generate_observation(0, init_state, 0, 0) - return spaces.Box(low=0, high=float('inf'), shape=obs.shape, dtype=float_precision) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=float_precision + ) -def get_club_mitigation_rates(coalition_price, intensity, theta_2, mitigation_cost, damages): - return pow(coalition_price * intensity / (theta_2 * mitigation_cost * damages), 1 / (theta_2 - 1)) +def get_club_mitigation_rates( + coalition_price, intensity, theta_2, mitigation_cost, damages +): + return pow( + coalition_price * intensity / (theta_2 * mitigation_cost * damages), + 1 / (theta_2 - 1), + ) diff --git a/pax/envs/rice/rice.py b/pax/envs/rice/rice.py index 1fc69081..9c5f5df8 100644 --- a/pax/envs/rice/rice.py +++ b/pax/envs/rice/rice.py @@ -68,7 +68,9 @@ class EnvParams: class Rice(environment.Environment): env_id: str = "Rice-N" - def __init__(self, config_folder: str, has_mediator=False, episode_length=20): + def __init__( + self, config_folder: str, has_mediator=False, episode_length=20 + ): super().__init__() # TODO refactor all the constants to use env_params @@ -78,7 +80,9 @@ def __init__(self, config_folder: str, has_mediator=False, episode_length=20): params, num_regions = load_rice_params(config_folder) self.has_mediator = has_mediator self.num_players = num_regions - self.num_actors = self.num_players + 1 if self.has_mediator else self.num_players + self.num_actors = ( + self.num_players + 1 if self.has_mediator else self.num_players + ) self.rice_constant = params["_RICE_GLOBAL_CONSTANT"] self.dice_constant = params["_DICE_CONSTANT"] self.region_constants = params["_REGIONS"] @@ -96,31 +100,42 @@ def __init__(self, config_folder: str, has_mediator=False, episode_length=20): self.tariff_actions_n = self.num_players self.actions_n = ( - self.savings_action_n - + self.mitigation_rate_action_n - + self.export_action_n - + self.import_actions_n - + self.tariff_actions_n + self.savings_action_n + + self.mitigation_rate_action_n + + self.export_action_n + + self.import_actions_n + + self.tariff_actions_n ) # Determine the index of each action to slice them in the step function self.savings_action_index = 0 - self.mitigation_rate_action_index = self.savings_action_index + self.savings_action_n - self.export_action_index = self.mitigation_rate_action_index + self.mitigation_rate_action_n - self.tariffs_action_index = self.export_action_index + self.export_action_n - self.desired_imports_action_index = self.tariffs_action_index + self.import_actions_n + self.mitigation_rate_action_index = ( + self.savings_action_index + self.savings_action_n + ) + self.export_action_index = ( + self.mitigation_rate_action_index + self.mitigation_rate_action_n + ) + self.tariffs_action_index = ( + self.export_action_index + self.export_action_n + ) + self.desired_imports_action_index = ( + self.tariffs_action_index + self.import_actions_n + ) # Parameters for armington aggregation utility self.sub_rate = jnp.asarray(0.5, dtype=float_precision) self.dom_pref = jnp.asarray(0.5, dtype=float_precision) - self.for_pref = jnp.asarray([0.5 / (self.num_players - 1)] * self.num_players, dtype=float_precision) + self.for_pref = jnp.asarray( + [0.5 / (self.num_players - 1)] * self.num_players, + dtype=float_precision, + ) self.episode_length = episode_length def _step( - key: chex.PRNGKey, - state: EnvState, - actions: Tuple[float, ...], - params: EnvParams, + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, ): t = state.inner_t + 1 # Rice equations expect to start at t=1 key, _ = jax.random.split(key, 2) @@ -128,7 +143,10 @@ def _step( t_at = state.global_temperature[0] global_exogenous_emissions = get_exogenous_emissions( - self.dice_constant["xf_0"], self.dice_constant["xf_1"], self.dice_constant["xt_f"], t + self.dice_constant["xf_0"], + self.dice_constant["xf_1"], + self.dice_constant["xt_f"], + t, ) actions = jnp.asarray(actions).astype(float_precision).squeeze() @@ -144,46 +162,77 @@ def _step( self.rice_constant["xtheta_2"], self.rice_constant["xdelta_pb"], state.intensity_all, - t + t, ) abatement_cost_all = get_abatement_cost( - region_actions[:, self.mitigation_rate_action_index], mitigation_cost_all, - self.rice_constant["xtheta_2"] + region_actions[:, self.mitigation_rate_action_index], + mitigation_cost_all, + self.rice_constant["xtheta_2"], + ) + damages_all = get_damages( + t_at, + self.region_params["xa_1"], + self.region_params["xa_2"], + self.region_params["xa_3"], ) - damages_all = get_damages(t_at, self.region_params["xa_1"], self.region_params["xa_2"], - self.region_params["xa_3"]) production_all = get_production( state.production_factor_all, state.capital_all, state.labor_all, self.region_params["xgamma"], ) - gross_output_all = get_gross_output(damages_all, abatement_cost_all, production_all) + gross_output_all = get_gross_output( + damages_all, abatement_cost_all, production_all + ) balance_all = state.balance_all * (1 + self.rice_constant["r"]) - investment_all = get_investment(region_actions[:, self.savings_action_index], gross_output_all) + investment_all = get_investment( + region_actions[:, self.savings_action_index], gross_output_all + ) # Trade - desired_imports = region_actions[:, - self.desired_imports_action_index: self.desired_imports_action_index + self.import_actions_n] + desired_imports = region_actions[ + :, + self.desired_imports_action_index : self.desired_imports_action_index + + self.import_actions_n, + ] # Countries cannot import from themselves desired_imports = zero_diag(desired_imports) total_desired_imports = desired_imports.sum(axis=1) - clipped_desired_imports = jnp.clip(total_desired_imports, 0, gross_output_all) - desired_imports = desired_imports * ( - # Transpose to apply the scaling row and not column-wise - clipped_desired_imports / (total_desired_imports + eps))[:, - jnp.newaxis] + clipped_desired_imports = jnp.clip( + total_desired_imports, 0, gross_output_all + ) + desired_imports = ( + desired_imports + * ( + # Transpose to apply the scaling row and not column-wise + clipped_desired_imports + / (total_desired_imports + eps) + )[:, jnp.newaxis] + ) init_capital_multiplier = 10.0 - debt_ratio = init_capital_multiplier * balance_all / self.region_params["xK_0"] + debt_ratio = ( + init_capital_multiplier + * balance_all + / self.region_params["xK_0"] + ) debt_ratio = jnp.clip(debt_ratio, -1.0, 0.0) scaled_imports = desired_imports * (1 + debt_ratio) max_potential_exports = get_max_potential_exports( - region_actions[:, self.export_action_index], gross_output_all, investment_all) + region_actions[:, self.export_action_index], + gross_output_all, + investment_all, + ) # Summing along columns yields the total number of exports requested by other regions total_desired_exports = jnp.sum(scaled_imports, axis=0) - clipped_desired_exports = jnp.clip(total_desired_exports, 0, max_potential_exports) - scaled_imports = scaled_imports * clipped_desired_exports / (total_desired_exports + eps) + clipped_desired_exports = jnp.clip( + total_desired_exports, 0, max_potential_exports + ) + scaled_imports = ( + scaled_imports + * clipped_desired_exports + / (total_desired_exports + eps) + ) prev_tariffs = state.future_tariff tariffed_imports = scaled_imports * (1 - prev_tariffs) @@ -193,13 +242,28 @@ def _step( total_exports = scaled_imports.sum(axis=0) balance_all = balance_all + self.dice_constant["xDelta"] * ( - total_exports - scaled_imports.sum(axis=1)) + total_exports - scaled_imports.sum(axis=1) + ) - c_dom = get_consumption(gross_output_all, investment_all, total_exports) - consumption_all = get_armington_agg(c_dom, tariffed_imports, self.sub_rate, self.dom_pref, self.for_pref) - utility_all = get_utility(state.labor_all, consumption_all, self.rice_constant["xalpha"]) - social_welfare_all = get_social_welfare(utility_all, self.rice_constant["xrho"], - self.dice_constant["xDelta"], t) + c_dom = get_consumption( + gross_output_all, investment_all, total_exports + ) + consumption_all = get_armington_agg( + c_dom, + tariffed_imports, + self.sub_rate, + self.dom_pref, + self.for_pref, + ) + utility_all = get_utility( + state.labor_all, consumption_all, self.rice_constant["xalpha"] + ) + social_welfare_all = get_social_welfare( + utility_all, + self.rice_constant["xrho"], + self.dice_constant["xDelta"], + t, + ) # Update ecology m_at = state.global_carbon_mass[0] @@ -214,14 +278,17 @@ def _step( ) global_land_emissions = get_land_emissions( - self.dice_constant["xE_L0"], self.dice_constant["xdelta_EL"], t, self.num_players + self.dice_constant["xE_L0"], + self.dice_constant["xdelta_EL"], + t, + self.num_players, ) aux_m_all = get_aux_m( state.intensity_all, region_actions[:, self.mitigation_rate_action_index], production_all, - global_land_emissions + global_land_emissions, ) global_carbon_mass = get_global_carbon_mass( @@ -235,18 +302,23 @@ def _step( self.rice_constant["xdelta_K"], self.dice_constant["xDelta"] ) capital_all = get_capital( - capital_depreciation, state.capital_all, + capital_depreciation, + state.capital_all, self.dice_constant["xDelta"], - investment_all + investment_all, ) intensity_all = get_carbon_intensity( state.intensity_all, self.rice_constant["xg_sigma"], self.rice_constant["xdelta_sigma"], self.dice_constant["xDelta"], - t + t, + ) + labor_all = get_labor( + state.labor_all, + self.region_params["xL_a"], + self.region_params["xl_g"], ) - labor_all = get_labor(state.labor_all, self.region_params["xL_a"], self.region_params["xl_g"]) production_factor_all = get_production_factor( state.production_factor_all, self.region_params["xg_A"], @@ -254,27 +326,31 @@ def _step( self.dice_constant["xDelta"], t, ) - carbon_price_all = get_carbon_price(mitigation_cost_all, intensity_all, - region_actions[:, self.mitigation_rate_action_index], - self.rice_constant["xtheta_2"], - damages_all) + carbon_price_all = get_carbon_price( + mitigation_cost_all, + intensity_all, + region_actions[:, self.mitigation_rate_action_index], + self.rice_constant["xtheta_2"], + damages_all, + ) next_state = EnvState( - inner_t=state.inner_t + 1, outer_t=state.outer_t, + inner_t=state.inner_t + 1, + outer_t=state.outer_t, global_temperature=global_temperature, global_carbon_mass=global_carbon_mass, global_exogenous_emissions=global_exogenous_emissions, global_land_emissions=global_land_emissions, - labor_all=labor_all, capital_all=capital_all, production_factor_all=production_factor_all, intensity_all=intensity_all, balance_all=balance_all, - - future_tariff=region_actions[:, - self.tariffs_action_index: self.tariffs_action_index + self.num_players], - + future_tariff=region_actions[ + :, + self.tariffs_action_index : self.tariffs_action_index + + self.num_players, + ], gross_output_all=gross_output_all, investment_all=investment_all, production_all=production_all, @@ -285,10 +361,11 @@ def _step( consumption_all=consumption_all, damages_all=damages_all, abatement_cost_all=abatement_cost_all, - tariff_revenue_all=tariff_revenue_all, carbon_price_all=carbon_price_all, - club_membership_all=jnp.zeros(self.num_players, dtype=jnp.int8), + club_membership_all=jnp.zeros( + self.num_players, dtype=jnp.int8 + ), ) reset_obs, reset_state = _reset(key, params) @@ -296,12 +373,16 @@ def _step( obs = [] if self.has_mediator: - obs.append(self._generate_mediator_observation(actions, next_state)) + obs.append( + self._generate_mediator_observation(actions, next_state) + ) for i in range(self.num_players): obs.append(self._generate_observation(i, actions, next_state)) - obs = jax.tree_map(lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs)) + obs = jax.tree_map( + lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs) + ) state = jax.tree_map( lambda x, y: jnp.where(done, x, y), @@ -310,7 +391,9 @@ def _step( ) if self.has_mediator: - rewards = jnp.insert(state.utility_all, 0, state.utility_all.sum()) + rewards = jnp.insert( + state.utility_all, 0, state.utility_all.sum() + ) else: rewards = state.utility_all @@ -323,10 +406,15 @@ def _step( ) def _reset( - key: chex.PRNGKey, params: EnvParams + key: chex.PRNGKey, params: EnvParams ) -> Tuple[Tuple, EnvState]: state = self._get_initial_state() - actions = jnp.asarray([jax.random.uniform(key, (self.num_actions,)) for _ in range(self.num_actors)]) + actions = jnp.asarray( + [ + jax.random.uniform(key, (self.num_actions,)) + for _ in range(self.num_actors) + ] + ) obs = [] if self.has_mediator: obs.append(self._generate_mediator_observation(actions, state)) @@ -341,83 +429,111 @@ def _get_initial_state(self) -> EnvState: return EnvState( inner_t=jnp.zeros((), dtype=jnp.int16), outer_t=jnp.zeros((), dtype=jnp.int16), - global_temperature=jnp.array([self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]]), + global_temperature=jnp.array( + [self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]] + ), global_carbon_mass=jnp.array( - [self.dice_constant["xM_AT_0"], self.dice_constant["xM_UP_0"], self.dice_constant["xM_LO_0"]]), + [ + self.dice_constant["xM_AT_0"], + self.dice_constant["xM_UP_0"], + self.dice_constant["xM_LO_0"], + ] + ), global_exogenous_emissions=jnp.zeros((), dtype=float_precision), global_land_emissions=jnp.zeros((), dtype=float_precision), labor_all=self.region_params["xL_0"], capital_all=self.region_params["xK_0"], production_factor_all=self.region_params["xA_0"], intensity_all=self.region_params["xsigma_0"], - balance_all=jnp.zeros(self.num_players, dtype=float_precision), - future_tariff=jnp.zeros((self.num_players, self.num_players), dtype=float_precision), - - gross_output_all=jnp.zeros(self.num_players, dtype=float_precision), + future_tariff=jnp.zeros( + (self.num_players, self.num_players), dtype=float_precision + ), + gross_output_all=jnp.zeros( + self.num_players, dtype=float_precision + ), investment_all=jnp.zeros(self.num_players, dtype=float_precision), production_all=jnp.zeros(self.num_players, dtype=float_precision), utility_all=jnp.zeros(self.num_players, dtype=float_precision), - social_welfare_all=jnp.zeros(self.num_players, dtype=float_precision), + social_welfare_all=jnp.zeros( + self.num_players, dtype=float_precision + ), capital_depreciation_all=jnp.zeros(0, dtype=float_precision), - mitigation_cost_all=jnp.zeros(self.num_players, dtype=float_precision), + mitigation_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), consumption_all=jnp.zeros(self.num_players, dtype=float_precision), damages_all=jnp.zeros(self.num_players, dtype=float_precision), - abatement_cost_all=jnp.zeros(self.num_players, dtype=float_precision), - - tariff_revenue_all=jnp.zeros(self.num_players, dtype=float_precision), - carbon_price_all=jnp.zeros(self.num_players, dtype=float_precision), + abatement_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + tariff_revenue_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + carbon_price_all=jnp.zeros( + self.num_players, dtype=float_precision + ), club_membership_all=jnp.zeros(self.num_players, dtype=jnp.int8), ) - def _generate_observation(self, index: int, actions: chex.ArrayDevice, state: EnvState) -> Array: - return jnp.concatenate([ - # Public features - jnp.asarray([index]), - jnp.asarray([state.inner_t]), - state.global_temperature, - state.global_carbon_mass, - jnp.asarray([state.global_exogenous_emissions]), - jnp.asarray([state.global_land_emissions]), - state.labor_all, - state.capital_all, - state.gross_output_all, - state.consumption_all, - state.investment_all, - state.balance_all, - state.tariff_revenue_all, - state.carbon_price_all, - state.club_membership_all, - # Private features - jnp.asarray([state.damages_all[index]]), - jnp.asarray([state.abatement_cost_all[index]]), - jnp.asarray([state.production_all[index]]), - # All agent actions - actions.ravel() - ], dtype=float_precision) - - def _generate_mediator_observation(self, actions: chex.ArrayDevice, state: EnvState) -> Array: - return jnp.concatenate([ - jnp.zeros(1), - jnp.asarray([state.inner_t]), - state.global_temperature, - state.global_carbon_mass, - jnp.asarray([state.global_exogenous_emissions]), - jnp.asarray([state.global_land_emissions]), - state.labor_all, - state.capital_all, - state.gross_output_all, - state.consumption_all, - state.investment_all, - state.balance_all, - state.tariff_revenue_all, - state.carbon_price_all, - state.club_membership_all, - # Maintain same dimensionality as for other players - jnp.zeros(3), - # All agent actions - actions.ravel() - ], dtype=float_precision) + def _generate_observation( + self, index: int, actions: chex.ArrayDevice, state: EnvState + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.asarray([index]), + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + state.tariff_revenue_all, + state.carbon_price_all, + state.club_membership_all, + # Private features + jnp.asarray([state.damages_all[index]]), + jnp.asarray([state.abatement_cost_all[index]]), + jnp.asarray([state.production_all[index]]), + # All agent actions + actions.ravel(), + ], + dtype=float_precision, + ) + + def _generate_mediator_observation( + self, actions: chex.ArrayDevice, state: EnvState + ) -> Array: + return jnp.concatenate( + [ + jnp.zeros(1), + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + state.tariff_revenue_all, + state.carbon_price_all, + state.club_membership_all, + # Maintain same dimensionality as for other players + jnp.zeros(3), + # All agent actions + actions.ravel(), + ], + dtype=float_precision, + ) @property def name(self) -> str: @@ -427,15 +543,17 @@ def name(self) -> str: def num_actions(self) -> int: return self.actions_n - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: return spaces.Box(low=0, high=1, shape=(self.actions_n,)) def observation_space(self, params: EnvParams) -> spaces.Box: init_state = self._get_initial_state() - obs = self._generate_observation(0, jnp.zeros(self.num_actions * self.num_actors), init_state) - return spaces.Box(low=0, high=float('inf'), shape=obs.shape, dtype=float_precision) + obs = self._generate_observation( + 0, jnp.zeros(self.num_actions * self.num_actors), init_state + ) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=float_precision + ) def load_rice_params(config_dir=None): @@ -454,7 +572,7 @@ def load_rice_params(config_dir=None): # _REGIONS is a list of dictionaries base_params["_REGIONS"] = [] - for idx, param in enumerate(region_params): + for param in region_params: region_to_append = param["_RICE_CONSTANT"] for k in base_params["_RICE_CONSTANT_DEFAULT"].keys(): if k not in region_to_append.keys(): @@ -465,20 +583,28 @@ def load_rice_params(config_dir=None): base_params["_REGION_PARAMS"] = {} for k in base_params["_RICE_CONSTANT_DEFAULT"].keys(): base_params["_REGION_PARAMS"][k] = [] - for idx, param in enumerate(region_params): - parameter_value = param["_RICE_CONSTANT"].get(k, base_params["_RICE_CONSTANT_DEFAULT"][k]) + for param in region_params: + parameter_value = param["_RICE_CONSTANT"].get( + k, base_params["_RICE_CONSTANT_DEFAULT"][k] + ) base_params["_REGION_PARAMS"][k].append(parameter_value) - base_params["_REGION_PARAMS"][k] = jnp.asarray(base_params["_REGION_PARAMS"][k], dtype=float_precision) + base_params["_REGION_PARAMS"][k] = jnp.asarray( + base_params["_REGION_PARAMS"][k], dtype=float_precision + ) return base_params, len(region_params) def zero_diag(matrix: jax.Array) -> jax.Array: - return matrix - matrix * jnp.eye(matrix.shape[0], matrix.shape[1], dtype=matrix.dtype) + return matrix - matrix * jnp.eye( + matrix.shape[0], matrix.shape[1], dtype=matrix.dtype + ) def get_exogenous_emissions(f_0, f_1, t_f, timestep): - return f_0 + jnp.min(jnp.array([f_1 - f_0, (f_1 - f_0) / t_f * (timestep - 1)])) + return f_0 + jnp.min( + jnp.array([f_1 - f_0, (f_1 - f_0) / t_f * (timestep - 1)]) + ) def get_land_emissions(e_l0, delta_el, timestep, num_regions): @@ -499,7 +625,9 @@ def get_abatement_cost(mitigation_rate, mitigation_cost, theta_2): def get_production(production_factor, capital, labor, gamma): """Obtain the amount of goods produced.""" - return production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma) + return ( + production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma) + ) def get_gross_output(damages, abatement_cost, production): @@ -515,7 +643,9 @@ def get_consumption(gross_output, investment, total_exports): def get_max_potential_exports(x_max, gross_output, investment): - return jnp.min(jnp.array([x_max * gross_output, gross_output - investment]), axis=0) + return jnp.min( + jnp.array([x_max * gross_output, gross_output - investment]), axis=0 + ) def get_capital_depreciation(x_delta_k, x_delta): @@ -524,10 +654,11 @@ def get_capital_depreciation(x_delta_k, x_delta): # Returns shape 2 def get_global_temperature( - phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions + phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions ): return jnp.dot(phi_t, temperature) + jnp.dot( - b_t, f_2x * jnp.log(m_at / m_at_1750) / jnp.log(2) + exogenous_emissions + b_t, + f_2x * jnp.log(m_at / m_at_1750) / jnp.log(2) + exogenous_emissions, ) @@ -550,12 +681,20 @@ def get_labor(labor, l_a, l_g): def get_production_factor(production_factor, g_a, delta_a, delta, timestep): return production_factor * ( - jnp.exp(0.0033) + g_a * jnp.exp(-delta_a * delta * (timestep - 1)) + jnp.exp(0.0033) + g_a * jnp.exp(-delta_a * delta * (timestep - 1)) ) -def get_carbon_price(mitigation_cost, intensity, mitigation_rate, theta_2, damages): - return theta_2 * damages * pow(mitigation_rate, theta_2 - 1) * mitigation_cost / intensity +def get_carbon_price( + mitigation_cost, intensity, mitigation_rate, theta_2, damages +): + return ( + theta_2 + * damages + * pow(mitigation_rate, theta_2 - 1) + * mitigation_cost + / intensity + ) def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep): @@ -569,9 +708,9 @@ def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep): def get_utility(labor, consumption, alpha): return ( - (labor / 1000.0) - * (pow(consumption / (labor / 1000.0) + _SMALL_NUM, 1 - alpha) - 1) - / (1 - alpha) + (labor / 1000.0) + * (pow(consumption / (labor / 1000.0) + _SMALL_NUM, 1 - alpha) - 1) + / (1 - alpha) ) @@ -580,11 +719,11 @@ def get_social_welfare(utility, rho, delta, timestep): def get_armington_agg( - c_dom, - imports, # np.array - sub_rate=0.5, # in (0,1) - dom_pref=0.5, # in [0,1] - for_pref=None, # np.array + c_dom, + imports, # np.array + sub_rate=0.5, # in (0,1) + dom_pref=0.5, # in [0,1] + for_pref=None, # np.array ): """ Armington aggregate from Lessmann, 2009. @@ -605,7 +744,7 @@ def get_armington_agg( relative preference for foreign goods from that country. """ - c_dom_pref = dom_pref * (c_dom ** sub_rate) + c_dom_pref = dom_pref * (c_dom**sub_rate) # Axis 1 because we consider imports not exports c_for_pref = jnp.sum(for_pref * pow(imports, sub_rate), axis=1) diff --git a/pax/envs/rice/sarl_rice.py b/pax/envs/rice/sarl_rice.py index 46d3b840..9742bd2c 100644 --- a/pax/envs/rice/sarl_rice.py +++ b/pax/envs/rice/sarl_rice.py @@ -16,30 +16,42 @@ class SarlRice(environment.Environment): env_id: str = "SarlRice-N" - def __init__(self, config_folder: str, fixed_mitigation_rate: int = None, episode_length: int = 20): + def __init__( + self, + config_folder: str, + fixed_mitigation_rate: int = None, + episode_length: int = 20, + ): super().__init__() self.rice = Rice(config_folder, episode_length=episode_length) self.fixed_mitigation_rate = fixed_mitigation_rate def _step( - key: chex.PRNGKey, - state: EnvState, - action: chex.Array, - params: EnvParams, + key: chex.PRNGKey, + state: EnvState, + action: chex.Array, + params: EnvParams, ): actions = jnp.asarray(jnp.split(action, self.rice.num_players)) # To facilitate learning and since we want to optimize a social planner # we disable defection actions actions = actions.at[:, self.rice.export_action_index].set(1.0) - actions = actions.at[:, self.rice.tariffs_action_index: - self.rice.tariffs_action_index + self.rice.num_players].set(0.0) + actions = actions.at[ + :, + self.rice.tariffs_action_index : self.rice.tariffs_action_index + + self.rice.num_players, + ].set(0.0) if self.fixed_mitigation_rate is not None: - actions = actions.at[:, self.rice.mitigation_rate_action_index].set(self.fixed_mitigation_rate) + actions = actions.at[ + :, self.rice.mitigation_rate_action_index + ].set(self.fixed_mitigation_rate) # actions = actions.at[:, self.rice.desired_imports_action_index: # self.rice.desired_imports_action_index + self.rice.num_players].set(0.0) - obs, state, rewards, done, info = self.rice.step(key, state, tuple(actions), params) + obs, state, rewards, done, info = self.rice.step( + key, state, tuple(actions), params + ) return ( self._generate_observation(state), @@ -50,7 +62,7 @@ def _step( ) def _reset( - key: chex.PRNGKey, params: EnvParams + key: chex.PRNGKey, params: EnvParams ) -> Tuple[chex.Array, EnvState]: _, state = self.rice.reset(key, params) return self._generate_observation(state), state @@ -59,28 +71,30 @@ def _reset( self.reset = jax.jit(_reset) def _generate_observation(self, state: EnvState): - return jnp.concatenate([ - # Public features - jnp.asarray([state.inner_t]), - state.global_temperature, - state.global_carbon_mass, - jnp.asarray([state.global_exogenous_emissions]), - jnp.asarray([state.global_land_emissions]), - state.labor_all, - state.capital_all, - state.gross_output_all, - state.consumption_all, - state.investment_all, - state.balance_all, - state.production_factor_all, - state.intensity_all, - state.mitigation_cost_all, - state.damages_all, - state.abatement_cost_all, - state.production_all, - state.utility_all, - state.social_welfare_all, - ]) + return jnp.concatenate( + [ + # Public features + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + state.production_factor_all, + state.intensity_all, + state.mitigation_cost_all, + state.damages_all, + state.abatement_cost_all, + state.production_all, + state.utility_all, + state.social_welfare_all, + ] + ) @property def name(self) -> str: @@ -90,12 +104,12 @@ def name(self) -> str: def num_actions(self) -> int: return self.rice.num_actions * self.rice.num_players - def action_space( - self, params: Optional[EnvParams] = None - ) -> spaces.Box: + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: return spaces.Box(low=0, high=1, shape=(self.num_actions,)) def observation_space(self, params: EnvParams) -> spaces.Box: _, state = self.reset(jax.random.PRNGKey(0), params) obs = self._generate_observation(state) - return spaces.Box(low=0, high=float('inf'), shape=obs.shape, dtype=jnp.float32) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=jnp.float32 + ) diff --git a/pax/experiment.py b/pax/experiment.py index a29f3cb3..4467f444 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -199,7 +199,8 @@ def env_setup(args, logger=None): a=args.a, b=args.b, marginal_cost=args.marginal_cost ) env = CournotGame( - num_players=args.num_players, num_inner_steps=args.num_inner_steps, + num_players=args.num_players, + num_inner_steps=args.num_inner_steps, ) if logger: logger.info( @@ -215,7 +216,8 @@ def env_setup(args, logger=None): s_max=args.s_max, ) env = Fishery( - num_players=args.num_players, num_inner_steps=args.num_inner_steps, + num_players=args.num_players, + num_inner_steps=args.num_inner_steps, ) if logger: logger.info( @@ -236,9 +238,15 @@ def env_setup(args, logger=None): env = ClubRice( config_folder=args.config_folder, has_mediator=args.has_mediator, - mediator_climate_objective=args.get("mediator_climate_objective", None), - default_club_mitigation_rate=args.get("default_club_mitigation_rate", None), - default_club_tariff_rate=args.get("default_club_tariff_rate", None), + mediator_climate_objective=args.get( + "mediator_climate_objective", None + ), + default_club_mitigation_rate=args.get( + "default_club_mitigation_rate", None + ), + default_club_tariff_rate=args.get( + "default_club_tariff_rate", None + ), mediator_climate_weight=args.get("mediator_climate_weight", None), mediator_utility_weight=args.get("mediator_utility_weight", None), ) @@ -428,8 +436,12 @@ def get_LOLA_agent(seed, player_id): ) def get_PPO_memory_agent(seed, player_id): - default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None) - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": @@ -445,8 +457,12 @@ def get_PPO_memory_agent(seed, player_id): ) def get_PPO_agent(seed, player_id): - default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None) - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": @@ -462,8 +478,12 @@ def get_PPO_agent(seed, player_id): ) def get_PPO_tabular_agent(seed, player_id): - default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None) - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": @@ -481,8 +501,12 @@ def get_PPO_tabular_agent(seed, player_id): return ppo_agent def get_mfos_agent(seed, player_id): - default_player_args = omegaconf.OmegaConf.select(args, "ppo_default", default=None) - agent_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id), default=default_player_args) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + agent_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": @@ -582,9 +606,15 @@ def get_stay_agent(seed, player_id): logger.info("Using Independent Learners") return agent_1 else: - default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None) - agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in - range(1, args.num_players + 1)] + default_agent = omegaconf.OmegaConf.select( + args, "agent_default", default=None + ) + agent_strategies = [ + omegaconf.OmegaConf.select( + args, "agent" + str(i), default=default_agent + ) + for i in range(1, args.num_players + 1) + ] for strategy in agent_strategies: assert strategy in strategies @@ -598,12 +628,8 @@ def get_stay_agent(seed, player_id): ] agents = [] for idx, strategy in enumerate(agent_strategies): - agents.append( - strategies[strategy](seeds[idx], pids[idx]) - ) - logger.info( - f"Agent Pair: {agents}" - ) + agents.append(strategies[strategy](seeds[idx], pids[idx])) + logger.info(f"Agent Pair: {agents}") logger.info(f"Agent seeds: {seeds}") return agents @@ -613,7 +639,7 @@ def watcher_setup(args, logger): """Set up watcher variables.""" def ppo_memory_log( - agent, + agent, ): losses = losses_ppo(agent) if args.env_id not in [ @@ -717,9 +743,15 @@ def naive_pg_log(agent): return agent_1_log else: agent_log = [] - default_agent = omegaconf.OmegaConf.select(args, "agent_default", default=None) - agent_strategies = [omegaconf.OmegaConf.select(args, "agent" + str(i), default=default_agent) for i in - range(1, args.num_players + 1)] + default_agent = omegaconf.OmegaConf.select( + args, "agent_default", default=None + ) + agent_strategies = [ + omegaconf.OmegaConf.select( + args, "agent" + str(i), default=default_agent + ) + for i in range(1, args.num_players + 1) + ] for strategy in agent_strategies: assert strategy in strategies agent_log.append(strategies[strategy]) diff --git a/pax/runners/runner_eval.py b/pax/runners/runner_eval.py index 68a0275c..c66b95d0 100644 --- a/pax/runners/runner_eval.py +++ b/pax/runners/runner_eval.py @@ -134,7 +134,7 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order + agent_order, ) = carry # unpack rngs @@ -144,7 +144,7 @@ def _inner_rollout(carry, unused): a1_actions = [] new_a1_memories = [] - for _obs, _mem in zip(obs1, a1_mem): + for _obs, _mem in zip(obs1, a1_mem, strict=True): a1_action, a1_state, new_a1_memory = agent1.batch_policy( a1_state, _obs, @@ -155,7 +155,7 @@ def _inner_rollout(carry, unused): a2_actions = [] new_a2_memories = [] - for _obs, _mem in zip(obs2, a2_mem): + for _obs, _mem in zip(obs2, a2_mem, strict=True): a2_action, a2_state, new_a2_memory = agent2.batch_policy( a2_state, _obs, @@ -177,37 +177,57 @@ def _inner_rollout(carry, unused): rewards = jnp.asarray(rewards)[inv_agent_order] a1_trajectories = [ - Sample(observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden) for observation, action, reward, new_memory, memory in - zip(obs1, a2_actions, rewards[:self.args.agent1_roles], new_a1_memories, a1_mem)] + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs1, + a2_actions, + rewards[: self.args.agent1_roles], + new_a1_memories, + a1_mem, + strict=True, + ) + ] a2_trajectories = [ - Sample(observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden) for observation, action, reward, new_memory, memory in - zip(obs2, a2_actions, rewards[self.args.agent1_roles:], new_a2_memories, a2_mem)] + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[self.args.agent1_roles :], + new_a2_memories, + a2_mem, + strict=True, + ) + ] return ( rngs, - tuple(obs[:self.args.agent1_roles]), - tuple(obs[self.args.agent1_roles:]), - tuple(rewards[:self.args.agent1_roles]), - tuple(rewards[self.args.agent1_roles:]), + tuple(obs[: self.args.agent1_roles]), + tuple(obs[self.args.agent1_roles :]), + tuple(rewards[: self.args.agent1_roles]), + tuple(rewards[self.args.agent1_roles :]), a1_state, tuple(new_a1_memories), a2_state, tuple(new_a2_memories), env_state, env_params, - agent_order + agent_order, ), ( a1_trajectories, a2_trajectories, @@ -235,7 +255,7 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order + agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": @@ -247,7 +267,9 @@ def _outer_rollout(carry, unused): a2_metrics = {} else: new_a2_memories = [] - for _obs, mem, traj in zip(obs2, a2_mem, stack[1]): + for _obs, mem, traj in zip( + obs2, a2_mem, stack[1], strict=True + ): a2_state, a2_mem, a2_metrics = agent2.batch_update( traj, _obs, @@ -268,7 +290,7 @@ def _outer_rollout(carry, unused): new_a2_memories, env_state, env_params, - agent_order + agent_order, ), (*stack, a2_metrics) self.rollout = jax.jit(_outer_rollout) @@ -283,7 +305,9 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): a1_state, a1_mem = agent1._state, agent1._mem a2_state, a2_mem = agent2._state, agent2._mem - preload_agent_2 = self.model_path2 is not None and self.run_path2 is not None + preload_agent_2 = ( + self.model_path2 is not None and self.run_path2 is not None + ) if watchers and not self.args.wandb.mode == "offline": wandb.restore( @@ -291,7 +315,9 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): ) if preload_agent_2: wandb.restore( - name=self.model_path2, run_path=self.run_path2, root=os.getcwd() + name=self.model_path2, + run_path=self.run_path2, + root=os.getcwd(), ) pretrained_params = load(self.model_path) @@ -315,12 +341,14 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): for i in range(num_episodes): rng, _ = jax.random.split(rng) rngs = jnp.concatenate( - [jax.random.split(rng, self.args.num_envs)] * self.args.num_opps + [jax.random.split(rng, self.args.num_envs)] + * self.args.num_opps ).reshape((self.args.num_opps, self.args.num_envs, -1)) obs, env_state = env.reset(rngs, env_params) rewards = [jnp.zeros((self.args.num_opps, self.args.num_envs))] * ( - self.args.agent1_roles + self.args.agent2_roles) + self.args.agent1_roles + self.args.agent2_roles + ) if i % self.args.agent2_reset_interval == 0: if self.args.agent2 == "NaiveEx": @@ -333,9 +361,12 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): if preload_agent_2: # If we are preloading agent 2 we want to keep the state - a2_state = a2_state._replace(params=jax.tree_util.tree_map( - lambda x: jnp.expand_dims(x, 0), a2_pretrained_params - )) + a2_state = a2_state._replace( + params=jax.tree_util.tree_map( + lambda x: jnp.expand_dims(x, 0), + a2_pretrained_params, + ) + ) a2_mem = (a2_mem,) * self.args.agent2_roles agent_order = jnp.arange(self.args.num_players) @@ -347,17 +378,17 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): self.rollout, ( rngs, - tuple(obs[:self.args.agent1_roles]), - tuple(obs[self.args.agent1_roles:]), - tuple(rewards[:self.args.agent1_roles]), - tuple(rewards[self.args.agent1_roles:]), + tuple(obs[: self.args.agent1_roles]), + tuple(obs[self.args.agent1_roles :]), + tuple(rewards[: self.args.agent1_roles]), + tuple(rewards[self.args.agent1_roles :]), a1_state, a1_mem, a2_state, a2_mem, env_state, env_params, - agent_order + agent_order, ), None, length=self.args.num_steps, @@ -376,7 +407,7 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): _a2_mem, env_state, env_params, - agent_order + agent_order, ) = vals traj_1, traj_2, env_states, a2_metrics = stack @@ -396,27 +427,41 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): rewards_2 = rewards_2.sum(axis=1).mean() elif self.args.env_id == "Fishery": - env_stats = fishery_stats(traj_1 + traj_2, self.args.num_players) + env_stats = fishery_stats( + traj_1 + traj_2, self.args.num_players + ) rewards_1 = rewards_1.sum(axis=1).mean() rewards_2 = rewards_2.sum(axis=1).mean() elif self.args.env_id == "Rice-N": - env_stats = rice_eval_stats(traj_1 + traj_2, env_states, env) + env_stats = rice_eval_stats( + traj_1 + traj_2, env_states, env + ) env_stats = jax.tree_util.tree_map( lambda x: x.tolist(), env_stats, ) - env_stats = env_stats | rice_stats(traj_1 + traj_2, self.args.num_players, self.args.has_mediator) + env_stats = env_stats | rice_stats( + traj_1 + traj_2, + self.args.num_players, + self.args.has_mediator, + ) rewards_1 = rewards_1.sum(axis=1).mean() rewards_2 = rewards_2.sum(axis=1).mean() elif self.args.env_id == "C-Rice-N": - env_stats = c_rice_eval_stats(traj_1 + traj_2, env_states, env) + env_stats = c_rice_eval_stats( + traj_1 + traj_2, env_states, env + ) env_stats = jax.tree_util.tree_map( lambda x: x.tolist(), env_stats, ) - env_stats = env_stats | c_rice_stats(traj_1 + traj_2, self.args.num_players, self.args.has_mediator) + env_stats = env_stats | c_rice_stats( + traj_1 + traj_2, + self.args.num_players, + self.args.has_mediator, + ) rewards_1 = rewards_1.sum(axis=1).mean() rewards_2 = rewards_2.sum(axis=1).mean() @@ -441,11 +486,17 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): env_stats = {} # Due to the permutation and roles the rewards are not in order - agent_indices = jnp.array([0] * self.args.agent1_roles + [1] * self.args.agent2_roles) + agent_indices = jnp.array( + [0] * self.args.agent1_roles + [1] * self.args.agent2_roles + ) agent_indices = agent_indices[agent_order] # Omit rich per timestep statistics for cleaner logging - printable_env_stats = {k: v for k, v in env_stats.items() if not k.startswith("states")} + printable_env_stats = { + k: v + for k, v in env_stats.items() + if not k.startswith("states") + } print(f"Env Stats: {printable_env_stats}") print( f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" @@ -458,10 +509,10 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics ) agent2._logger.metrics = ( - agent2._logger.metrics | flattened_metrics + agent2._logger.metrics | flattened_metrics ) - for watcher, agent in zip(watchers, agents): + for watcher, agent in zip(watchers, agents, strict=True): watcher(agent) wandb.log( { diff --git a/pax/runners/runner_eval_multishaper.py b/pax/runners/runner_eval_multishaper.py index a7353944..ea2f698f 100644 --- a/pax/runners/runner_eval_multishaper.py +++ b/pax/runners/runner_eval_multishaper.py @@ -113,7 +113,7 @@ def _reshape_opp_dim(x): self.num_targets = args.num_players - args.num_shapers self.num_outer_steps = self.args.num_outer_steps shapers = agents[: self.num_shapers] - targets = agents[self.num_shapers:] + targets = agents[self.num_shapers :] # set up agents # batch MemoryState not TrainingState for agent_idx, shaper_agent in enumerate(shapers): @@ -252,10 +252,10 @@ def _inner_rollout(carry, unused): env_params, ) shapers_next_obs = all_agent_next_obs[: self.num_shapers] - targets_next_obs = all_agent_next_obs[self.num_shapers:] + targets_next_obs = all_agent_next_obs[self.num_shapers :] shapers_reward, targets_reward = ( all_agent_rewards[: self.num_shapers], - all_agent_rewards[self.num_shapers:], + all_agent_rewards[self.num_shapers :], ) shapers_traj = [ @@ -329,7 +329,7 @@ def _outer_rollout(carry, unused): ) # update second agent - targets_traj = trajectories[self.num_shapers:] + targets_traj = trajectories[self.num_shapers :] for agent_idx, target_agent in enumerate(targets): ( targets_state[agent_idx], @@ -356,10 +356,10 @@ def _outer_rollout(carry, unused): ), (trajectories, targets_metrics) def _rollout( - _rng_run: jnp.ndarray, - _shapers_state: List[TrainingState], - _shapers_mem: List[MemoryState], - _env_params: Any, + _rng_run: jnp.ndarray, + _shapers_state: List[TrainingState], + _shapers_mem: List[MemoryState], + _env_params: Any, ): # env reset rngs = jnp.concatenate( @@ -368,10 +368,10 @@ def _rollout( obs, env_state = env.reset(rngs, _env_params) shapers_obs = obs[: self.num_shapers] - targets_obs = obs[self.num_shapers:] + targets_obs = obs[self.num_shapers :] rewards = [ - jnp.zeros((args.num_opps, args.num_envs), dtype=jnp.float32) - ] * args.num_players + jnp.zeros((args.num_opps, args.num_envs), dtype=jnp.float32) + ] * args.num_players # Player 1 for agent_idx, shaper_agent in enumerate(shapers): _shapers_mem[agent_idx] = shaper_agent.batch_reset( @@ -411,9 +411,9 @@ def _rollout( ( rngs, tuple(obs[: self.num_shapers]), - tuple(obs[self.num_shapers:]), + tuple(obs[self.num_shapers :]), tuple(rewards[: self.num_shapers]), - tuple(rewards[self.num_shapers:]), + tuple(rewards[self.num_shapers :]), _shapers_state, targets_state, _shapers_mem, @@ -440,7 +440,7 @@ def _rollout( ) = vals trajectories, targets_metrics = stack shapers_traj = trajectories[: self.num_shapers] - targets_traj = trajectories[self.num_shapers:] + targets_traj = trajectories[self.num_shapers :] # reset memory for agent_idx, shaper_agent in enumerate(shapers): @@ -493,7 +493,7 @@ def run_loop(self, env_params, agents, watchers): print("Training") print("-----------------------") shaper_agents = agents[: self.num_shapers] - target_agents = agents[self.num_shapers:] + target_agents = agents[self.num_shapers :] rng, _ = jax.random.split(self.random_key) # get initial state and memory @@ -587,7 +587,7 @@ def run_loop(self, env_params, agents, watchers): for traj in list_traj1 ] shaper_traj = trajectories[: self.num_shapers] - target_traj = trajectories[self.num_shapers:] + target_traj = trajectories[self.num_shapers :] # log agent one watchers[0](agents[0]) @@ -620,7 +620,7 @@ def run_loop(self, env_params, agents, watchers): for traj in trajectories ) ) - / len(trajectories) + / len(trajectories) } for i in range(len(list_of_env_stats)) ] @@ -632,7 +632,7 @@ def run_loop(self, env_params, agents, watchers): for traj in shaper_traj ) ) - / len(shaper_traj) + / len(shaper_traj) } for i in range(len(list_of_env_stats)) ] @@ -644,7 +644,7 @@ def run_loop(self, env_params, agents, watchers): for traj in target_traj ) ) - / len(target_traj) + / len(target_traj) } for i in range(len(list_of_env_stats)) ] diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index 92b3cbde..f5b4e459 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -46,8 +46,10 @@ class EvoRunner: A tuple of experiment arguments used (usually provided by HydraConfig). """ + # TODO fix C901 (function too complex) + # flake8: noqa: C901 def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args + self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): self.args = args self.algo = args.es.algo @@ -185,7 +187,7 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order + agent_order, ) = carry # unpack rngs @@ -200,7 +202,7 @@ def _inner_rollout(carry, unused): ) a2_actions = [] new_a2_memories = [] - for _obs, _mem in zip(obs2, a2_mem): + for _obs, _mem in zip(obs2, a2_mem, strict=True): a2_action, a2_state, new_a2_memory = agent2.batch_policy( a2_state, _obs, @@ -231,14 +233,24 @@ def _inner_rollout(carry, unused): a1_mem.hidden, ) a2_trajectories = [ - Sample(observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden) for observation, action, reward, new_memory, memory in - zip(obs2, a2_actions, rewards[1:], new_a2_memories, a2_mem)] + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[1:], + new_a2_memories, + a2_mem, + strict=True, + ) + ] return ( rngs, @@ -252,7 +264,7 @@ def _inner_rollout(carry, unused): tuple(new_a2_memories), env_state, env_params, - agent_order + agent_order, ), ( traj1, a2_trajectories, @@ -279,7 +291,7 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, - agent_order + agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": @@ -287,7 +299,9 @@ def _outer_rollout(carry, unused): # update second agent new_a2_memories = [] - for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]): + for _obs, mem, traj in zip( + obs2, a2_mem, trajectories[1], strict=True + ): a2_state, a2_mem, a2_metrics = agent2.batch_update( traj, _obs, @@ -307,16 +321,16 @@ def _outer_rollout(carry, unused): tuple(new_a2_memories), env_state, env_params, - agent_order + agent_order, ), (*trajectories, a2_metrics) def _rollout( - _params: jnp.ndarray, - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _a2_state: TrainingState, - _env_params: Any, + _params: jnp.ndarray, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _a2_state: TrainingState, + _env_params: Any, ): # env reset env_rngs = jnp.concatenate( @@ -326,8 +340,12 @@ def _rollout( ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) obs, env_state = env.reset(env_rngs, _env_params) - rewards = [jnp.zeros((args.popsize, args.num_opps, args.num_envs), dtype=float_precision)] * ( - 1 + args.agent2_roles) + rewards = [ + jnp.zeros( + (args.popsize, args.num_opps, args.num_envs), + dtype=float_precision, + ) + ] * (1 + args.agent2_roles) # Player 1 _a1_state = _a1_state._replace(params=_params) @@ -369,7 +387,7 @@ def _rollout( (a2_mem,) * args.agent2_roles, env_state, _env_params, - agent_order + agent_order, ), None, length=self.num_outer_steps, @@ -387,13 +405,15 @@ def _rollout( a2_mem, env_state, _env_params, - agent_order + agent_order, ) = vals traj_1, traj_2, a2_metrics = stack # Fitness fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - agent_2_rewards = jnp.concatenate([traj.rewards for traj in traj_2]) + agent_2_rewards = jnp.concatenate( + [traj.rewards for traj in traj_2] + ) other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) rewards_1 = traj_1.rewards.mean() rewards_2 = agent_2_rewards.mean() @@ -431,18 +451,18 @@ def _rollout( elif args.env_id == "Cournot": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), - self.cournot_stats( - traj_1.observations, - _env_params, - 2 - ), + self.cournot_stats(traj_1.observations, _env_params, 2), ) elif args.env_id == "Fishery": env_stats = fishery_stats([traj_1] + traj_2, args.num_players) elif args.env_id == "Rice-N": - env_stats = rice_stats([traj_1] + traj_2, args.num_players, args.has_mediator) + env_stats = rice_stats( + [traj_1] + traj_2, args.num_players, args.has_mediator + ) elif args.env_id == "C-Rice-N": - env_stats = c_rice_stats([traj_1] + traj_2, args.num_players, args.has_mediator) + env_stats = c_rice_stats( + [traj_1] + traj_2, args.num_players, args.has_mediator + ) else: env_stats = {} @@ -453,7 +473,7 @@ def _rollout( rewards_1, rewards_2, a2_metrics, - a2_state + a2_state, ) self.rollout = jax.pmap( @@ -466,11 +486,11 @@ def _rollout( ) def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, + self, + env_params, + agents, + num_iters: int, + watchers: Callable, ): """Run training of agents in environment""" print("Training") @@ -538,8 +558,8 @@ def run_loop( a2_state = None if self.args.num_devices == 1 and a2_state is not None: - # The first rollout returns a2_state with an extra batch dim that will cause issues when passing - # it back to the vmapped batch_policy + # The first rollout returns a2_state with an extra batch dim that + # will cause issues when passing it back to the vmapped batch_policy a2_state = jax.tree_util.tree_map( lambda w: jnp.squeeze(w, axis=0), a2_state ) @@ -552,11 +572,15 @@ def run_loop( rewards_1, rewards_2, a2_metrics, - a2_state - ) = self.rollout(params, rng_run, a1_state, a1_mem, a2_state, env_params) + a2_state, + ) = self.rollout( + params, rng_run, a1_state, a1_mem, a2_state, env_params + ) # Aggregate over devices - fitness = jnp.reshape(fitness, popsize * self.args.num_devices).astype(dtype=jnp.float32) + fitness = jnp.reshape( + fitness, popsize * self.args.num_devices + ).astype(dtype=jnp.float32) env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) # Tell @@ -572,10 +596,12 @@ def run_loop( is_last_loop = gen == num_iters - 1 # Saving if gen % self.args.save_interval == 0 or is_last_loop: - log_savepath1 = os.path.join(self.save_dir, f"generation_{gen}") + log_savepath1 = os.path.join( + self.save_dir, f"generation_{gen}" + ) if self.args.num_devices > 1: top_params = param_reshaper.reshape( - log["top_gen_params"][0: self.args.num_devices] + log["top_gen_params"][0 : self.args.num_devices] ) top_params = jax.tree_util.tree_map( lambda x: x[0].reshape(x[0].shape[1:]), top_params @@ -655,7 +681,9 @@ def run_loop( wandb_log.update(env_stats) # loop through population for idx, (overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) + zip( + log["top_fitness"], log["top_gen_fitness"], strict=True + ) ): wandb_log[ f"train/fitness/top_overall_agent_{idx + 1}" @@ -671,7 +699,7 @@ def run_loop( ) agent2._logger.metrics.update(flattened_metrics) - for watcher, agent in zip(watchers, agents): + for watcher, agent in zip(watchers, agents, strict=True): watcher(agent) wandb_log = jax.tree_util.tree_map( lambda x: x.item() if isinstance(x, jax.Array) else x, diff --git a/pax/runners/runner_evo_multishaper.py b/pax/runners/runner_evo_multishaper.py index 31f0fdcd..6aae4902 100644 --- a/pax/runners/runner_evo_multishaper.py +++ b/pax/runners/runner_evo_multishaper.py @@ -57,7 +57,7 @@ class MultishaperEvoRunner: """ def __init__( - self, agents, env, strategy, es_params, param_reshaper, save_dir, args + self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): self.args = args self.algo = args.es.algo @@ -108,7 +108,7 @@ def __init__( self.num_targets = args.num_players - args.num_shapers self.num_outer_steps = args.num_outer_steps shapers = agents[: self.num_shapers] - targets = agents[self.num_shapers:] + targets = agents[self.num_shapers :] # vmap agents accordingly # shapers are batched over popsize and num_opps @@ -263,10 +263,10 @@ def _inner_rollout(carry, unused): env_params, ) shapers_next_obs = all_agent_next_obs[: self.num_shapers] - targets_next_obs = all_agent_next_obs[self.num_shapers:] + targets_next_obs = all_agent_next_obs[self.num_shapers :] shapers_reward, targets_reward = ( all_agent_rewards[: self.num_shapers], - all_agent_rewards[self.num_shapers:], + all_agent_rewards[self.num_shapers :], ) shapers_traj = [ @@ -344,7 +344,7 @@ def _outer_rollout(carry, unused): shapers_mem[agent_idx] ) # update opponents - targets_traj = trajectories[self.num_shapers:] + targets_traj = trajectories[self.num_shapers :] for agent_idx, target_agent in enumerate(targets): ( targets_state[agent_idx], @@ -371,11 +371,11 @@ def _outer_rollout(carry, unused): ), (trajectories, targets_metrics) def _rollout( - _params: List[jnp.ndarray], - _rng_run: jnp.ndarray, - _shapers_state: List[TrainingState], - _shapers_mem: List[MemoryState], - _env_params: Any, + _params: List[jnp.ndarray], + _rng_run: jnp.ndarray, + _shapers_state: List[TrainingState], + _shapers_mem: List[MemoryState], + _env_params: Any, ): # env reset env_rngs = jnp.concatenate( @@ -386,10 +386,10 @@ def _rollout( obs, env_state = env.reset(env_rngs, _env_params) shapers_obs = obs[: self.num_shapers] - targets_obs = obs[self.num_shapers:] + targets_obs = obs[self.num_shapers :] rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] * args.num_players + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] * args.num_players # Shapers for agent_idx, shaper_agent in enumerate(shapers): @@ -439,9 +439,9 @@ def _rollout( ( env_rngs, tuple(obs[: self.num_shapers]), - tuple(obs[self.num_shapers:]), + tuple(obs[self.num_shapers :]), tuple(rewards[: self.num_shapers]), - tuple(rewards[self.num_shapers:]), + tuple(rewards[self.num_shapers :]), _shapers_state, targets_state, _shapers_mem, @@ -467,7 +467,7 @@ def _rollout( ) = vals trajectories, targets_metrics = stack shapers_traj = trajectories[: self.num_shapers] - targets_traj = trajectories[self.num_shapers:] + targets_traj = trajectories[self.num_shapers :] # Fitness shapers_fitness = [ @@ -519,11 +519,11 @@ def _rollout( ) def run_loop( - self, - env_params, - agents, - num_iters: int, - watchers: Callable, + self, + env_params, + agents, + num_iters: int, + watchers: Callable, ): """Run training of agents in environment""" print("Training") @@ -561,18 +561,17 @@ def run_loop( z_score=self.args.es.z_score, ) es_logging = [ - ESLog( - param_reshaper.total_params, - num_gens, - top_k=self.top_k, - maximize=True, - ) - ] * self.num_shapers + ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + ] * self.num_shapers logs = [es_log.initialize() for es_log in es_logging] # Reshape a single agent's params before vmapping shaper_agents = agents[: self.num_shapers] - target_agents = agents[self.num_shapers:] init_hiddens = [ jnp.tile( @@ -601,7 +600,7 @@ def run_loop( shapers_params = [] old_evo_states = evo_states evo_states = [] - for shaper_idx, shaper_agent in enumerate(shaper_agents): + for shaper_idx, _ in enumerate(shaper_agents): shaper_rng, rng_evo = jax.random.split(rng_evo, 2) x, evo_state = strategy.ask( shaper_rng, old_evo_states[shaper_idx], es_params @@ -636,21 +635,23 @@ def run_loop( # Tell fitness_re = [ fit_shaper.apply(x, fitness) - for x, fitness in zip(xs, shapers_fitness) + for x, fitness in zip(xs, shapers_fitness, strict=True) ] if self.args.es.mean_reduce: fitness_re = [fit_re - fit_re.mean() for fit_re in fitness_re] evo_states = [ strategy.tell(x, fit_re, evo_state, es_params) - for x, fit_re, evo_state in zip(xs, fitness_re, evo_states) + for x, fit_re, evo_state in zip( + xs, fitness_re, evo_states, strict=True + ) ] # Logging logs = [ es_log.update(log, x, fitness) for es_log, log, x, fitness in zip( - es_logging, logs, xs, shapers_fitness + es_logging, logs, xs, shapers_fitness, strict=True ) ] # Saving @@ -662,7 +663,7 @@ def run_loop( if self.args.num_devices > 1: top_params = param_reshaper.reshape( logs[shaper_idx]["top_gen_params"][ - 0: self.args.num_devices + 0 : self.args.num_devices ] ) top_params = jax.tree_util.tree_map( @@ -726,7 +727,9 @@ def run_loop( ] rewards_strs = shaper_rewards_strs + target_rewards_strs rewards_val = shaper_rewards_val + target_rewards_val - rewards_dict = dict(zip(rewards_strs, rewards_val)) + rewards_dict = dict( + zip(rewards_strs, rewards_val, strict=True) + ) shaper_fitness_str = [ "train/fitness/shaper_" + str(i) @@ -745,7 +748,9 @@ def run_loop( fitness_strs = shaper_fitness_str + target_fitness_str fitness_vals = shaper_fitness_val + target_fitness_val - fitness_dict = dict(zip(fitness_strs, fitness_vals)) + fitness_dict = dict( + zip(fitness_strs, fitness_vals, strict=True) + ) shaper_welfare = float( sum([reward.mean() for reward in shapers_rewards]) @@ -765,22 +770,22 @@ def run_loop( / self.args.num_players ) wandb_log = { - "train_iteration": gen, - # "train/fitness/top_overall_mean": log["log_top_mean"][gen], - # "train/fitness/top_overall_std": log["log_top_std"][gen], - # "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], - # "train/fitness/top_gen_std": log["log_top_gen_std"][gen], - # "train/fitness/gen_std": log["log_gen_std"][gen], - "train/time/minutes": float( - (time.time() - self.start_time) / 60 - ), - "train/time/seconds": float( - (time.time() - self.start_time) - ), - "train/welfare/shaper": shaper_welfare, - "train/welfare/target": target_welfare, - "train/global_welfare": global_welfare, - } | rewards_dict + "train_iteration": gen, + # "train/fitness/top_overall_mean": log["log_top_mean"][gen], + # "train/fitness/top_overall_std": log["log_top_std"][gen], + # "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + # "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + # "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/welfare/shaper": shaper_welfare, + "train/welfare/target": target_welfare, + "train/global_welfare": global_welfare, + } | rewards_dict wandb_log = wandb_log | fitness_dict wandb_log.update(env_stats) # # loop through population @@ -796,7 +801,9 @@ def run_loop( # other player metrics # metrics [outer_timesteps, num_opps] - for agent, metrics in zip(agents[1:], targets_metrics): + for agent, metrics in zip( + agents[1:], targets_metrics, strict=True + ): flattened_metrics = jax.tree_util.tree_map( lambda x: jnp.sum(jnp.mean(x, 1)), metrics ) diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index fc53cd16..92d1511c 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -233,15 +233,16 @@ def _inner_rollout(carry, unused): a1_mem.hidden, a1_mem.th, ) - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) + else: + traj1 = Sample( + obs1, + a1, + rewards[0], + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) traj2 = Sample( obs2, a2, @@ -316,12 +317,12 @@ def _outer_rollout(carry, unused): ), (*trajectories, a2_metrics) def _rollout( - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _a1_mem: MemoryState, - _a2_state: TrainingState, - _a2_mem: MemoryState, - _env_params: Any, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _a2_state: TrainingState, + _a2_mem: MemoryState, + _env_params: Any, ): # env reset rngs = jnp.concatenate( @@ -452,11 +453,7 @@ def _rollout( elif args.env_id == "Cournot": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), - self.cournot_stats( - traj_1.observations, - _env_params, - 2 - ), + self.cournot_stats(traj_1.observations, _env_params, 2), ) elif args.env_id == "Fishery": env_stats = fishery_stats([traj_1, traj_2], 2) @@ -546,14 +543,14 @@ def run_loop(self, env_params, agents, num_iters, watchers): ) if self.args.agent1 != "LOLA": agent1._logger.metrics = ( - agent1._logger.metrics | flattened_metrics_1 + agent1._logger.metrics | flattened_metrics_1 ) # metrics [outer_timesteps, num_opps] flattened_metrics_2 = jax.tree_util.tree_map( lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics ) agent2._logger.metrics = ( - agent2._logger.metrics | flattened_metrics_2 + agent2._logger.metrics | flattened_metrics_2 ) for watcher, agent in zip(watchers, agents): diff --git a/pax/runners/runner_marl_nplayer.py b/pax/runners/runner_marl_nplayer.py index f3f78d19..5244a098 100644 --- a/pax/runners/runner_marl_nplayer.py +++ b/pax/runners/runner_marl_nplayer.py @@ -24,6 +24,7 @@ class LOLASample(NamedTuple): rewards_self: jnp.ndarray rewards_other: jnp.ndarray + class Sample(NamedTuple): """Object containing a batch of data""" @@ -349,12 +350,12 @@ def _outer_rollout(carry, unused): ), (trajectories, other_agent_metrics) def _rollout( - _rng_run: jnp.ndarray, - first_agent_state: TrainingState, - first_agent_mem: MemoryState, - other_agent_state: List[TrainingState], - other_agent_mem: List[MemoryState], - _env_params: Any, + _rng_run: jnp.ndarray, + first_agent_state: TrainingState, + first_agent_mem: MemoryState, + other_agent_state: List[TrainingState], + other_agent_mem: List[MemoryState], + _env_params: Any, ): # env reset rngs = jnp.concatenate( @@ -363,8 +364,8 @@ def _rollout( obs, env_state = env.reset(rngs, _env_params) rewards = [ - jnp.zeros((args.num_opps, args.num_envs)), - ] * args.num_players + jnp.zeros((args.num_opps, args.num_envs)), + ] * args.num_players # Player 1 first_agent_mem = agent1.batch_reset(first_agent_mem, False) # Other players @@ -502,7 +503,9 @@ def _rollout( total_env_stats = jax.tree_util.tree_map( lambda x: x, self.cournot_stats( - trajectories[0].observations, _env_params, args.num_players + trajectories[0].observations, + _env_params, + args.num_players, ), ) elif args.env_id == "Fishery": @@ -541,7 +544,6 @@ def run_loop(self, env_params, agents, num_iters, watchers): other_agent_mem = [None] * len(other_agents) other_agent_state = [None] * len(other_agents) - for agent_idx, non_first_agent in enumerate(other_agents): other_agent_state[agent_idx], other_agent_mem[agent_idx] = ( non_first_agent._state, @@ -602,14 +604,16 @@ def run_loop(self, env_params, agents, num_iters, watchers): ) if self.args.agent1 != "LOLA": agent1._logger.metrics = ( - agent1._logger.metrics | flattened_metrics_1 + agent1._logger.metrics | flattened_metrics_1 ) - for agent, metric in zip(other_agents, other_agent_metrics): + for agent, metric in zip( + other_agents, other_agent_metrics + ): flattened_metrics = jax.tree_util.tree_map( lambda x: jnp.mean(x), first_agent_metrics ) agent._logger.metrics = ( - agent._logger.metrics | flattened_metrics + agent._logger.metrics | flattened_metrics ) for watcher, agent in zip(watchers, agents): diff --git a/pax/runners/runner_weight_sharing.py b/pax/runners/runner_weight_sharing.py index 36e2201d..291d6ec4 100644 --- a/pax/runners/runner_weight_sharing.py +++ b/pax/runners/runner_weight_sharing.py @@ -20,6 +20,7 @@ class WeightSharingRunner: """Holds the runner's state.""" + id = "weight_sharing" def __init__(self, agent, env, save_dir, args): @@ -93,14 +94,19 @@ def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]: ) trajectories = [ - Sample(observation, - action, - reward * jnp.logical_not(done), - new_memory.extras["log_probs"], - new_memory.extras["values"], - done, - memory.hidden) for observation, action, reward, memory, new_memory in - zip(obs, actions, rewards, memories, new_memories)] + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, memory, new_memory in zip( + obs, actions, rewards, memories, new_memories, strict=True + ) + ] return ( rngs, @@ -112,10 +118,10 @@ def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]: ), trajectories def _rollout( - _rng_run: jnp.ndarray, - _a1_state: TrainingState, - _memories: List[MemoryState], - _env_params: Any, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _memories: List[MemoryState], + _env_params: Any, ): # env reset rngs = jnp.concatenate( @@ -164,12 +170,22 @@ def _rollout( rewards = [] num_episodes = jnp.sum(trajectories[0].dones) for traj in trajectories: - rewards.append(jnp.where(num_episodes != 0, jnp.sum(traj.rewards) / num_episodes, 0)) + rewards.append( + jnp.where( + num_episodes != 0, + jnp.sum(traj.rewards) / num_episodes, + 0, + ) + ) env_stats = {} if args.env_id == "Rice-N": - env_stats = rice_stats(trajectories, args.num_players, args.has_mediator) + env_stats = rice_stats( + trajectories, args.num_players, args.has_mediator + ) elif args.env_id == "C-Rice-N": - env_stats = c_rice_stats(trajectories, args.num_players, args.has_mediator) + env_stats = c_rice_stats( + trajectories, args.num_players, args.has_mediator + ) elif args.env_id == "Fishery": env_stats = fishery_stats(trajectories, args.num_players) @@ -234,7 +250,7 @@ def run_loop(self, env, env_params, agent, num_iters, watcher): lambda x: jnp.mean(x), a1_metrics ) agent._logger.metrics = ( - agent._logger.metrics | flattened_metrics_1 + agent._logger.metrics | flattened_metrics_1 ) watcher(agent) diff --git a/pax/utils.py b/pax/utils.py index c9e4c32e..3acbd7d1 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -5,7 +5,6 @@ import chex import haiku as hk import jax -import jax.numpy as jnp import numpy as np import optax from jax import numpy as jnp diff --git a/pax/version.py b/pax/version.py index 35bd3280..3e6e1313 100644 --- a/pax/version.py +++ b/pax/version.py @@ -1,2 +1,2 @@ -__version__ = '0.1.0b+5fbcd0d' -git_version = '5fbcd0d7e6925d156162e86b1e4ecc6e0d3c1a61' +__version__ = "0.1.0b+5fbcd0d" +git_version = "5fbcd0d7e6925d156162e86b1e4ecc6e0d3c1a61" diff --git a/pax/watchers/__init__.py b/pax/watchers/__init__.py index 24c921eb..089842b5 100644 --- a/pax/watchers/__init__.py +++ b/pax/watchers/__init__.py @@ -42,7 +42,7 @@ def policy_logger(agent) -> dict: ] # [layer_name]['w'] log_pi = nn.softmax(weights) probs = { - "policy/" + str(s): p[0] for (s, p) in zip(State, log_pi) + "policy/" + str(s): p[0] for (s, p) in zip(State, log_pi, strict=True) } # probability of cooperating is p[0] return probs @@ -50,10 +50,14 @@ def policy_logger(agent) -> dict: def value_logger(agent) -> dict: weights = agent.critic_optimizer.target["Dense_0"]["kernel"] values = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + f"value/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, weights, strict=True) } values.update( - {f"value/{str(s)}.defect": p[1] for (s, p) in zip(State, weights)} + { + f"value/{str(s)}.defect": p[1] + for (s, p) in zip(State, weights, strict=True) + } ) return values @@ -66,12 +70,12 @@ def policy_logger_dqn(agent) -> None: target_steps = agent.target_step_updates probs = { f"policy/player_{str(pid)}/{str(s)}.cooperate": p[0] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } probs.update( { f"policy/player_{str(pid)}/{str(s)}.defect": p[1] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } ) probs.update({"policy/target_step_updates": target_steps}) @@ -84,12 +88,12 @@ def value_logger_dqn(agent) -> dict: target_steps = agent.target_step_updates values = { f"value/player_{str(pid)}/{str(s)}.cooperate": p[0] - for (s, p) in zip(State, weights) + for (s, p) in zip(State, weights, strict=True) } values.update( { f"value/player_{str(pid)}/{str(s)}.defect": p[1] - for (s, p) in zip(State, weights) + for (s, p) in zip(State, weights, strict=True) } ) values.update({"value/target_step_updates": target_steps}) @@ -100,7 +104,10 @@ def policy_logger_ppo(agent: PPO) -> dict: weights = agent._state.params["categorical_value_head/~/linear"]["w"] pi = nn.softmax(weights) sgd_steps = agent._total_steps / agent._num_steps - probs = {f"policy/{str(s)}.cooperate": p[0] for (s, p) in zip(State, pi)} + probs = { + f"policy/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, pi, strict=True) + } probs.update({"policy/total_steps": sgd_steps}) return probs @@ -111,7 +118,8 @@ def value_logger_ppo(agent: PPO) -> dict: ] # 5 x 1 matrix sgd_steps = agent._total_steps / agent._num_steps probs = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + f"value/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, weights, strict=True) } probs.update({"value/total_steps": sgd_steps}) return probs @@ -214,7 +222,7 @@ def policy_logger_naive(agent) -> None: sgd_steps = agent._total_steps / agent._num_steps probs = { f"policy/{str(s)}/{agent.player_id}.cooperate": p[0] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } probs.update({"policy/total_steps": sgd_steps}) return probs @@ -222,7 +230,7 @@ def policy_logger_naive(agent) -> None: class ESLog(object): def __init__( - self, num_dims: int, num_generations: int, top_k: int, maximize: bool + self, num_dims: int, num_generations: int, top_k: int, maximize: bool ): """Simple jittable logging tool for ES rollouts.""" self.num_dims = num_dims @@ -235,55 +243,55 @@ def initialize(self) -> chex.ArrayTree: """Initialize the logger storage.""" log = { "top_fitness": jnp.zeros(self.top_k) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "top_params": jnp.zeros((self.top_k, self.num_dims)) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_top_1": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_top_mean": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_top_std": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "top_gen_fitness": jnp.zeros(self.top_k) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "top_gen_params": jnp.zeros((self.top_k, self.num_dims)) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_gen_1": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_top_gen_mean": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_top_gen_std": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_gen_mean": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "log_gen_std": jnp.zeros(self.num_generations) - - 1e10 * self.maximize - + 1e10 * (1 - self.maximize), + - 1e10 * self.maximize + + 1e10 * (1 - self.maximize), "gen_counter": 0, } return log # @partial(jax.jit, static_argnums=(0,)) def update( - self, log: chex.ArrayTree, x: chex.Array, fitness: chex.Array + self, log: chex.ArrayTree, x: chex.Array, fitness: chex.Array ) -> chex.ArrayTree: """Update the logging storage with newest data.""" # Check if there are solutions better than current archive def get_top_idx(maximize: bool, vals: jnp.ndarray) -> jnp.ndarray: top_idx = maximize * ((-1) * vals).argsort() + ( - (1 - maximize) * vals.argsort() + (1 - maximize) * vals.argsort() ) return top_idx @@ -347,13 +355,13 @@ def load(self, filename: str): return es_logger def plot( - self, - log, - title, - ylims=None, - fig=None, - ax=None, - no_legend=False, + self, + log, + title, + ylims=None, + fig=None, + ax=None, + no_legend=False, ): """Plot fitness trajectory from evo logger over generations.""" import matplotlib.pyplot as plt @@ -391,7 +399,7 @@ def plot( def ipd_visitation( - observations: jnp.ndarray, actions: jnp.ndarray, final_obs: jnp.ndarray + observations: jnp.ndarray, actions: jnp.ndarray, final_obs: jnp.ndarray ) -> dict: # obs [num_outer_steps, num_inner_steps, num_opps, num_envs, ...] # final_t [num_opps, num_envs, ...] @@ -402,7 +410,7 @@ def ipd_visitation( state_actions = jnp.reshape( state_actions, (num_timesteps,) + state_actions.shape[2:], - ) + ) # assume final step taken is cooperate final_obs = jax.lax.expand_dims(2 * jnp.argmax(final_obs, axis=-1), [0]) state_actions = jnp.append(state_actions, final_obs, axis=0) @@ -500,10 +508,10 @@ def generate_grouped_combs_strs(num_players): ] visitation_dict = ( - dict(zip(visitation_strs, state_freq)) - | dict(zip(prob_strs, state_probs)) - | dict(zip(grouped_visitation_strs, grouped_state_freq)) - | dict(zip(grouped_prob_strs, grouped_state_probs)) + dict(zip(visitation_strs, state_freq, strict=True)) + | dict(zip(prob_strs, state_probs, strict=True)) + | dict(zip(grouped_visitation_strs, grouped_state_freq, strict=True)) + | dict(zip(grouped_prob_strs, grouped_state_probs, strict=True)) ) return visitation_dict @@ -799,14 +807,20 @@ def third_party_punishment_visitation( total_punishment_str = "total_punishment" visitation_dict = ( - dict(zip(all_game_visitation_strs, action_freq)) - | dict(zip(all_game_prob_strs, action_probs)) - | dict(zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq)) - | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs)) - | dict(zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq)) - | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs)) - | dict(zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq)) - | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs)) + dict(zip(all_game_visitation_strs, action_freq, strict=True)) + | dict(zip(all_game_prob_strs, action_probs, strict=True)) + | dict( + zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq, strict=True) + ) + | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs, strict=True)) + | dict( + zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq, strict=True) + ) + | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs, strict=True)) + | dict( + zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq, strict=True) + ) + | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs, strict=True)) | {pl1_total_defects_prob_str: pl1_total_defects_prob} | {pl2_total_defects_prob_str: pl2_total_defects_prob} | {pl3_total_defects_prob_str: pl3_total_defects_prob} @@ -864,10 +878,6 @@ def third_party_random_visitation( game2 = jnp.stack([prev_cd_actions[:, 1], prev_cd_actions[:, 2]], axis=-1) game3 = jnp.stack([prev_cd_actions[:, 2], prev_cd_actions[:, 0]], axis=-1) - selected_game_actions = jnp.stack( # len_idx3x2 - [game1, game2, game3], axis=1 - ) - b2i = 2 ** jnp.arange(2 - 1, -1, -1) pl1_vs_pl2 = (game1 * b2i).sum(axis=-1) pl2_vs_pl3 = (game2 * b2i).sum(axis=-1) @@ -996,10 +1006,10 @@ def third_party_random_visitation( ) visitation_dict = ( - dict(zip(game_prob_strs, action_probs)) - | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs)) - | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs)) - | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs)) + dict(zip(game_prob_strs, action_probs, strict=True)) + | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs, strict=True)) + | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs, strict=True)) + | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs, strict=True)) | {game_selected_punish_str: game_selected_punish} | {pl1_defects_prob_str: pl1_defect_prob} | {pl2_defects_prob_str: pl2_defect_prob} @@ -1121,16 +1131,16 @@ def cg_visitation(state: NamedTuple) -> dict: def ipditm_stats( - state: EnvState, traj1: NamedTuple, traj2: NamedTuple, num_envs: int + state: EnvState, traj1: NamedTuple, traj2: NamedTuple, num_envs: int ) -> dict: from pax.envs.in_the_matrix import Actions """Compute statistics for IPDITM.""" interacts1 = ( - jnp.count_nonzero(traj1.actions == Actions.interact) / num_envs + jnp.count_nonzero(traj1.actions == Actions.interact) / num_envs ) interacts2 = ( - jnp.count_nonzero(traj2.actions == Actions.interact) / num_envs + jnp.count_nonzero(traj2.actions == Actions.interact) / num_envs ) soft_reset_mask = jnp.where(traj1.rewards != 0, 1, 0) @@ -1138,17 +1148,17 @@ def ipditm_stats( num_sft_resets = jnp.maximum(1, num_soft_resets) coops1 = ( - soft_reset_mask * traj1.observations["inventory"][..., 0] - ).sum() / (num_envs * num_sft_resets) + soft_reset_mask * traj1.observations["inventory"][..., 0] + ).sum() / (num_envs * num_sft_resets) defect1 = ( - soft_reset_mask * traj1.observations["inventory"][..., 1] - ).sum() / (num_envs * num_sft_resets) + soft_reset_mask * traj1.observations["inventory"][..., 1] + ).sum() / (num_envs * num_sft_resets) coops2 = ( - soft_reset_mask * traj2.observations["inventory"][..., 0] - ).sum() / (num_envs * num_sft_resets) + soft_reset_mask * traj2.observations["inventory"][..., 0] + ).sum() / (num_envs * num_sft_resets) defect2 = ( - soft_reset_mask * traj2.observations["inventory"][..., 1] - ).sum() / (num_envs * num_sft_resets) + soft_reset_mask * traj2.observations["inventory"][..., 1] + ).sum() / (num_envs * num_sft_resets) rewards1 = traj1.rewards.sum() / num_envs rewards2 = traj2.rewards.sum() / num_envs @@ -1172,5 +1182,3 @@ def ipditm_stats( "train/final_reward/player1": f_rewards1, "train/final_reward/player2": f_rewards2, } - - diff --git a/pax/watchers/c_rice.py b/pax/watchers/c_rice.py index 38e531a1..fec7967a 100644 --- a/pax/watchers/c_rice.py +++ b/pax/watchers/c_rice.py @@ -9,7 +9,9 @@ @partial(jax.jit, static_argnums=(1, 2)) -def c_rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: bool) -> dict: +def c_rice_stats( + trajectories: List[NamedTuple], num_players: int, mediator: bool +) -> dict: traj = trajectories[0] # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim result = { @@ -37,18 +39,27 @@ def c_rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: boo result[label] = jnp.mean(traj.observations[..., start:end]) num_episodes = jnp.sum(traj.dones) - result["final_temperature"] = jnp.where(num_episodes != 0, - jnp.sum(traj.dones * traj.observations[..., 4]) / num_episodes, 0) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 4]) / num_episodes, + 0, + ) - region_trajectories = trajectories if mediator is False else trajectories[1:] - total_reward = jnp.array([jnp.sum(_traj.rewards) for _traj in region_trajectories]).sum() - result["train/total_reward_per_episode"] = jnp.where(num_episodes != 0, total_reward / num_episodes, 0) + region_trajectories = ( + trajectories if mediator is False else trajectories[1:] + ) + total_reward = jnp.array( + [jnp.sum(_traj.rewards) for _traj in region_trajectories] + ).sum() + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, total_reward / num_episodes, 0 + ) # Reward per region (necessary to compare weight_sharing against distributed methods) for i, _traj in enumerate(region_trajectories): - result[f"train/total_reward_per_episode_region_{i}"] = jnp.where(num_episodes != 0, - _traj.rewards.sum() / num_episodes, - 0) + result[f"train/total_reward_per_episode_region_{i}"] = jnp.where( + num_episodes != 0, _traj.rewards.sum() / num_episodes, 0 + ) return result @@ -57,7 +68,9 @@ def c_rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: boo @partial(jax.jit, static_argnums=(2)) -def c_rice_eval_stats(trajectories: List[NamedTuple], env_state: EnvState, env: ClubRice) -> dict: +def c_rice_eval_stats( + trajectories: List[NamedTuple], env_state: EnvState, env: ClubRice +) -> dict: # In the stacked env_state the inner steps are one long sequence in the first dimension # But to compute step statistics we need it to be episodes x steps x ... ep_count = int(env_state.global_temperature.shape[1] / 20) @@ -85,7 +98,9 @@ def add_atrib(name, value, axis): add_atrib("carbon_mass", env_state.global_carbon_mass, axis=(0, 2, 3)) add_atrib("labor", env_state.labor_all, axis=(0, 2, 3)) add_atrib("capital", env_state.capital_all, axis=(0, 2, 3)) - add_atrib("production_factor", env_state.production_factor_all, axis=(0, 2, 3)) + add_atrib( + "production_factor", env_state.production_factor_all, axis=(0, 2, 3) + ) add_atrib("intensity", env_state.intensity_all, axis=(0, 2, 3)) add_atrib("balance", env_state.balance_all, axis=(0, 2, 3)) add_atrib("gross_output", env_state.gross_output_all, axis=(0, 2, 3)) @@ -101,33 +116,58 @@ def add_atrib(name, value, axis): add_atrib("mitigation_cost", env_state.mitigation_cost_all, axis=(0, 2, 3)) add_atrib("damages", env_state.damages_all, axis=(0, 2, 3)) - actions = jnp.stack([traj.actions for traj in trajectories]) # n_players x rollouts x steps x ... x n_actions # reshape -> n_players x (rollouts * n_episodes) x episode_steps x ... x n_actions # transpose -> episode_steps x n_players x (rollouts * n_episodes) x ... x n_actions actions = jax.tree_util.tree_map( - lambda x: x.reshape((x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:])) - .transpose((2, 0, 1, 3, 4, 5)), + lambda x: x.reshape( + (x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:]) + ).transpose((2, 0, 1, 3, 4, 5)), actions, ) - add_atrib("savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4)) - add_atrib("mitigation_rate", actions[..., env.mitigation_rate_action_index], axis=(2, 3, 4)) - add_atrib("export_limit", actions[..., env.export_action_index], axis=(2, 3, 4)) - add_atrib("club_join_action", actions[..., env.join_club_action_index], axis=(2, 3, 4)) - add_atrib("imports", - actions[..., env.desired_imports_action_index: env.desired_imports_action_index + env.import_actions_n], - axis=(2, 3, 4)) - add_atrib("tariffs", - actions[..., env.tariffs_action_index: env.tariffs_action_index + env.tariff_actions_n], - axis=(2, 3, 4)) + add_atrib( + "savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4) + ) + add_atrib( + "mitigation_rate", + actions[..., env.mitigation_rate_action_index], + axis=(2, 3, 4), + ) + add_atrib( + "export_limit", actions[..., env.export_action_index], axis=(2, 3, 4) + ) + add_atrib( + "club_join_action", + actions[..., env.join_club_action_index], + axis=(2, 3, 4), + ) + add_atrib( + "imports", + actions[ + ..., + env.desired_imports_action_index : env.desired_imports_action_index + + env.import_actions_n, + ], + axis=(2, 3, 4), + ) + add_atrib( + "tariffs", + actions[ + ..., + env.tariffs_action_index : env.tariffs_action_index + + env.tariff_actions_n, + ], + axis=(2, 3, 4), + ) observations = trajectories[0].observations observations = jax.tree_util.tree_map( - lambda x: x.reshape((x.shape[0] * ep_count, ep_length, *x.shape[2:])) - .transpose((1, 0, 2, 3, 4)), + lambda x: x.reshape( + (x.shape[0] * ep_count, ep_length, *x.shape[2:]) + ).transpose((1, 0, 2, 3, 4)), observations, - ) + ) add_atrib("club_mitigation_rate", observations[..., 2], axis=(1, 2, 3)) add_atrib("club_tariff_rate", observations[..., 3], axis=(1, 2, 3)) diff --git a/pax/watchers/cournot.py b/pax/watchers/cournot.py index d10c966e..63439a72 100644 --- a/pax/watchers/cournot.py +++ b/pax/watchers/cournot.py @@ -7,7 +7,9 @@ @partial(jax.jit, static_argnums=2) -def cournot_stats(observations: jnp.ndarray, params: CournotEnvParams, num_players: int) -> dict: +def cournot_stats( + observations: jnp.ndarray, params: CournotEnvParams, num_players: int +) -> dict: opt_quantity = CournotGame.nash_policy(params) actions = observations[..., :num_players] @@ -15,7 +17,9 @@ def cournot_stats(observations: jnp.ndarray, params: CournotEnvParams, num_playe stats = { "cournot/average_quantity": average_quantity, - "cournot/quantity_loss": jnp.mean((opt_quantity - average_quantity) ** 2), + "cournot/quantity_loss": jnp.mean( + (opt_quantity - average_quantity) ** 2 + ), } for i in range(num_players): diff --git a/pax/watchers/fishery.py b/pax/watchers/fishery.py index 6724707f..49f4251a 100644 --- a/pax/watchers/fishery.py +++ b/pax/watchers/fishery.py @@ -18,16 +18,30 @@ def fishery_stats(trajectories: List[NamedTuple], num_players: int) -> dict: "fishery/stock": jnp.mean(stock_obs), "fishery/effort_mean": actions.mean(), "fishery/effort_std": actions.std(), - "fishery/final_stock": jnp.where(completed_episodes != 0, jnp.sum(traj.dones * stock_obs) / jnp.sum(traj.dones), 0), + "fishery/final_stock": jnp.where( + completed_episodes != 0, + jnp.sum(traj.dones * stock_obs) / jnp.sum(traj.dones), + 0, + ), } for i in range(num_players): stats["fishery/effort_" + str(i)] = jnp.mean(traj.observations[..., i]) - stats["fishery/effort_" + str(i) + "_std"] = jnp.mean(traj.observations[..., i]) - stats["train/total_reward_" + str(i)] = jnp.where(completed_episodes != 0, trajectories[i].rewards.sum() / completed_episodes, 0) + stats["fishery/effort_" + str(i) + "_std"] = jnp.mean( + traj.observations[..., i] + ) + stats["train/total_reward_" + str(i)] = jnp.where( + completed_episodes != 0, + trajectories[i].rewards.sum() / completed_episodes, + 0, + ) - total_reward = jnp.array([jnp.sum(_traj.rewards) for _traj in trajectories]).sum() - stats["train/total_reward_per_episode"] = jnp.where(completed_episodes != 0, total_reward / completed_episodes, 0) + total_reward = jnp.array( + [jnp.sum(_traj.rewards) for _traj in trajectories] + ).sum() + stats["train/total_reward_per_episode"] = jnp.where( + completed_episodes != 0, total_reward / completed_episodes, 0 + ) return stats @@ -39,17 +53,22 @@ def fishery_eval_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict: ep_length = np.arange(len(effort_1)).tolist() # Plot the effort on the same graph - effort_plot = wandb.plot.line_series(xs=ep_length, ys=[effort_1, effort_2], - keys=["effort_1", "effort_2"], - xname="step", - title="Agent effort over one episode") + effort_plot = wandb.plot.line_series( + xs=ep_length, + ys=[effort_1, effort_2], + keys=["effort_1", "effort_2"], + xname="step", + title="Agent effort over one episode", + ) stock_obs = traj1.observations[..., 0].squeeze().tolist() - stock_table = wandb.Table(data=[[x, y] for (x, y) in zip(ep_length, stock_obs)], columns=["step", "stock"]) + stock_table = wandb.Table( + data=[[x, y] for (x, y) in zip(ep_length, stock_obs, strict=True)], + columns=["step", "stock"], + ) # Plot the stock in a separate graph - stock_plot = wandb.plot.line(stock_table, x="step", y="stock", title="Stock over one episode") + stock_plot = wandb.plot.line( + stock_table, x="step", y="stock", title="Stock over one episode" + ) - return { - "fishery/stock": stock_plot, - "fishery/effort": effort_plot - } + return {"fishery/stock": stock_plot, "fishery/effort": effort_plot} diff --git a/pax/watchers/rice.py b/pax/watchers/rice.py index ffb5bcdd..293449cd 100644 --- a/pax/watchers/rice.py +++ b/pax/watchers/rice.py @@ -8,7 +8,9 @@ @partial(jax.jit, static_argnums=(1, 2)) -def rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: bool) -> dict: +def rice_stats( + trajectories: List[NamedTuple], num_players: int, mediator: bool +) -> dict: traj = trajectories[0] # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim result = { @@ -39,18 +41,27 @@ def rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: bool) result[label] = jnp.mean(traj.observations[..., start:end]) num_episodes = jnp.sum(traj.dones) - result["final_temperature"] = jnp.where(num_episodes != 0, - jnp.sum(traj.dones * traj.observations[..., 2]) / num_episodes, 0) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 2]) / num_episodes, + 0, + ) - region_trajectories = trajectories if mediator is False else trajectories[1:] - total_reward = jnp.array([jnp.sum(_traj.rewards) for _traj in region_trajectories]).sum() - result["train/total_reward_per_episode"] = jnp.where(num_episodes != 0, total_reward / num_episodes, 0) + region_trajectories = ( + trajectories if mediator is False else trajectories[1:] + ) + total_reward = jnp.array( + [jnp.sum(_traj.rewards) for _traj in region_trajectories] + ).sum() + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, total_reward / num_episodes, 0 + ) # Reward per region (necessary to compare weight_sharing against distributed methods) for i, _traj in enumerate(region_trajectories): - result[f"train/total_reward_per_episode_region_{i}"] = jnp.where(num_episodes != 0, - _traj.rewards.sum() / num_episodes, - 0) + result[f"train/total_reward_per_episode_region_{i}"] = jnp.where( + num_episodes != 0, _traj.rewards.sum() / num_episodes, 0 + ) return result @@ -59,7 +70,9 @@ def rice_stats(trajectories: List[NamedTuple], num_players: int, mediator: bool) @partial(jax.jit, static_argnums=(2)) -def rice_eval_stats(trajectories: List[NamedTuple], env_state: EnvState, env: Rice) -> dict: +def rice_eval_stats( + trajectories: List[NamedTuple], env_state: EnvState, env: Rice +) -> dict: # In the stacked env_state the inner steps are one long sequence in the first dimension # But to compute step statistics we need it to be episodes x steps x ... ep_count = int(env_state.global_temperature.shape[1] / 20) @@ -87,7 +100,9 @@ def add_atrib(name, value, axis): add_atrib("carbon_mass", env_state.global_carbon_mass, axis=(0, 2, 3)) add_atrib("labor", env_state.labor_all, axis=(0, 2, 3)) add_atrib("capital", env_state.capital_all, axis=(0, 2, 3)) - add_atrib("production_factor", env_state.production_factor_all, axis=(0, 2, 3)) + add_atrib( + "production_factor", env_state.production_factor_all, axis=(0, 2, 3) + ) add_atrib("intensity", env_state.intensity_all, axis=(0, 2, 3)) add_atrib("balance", env_state.balance_all, axis=(0, 2, 3)) add_atrib("gross_output", env_state.gross_output_all, axis=(0, 2, 3)) @@ -108,22 +123,44 @@ def add_atrib(name, value, axis): # reshape -> n_players x (rollouts * n_episodes) x episode_steps x ... x n_actions # transpose -> episode_steps x n_players x (rollouts * n_episodes) x ... x n_actions actions = jax.tree_util.tree_map( - lambda x: x.reshape((x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:])) - .transpose((2, 0, 1, 3, 4, 5)), + lambda x: x.reshape( + (x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:]) + ).transpose((2, 0, 1, 3, 4, 5)), actions, ) - add_atrib("savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4)) - add_atrib("mitigation_rate", actions[..., env.mitigation_rate_action_index], axis=(2, 3, 4)) - add_atrib("export_limit", actions[..., env.export_action_index], axis=(2, 3, 4)) - add_atrib("imports", - actions[..., env.desired_imports_action_index: env.desired_imports_action_index + env.import_actions_n], - axis=(2, 3, 4)) - add_atrib("tariffs", - actions[..., env.tariffs_action_index: env.tariffs_action_index + env.tariff_actions_n], - axis=(2, 3, 4)) + add_atrib( + "savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4) + ) + add_atrib( + "mitigation_rate", + actions[..., env.mitigation_rate_action_index], + axis=(2, 3, 4), + ) + add_atrib( + "export_limit", actions[..., env.export_action_index], axis=(2, 3, 4) + ) + add_atrib( + "imports", + actions[ + ..., + env.desired_imports_action_index : env.desired_imports_action_index + + env.import_actions_n, + ], + axis=(2, 3, 4), + ) + add_atrib( + "tariffs", + actions[ + ..., + env.tariffs_action_index : env.tariffs_action_index + + env.tariff_actions_n, + ], + axis=(2, 3, 4), + ) return result + @partial(jax.jit, static_argnums=(1,)) def rice_sarl_stats(traj: NamedTuple, num_players: int) -> dict: # obs shape: num_steps x num_envs x obs_dim @@ -131,7 +168,6 @@ def rice_sarl_stats(traj: NamedTuple, num_players: int) -> dict: # Actions "savings_rate": jnp.mean(traj.actions[..., 0]), "mitigation_rate": jnp.mean(traj.actions[..., 1]), - "temperature": jnp.mean(traj.observations[..., 1]), "land_temperature": jnp.mean(traj.observations[..., 2]), "carbon_atmosphere": jnp.mean(traj.observations[..., 3]), @@ -160,8 +196,13 @@ def rice_sarl_stats(traj: NamedTuple, num_players: int) -> dict: result[label] = jnp.mean(traj.observations[..., start:end]) num_episodes = jnp.sum(traj.dones) - result["final_temperature"] = jnp.where(num_episodes != 0, - jnp.sum(traj.dones * traj.observations[..., 1]) / num_episodes, 0) - result["train/total_reward_per_episode"] = jnp.where(num_episodes != 0, traj.rewards.sum() / num_episodes, 0) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 1]) / num_episodes, + 0, + ) + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, traj.rewards.sum() / num_episodes, 0 + ) return result diff --git a/test/envs/test_cournot.py b/test/envs/test_cournot.py index a989cb59..6d17e796 100644 --- a/test/envs/test_cournot.py +++ b/test/envs/test_cournot.py @@ -17,7 +17,10 @@ def test_single_cournot_game(): obs, env_state = env.reset(rng, env_params) obs, env_state, rewards, done, info = env.step( - rng, env_state, tuple([nash_action for _ in range(n_player)]), env_params + rng, + env_state, + tuple([nash_action for _ in range(n_player)]), + env_params, ) assert all(element == rewards[0] for element in rewards) @@ -26,14 +29,18 @@ def test_single_cournot_game(): nash_reward = CournotGame.nash_reward(env_params) assert nash_reward == 1800 assert jnp.isclose(nash_reward / n_player, rewards[0], atol=0.01) - expected_obs = jnp.array([60/n_player for _ in range(n_player)] + [40]) + expected_obs = jnp.array( + [60 / n_player for _ in range(n_player)] + [40] + ) assert jnp.allclose(obs[0], expected_obs, atol=0.01) assert jnp.allclose(obs[0], obs[1], atol=0.0) social_opt_action = jnp.array([45 / n_player]) obs, env_state = env.reset(rng, env_params) obs, env_state, rewards, done, info = env.step( - rng, env_state, tuple([social_opt_action for _ in range(n_player)]), env_params + rng, + env_state, + tuple([social_opt_action for _ in range(n_player)]), + env_params, ) assert jnp.asarray(rewards).sum() == 2025 - diff --git a/test/envs/test_fishery.py b/test/envs/test_fishery.py index 73124151..659e83bd 100644 --- a/test/envs/test_fishery.py +++ b/test/envs/test_fishery.py @@ -9,14 +9,7 @@ def test_fishery_convergence(): ep_length = 300 env = Fishery(num_players=2, num_inner_steps=ep_length) - env_params = EnvParams( - g=0.15, - e=0.009, - P=200, - w=0.9, - s_0=1.0, - s_max=1.0 - ) + env_params = EnvParams(g=0.15, e=0.009, P=200, w=0.9, s_0=1.0, s_max=1.0) # response parameter obs, env_state = env.reset(rng, env_params) @@ -34,10 +27,16 @@ def test_fishery_convergence(): # Check convergence at the end of an episode if i % ep_length == ep_length - 2: - S_star = env_params.w / (env_params.P * env_params.e * env_params.s_max) # 0.5 + S_star = env_params.w / ( + env_params.P * env_params.e * env_params.s_max + ) # 0.5 assert jnp.isclose(S_star, env_state.s, atol=0.01) - H_star = (env_params.w * env_params.g / (env_params.P * env_params.e)) * (1 - S_star) # 0.0375 + H_star = ( + env_params.w * env_params.g / (env_params.P * env_params.e) + ) * ( + 1 - S_star + ) # 0.0375 assert jnp.isclose(H_star, info["H"], atol=0.01) E_star = (env_params.g / env_params.e) * (1 - S_star) # ~ 8.3333 @@ -49,4 +48,3 @@ def test_fishery_convergence(): assert env_state.outer_t == i // ep_length assert env_state.s == env_params.s_0 assert step_reward == 0 - diff --git a/test/envs/test_iterated_tensor_game_n_player.py b/test/envs/test_iterated_tensor_game_n_player.py index bffbd423..a80ab63e 100644 --- a/test/envs/test_iterated_tensor_game_n_player.py +++ b/test/envs/test_iterated_tensor_game_n_player.py @@ -9,7 +9,7 @@ IteratedTensorGameNPlayer, ) -####### 2 PLAYERS ####### +# 2 PLAYERS payoff_table_2pl = [ [4, jnp.nan], [2, 5], @@ -25,7 +25,7 @@ dc_obs = 2 dd_obs = 3 -####### 3 PLAYERS ####### +# 3 PLAYERS payoff_table_3pl = [ [4, jnp.nan], [2.66, 5.66], @@ -84,7 +84,7 @@ ddc_obs = 6 ddd_obs = 7 -####### 4 PLAYERS ####### +# 4 PLAYERS payoff_table_4pl = [ [4, jnp.nan], [3, 6], @@ -207,6 +207,7 @@ dddc_obs = 14 dddd_obs = 15 + # ###### Begin actual tests ####### @pytest.mark.parametrize("payoff", [payoff_table_2pl]) def test_single_batch_2pl(payoff) -> None: @@ -214,7 +215,7 @@ def test_single_batch_2pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 2 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=2, num_inner_steps=5, num_outer_steps=1 ) @@ -230,7 +231,7 @@ def test_single_batch_2pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 2 player + # test 2 player # cc obs, env_state, rewards, done, info = env.step( rng, env_state, (0 * action, 0 * action), env_params @@ -248,7 +249,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##dc + # dc obs, env_state, rewards, done, info = env.step( rng, env_state, (1 * action, 0 * action), env_params ) @@ -264,7 +265,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##cd + # cd obs, env_state, rewards, done, info = env.step( rng, env_state, (0 * action, 1 * action), env_params ) @@ -279,7 +280,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##dd + # dd obs, env_state, rewards, done, info = env.step( rng, env_state, (1 * action, 1 * action), env_params ) @@ -301,7 +302,7 @@ def test_single_batch_4pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 4 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=num_players, num_inner_steps=100, num_outer_steps=1 ) @@ -317,7 +318,7 @@ def test_single_batch_4pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 4 player + # test 4 player # cccc cccc = (0 * action, 0 * action, 0 * action, 0 * action) obs, env_state, rewards, done, info = env.step( @@ -510,7 +511,7 @@ def test_single_batch_3pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 3 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=num_players, num_inner_steps=100, num_outer_steps=1 ) @@ -526,7 +527,7 @@ def test_single_batch_3pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 3 player + # test 3 player # ccc ccc = (0 * action, 0 * action, 0 * action) obs, env_state, rewards, done, info = env.step( @@ -617,9 +618,6 @@ def test_single_batch_3pl(payoff) -> None: def test_batch_diff_actions() -> None: - num_envs = 5 - all_ones = jnp.ones((num_envs,)) - rng = jax.random.PRNGKey(0) # 5 envs, with actions ccc, dcc, cdc, ddd, ccd diff --git a/test/envs/test_rice.py b/test/envs/test_rice.py index a48de488..82dbdc6a 100644 --- a/test/envs/test_rice.py +++ b/test/envs/test_rice.py @@ -2,13 +2,9 @@ import time import jax -import jax.numpy as jnp -import tree from pax.envs.rice.rice import Rice, EnvParams -# config.update('jax_disable_jit', True) - file_dir = os.path.join(os.path.dirname(__file__)) config_folder = os.path.join(file_dir, "../../pax/envs/rice/5_regions") num_players = 5 @@ -33,43 +29,31 @@ def test_rice(): ) for j in range(num_players): # assert all obs positive - assert env_state.consumption_all[j].item() >= 0, "consumption cannot be negative!" - assert env_state.production_all[j].item() >= 0, "production cannot be negative!" - assert env_state.labor_all[j].item() >= 0, "labor cannot be negative!" - assert env_state.capital_all[j].item() >= 0, "capital cannot be negative!" - assert env_state.gross_output_all[j].item() >= 0, "gross output cannot be negative!" - assert env_state.investment_all[j].item() >= 0, "investment cannot be negative!" + assert ( + env_state.consumption_all[j].item() >= 0 + ), "consumption cannot be negative!" + assert ( + env_state.production_all[j].item() >= 0 + ), "production cannot be negative!" + assert ( + env_state.labor_all[j].item() >= 0 + ), "labor cannot be negative!" + assert ( + env_state.capital_all[j].item() >= 0 + ), "capital cannot be negative!" + assert ( + env_state.gross_output_all[j].item() >= 0 + ), "gross output cannot be negative!" + assert ( + env_state.investment_all[j].item() >= 0 + ), "investment cannot be negative!" if done is True: assert env_state.inner_t == 0 assert (i + 1) / ep_length == 0 - assert (i + 1) % ep_length == env_state.inner_t, "inner_t not updating correctly" - - -def test_rice_regression(data_regression): - rng = jax.random.PRNGKey(0) - - env = Rice(config_folder=config_folder, episode_length=ep_length) # Make sure to define 'config_folder' - env_params = EnvParams() - obs, env_state = env.reset(rng, env_params) - - results = {} - - for i in range(ep_length): - action = jnp.ones((env.num_actions,)) * 0.5 # Deterministic action for testing - actions = tuple([action for _ in range(num_players)]) - obs, env_state, rewards, done, info = env.step(rng, env_state, actions, env_params) - - stored_obs = [_obs.round(1).tolist() for _obs in obs] - stored_env_state = [leaf.round(1).tolist() for leaf in tree.flatten(env_state)] - stored_rewards = [reward.round(1).item() for reward in rewards] - results[i] = { - "obs": stored_obs, - "env_state": stored_env_state, - "rewards": stored_rewards, - } - - data_regression.check(results) + assert ( + i + 1 + ) % ep_length == env_state.inner_t, "inner_t not updating correctly" def rice_performance_benchmark(): @@ -82,7 +66,7 @@ def rice_performance_benchmark(): start_time = time.time() - for i in range(iterations * ep_length): + for _ in range(iterations * ep_length): # Do random actions key, _ = jax.random.split(rng, 2) action = jax.random.uniform(rng, (env.num_actions,)) @@ -97,7 +81,9 @@ def rice_performance_benchmark(): # Print or log the total time taken for all iterations print(f"Total iterations:\t{iterations * ep_length}") print(f"Total time taken:\t{total_time:.4f} seconds") - print(f"Average step duration:\t{total_time / (iterations * ep_length):.6f} seconds") + print( + f"Average step duration:\t{total_time / (iterations * ep_length):.6f} seconds" + ) # Run a benchmark diff --git a/test/test_strategies.py b/test/test_strategies.py index ccc38fcc..1d4f9b9b 100644 --- a/test/test_strategies.py +++ b/test/test_strategies.py @@ -247,58 +247,6 @@ def test_naive_tft(): assert jnp.allclose(0.99, reward0, atol=0.01) -def test_naive_tft(): - num_envs = 2 - rng = jnp.concatenate([jax.random.PRNGKey(0)] * num_envs).reshape( - num_envs, -1 - ) - - env = InfiniteMatrixGame(num_steps=jnp.inf) - env_params = InfiniteMatrixGameParams( - payoff_matrix=[[2, 2], [0, 3], [3, 0], [1, 1]], gamma=0.96 - ) - - # vmaps - split = jax.vmap(jax.random.split, in_axes=(0, None)) - env.reset = jax.jit( - jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None)) - ) - env.step = jax.jit( - jax.vmap( - env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0) - ) - ) - - (obs1, _), env_state = env.reset(rng, env_params) - agent = NaiveExact( - action_dim=5, - env_params=env_params, - lr=1, - num_envs=num_envs, - player_id=0, - ) - agent_state, agent_memory = agent.make_initial_state(obs1) - tft_action = jnp.tile( - 20 * jnp.array([1.0, -1.0, 1.0, -1.0, 1.0]), (num_envs, 1) - ) - - for _ in range(50): - rng, _ = split(rng, 2) - action, agent_state, agent_memory = agent._policy( - agent_state, obs1, agent_memory - ) - - (obs1, _), env_state, rewards, _, _ = env.step( - rng, env_state, (action, tft_action), env_params - ) - - action, _, _ = agent._policy(agent_state, obs1, agent_memory) - _, _, (reward0, reward1), _, _ = env.step( - rng, env_state, (action, tft_action), env_params - ) - assert jnp.allclose(2.0, reward0, atol=0.01) - - def test_naive_tft_as_second_player(): num_envs = 2 rng = jnp.concatenate([jax.random.PRNGKey(0)] * num_envs).reshape( diff --git a/test/test_tensor_strategies.py b/test/test_tensor_strategies.py index ab2e36d6..c6e55701 100644 --- a/test/test_tensor_strategies.py +++ b/test/test_tensor_strategies.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp from pax.agents.tensor_strategies import (