From 9d3fa62e34279a338c07cffcbf208edc8a95e7ba Mon Sep 17 00:00:00 2001 From: Christoph Proeschel Date: Wed, 18 Oct 2023 16:42:01 +0200 Subject: [PATCH] Add Rice-N, C-Rice-N, Fishery, Cournot competition and a parameter sharing runner (#161) * most of the changes * most of the changes * more stuff * Add cournot game * test optimal policy * Add more cournot configs, fixes and first draft of the fishery environment * fix cournot test * fix cournot test and config * fix fishery tests * improvements: cournot optimum v nash optimum, fishery configs * fishery eval checkpoint * nplayer * n player fixes, n player cournot * add rice environment * checkpoint: parity between pax rice and ai4coop rice * Add 5 regions rice configuration * add a rice_n regression * rice consistency checkpoint * fully vectorized rice environment * checkpoint * checkpoint * fix fishery * cleanup * more refactoring * fixes * Add pytest regression * checkpoint * refactor watchers file and fix some types and unused imports * more experiments plus rice refactor * tests failing * reformat version file * exclude version file * fix exclude statement * exclude version file in github action * fix * fix * another attempt --------- Co-authored-by: alexandrasouly --- .github/workflows/test.yaml | 5 +- .gitignore | 2 - .vscode/launch.json | 4 +- .vscode/settings.json | 3 + README.md | 4 +- docs/envs.md | 20 +- pax/agents/hyper/networks.py | 5 +- pax/agents/lola/lola.py | 18 +- pax/agents/lola/network.py | 2 - pax/agents/mfos_ppo/networks.py | 26 +- pax/agents/mfos_ppo/ppo_gru.py | 17 +- pax/agents/naive/buffer.py | 3 +- pax/agents/naive/naive.py | 9 +- pax/agents/naive/network.py | 60 +- pax/agents/ppo/batched_envs.py | 2 +- pax/agents/ppo/buffer.py | 2 +- pax/agents/ppo/networks.py | 193 +- pax/agents/ppo/ppo.py | 45 +- pax/agents/ppo/ppo_gru.py | 40 +- pax/agents/strategies.py | 1 - pax/agents/tensor_strategies.py | 9 +- pax/conf/config.yaml | 15 +- pax/conf/experiment/c_rice/debug.yaml | 81 + .../c_rice/eval_mediator_gs_ppo.yaml | 93 + pax/conf/experiment/c_rice/marl_baseline.yaml | 55 + .../experiment/c_rice/mediator_gs_ppo.yaml | 83 + pax/conf/experiment/c_rice/shaper_v_ppo.yaml | 90 + .../experiment/c_rice/weight_sharing.yaml | 53 + pax/conf/experiment/cournot/gs_v_ppo.yaml | 77 + .../experiment/cournot/marl_baseline.yaml | 51 + pax/conf/experiment/cournot/shaper_v_ppo.yaml | 100 + .../experiment/cournot/weight_sharing.yaml | 100 + .../experiment/fishery/eval_gs_v_ppo.yaml | 109 + .../fishery/eval_marl_baseline.yaml | 79 + .../fishery/eval_weight_sharing_v_ppo.yaml | 64 + pax/conf/experiment/fishery/gs_v_ppo.yaml | 103 + .../experiment/fishery/marl_baseline.yaml | 58 + pax/conf/experiment/fishery/mfos_v_ppo.yaml | 101 + pax/conf/experiment/fishery/shaper_v_ppo.yaml | 102 + .../experiment/fishery/weight_sharing.yaml | 56 + pax/conf/experiment/ipd/inf_mfos_v_nl.yaml | 9 +- pax/conf/experiment/ipd/mfos_v_ppo.yaml | 24 +- pax/conf/experiment/rice/debug.yaml | 82 + .../experiment/rice/eval_shaper_v_ppo.yaml | 100 + pax/conf/experiment/rice/gs_v_ppo.yaml | 100 + pax/conf/experiment/rice/marl_baseline.yaml | 54 + .../experiment/rice/mediator_gs_naive.yaml | 90 + pax/conf/experiment/rice/mediator_gs_ppo.yaml | 80 + pax/conf/experiment/rice/mediator_shaper.yaml | 79 + pax/conf/experiment/rice/mfos_v_ppo.yaml | 84 + pax/conf/experiment/rice/sarl.yaml | 56 + pax/conf/experiment/rice/shaper_v_ppo.yaml | 83 + pax/conf/experiment/rice/weight_sharing.yaml | 54 + pax/envs/cournot.py | 109 + pax/envs/fishery.py | 145 + 
pax/envs/infinite_matrix_game.py | 8 +- pax/envs/iterated_matrix_game.py | 6 +- pax/envs/iterated_tensor_game_n_player.py | 5 +- pax/envs/rice/27_regions/11.yml | 13 + pax/envs/rice/27_regions/12.yml | 13 + pax/envs/rice/27_regions/13.yml | 13 + pax/envs/rice/27_regions/14.yml | 13 + pax/envs/rice/27_regions/15.yml | 13 + pax/envs/rice/27_regions/16.yml | 13 + pax/envs/rice/27_regions/17.yml | 13 + pax/envs/rice/27_regions/18.yml | 13 + pax/envs/rice/27_regions/19.yml | 13 + pax/envs/rice/27_regions/2.yml | 13 + pax/envs/rice/27_regions/20.yml | 13 + pax/envs/rice/27_regions/21.yml | 13 + pax/envs/rice/27_regions/22.yml | 13 + pax/envs/rice/27_regions/23.yml | 13 + pax/envs/rice/27_regions/24.yml | 13 + pax/envs/rice/27_regions/25.yml | 13 + pax/envs/rice/27_regions/26.yml | 13 + pax/envs/rice/27_regions/27.yml | 13 + pax/envs/rice/27_regions/28.yml | 13 + pax/envs/rice/27_regions/29.yml | 13 + pax/envs/rice/27_regions/3.yml | 13 + pax/envs/rice/27_regions/30.yml | 13 + pax/envs/rice/27_regions/4.yml | 13 + pax/envs/rice/27_regions/5.yml | 13 + pax/envs/rice/27_regions/6.yml | 13 + pax/envs/rice/27_regions/7.yml | 13 + pax/envs/rice/27_regions/9.yml | 13 + pax/envs/rice/27_regions/default.yml | 68 + pax/envs/rice/5_regions/1.yml | 13 + pax/envs/rice/5_regions/2.yml | 13 + pax/envs/rice/5_regions/3.yml | 13 + pax/envs/rice/5_regions/4.yml | 13 + pax/envs/rice/5_regions/5.yml | 13 + pax/envs/rice/5_regions/default.yml | 68 + pax/envs/rice/c_rice.py | 623 + pax/envs/rice/rice.py | 776 + pax/envs/rice/sarl_rice.py | 115 + pax/experiment.py | 205 +- pax/runners/README.md | 26 + pax/runners/runner_eval.py | 310 +- pax/runners/runner_eval_multishaper.py | 26 +- pax/runners/runner_evo.py | 236 +- pax/runners/runner_evo_multishaper.py | 29 +- pax/runners/runner_marl.py | 56 +- pax/runners/runner_marl_nplayer.py | 45 +- pax/runners/runner_sarl.py | 25 +- pax/runners/runner_weight_sharing.py | 265 + pax/utils.py | 18 +- pax/version.py | 5 +- pax/{watchers.py => watchers/__init__.py} | 85 +- pax/watchers/c_rice.py | 174 + pax/watchers/cournot.py | 28 + pax/watchers/fishery.py | 74 + pax/watchers/rice.py | 208 + requirements.txt | 5 +- test/envs/test_cournot.py | 46 + test/envs/test_fishery.py | 50 + .../test_iterated_tensor_game_n_player.py | 28 +- test/envs/test_rice.py | 91 + test/envs/test_rice/test_rice_regression.yml | 14780 ++++++++++++++++ test/test_strategies.py | 52 - 119 files changed, 21547 insertions(+), 494 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 pax/conf/experiment/c_rice/debug.yaml create mode 100644 pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml create mode 100644 pax/conf/experiment/c_rice/marl_baseline.yaml create mode 100644 pax/conf/experiment/c_rice/mediator_gs_ppo.yaml create mode 100644 pax/conf/experiment/c_rice/shaper_v_ppo.yaml create mode 100644 pax/conf/experiment/c_rice/weight_sharing.yaml create mode 100644 pax/conf/experiment/cournot/gs_v_ppo.yaml create mode 100644 pax/conf/experiment/cournot/marl_baseline.yaml create mode 100644 pax/conf/experiment/cournot/shaper_v_ppo.yaml create mode 100644 pax/conf/experiment/cournot/weight_sharing.yaml create mode 100644 pax/conf/experiment/fishery/eval_gs_v_ppo.yaml create mode 100644 pax/conf/experiment/fishery/eval_marl_baseline.yaml create mode 100644 pax/conf/experiment/fishery/eval_weight_sharing_v_ppo.yaml create mode 100644 pax/conf/experiment/fishery/gs_v_ppo.yaml create mode 100644 pax/conf/experiment/fishery/marl_baseline.yaml create mode 100644 
pax/conf/experiment/fishery/mfos_v_ppo.yaml create mode 100644 pax/conf/experiment/fishery/shaper_v_ppo.yaml create mode 100644 pax/conf/experiment/fishery/weight_sharing.yaml create mode 100644 pax/conf/experiment/rice/debug.yaml create mode 100644 pax/conf/experiment/rice/eval_shaper_v_ppo.yaml create mode 100644 pax/conf/experiment/rice/gs_v_ppo.yaml create mode 100644 pax/conf/experiment/rice/marl_baseline.yaml create mode 100644 pax/conf/experiment/rice/mediator_gs_naive.yaml create mode 100644 pax/conf/experiment/rice/mediator_gs_ppo.yaml create mode 100644 pax/conf/experiment/rice/mediator_shaper.yaml create mode 100644 pax/conf/experiment/rice/mfos_v_ppo.yaml create mode 100644 pax/conf/experiment/rice/sarl.yaml create mode 100644 pax/conf/experiment/rice/shaper_v_ppo.yaml create mode 100644 pax/conf/experiment/rice/weight_sharing.yaml create mode 100644 pax/envs/cournot.py create mode 100644 pax/envs/fishery.py create mode 100644 pax/envs/rice/27_regions/11.yml create mode 100644 pax/envs/rice/27_regions/12.yml create mode 100644 pax/envs/rice/27_regions/13.yml create mode 100644 pax/envs/rice/27_regions/14.yml create mode 100644 pax/envs/rice/27_regions/15.yml create mode 100644 pax/envs/rice/27_regions/16.yml create mode 100644 pax/envs/rice/27_regions/17.yml create mode 100644 pax/envs/rice/27_regions/18.yml create mode 100644 pax/envs/rice/27_regions/19.yml create mode 100644 pax/envs/rice/27_regions/2.yml create mode 100644 pax/envs/rice/27_regions/20.yml create mode 100644 pax/envs/rice/27_regions/21.yml create mode 100644 pax/envs/rice/27_regions/22.yml create mode 100644 pax/envs/rice/27_regions/23.yml create mode 100644 pax/envs/rice/27_regions/24.yml create mode 100644 pax/envs/rice/27_regions/25.yml create mode 100644 pax/envs/rice/27_regions/26.yml create mode 100644 pax/envs/rice/27_regions/27.yml create mode 100644 pax/envs/rice/27_regions/28.yml create mode 100644 pax/envs/rice/27_regions/29.yml create mode 100644 pax/envs/rice/27_regions/3.yml create mode 100644 pax/envs/rice/27_regions/30.yml create mode 100644 pax/envs/rice/27_regions/4.yml create mode 100644 pax/envs/rice/27_regions/5.yml create mode 100644 pax/envs/rice/27_regions/6.yml create mode 100644 pax/envs/rice/27_regions/7.yml create mode 100644 pax/envs/rice/27_regions/9.yml create mode 100644 pax/envs/rice/27_regions/default.yml create mode 100644 pax/envs/rice/5_regions/1.yml create mode 100644 pax/envs/rice/5_regions/2.yml create mode 100644 pax/envs/rice/5_regions/3.yml create mode 100644 pax/envs/rice/5_regions/4.yml create mode 100644 pax/envs/rice/5_regions/5.yml create mode 100644 pax/envs/rice/5_regions/default.yml create mode 100644 pax/envs/rice/c_rice.py create mode 100644 pax/envs/rice/rice.py create mode 100644 pax/envs/rice/sarl_rice.py create mode 100644 pax/runners/README.md create mode 100644 pax/runners/runner_weight_sharing.py rename pax/{watchers.py => watchers/__init__.py} (95%) create mode 100644 pax/watchers/c_rice.py create mode 100644 pax/watchers/cournot.py create mode 100644 pax/watchers/fishery.py create mode 100644 pax/watchers/rice.py create mode 100644 test/envs/test_cournot.py create mode 100644 test/envs/test_fishery.py create mode 100644 test/envs/test_rice.py create mode 100644 test/envs/test_rice/test_rice_regression.yml diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index dfc31072..89bcc56a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,8 +1,8 @@ # .github/workflows/app.yaml name: PyTest -on: +on: 
pull_request: - branches: + branches: - main jobs: @@ -16,6 +16,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: '3.9' + - uses: pre-commit/action@v3.0.0 - name: Ensure latest pip run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index 337e4a09..f8f04061 100644 --- a/.gitignore +++ b/.gitignore @@ -104,7 +104,6 @@ exp/ # Hydra .hydra -exp/ venv/ plots/ figures/ @@ -115,4 +114,3 @@ experiment.log # Pax pax/version.py -experiment.log diff --git a/.vscode/launch.json b/.vscode/launch.json index bc2baf4c..448e2dc7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -13,8 +13,8 @@ "justMyCode": false, "env": { "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES", - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python" } } ] -} \ No newline at end of file +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..de288e1e --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.formatting.provider": "black" +} \ No newline at end of file diff --git a/README.md b/README.md index 87018ce1..cfce1d1f 100644 --- a/README.md +++ b/README.md @@ -215,7 +215,7 @@ Pax is written in pure Python, but depends on C++ code via JAX. Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt. -First, follow these instructions to install JAX with the relevant accelerator support. +First, follow these instructions to [install](https://github.com/google/jax#installation) JAX with the relevant accelerator support. ## General Information The project entrypoint is `pax/experiment.py`. The simplest command to run a game would be: @@ -265,4 +265,4 @@ If you use Pax in any of your work, please cite: journal = {GitHub repository}, howpublished = {\url{https://github.com/akbir/pax}}, } -``` \ No newline at end of file +``` diff --git a/docs/envs.md b/docs/envs.md index 7f4a711f..6f6e5092 100644 --- a/docs/envs.md +++ b/docs/envs.md @@ -1,13 +1,17 @@ ## Environments -Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independetly you can specify your enviroment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`. +Pax includes many environments specified by `env_id`. These are `infinite_matrix_game`, `iterated_matrix_game` and `coin_game`. Independently you can specify your environment type as either a meta environment (with an inner/ outer loop) by `env_type`, the options supported are `sequential` or `meta`. These are specified in the config files in `pax/configs/{env_id}/EXPERIMENT.yaml`. -| Environment ID | Environment Type | Description | -| ----------- | ----------- | ----------- | -|`iterated_matrix_game`| `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ | -|`iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode | -|`infinite_matrix_game` | `meta`| An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ | -|coin_game | `sequential` | A sequential series of episode of the coin game between two players. 
Each player updates at the end of an episode| -|coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the the outer agent updates every meta-episode| +| Environment ID | Environment Type | Description | |------------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `iterated_matrix_game` | `sequential` | An iterated matrix game with a predetermined number of timesteps per episode with a discount factor $\gamma$ | +| `iterated_matrix_game` | `meta` | A meta game over the iterated matrix game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the outer agent updates every meta-episode | +| `infinite_matrix_game` | `meta` | An infinite matrix game that calculates exact returns given a payoff and discount factor $\gamma$ | +| coin_game | `sequential` | A sequential series of episodes of the coin game between two players. Each player updates at the end of an episode | +| coin_game | `meta` | A meta learning version of the coin game with an outer agent (player 1) and an inner agent (player 2). The inner updates every episode, while the outer agent updates every meta-episode | +| cournot | `sequential`/`meta` | A one-shot version of a [Cournot competition](https://en.wikipedia.org/wiki/Cournot_competition) | +| fishery | `sequential`/`meta` | A dynamic resource harvesting game as specified in Perman et al. | +| Rice-N | `sequential`/`meta` | A re-implementation of the Integrated Assessment Model introduced by [Zhang et al.](https://papers.ssrn.com/abstract=4189735) available with either the original 27 regions or a new calibration of only 5 regions | +| C-Rice-N | `sequential`/`meta` | An extension of Rice-N with a simple climate club mechanism |
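The new environments documented above plug into the same JAX-native, vmappable interface as the existing Pax environments. As a rough usage sketch for the fishery game, assuming the gymnax-style `reset`/`step` signature used by the other envs in `pax/envs` (the exact constructor arguments and `EnvParams` fields live in `pax/envs/fishery.py`; the parameter names below mirror the fishery experiment configs, and the action shape is illustrative):

```
import jax
import jax.numpy as jnp

from pax.envs.fishery import Fishery, EnvParams

# Parameter names (g, e, P, w, s_0, s_max) as in the fishery experiment configs.
params = EnvParams(g=0.15, e=0.009, P=200.0, w=0.9, s_0=0.5, s_max=1.0)
env = Fishery(num_players=2)  # hypothetical constructor argument

key = jax.random.PRNGKey(0)
obs, env_state = env.reset(key, params)
# One continuous action (fishing effort) per player.
actions = jnp.array([[0.1], [0.1]])
obs, env_state, rewards, done, info = env.step(key, env_state, actions, params)
```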
diff --git a/pax/agents/hyper/networks.py b/pax/agents/hyper/networks.py index 686f2c7d..c623c0dd 100644 --- a/pax/agents/hyper/networks.py +++ b/pax/agents/hyper/networks.py @@ -4,6 +4,7 @@ import haiku as hk import jax import jax.numpy as jnp +from distrax import MultivariateNormalDiag from pax import utils @@ -74,7 +75,7 @@ def make_GRU(num_actions: int): def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]: """forward function""" gru = hk.GRU(hidden_size) embedding, state = gru(inputs, state) @@ -92,7 +93,7 @@ def make_GRU_hypernetwork(num_actions: int): def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[MultivariateNormalDiag, jnp.ndarray], jnp.ndarray]: """forward function""" gru = hk.GRU(hidden_size) embedding, state = gru(inputs, state) diff --git a/pax/agents/lola/lola.py b/pax/agents/lola/lola.py index d1d95126..a2c5532a 100644 --- a/pax/agents/lola/lola.py +++ b/pax/agents/lola/lola.py @@ -1,17 +1,11 @@ -from typing import Any, Dict, List, Mapping, NamedTuple, Tuple +from typing import Any, List, NamedTuple, Tuple -import haiku as hk import jax import jax.numpy as jnp -import numpy as np import optax -from dm_env import TimeStep - -# from pax.lola.buffer import TrajectoryBuffer -from pax.agents.lola.network import make_network from pax import utils -from pax.agents.ppo.ppo_gru import PPO +from pax.agents.lola.network import make_network from pax.runners.runner_marl import Sample from pax.utils import MemoryState, TrainingState @@ -295,7 +289,9 @@ def inner_loss( "loss_value": value_objective, } - def make_initial_state(key: Any, hidden) -> TrainingState: + def make_initial_state( + key: Any, hidden + ) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) dummy_obs = jnp.zeros(shape=obs_spec) @@ -530,7 +526,7 @@ def lola_inlookahead_rollout(carry, unused): ) my_mem = carry[7] other_mems = carry[8] - # jax.debug.breakpoint() + # flip axes to get (num_envs, num_inner, obs_dim) to vmap over numenvs vmap_trajectories = jax.tree_map( lambda x: jnp.swapaxes(x, 0, 1), trajectories @@ -545,7 +541,7 @@ def lola_inlookahead_rollout(carry, unused): rewards_self=vmap_trajectories[0].rewards_self, rewards_other=[traj.rewards for traj in vmap_trajectories[1:]], ) - # jax.debug.breakpoint() + # get gradients of opponents other_gradients = [] for idx in range(len((self.other_agents))): diff --git a/pax/agents/lola/network.py b/pax/agents/lola/network.py index bc32ac44..3c3f9991 100644 --- a/pax/agents/lola/network.py +++ b/pax/agents/lola/network.py @@ -5,8 +5,6 @@ import jax import jax.numpy as jnp -from pax import utils class CategoricalValueHead(hk.Module): """Network head that produces a categorical distribution and value.""" diff --git a/pax/agents/mfos_ppo/networks.py b/pax/agents/mfos_ppo/networks.py index 6066251c..cb5397a1 100644 --- a/pax/agents/mfos_ppo/networks.py +++ b/pax/agents/mfos_ppo/networks.py @@ -4,12 +4,13 @@ import haiku as hk import jax import jax.numpy as jnp +from distrax import Categorical from pax import utils class ActorCriticMFOS(hk.Module): - def __init__(self, num_values, hidden_size): + def __init__(self, num_values, hidden_size, categorical=True): super().__init__(name="ActorCriticMFOS") self.linear_t_0 = hk.Linear(hidden_size) self.linear_a_0 = hk.Linear(hidden_size) @@ -30,6 +31,7 @@ def __init__(self, num_values, hidden_size): w_init=hk.initializers.Orthogonal(1.0), # baseline with_bias=False, ) + self._categorical = categorical def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray): input, th = inputs @@ -54,7 +56,10 @@ def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray): hidden = jnp.concatenate([hidden_t, hidden_a, hidden_v], axis=-1) state = (_current_th, hidden) - return (distrax.Categorical(logits=logits), value, state) + if self._categorical: + return distrax.Categorical(logits=logits), value, state + else: + return distrax.MultivariateNormalDiag(loc=logits), value, state class CNNFusion(hk.Module): @@ -147,6 +152,21 @@ def forward_fn( return network, hidden_state +def make_mfos_continuous_network(num_actions: int, hidden_size: int): + hidden_state = jnp.zeros((1, 3 * hidden_size)) + + def forward_fn( + inputs: jnp.ndarray, + state: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + mfos = ActorCriticMFOS(num_actions, hidden_size, categorical=False) + logits, values, state = mfos(inputs, state) + return (logits, values), state + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network, hidden_state + + def make_mfos_ipditm_network( num_actions: int, hidden_size: int, output_channels, kernel_shape ): @@ -155,7 +175,7 @@ def make_mfos_ipditm_network( def forward_fn( inputs: jnp.ndarray, state: Tuple[jnp.ndarray, jnp.ndarray, 
jnp.ndarray], - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: mfos = CNNMFOS(num_actions, hidden_size, output_channels, kernel_shape) logits, values, state = mfos(inputs, state) return (logits, values), state diff --git a/pax/agents/mfos_ppo/ppo_gru.py b/pax/agents/mfos_ppo/ppo_gru.py index 2191043e..2e130fa7 100644 --- a/pax/agents/mfos_ppo/ppo_gru.py +++ b/pax/agents/mfos_ppo/ppo_gru.py @@ -12,7 +12,9 @@ from pax.agents.mfos_ppo.networks import ( make_mfos_ipditm_network, make_mfos_network, + make_mfos_continuous_network, ) +from pax.envs.rice.rice import Rice from pax.utils import TrainingState, get_advantages @@ -559,6 +561,16 @@ def make_mfos_agent( action_spec, agent_args.hidden_size, ) + elif args.env_id in ["Cournot", Rice.env_id]: + network, initial_hidden_state = make_mfos_continuous_network( + action_spec, + agent_args.hidden_size, + ) + elif args.env_id == "Fishery": + network, initial_hidden_state = make_mfos_continuous_network( + action_spec, + agent_args.hidden_size, + ) elif args.env_id == "InTheMatrix": network, initial_hidden_state = make_mfos_ipditm_network( action_spec, @@ -573,13 +585,12 @@ def make_mfos_agent( # Optimizer transition_steps = ( - num_iterations, - *agent_args.num_epochs * agent_args.num_minibatches, + num_iterations * agent_args.num_epochs * agent_args.num_minibatches ) if agent_args.lr_scheduling: scheduler = optax.linear_schedule( - init_value=agent_args.ppo.learning_rate, + init_value=agent_args.learning_rate, end_value=0, transition_steps=transition_steps, ) diff --git a/pax/agents/naive/buffer.py b/pax/agents/naive/buffer.py index a9546e28..8da906a0 100644 --- a/pax/agents/naive/buffer.py +++ b/pax/agents/naive/buffer.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -import numpy as np from dm_env import TimeStep @@ -105,7 +104,7 @@ def size(self) -> int: return min(self._rollout_length, self._num_added) def fraction_filled(self) -> float: - return self.size / self._rollout_length + return self.size() / self._rollout_length def reset(self): """Resets the replay buffer. 
Called upon __init__ and when buffer is full""" diff --git a/pax/agents/naive/naive.py b/pax/agents/naive/naive.py index a45e3f95..4eeaf9fb 100644 --- a/pax/agents/naive/naive.py +++ b/pax/agents/naive/naive.py @@ -6,11 +6,14 @@ import jax import jax.numpy as jnp import optax -from dm_env import TimeStep from pax import utils from pax.agents.agent import AgentInterface -from pax.agents.naive.network import make_coingame_network, make_network +from pax.agents.naive.network import ( + make_coingame_network, + make_network, + make_rice_network, +) from pax.utils import MemoryState, TrainingState, get_advantages @@ -397,6 +400,8 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int): if args.env_id == "coin_game": print(f"Making network for {args.env_id} with CNN") network = make_coingame_network(action_spec, args) + elif args.env_id == "Rice-N": + network = make_rice_network(action_spec) else: network = make_network(action_spec) diff --git a/pax/agents/naive/network.py b/pax/agents/naive/network.py index fa74751e..467f0d33 100644 --- a/pax/agents/naive/network.py +++ b/pax/agents/naive/network.py @@ -27,7 +27,48 @@ def __init__( def __call__(self, inputs: jnp.ndarray): logits = self._logit_layer(inputs) value = jnp.squeeze(self._value_layer(inputs), axis=-1) - return (distrax.Categorical(logits=logits), value) + return distrax.Categorical(logits=logits), value + + +class ContinuousValueHead(hk.Module): + """Network head that produces a continuous distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + mean_activation: Optional[str] = None, + ): + super().__init__(name=name) + self.mean_action = mean_activation + self._mean_layer = hk.Linear( + num_values, + w_init=hk.initializers.Orthogonal(0.01), + with_bias=False, + ) + self._scale_layer = hk.Linear( + num_values, + w_init=hk.initializers.Orthogonal(0.01), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Orthogonal(1.0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + if self.mean_action == "sigmoid": + means = jax.nn.sigmoid(self._mean_layer(inputs)) + else: + means = self._mean_layer(inputs) + scales = self._scale_layer(inputs) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + scales = jnp.maximum(scales, 0.01) + return ( + distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), + value, + ) class CNN(hk.Module): @@ -81,6 +122,23 @@ def forward_fn(inputs): return network +def make_rice_network(num_actions: int): + """Creates a continuous-action policy and value head for the Rice-N environment""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + ContinuousValueHead(num_values=num_actions), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + def make_coingame_network(num_actions: int, args): def forward_fn(inputs): layers = [] diff --git a/pax/agents/ppo/batched_envs.py b/pax/agents/ppo/batched_envs.py index 2497f302..7713a1e0 100644 --- a/pax/agents/ppo/batched_envs.py +++ b/pax/agents/ppo/batched_envs.py @@ -30,7 +30,8 @@ def step(self, actions: jnp.ndarray) -> TimeStep: rewards = [] observations = [] discounts = [] - for env, action in zip(self.envs, actions): + assert len(self.envs) == len(actions) + for env, action in zip(self.envs, actions): t = env.step(int(action)) if t.step_type == 2: t_reset = env.reset() diff --git a/pax/agents/ppo/buffer.py b/pax/agents/ppo/buffer.py index 
8fe601de..825f7c80 100644 --- a/pax/agents/ppo/buffer.py +++ b/pax/agents/ppo/buffer.py @@ -104,7 +104,7 @@ def size(self) -> int: return min(self._rollout_length, self._num_added) def fraction_filled(self) -> float: - return self.size / self._rollout_length + return self.size() / self._rollout_length def reset(self): """Resets the replay buffer. Called upon __init__ and when buffer is full""" diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py index 1461d15a..9bd123a2 100644 --- a/pax/agents/ppo/networks.py +++ b/pax/agents/ppo/networks.py @@ -1,11 +1,15 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Any import distrax import haiku as hk import jax import jax.numpy as jnp +import jmp +from distrax import MultivariateNormalDiag, Categorical +from jax import Array from pax import utils +from pax.utils import float_precision class CategoricalValueHead(hk.Module): @@ -31,7 +35,7 @@ def __init__( def __call__(self, inputs: jnp.ndarray): logits = self._logit_layer(inputs) value = jnp.squeeze(self._value_layer(inputs), axis=-1) - return (distrax.Categorical(logits=logits), value) + return distrax.Categorical(logits=logits), value class CategoricalValueHead_ipd(hk.Module): @@ -146,7 +150,7 @@ def __call__(self, inputs: jnp.ndarray): value = self._value_body(inputs) value = jnp.squeeze(self._value_layer(value), axis=-1) - return (distrax.Categorical(logits=logits), value) + return distrax.Categorical(logits=logits), value class ContinuousValueHead(hk.Module): @@ -156,23 +160,39 @@ def __init__( self, num_values: int, name: Optional[str] = None, + mean_activation: Optional[str] = None, + with_bias=False, ): super().__init__(name=name) - self._logit_layer = hk.Linear( + self.mean_action = mean_activation + self._mean_layer = hk.Linear( num_values, - w_init=hk.initializers.Orthogonal(0.01), # baseline - with_bias=False, + w_init=hk.initializers.Orthogonal(0.01), + with_bias=with_bias, + ) + self._scale_layer = hk.Linear( + num_values, + w_init=hk.initializers.Orthogonal(0.01), + with_bias=with_bias, ) self._value_layer = hk.Linear( 1, - w_init=hk.initializers.Orthogonal(1.0), # baseline - with_bias=False, + w_init=hk.initializers.Orthogonal(1.0), + with_bias=with_bias, ) def __call__(self, inputs: jnp.ndarray): - logits = self._logit_layer(inputs) + if self.mean_action == "sigmoid": + means = jax.nn.sigmoid(self._mean_layer(inputs)) + else: + means = self._mean_layer(inputs) + scales = self._scale_layer(inputs) value = jnp.squeeze(self._value_layer(inputs), axis=-1) - return (distrax.MultivariateNormalDiag(loc=logits), value) + scales = jnp.maximum(scales, 0.01) + return ( + distrax.MultivariateNormalDiag(loc=means, scale_diag=scales), + value, + ) class Tabular(hk.Module): @@ -336,6 +356,89 @@ def forward_fn(inputs): return network + +def make_cournot_network(num_actions: int, hidden_size: int): + """Continuous-action MLP policy and value network for the Cournot game""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + hk.nets.MLP( + [hidden_size, hidden_size], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ), + ContinuousValueHead( + num_values=num_actions, name="cournot_value_head" + ), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + +def make_fishery_network(num_actions: int, hidden_size: int): + """Continuous-action MLP policy and value network for the Fishery environment""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + hk.nets.MLP( + [hidden_size, hidden_size], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ), + ContinuousValueHead( + num_values=num_actions, name="fishery_value_head" + ), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + +def make_rice_sarl_network(num_actions: int, hidden_size: int): + """Continuous action space network with action means squashed into (0, 1) by a sigmoid""" + if float_precision == jnp.float16: + policy = jmp.get_policy( + "params=float16,compute=float16,output=float32" + ) + hk.mixed_precision.set_policy(hk.nets.MLP, policy) + + def forward_fn(inputs): + layers = [ + hk.nets.MLP( + [hidden_size, hidden_size], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jax.nn.relu, + ), + ContinuousValueHead( + num_values=num_actions, + name="rice_value_head", + mean_activation="sigmoid", + ), + ] + + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + def make_coingame_network( num_actions: int, tabular: bool, @@ -454,7 +557,7 @@ def make_GRU_ipd_network(num_actions: int, hidden_size: int): def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" gru = hk.GRU(hidden_size) embedding, state = gru(inputs, state) @@ -472,7 +575,7 @@ def make_GRU_cartpole_network(num_actions: int): def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" torso = hk.nets.MLP( [hidden_size, hidden_size], @@ -502,7 +605,7 @@ def make_GRU_coingame_network( def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: if with_cnn: torso = CNN(output_channels, kernel_shape)(inputs) @@ -542,7 +645,7 @@ def make_GRU_ipditm_network( def forward_fn( inputs: jnp.ndarray, state: jnp.ndarray - ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: + ) -> Tuple[Tuple[Categorical, jnp.ndarray], jnp.ndarray]: """forward function""" torso = CNN_ipditm(output_channels, kernel_shape) gru = hk.GRU( @@ -566,6 +669,68 @@ def forward_fn( return network, hidden_state +def make_GRU_fishery_network( + num_actions: int, + hidden_size: int, +): + hidden_state = jnp.zeros((1, hidden_size)) + + def forward_fn( + inputs: jnp.ndarray, state: jnp.ndarray + ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]: + """forward function""" + gru = hk.GRU( + hidden_size, + w_i_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + w_h_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + b_init=hk.initializers.Constant(0), + ) + + cvh = ContinuousValueHead(num_values=num_actions) + embedding, state = gru(inputs, state) + logits, values = cvh(embedding) + return (logits, values), state + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network, hidden_state + +
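The fishery and rice GRU networks here, like the feed-forward variants above, share the new `ContinuousValueHead`. A minimal standalone sketch of what that head produces (the input shape and values are illustrative only):

```
import distrax
import haiku as hk
import jax
import jax.numpy as jnp

from pax.agents.ppo.networks import ContinuousValueHead


def forward(x):
    # Diagonal-Gaussian policy plus a scalar value; with
    # mean_activation="sigmoid" the action means land in (0, 1),
    # and the predicted scales are floored at 0.01.
    return ContinuousValueHead(num_values=2, mean_activation="sigmoid")(x)


net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))
dist, value = net.apply(params, jnp.zeros((1, 8)))
assert isinstance(dist, distrax.MultivariateNormalDiag)
action = dist.sample(seed=jax.random.PRNGKey(1))  # shape (1, 2)
```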
+def make_GRU_rice_network( + num_actions: int, + hidden_size: int, + v2=False, +): + # if float_precision == jnp.float16: + # policy = jmp.get_policy('params=float16,compute=float16,output=float32') + # hk.mixed_precision.set_policy(hk.GRU, policy) + hidden_state = jnp.zeros((1, hidden_size)) + + def forward_fn( + inputs: jnp.ndarray, state: jnp.ndarray + ) -> tuple[tuple[MultivariateNormalDiag, Array], Any]: + gru = hk.GRU( + hidden_size, + w_i_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + w_h_init=hk.initializers.Orthogonal(jnp.sqrt(1)), + b_init=hk.initializers.Constant(0), + ) + + if v2: + cvh = ContinuousValueHead( + num_values=num_actions, + mean_activation="sigmoid", + with_bias=True, + ) + else: + cvh = ContinuousValueHead(num_values=num_actions) + embedding, state = gru(inputs, state) + logits, values = cvh(embedding) + return (logits, values), state + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network, hidden_state + + def test_GRU(): key = jax.random.PRNGKey(seed=0) num_actions = 2 diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index 49fdf7ef..08cb8009 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -11,11 +11,22 @@ from pax.agents.agent import AgentInterface from pax.agents.ppo.networks import ( make_coingame_network, - make_ipd_network, make_ipditm_network, make_sarl_network, + make_cournot_network, + make_fishery_network, + make_rice_sarl_network, +) +from pax.envs.rice.c_rice import ClubRice +from pax.envs.rice.rice import Rice +from pax.envs.rice.sarl_rice import SarlRice +from pax.utils import ( + Logger, + MemoryState, + TrainingState, + get_advantages, + float_precision, ) -from pax.utils import Logger, MemoryState, TrainingState, get_advantages class Batch(NamedTuple): @@ -64,9 +75,11 @@ def policy( """Agent policy to select actions and calculate agent specific information""" key, subkey = jax.random.split(state.random_key) dist, values = network.apply(state.params, observation) - actions = dist.sample(seed=subkey) + # Calculating logprob separately can cause numerical issues + # https://github.com/deepmind/distrax/issues/7 + actions, log_prob = dist.sample_and_log_prob(seed=subkey) mem.extras["values"] = values - mem.extras["log_probs"] = dist.log_prob(actions) + mem.extras["log_probs"] = log_prob mem = mem._replace(extras=mem.extras) state = state._replace(random_key=key) return actions, state, mem @@ -348,7 +361,7 @@ def make_initial_state( dummy_obs[k] = jnp.zeros(shape=v) elif not tabular: - dummy_obs = jnp.zeros(shape=obs_spec) + dummy_obs = jnp.zeros(shape=obs_spec, dtype=float_precision) dummy_obs = dummy_obs.at[0].set(1) dummy_obs = dummy_obs.at[9].set(1) dummy_obs = dummy_obs.at[18].set(1) @@ -470,10 +483,8 @@ def make_agent( tabular=False, ): """Make PPO agent""" - if args.runner == "sarl": - network = make_sarl_network(action_spec) - elif args.env_id == "coin_game": - print(f"Making network for {args.env_id}") + print(f"Making network for {args.env_id}") + if args.env_id == "coin_game": network = make_coingame_network( action_spec, tabular, @@ -492,9 +503,17 @@ def make_agent( agent_args.output_channels, agent_args.kernel_shape, ) + elif args.env_id == "Cournot": + network = make_cournot_network(action_spec, agent_args.hidden_size) + elif args.env_id == "Fishery": + network = make_fishery_network(action_spec, agent_args.hidden_size) + elif args.env_id in [SarlRice.env_id, Rice.env_id, ClubRice.env_id]: + network = make_rice_sarl_network(action_spec, agent_args.hidden_size) + elif args.runner == "sarl": + network = make_sarl_network(action_spec) else: - network = make_ipd_network( - action_spec, tabular, agent_args.hidden_size + raise NotImplementedError( + f"No ppo network implemented for env {args.env_id}" ) # Optimizer diff --git a/pax/agents/ppo/ppo_gru.py b/pax/agents/ppo/ppo_gru.py index 2944a9d6..a1f8317b 100644 --- a/pax/agents/ppo/ppo_gru.py +++ b/pax/agents/ppo/ppo_gru.py @@ -14,7 +14,11 @@ make_GRU_coingame_network, make_GRU_ipd_network, make_GRU_ipditm_network, + make_GRU_fishery_network, + make_GRU_rice_network, ) +from pax.envs.rice.rice import Rice +from pax.envs.rice.c_rice import ClubRice from pax.utils import MemoryState, TrainingState, get_advantages # from dm_env import TimeStep @@ -359,7 +363,7 @@ def model_update_epoch( def make_initial_state( key: Any, initial_hidden_state: jnp.ndarray - ) -> TrainingState: + ) -> Tuple[TrainingState, MemoryState]: """Initialises the training state (parameters and optimiser state).""" # We pass through initial_hidden_state so its easy to batch memory @@ -393,7 +397,6 @@ def make_initial_state( }, ) - # @jax.jit def prepare_batch( traj_batch: NamedTuple, done: Any, @@ -524,7 +527,26 @@ def make_gru_agent( network, initial_hidden_state = make_GRU_ipd_network( action_spec, agent_args.hidden_size ) - + elif args.env_id == "Fishery": + network, initial_hidden_state = make_GRU_fishery_network( + action_spec, agent_args.hidden_size + ) + elif args.env_id == "iterated_tensor_game": + network, initial_hidden_state = make_GRU_ipd_network( + action_spec, agent_args.hidden_size + ) + elif args.env_id == "iterated_nplayer_tensor_game": + network, initial_hidden_state = make_GRU_ipd_network( + action_spec, agent_args.hidden_size + ) + elif args.env_id == "Cournot": + network, initial_hidden_state = make_GRU_fishery_network( + action_spec, agent_args.hidden_size + ) + elif args.env_id in [Rice.env_id, "Rice-v1", ClubRice.env_id]: + network, initial_hidden_state = make_GRU_rice_network( + action_spec, agent_args.hidden_size, args.rice_v2_network + ) elif args.env_id == "InTheMatrix": network, initial_hidden_state = make_GRU_ipditm_network( action_spec, agent_args.hidden_size, agent_args.output_channels, agent_args.kernel_shape, ) + else: + raise NotImplementedError( + f"No gru network implemented for env {args.env_id}" ) gru_dim = initial_hidden_state.shape[1] diff --git a/pax/agents/strategies.py b/pax/agents/strategies.py index cda6ebf5..40cb9587 100644 --- a/pax/agents/strategies.py +++ b/pax/agents/strategies.py @@ -1,5 +1,4 @@ from functools import partial -from re import A from typing import Callable, NamedTuple, Union import jax.numpy as jnp
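The PPO policy above now draws actions with distrax's fused `sample_and_log_prob`. A small standalone illustration of the pattern (a sketch with arbitrary toy values, not part of the patch):

```
import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
dist = distrax.MultivariateNormalDiag(loc=jnp.zeros(3), scale_diag=jnp.ones(3))

# Sampling and evaluating the log-density in one call keeps the two
# consistent; a separate sample() followed by log_prob() can diverge
# numerically (https://github.com/deepmind/distrax/issues/7).
action, log_prob = dist.sample_and_log_prob(seed=key)
```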
diff --git a/pax/agents/tensor_strategies.py b/pax/agents/tensor_strategies.py index 7d77bbed..478ca0fe 100644 --- a/pax/agents/tensor_strategies.py +++ b/pax/agents/tensor_strategies.py @@ -1,13 +1,12 @@ from functools import partial -from re import A -from typing import Callable, NamedTuple +from typing import NamedTuple import jax.numpy as jnp import jax.random from pax.agents.agent import AgentInterface from pax.agents.strategies import initial_state_fun -from pax.utils import Logger, MemoryState, TrainingState +from pax.utils import Logger, MemoryState class TitForTatStrictStay(AgentInterface): @@ -191,8 +190,8 @@ def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray: # 0th state is all c...cc # 1st state is all c...cd # (2**num_playerth)-1 state is d...dd - # we cooperate in states _CCCC (0th and (2**num_player-1)th) and in start state (state 2**num_player)th ) - # otherwise we defect + # we cooperate in states _CCCC (0th and (2**num_player-1)th) + # and in start state (state 2**num_player)th ) otherwise we defect # 0 is cooperate, 1 is defect action = jnp.where(obs % jnp.exp2(num_players - 1) == 0, 0, 1) return action diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 36b25085..e7fabcc0 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -8,18 +8,27 @@ hydra: root: level: INFO -# Global variables +# Global variables seed: 0 save_dir: "./exp/${wandb.group}/${wandb.name}" save: True save_interval: 100 debug: False -# Agents + +# Players num_players: 2 num_shapers: 1 + +# Agents agent1: 'PPO' agent2: 'PPO' +shuffle_players: False +# Disable agent 2 learning in an eval setting +agent2_learning: True +agent1_roles: 1 +agent2_roles: 1 # Make agent 2 assume multiple roles in an n-player game +agent2_reset_interval: 1 # Reset agent 2 every rollout # Logging setup wandb: @@ -28,3 +37,5 @@ group: ?? name: run-seed-${seed} log: True + mode: online + tags: [] diff --git a/pax/conf/experiment/c_rice/debug.yaml b/pax/conf/experiment/c_rice/debug.yaml new file mode 100644 index 00000000..a34b6ccc --- /dev/null +++ b/pax/conf/experiment/c_rice/debug.yaml @@ -0,0 +1,81 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO' + +# Environment +env_id: C-Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 2 +num_inner_steps: 20 +num_iters: 1 +num_devices: 1 +num_steps: 4 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-rice + group: 'mediator' + mode: 'offline' + name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}' + log: False + + diff --git 
a/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml new file mode 100644 index 00000000..9fe6b239 --- /dev/null +++ b/pax/conf/experiment/c_rice/eval_mediator_gs_ppo.yaml @@ -0,0 +1,93 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' +agent2_roles: 5 + +# Environment +env_id: C-Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: eval +rice_v2_network: True +shuffle_agents: False + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 2000 +num_iters: 2000 +num_devices: 1 +num_steps: 10 + +# Train to convergence +agent2_reset_interval: 1000 +# Regular mediator +#run_path: chrismatix/c-rice/runs/3w7d59ug +#model_path: exp/mediator/c_rice-mediator-gs-ppo-interval10_seed0/2023-10-09_17.00.59.872280/generation_1499 + +# Climate objective +run_path: chrismatix/c-rice/runs/ovss1ahd +model_path: exp/mediator/c-rice-mediator-GS-PPO_memory-seed-0-climate-obj/2023-10-14_17.23.35.878225/generation_1499 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-rice + group: 'eval' + name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/marl_baseline.yaml b/pax/conf/experiment/c_rice/marl_baseline.yaml new file mode 100644 index 00000000..5c30f380 --- /dev/null +++ b/pax/conf/experiment/c_rice/marl_baseline.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +# Agents +agent_default: 'PPO' + +# Environment +env_id: C-Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: evo +rice_v2_network: True + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 20 +num_iters: 2000 +num_devices: 1 +num_steps: 200 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + 
lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# Logging setup +wandb: + project: c-rice + group: 'mediator' + name: 'c-rice-MARL-${agent_default}-seed-${seed}' + log: True diff --git a/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml new file mode 100644 index 00000000..f5ca57e2 --- /dev/null +++ b/pax/conf/experiment/c_rice/mediator_gs_ppo.yaml @@ -0,0 +1,83 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' +agent2_roles: 5 + +# Environment +env_id: C-Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: evo +rice_v2_network: True +agent2_reset_interval: 10 + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 250 +num_inner_steps: 200 +num_iters: 1500 +num_devices: 1 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-rice + group: 'mediator' + name: 'c-rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/shaper_v_ppo.yaml b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml new file mode 100644 index 00000000..81d2edbf --- /dev/null +++ b/pax/conf/experiment/c_rice/shaper_v_ppo.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +agent2_roles: 4 + +# Environment +env_id: C-Rice-N +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: False +config_folder: pax/envs/rice/5_regions +runner: evo +rice_v2_network: True + +default_club_mitigation_rate: 0.1 +default_club_tariff_rate: 0.1 + +agent2_reset_interval: 10 + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 500 +num_inner_steps: 20 +num_iters: 1500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + 
entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: c-rice + group: 'shaper' + name: 'c-rice-SHAPER-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/c_rice/weight_sharing.yaml b/pax/conf/experiment/c_rice/weight_sharing.yaml new file mode 100644 index 00000000..e7e7f9f9 --- /dev/null +++ b/pax/conf/experiment/c_rice/weight_sharing.yaml @@ -0,0 +1,53 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: C-Rice-N +env_type: sequential +num_players: 5 +has_mediator: False +config_folder: pax/envs/rice/5_regions +runner: weight_sharing +rice_v2_network: True + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 20 +num_iters: 6e6 +save_interval: 100 +num_steps: 2000 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: c-rice + group: 'weight_sharing' + name: 'c-rice-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/conf/experiment/cournot/gs_v_ppo.yaml b/pax/conf/experiment/cournot/gs_v_ppo.yaml new file mode 100644 index 00000000..40660332 --- /dev/null +++ b/pax/conf/experiment/cournot/gs_v_ppo.yaml @@ -0,0 +1,77 @@ +# @package _global_ + +# Agents +# Agent default applies to all agents +agent_default: 'PPO' + +# Environment +env_id: Cournot +env_type: meta +a: 100 +b: 1 +marginal_cost: 10 +# This means the nash quantity is 2(a-marginal_cost)/3b = 60 + +# Runner +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game +num_iters: 3000 +num_devices: 1 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + 
adam_epsilon: 1e-5 + with_memory: False + with_cnn: False + hidden_size: 16 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: cournot + name: 'cournot-GS-${agent1}-vs-${agent2}' + log: True + + diff --git a/pax/conf/experiment/cournot/marl_baseline.yaml b/pax/conf/experiment/cournot/marl_baseline.yaml new file mode 100644 index 00000000..4a6a56cb --- /dev/null +++ b/pax/conf/experiment/cournot/marl_baseline.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# Agents +agent_default: 'PPO' + +# Environment +env_id: Cournot +env_type: sequential +a: 100 +b: 1 +marginal_cost: 10 +# This means the nash quantity is 2(a-marginal_cost)/3b = 60 +runner: tensor_rl_nplayer + +# env_batch_size = num_envs * num_opponents +num_envs: 20 +num_opps: 1 +num_outer_steps: 1 # This makes it a symmetric game +num_inner_steps: 1 # One-shot game +num_iters: 1e7 + +# Useful information +# batch_size = num_envs * num_inner_steps +# batch_size % num_minibatches == 0 must hold + +# PPO agent parameters +ppo_default: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.25e9 + entropy_coeff_end: 0.05 + lr_scheduling: True + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + hidden_size: 16 + with_cnn: False + +# Logging setup +wandb: + group: 'cournot' + name: 'cournot-MARL^2-${agent1}-vs-${agent2}-parity' + log: True diff --git a/pax/conf/experiment/cournot/shaper_v_ppo.yaml b/pax/conf/experiment/cournot/shaper_v_ppo.yaml new file mode 100644 index 00000000..1d44629d --- /dev/null +++ b/pax/conf/experiment/cournot/shaper_v_ppo.yaml @@ -0,0 +1,100 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent_default: 'PPO' + +# Environment +env_id: Cournot +env_type: meta +a: 100 +b: 1 +marginal_cost: 10 + +# Runner +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game +num_iters: 1000 +num_devices: 1 +num_steps: '${num_inner_steps}' + + +# PPO agent parameters +ppo1: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + hidden_size: 16 + +# PPO agent parameters +ppo2: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + 
value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + hidden_size: 16 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: cournot + group: 'shaper' + name: 'cournot-SHAPER-${num_players}p-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/cournot/weight_sharing.yaml b/pax/conf/experiment/cournot/weight_sharing.yaml new file mode 100644 index 00000000..433df3b6 --- /dev/null +++ b/pax/conf/experiment/cournot/weight_sharing.yaml @@ -0,0 +1,100 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO' + +# Environment +env_id: Cournot +env_type: meta +a: 100 +b: 1 +marginal_cost: 10 + +# Runner +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 300 +num_inner_steps: 1 # One-shot game +num_iters: 1000 +num_devices: 1 +num_steps: '${num_inner_steps}' + + +# PPO agent parameters +ppo1: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + hidden_size: 16 + +# PPO agent parameters +ppo2: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + hidden_size: 16 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + 
maximise: True # Maximise fitness
+  z_score: False # Normalise fitness
+  mean_reduce: True # Remove mean
+
+# Logging setup
+wandb:
+  project: cournot
+  group: 'weight_sharing'
+  name: 'cournot-weight_sharing-${num_players}p-seed-${seed}'
+  log: True
+
+
diff --git a/pax/conf/experiment/fishery/eval_gs_v_ppo.yaml b/pax/conf/experiment/fishery/eval_gs_v_ppo.yaml
new file mode 100644
index 00000000..a0e86b87
--- /dev/null
+++ b/pax/conf/experiment/fishery/eval_gs_v_ppo.yaml
@@ -0,0 +1,109 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO'
+agent2: 'PPO_memory'
+
+# Environment
+env_id: Fishery
+env_type: meta
+g: 0.15
+e: 0.009
+P: 200
+w: 0.9
+s_0: 0.5
+s_max: 1.0
+
+
+# Runner
+runner: eval
+
+# TODO
+run_path: chrismatix/thesis/dhzxkw57
+model_path: exp/fishery/fishery-GS-PPO-vs-PPO_memory/2023-08-01_20.27.07.402547/generation_900
+
+
+# Training
+top_k: 5
+popsize: 1000
+num_envs: 1
+num_opps: 1
+num_outer_steps: 1
+num_steps: 600 # Run num_steps // num_inner_steps trials
+num_inner_steps: 300
+num_iters: 2
+num_devices: 1
+
+# PPO agent parameters
+ppo1:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+# PPO agent parameters
+ppo2:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+
+# ES parameters
+es:
+  algo: OpenES # [OpenES, CMA_ES]
+  sigma_init: 0.04 # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999 # Multiplicative decay factor
+  sigma_limit: 0.01 # Smallest possible scale
+  init_min: 0.0 # Range of parameter mean initialization - Min
+  init_max: 0.0 # Range of parameter mean initialization - Max
+  clip_min: -1e10 # Range of parameter proposals - Min
+  clip_max: 1e10 # Range of parameter proposals - Max
+  lrate_init: 0.01 # Initial learning rate
+  lrate_decay: 0.9999 # Multiplicative decay factor
+  lrate_limit: 0.001 # Smallest possible lrate
+  beta_1: 0.99 # Adam - beta_1
+  beta_2: 0.999 # Adam - beta_2
+  eps: 1e-8 # eps constant,
+  centered_rank: False # Fitness centered_rank
+  w_decay: 0 # Decay old elite fitness
+  maximise: True # Maximise fitness
+  z_score: False # Normalise fitness
+  mean_reduce: True # Remove mean
+
+# Logging setup
+wandb:
+  group: eval
+  project: fishery
+  name: 'EVAL_fishery-GS-${agent1}-vs-${agent2}'
+  log: True
+
+
diff --git a/pax/conf/experiment/fishery/eval_marl_baseline.yaml b/pax/conf/experiment/fishery/eval_marl_baseline.yaml
new file mode 100644
index 00000000..9900420d
--- /dev/null
+++ b/pax/conf/experiment/fishery/eval_marl_baseline.yaml
@@ -0,0 +1,79 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO_memory'
+agent2: 'PPO'
+
+# Environment
+env_id: Fishery
+env_type: sequential
+g: 0.15
+e: 0.009
+P: 200
+w: 0.9
+s_0: 0.5
+s_max: 1.0
+
+# Runner
+runner: eval
+
+# TODO
+run_path: chrismatix/thesis/dhzxkw57
+model_path: exp/fishery/fishery-GS-PPO-vs-PPO_memory/2023-08-01_20.27.07.402547/generation_900
+
+# env_batch_size = num_envs * num_opponents
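The sizing comments in these configs reduce to simple arithmetic. A runnable sketch (illustrative values from the surrounding files; the divisibility constraint is the one the Cournot marl_baseline above spells out):

```python
# Batch-size bookkeeping for the configs above (illustrative values).
num_envs, num_opps, num_inner_steps, num_minibatches = 100, 1, 300, 10

env_batch_size = num_envs * num_opps     # environments stepped in parallel
batch_size = num_envs * num_inner_steps  # transitions gathered per PPO update
assert batch_size % num_minibatches == 0, "minibatches must split the batch evenly"
```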
+num_envs: 100 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 300 # number of inner steps (only for MetaFinite Env) +num_iters: 1e6 + +# Useful information +# batch_size = num_envs * num_steps + +# PPO agent parameters +ppo1: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.25e9 + entropy_coeff_end: 0.05 + lr_scheduling: True + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + hidden_size: 16 + with_cnn: False + +ppo2: + num_minibatches: 10 + num_epochs: 4 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.25e9 + entropy_coeff_end: 0.05 + lr_scheduling: True + learning_rate: 3e-4 + adam_epsilon: 1e-5 + with_memory: True + hidden_size: 16 + with_cnn: False + +# Logging setup +wandb: + project: fishery + name: 'fishery-MARL^2-${agent1}-vs-${agent2}-parity' + log: True diff --git a/pax/conf/experiment/fishery/eval_weight_sharing_v_ppo.yaml b/pax/conf/experiment/fishery/eval_weight_sharing_v_ppo.yaml new file mode 100644 index 00000000..a8d34ead --- /dev/null +++ b/pax/conf/experiment/fishery/eval_weight_sharing_v_ppo.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: Fishery +env_type: meta +g: 0.15 +e: 0.009 +P: 200 +w: 0.9 +s_0: 0.5 +s_max: 1.0 +# Runner +runner: eval + +run_path: chrismatix/fishery/runs/8ux5pgbx +model_path: exp/weight_sharing/fishery-weight_sharing-PPO_memory-seed-0/2023-09-30_21.05.25.064114/iteration_59999 + + +# env_batch_size = num_envs * num_opponents +num_devices: 1 +num_envs: 10 +num_opps: 1 +num_inner_steps: 300 +num_iters: 500 +num_outer_steps: 2 +num_steps: 4200 + +agent2_reset_interval: 500 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: fishery + group: 'eval' + name: 'eval-fishery-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/conf/experiment/fishery/gs_v_ppo.yaml b/pax/conf/experiment/fishery/gs_v_ppo.yaml new file mode 100644 index 00000000..849fbd0b --- /dev/null +++ b/pax/conf/experiment/fishery/gs_v_ppo.yaml @@ -0,0 +1,103 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO_memory' + +# Environment +env_id: Fishery +env_type: meta +g: 0.15 +e: 0.009 +P: 200 +w: 0.9 +s_0: 0.5 +s_max: 1.0 + + +# Runner +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 300 +num_iters: 1000 +num_devices: 1 + +# PPO agent parameters +ppo1: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 
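The baseline PPO blocks above anneal the entropy bonus from `entropy_coeff_start` to `entropy_coeff_end` over `entropy_coeff_horizon` steps. A sketch of that schedule, assuming a linear ramp (the exact schedule lives in Pax's PPO agent, not shown in this diff):

```python
def entropy_coeff(step, start=0.1, end=0.05, horizon=0.25e9):
    # Assumed linear anneal from `start` to `end` over `horizon` steps.
    frac = min(step / horizon, 1.0)
    return start + frac * (end - start)

assert entropy_coeff(0) == 0.1
assert abs(entropy_coeff(1e12) - 0.05) < 1e-12
```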
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+# PPO agent parameters
+ppo2:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+
+# ES parameters
+es:
+  algo: OpenES # [OpenES, CMA_ES]
+  sigma_init: 0.04 # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999 # Multiplicative decay factor
+  sigma_limit: 0.01 # Smallest possible scale
+  init_min: 0.0 # Range of parameter mean initialization - Min
+  init_max: 0.0 # Range of parameter mean initialization - Max
+  clip_min: -1e10 # Range of parameter proposals - Min
+  clip_max: 1e10 # Range of parameter proposals - Max
+  lrate_init: 0.01 # Initial learning rate
+  lrate_decay: 0.9999 # Multiplicative decay factor
+  lrate_limit: 0.001 # Smallest possible lrate
+  beta_1: 0.99 # Adam - beta_1
+  beta_2: 0.999 # Adam - beta_2
+  eps: 1e-8 # eps constant,
+  centered_rank: False # Fitness centered_rank
+  w_decay: 0 # Decay old elite fitness
+  maximise: True # Maximise fitness
+  z_score: False # Normalise fitness
+  mean_reduce: True # Remove mean
+
+# Logging setup
+wandb:
+  project: fishery
+  group: gs
+  name: 'fishery-GS-${agent1}-vs-${agent2}'
+  log: True
+
+
diff --git a/pax/conf/experiment/fishery/marl_baseline.yaml b/pax/conf/experiment/fishery/marl_baseline.yaml
new file mode 100644
index 00000000..f0faa91b
--- /dev/null
+++ b/pax/conf/experiment/fishery/marl_baseline.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO_memory'
+agent2: 'PPO_memory'
+
+# Environment
+env_id: Fishery
+env_type: meta
+g: 0.15
+e: 0.009
+P: 200
+w: 0.9
+s_0: 0.5
+s_max: 1.0
+
+# The open-access equilibrium stock is s_max * (1 - g / (e * P)); see Fishery.equilibrium
+runner: evo
+
+# env_batch_size = num_envs * num_opponents
+num_envs: 100
+num_opps: 1
+num_outer_steps: 2000
+num_inner_steps: 300 # number of inner steps (only for MetaFinite Env)
+num_iters: 1500
+num_devices: 1
+num_steps: 1500 # number of steps per meta-episode
+
+
+# Useful information
+# batch_size = num_envs * num_steps
+
+# PPO agent parameters
+ppo_default:
+  num_minibatches: 10
+  num_epochs: 4
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: True
+  entropy_coeff_start: 0.1
+  entropy_coeff_horizon: 0.25e9
+  entropy_coeff_end: 0.05
+  lr_scheduling: True
+  learning_rate: 3e-4
+  adam_epsilon: 1e-5
+  with_memory: True
+  hidden_size: 16
+  with_cnn: False
+
+# Logging setup
+wandb:
+  project: fishery
+  name: 'fishery-marl_baseline-${agent1}-vs-${agent2}-seed-${seed}'
+  log: True
diff --git a/pax/conf/experiment/fishery/mfos_v_ppo.yaml b/pax/conf/experiment/fishery/mfos_v_ppo.yaml
new file mode 100644
index 00000000..78e4aa97
--- /dev/null
+++ b/pax/conf/experiment/fishery/mfos_v_ppo.yaml
@@ -0,0 +1,101 @@
+# @package _global_
+
+# Agents
+agent1: 'MFOS'
+agent2: 'PPO_memory'
+
+# Environment
+env_id: Fishery
+env_type: meta
+g: 0.15
+e: 0.009
+P: 200
+w: 0.9
+s_0: 0.5
+s_max: 1.0
+
+# Runner
+runner: evo
+
+# Training
+top_k: 5
+popsize: 1000
+num_envs: 4
+num_opps: 1
+num_outer_steps: 100
+num_inner_steps: 300
+num_iters: 1000
+num_devices: 1
+num_steps: 1500
+
+# PPO agent parameters
+ppo1:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda:
0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+ppo2:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: False
+  with_cnn: False
+  hidden_size: 16
+
+# ES parameters
+es:
+  algo: OpenES # [OpenES, CMA_ES]
+  sigma_init: 0.04 # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999 # Multiplicative decay factor
+  sigma_limit: 0.01 # Smallest possible scale
+  init_min: 0.0 # Range of parameter mean initialization - Min
+  init_max: 0.0 # Range of parameter mean initialization - Max
+  clip_min: -1e10 # Range of parameter proposals - Min
+  clip_max: 1e10 # Range of parameter proposals - Max
+  lrate_init: 0.01 # Initial learning rate
+  lrate_decay: 0.9999 # Multiplicative decay factor
+  lrate_limit: 0.001 # Smallest possible lrate
+  beta_1: 0.99 # Adam - beta_1
+  beta_2: 0.999 # Adam - beta_2
+  eps: 1e-8 # eps constant,
+  centered_rank: False # Fitness centered_rank
+  w_decay: 0 # Decay old elite fitness
+  maximise: True # Maximise fitness
+  z_score: True # Normalise fitness
+  mean_reduce: True # Remove mean
+
+# Logging setup
+wandb:
+  project: fishery
+  group: mfos
+  name: 'fishery-MFOS-${agent1}-vs-${agent2}'
+  log: True
+
+
diff --git a/pax/conf/experiment/fishery/shaper_v_ppo.yaml b/pax/conf/experiment/fishery/shaper_v_ppo.yaml
new file mode 100644
index 00000000..a44819b6
--- /dev/null
+++ b/pax/conf/experiment/fishery/shaper_v_ppo.yaml
@@ -0,0 +1,102 @@
+# @package _global_
+
+# Agents
+agent1: 'PPO_memory'
+agent2: 'PPO_memory'
+
+# Environment
+env_id: Fishery
+env_type: meta
+g: 0.15
+e: 0.009
+P: 200
+w: 0.9
+s_0: 0.5
+s_max: 1.0
+# Runner
+runner: evo
+
+# Training
+top_k: 5
+popsize: 1000
+num_envs: 2
+num_opps: 1
+num_outer_steps: 400
+num_inner_steps: 300
+num_iters: 1500
+num_devices: 1
+num_steps: 1100
+
+
+# PPO agent parameters
+ppo1:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: True
+  with_cnn: False
+  hidden_size: 16
+
+# PPO agent parameters
+ppo2:
+  num_minibatches: 4
+  num_epochs: 2
+  gamma: 0.96
+  gae_lambda: 0.95
+  ppo_clipping_epsilon: 0.2
+  value_coeff: 0.5
+  clip_value: True
+  max_gradient_norm: 0.5
+  anneal_entropy: False
+  entropy_coeff_start: 0.02
+  entropy_coeff_horizon: 2000000
+  entropy_coeff_end: 0.001
+  lr_scheduling: False
+  learning_rate: 1
+  adam_epsilon: 1e-5
+  with_memory: True
+  with_cnn: False
+  hidden_size: 16
+
+# ES parameters
+es:
+  algo: OpenES # [OpenES, CMA_ES]
+  sigma_init: 0.04 # Initial scale of isotropic Gaussian noise
+  sigma_decay: 0.999 # Multiplicative decay factor
+  sigma_limit: 0.01 # Smallest possible scale
+  init_min: 0.0 # Range of parameter mean initialization - Min
+  init_max: 0.0 # Range of parameter mean initialization - Max
+  clip_min: -1e10 # Range
of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: fishery + group: shaper + name: 'fishery-SHAPER-${agent1}-vs-${agent2}' + log: True + + diff --git a/pax/conf/experiment/fishery/weight_sharing.yaml b/pax/conf/experiment/fishery/weight_sharing.yaml new file mode 100644 index 00000000..930dad6e --- /dev/null +++ b/pax/conf/experiment/fishery/weight_sharing.yaml @@ -0,0 +1,56 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: Fishery +env_type: meta +g: 0.15 +e: 0.009 +P: 200 +w: 0.9 +s_0: 0.5 +s_max: 1.0 +# Runner +runner: weight_sharing + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 300 +num_iters: 4e6 +save_interval: 100 +num_steps: 2100 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: fishery + group: 'weight_sharing' + name: 'fishery-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml b/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml index aa176d22..c486d240 100644 --- a/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml +++ b/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'Hyper' agent2: 'NaiveEx' @@ -9,12 +9,11 @@ env_id: infinite_matrix_game env_type: meta env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] -runner: rl +runner: rl # Training hyperparameters num_envs: 4000 num_opps: 1 -num_outer_steps: 1 num_outer_steps: 100 num_iters: 1e4 @@ -24,7 +23,7 @@ num_iters: 1e4 # PPO agent parameters ppo: num_minibatches: 1 - num_epochs: 4 + num_epochs: 4 gamma: 0.96 gae_lambda: 0.99 ppo_clipping_epsilon: 0.2 @@ -33,7 +32,7 @@ ppo: max_gradient_norm: 0.5 anneal_entropy: False entropy_coeff_start: 0.01 - entropy_coeff_horizon: 0.4e9 + entropy_coeff_horizon: 0.4e9 entropy_coeff_end: 0.001 lr_scheduling: False learning_rate: 4e-3 diff --git a/pax/conf/experiment/ipd/mfos_v_ppo.yaml b/pax/conf/experiment/ipd/mfos_v_ppo.yaml index d984861f..31c76f87 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo.yaml @@ -1,6 +1,6 @@ # @package _global_ -# Agents +# Agents agent1: 'MFOS' agent2: 'PPO' @@ -10,16 +10,16 @@ env_type: meta env_discount: 0.96 payoff: [[-1, -1], [-3, 0], [0, -3], [-2, -2]] -# Runner -runner: evo +# Runner +runner: evo # Training top_k: 5 popsize: 1000 num_envs: 100 num_opps: 1 -num_outer_steps: 100 -num_inner_steps: 100 +num_outer_steps: 100 +num_inner_steps: 100 num_iters: 5000 num_devices: 2 @@ -31,7 +31,7 @@ ppo1: gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 value_coeff: 0.5 - 
clip_value: True + clip_value: True max_gradient_norm: 0.5 anneal_entropy: False entropy_coeff_start: 0.02 @@ -44,8 +44,8 @@ ppo1: with_cnn: False hidden_size: 16 -# ES parameters -es: +# ES parameters +es: algo: OpenES # [OpenES, CMA_ES] sigma_init: 0.04 # Initial scale of isotropic Gaussian noise sigma_decay: 0.999 # Multiplicative decay factor @@ -61,11 +61,11 @@ es: beta_2: 0.999 # Adam - beta_2 eps: 1e-8 # eps constant, centered_rank: False # Fitness centered_rank - w_decay: 0 # Decay old elite fitness - maximise: True # Maximise fitness - z_score: True # Normalise fitness + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: True # Normalise fitness mean_reduce: True # Remove mean - + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/rice/debug.yaml b/pax/conf/experiment/rice/debug.yaml new file mode 100644 index 00000000..ed3354c4 --- /dev/null +++ b/pax/conf/experiment/rice/debug.yaml @@ -0,0 +1,82 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO' + +# Environment +env_id: Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: tensor_evo +rice_v2_network: True + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 2 +num_inner_steps: 20 +num_iters: 1 +num_devices: 1 +num_steps: 4 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'mediator' + mode: 'offline' + name: 'rice-mediator-GS-${agent_default}-seed-${seed}' + log: False + + diff --git a/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml b/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml new file mode 100644 index 00000000..ceca98d7 --- /dev/null +++ b/pax/conf/experiment/rice/eval_shaper_v_ppo.yaml @@ -0,0 +1,100 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +agent2_roles: 4 + +# Environment +env_id: Rice-N +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: True +config_folder: pax/envs/rice/5_regions +rice_v2_network: True + +runner: eval + +run_path: chrismatix/rice/runs/yg67hb4e +model_path: 
exp/shaper/rice-SHAPER-PPO_memory-seed-0-interval_20/2023-10-09_12.14.28.778753/generation_1499 + +# Better run but with old network +#run_path: chrismatix/rice/runs/btpdx3d2 +#model_path: exp/shaper/rice-SHAPER-PPO_memory-seed-0-interval_10/2023-10-03_17.06.36.625352/generation_1499 + +# v2 network sharing +run_path2: chrismatix/rice/runs/l6ug3nod +model_path2: exp/weight_sharing/rice-weight_sharing-PPO_memory-seed-0/2023-10-12_18.54.03.092581/iteration_119999 + +# v1 network weight sharing +#run_path2: chrismatix/rice/runs/ozked2ow +#model_path2: exp/weight_sharing/rice-weight_sharing-PPO_memory-seed-0/2023-09-23_09.07.36.737803/iteration_119999 + +# Training +top_k: 5 +popsize: 1000 +num_devices: 1 +num_envs: 20 +num_opps: 1 +num_inner_steps: 20 +num_outer_steps: 1 +num_iters: 100 +num_steps: 200 + +agent2_reset_interval: 1000 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'eval' + name: 'eval-rice-SHAPER-${agent_default}-seed-${seed}' + log: True diff --git a/pax/conf/experiment/rice/gs_v_ppo.yaml b/pax/conf/experiment/rice/gs_v_ppo.yaml new file mode 100644 index 00000000..daf31c69 --- /dev/null +++ b/pax/conf/experiment/rice/gs_v_ppo.yaml @@ -0,0 +1,100 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO_memory' +agent2_roles: 4 + +# Environment +env_id: Rice-N +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: False +config_folder: pax/envs/rice/5_regions +runner: evo + + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 500 +num_inner_steps: 20 +num_iters: 1500 +num_devices: 1 +num_steps: 200 + +# PPO agent parameters +ppo1: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: False + with_cnn: False + hidden_size: 16 + +# PPO agent parameters +ppo2: + num_minibatches: 4 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + 
clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.02 + entropy_coeff_horizon: 2000000 + entropy_coeff_end: 0.001 + lr_scheduling: False + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: False + with_cnn: False + hidden_size: 16 + + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + name: 'rice-GS-${agent1}-vs-${agent2}' + log: True + + diff --git a/pax/conf/experiment/rice/marl_baseline.yaml b/pax/conf/experiment/rice/marl_baseline.yaml new file mode 100644 index 00000000..bea52f74 --- /dev/null +++ b/pax/conf/experiment/rice/marl_baseline.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# Agents +agent_default: 'PPO' + +# Environment +env_id: Rice-N +env_type: meta +num_players: 5 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 2 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 20 +num_iters: 2000 +num_devices: 1 +num_steps: 200 + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# Logging setup +wandb: + project: rice + group: 'marl_baseline' + name: 'rice-MARL-${agent_default}-seed-${seed}' + log: True diff --git a/pax/conf/experiment/rice/mediator_gs_naive.yaml b/pax/conf/experiment/rice/mediator_gs_naive.yaml new file mode 100644 index 00000000..57317874 --- /dev/null +++ b/pax/conf/experiment/rice/mediator_gs_naive.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'Naive' + +# Environment +env_id: Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 1 +num_inner_steps: 20 +num_iters: 3500 +num_devices: 1 +num_steps: 4 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + 
hidden_size: 64 + +naive: + num_minibatches: 1 + num_epochs: 1 + gamma: 1 + gae_lambda: 0.95 + max_gradient_norm: 1.0 + learning_rate: 1.0 + adam_epsilon: 1e-5 + entropy_coeff: 0.0 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'mediator' + name: 'rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/rice/mediator_gs_ppo.yaml b/pax/conf/experiment/rice/mediator_gs_ppo.yaml new file mode 100644 index 00000000..36bc65df --- /dev/null +++ b/pax/conf/experiment/rice/mediator_gs_ppo.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent_default: 'PPO' + +# Environment +env_id: Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 200 +num_inner_steps: 20 +num_iters: 3500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'mediator' + name: 'rice-mediator-GS-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/rice/mediator_shaper.yaml b/pax/conf/experiment/rice/mediator_shaper.yaml new file mode 100644 index 00000000..67e1e2a4 --- /dev/null +++ b/pax/conf/experiment/rice/mediator_shaper.yaml @@ -0,0 +1,79 @@ +# 
@package _global_ + +# Agents +agent_default: 'PPO_memory' + +# Environment +env_id: Rice-N +env_type: meta +num_players: 6 +has_mediator: True +config_folder: pax/envs/rice/5_regions +runner: tensor_evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 180 +num_inner_steps: 20 +num_iters: 3500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'mediator' + name: 'rice-mediator-SHAPER-${agent_default}-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/rice/mfos_v_ppo.yaml b/pax/conf/experiment/rice/mfos_v_ppo.yaml new file mode 100644 index 00000000..13e5f4f4 --- /dev/null +++ b/pax/conf/experiment/rice/mfos_v_ppo.yaml @@ -0,0 +1,84 @@ +# @package _global_ + +# Agents +agent1: 'MFOS' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +agent2_roles: 4 + +# Environment +env_id: Rice-N +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: False +config_folder: pax/envs/rice/5_regions +runner: evo + +# Training +top_k: 5 +popsize: 1000 +num_envs: 4 +num_opps: 1 +num_outer_steps: 500 +num_inner_steps: 20 +num_iters: 1500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # 
Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'mfos' + name: 'rice-MFOS-seed-${seed}' + log: True + + diff --git a/pax/conf/experiment/rice/sarl.yaml b/pax/conf/experiment/rice/sarl.yaml new file mode 100644 index 00000000..d184b60e --- /dev/null +++ b/pax/conf/experiment/rice/sarl.yaml @@ -0,0 +1,56 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: SarlRice-N +env_type: sequential +num_players: 5 +has_mediator: False +config_folder: pax/envs/rice/5_regions +runner: sarl +rice_v2_network: True + +# fixed_mitigation_rate: 0 Set this to test BAU scenario +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 20 +num_iters: 5e6 +save_interval: 100 +num_steps: 2000 + +# Evaluation +#run_path: ucl-dark/cg/3sp0y2cy +#model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 + +ppo0: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 64 + +# Logging setup +wandb: + project: rice + group: 'sarl' + name: 'rice-SARL-${agent1}' diff --git a/pax/conf/experiment/rice/shaper_v_ppo.yaml b/pax/conf/experiment/rice/shaper_v_ppo.yaml new file mode 100644 index 00000000..52c94f6d --- /dev/null +++ b/pax/conf/experiment/rice/shaper_v_ppo.yaml @@ -0,0 +1,83 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +agent2_roles: 4 + +# Environment +env_id: Rice-N +env_type: meta +num_players: 5 +has_mediator: False +shuffle_players: True +config_folder: pax/envs/rice/5_regions +runner: evo +rice_v2_network: True + +# Training +top_k: 5 +popsize: 1000 +num_envs: 1 +num_opps: 1 +num_outer_steps: 200 +num_inner_steps: 20 +num_iters: 1500 +num_devices: 1 +num_steps: 200 + + +# PPO agent parameters +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 32 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.01 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative 
decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + centered_rank: False # Fitness centered_rank + w_decay: 0 # Decay old elite fitness + maximise: True # Maximise fitness + z_score: False # Normalise fitness + mean_reduce: True # Remove mean + +# Logging setup +wandb: + project: rice + group: 'shaper' + name: 'rice-SHAPER-${agent_default}-seed-${seed}' + log: True diff --git a/pax/conf/experiment/rice/weight_sharing.yaml b/pax/conf/experiment/rice/weight_sharing.yaml new file mode 100644 index 00000000..594d3185 --- /dev/null +++ b/pax/conf/experiment/rice/weight_sharing.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' +agent_default: 'PPO_memory' + +# Environment +env_id: Rice-N +env_type: sequential +num_players: 5 +has_mediator: False +config_folder: pax/envs/rice/5_regions +runner: weight_sharing +rice_v2_network: True +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_inner_steps: 20 +num_iters: 6e6 +save_interval: 100 +num_steps: 2000 + + +ppo_default: + num_minibatches: 4 + num_epochs: 4 + gamma: 1.0 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.0 + entropy_coeff_horizon: 10000000 + entropy_coeff_end: 0.0 + lr_scheduling: True + learning_rate: 1e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [ 3, 3 ] + separate: True + hidden_size: 32 + + +# Logging setup +wandb: + project: rice + group: 'weight_sharing' + name: 'rice-weight_sharing-${agent1}-seed-${seed}' diff --git a/pax/envs/cournot.py b/pax/envs/cournot.py new file mode 100644 index 00000000..83f91513 --- /dev/null +++ b/pax/envs/cournot.py @@ -0,0 +1,109 @@ +from typing import Optional, Tuple + +import chex +import jax +import jax.numpy as jnp +from gymnax.environments import environment, spaces + + +@chex.dataclass +class EnvState: + inner_t: int + outer_t: int + + +@chex.dataclass +class EnvParams: + a: int + b: float + marginal_cost: float + + +class CournotGame(environment.Environment): + def __init__(self, num_players: int, num_inner_steps: int): + super().__init__() + self.num_players = num_players + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, + ): + assert len(actions) == num_players + t = state.outer_t + done = t >= num_inner_steps + key, _ = jax.random.split(key, 2) + + actions = jnp.asarray(actions).squeeze() + actions = jnp.clip(actions, a_min=0) + p = params.a - params.b * actions.sum() + + all_obs = [] + all_rewards = [] + for i in range(num_players): + q = actions[i] + obs = jnp.concatenate([actions, jnp.array([p])]) + all_obs.append(obs) + + r = p * q - params.marginal_cost * q + all_rewards.append(r) + + state = EnvState( + inner_t=state.inner_t + 1, outer_t=state.outer_t + 1 + ) + + return ( + tuple(all_obs), + state, + tuple(all_rewards), + done, + {}, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[Tuple, EnvState]: + state = EnvState( + inner_t=jnp.zeros((), dtype=jnp.int8), + outer_t=jnp.zeros((), dtype=jnp.int8), + ) + obs = jax.random.uniform(key, (num_players + 1,)) + obs = jnp.concatenate([obs]) + return tuple([obs for _ in range(num_players)]), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + @property + def name(self) -> str: + """Environment 
name."""
+        return "Cournot-v1"
+
+    @property
+    def num_actions(self) -> int:
+        """Number of actions possible in environment."""
+        return 1
+
+    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box:
+        """Action space of the environment."""
+        return spaces.Box(low=0, high=float("inf"), shape=(1,))
+
+    def observation_space(self, params: EnvParams) -> spaces.Box:
+        """Observation space of the environment."""
+        return spaces.Box(
+            low=0,
+            high=float("inf"),
+            shape=(self.num_players + 1,),
+            dtype=jnp.float32,
+        )
+
+    @staticmethod
+    def nash_policy(params: EnvParams) -> float:
+        return 2 * (params.a - params.marginal_cost) / (3 * params.b)
+
+    @staticmethod
+    def nash_reward(params: EnvParams) -> float:
+        q = CournotGame.nash_policy(params)
+        p = params.a - params.b * q
+        return p * q - params.marginal_cost * q
diff --git a/pax/envs/fishery.py b/pax/envs/fishery.py
new file mode 100644
index 00000000..c86dc23e
--- /dev/null
+++ b/pax/envs/fishery.py
@@ -0,0 +1,145 @@
+from typing import Optional, Tuple
+
+import chex
+import jax
+import jax.debug
+import jax.numpy as jnp
+from gymnax.environments import environment, spaces
+
+
+@chex.dataclass
+class EnvState:
+    inner_t: int
+    outer_t: int
+    s: float
+
+
+@chex.dataclass
+class EnvParams:
+    g: float
+    e: float
+    P: float
+    w: float
+    s_0: float
+    s_max: float
+
+
+def to_obs_array(params: EnvParams) -> jnp.ndarray:
+    return jnp.array([params.g, params.e, params.P, params.w])
+
+
+"""
+Based on the open-access fishery model from Perman et al. (3rd edition).
+
+Biological sub-model:
+
+The stock of fish is given by s and grows at a natural rate of g.
+The stock is harvested at a rate of E (sum over agent actions) and the harvest is sold at a price of P.
+
+g growth rate
+e efficiency parameter
+P price of fish
+w cost per unit of effort
+s_0 initial stock TODO randomize?
+s_max maximum stock size
+"""
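For readers skimming the diff, the transition that `_step` below implements can be restated compactly. A plain-Python sketch (symbols match the model description above and the code that follows):

```python
# One Fishery transition, mirroring _step below (plain Python, no JAX).
def fishery_transition(s, efforts, g, e, P, w, s_max):
    E = sum(efforts)                  # total harvesting effort
    growth = g * s * (1 - s / s_max)  # logistic stock growth
    H = min(E * s * e, s + growth)    # harvest, capped so the stock stays >= 0
    s_next = s + growth - H
    # each agent earns its effort share of the revenue minus its effort cost
    rewards = [P * (ei / E) * H - w * ei if E > 0 else 0.0 for ei in efforts]
    return s_next, rewards
```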
+
+
+class Fishery(environment.Environment):
+    env_id: str = "Fishery"
+
+    def __init__(self, num_players: int, num_inner_steps: int):
+        super().__init__()
+        self.num_players = num_players
+
+        def _step(
+            key: chex.PRNGKey,
+            state: EnvState,
+            actions: Tuple[float, ...],
+            params: EnvParams,
+        ):
+            t = state.inner_t + 1
+            key, _ = jax.random.split(key, 2)
+            done = t >= num_inner_steps
+
+            actions = jnp.asarray(actions).squeeze()
+            actions = jnp.clip(actions, a_min=0)
+            E = actions.sum()
+            s_new = params.g * state.s * (1 - state.s / params.s_max)
+            s_growth = state.s + s_new
+
+            # Prevent s from dropping below 0
+            H = jnp.clip(E * state.s * params.e, a_max=s_growth)
+            s_next = s_growth - H
+            next_state = EnvState(inner_t=t, outer_t=state.outer_t, s=s_next)
+            reset_obs, reset_state = _reset(key, params)
+            reset_state = reset_state.replace(outer_t=state.outer_t + 1)
+
+            all_obs = []
+            all_rewards = []
+            for i in range(num_players):
+                obs = jnp.concatenate([actions, jnp.array([s_next])])
+                obs = jax.lax.select(done, reset_obs[i], obs)
+                all_obs.append(obs)
+
+                e = actions[i]
+                # reward_i = revenue share - effort cost
+                #          = P * (e_i / E) * H - w * e_i
+                r = jnp.where(E != 0, params.P * e / E * H - params.w * e, 0)
+                all_rewards.append(r)
+
+            state = jax.tree_map(
+                lambda x, y: jax.lax.select(done, x, y),
+                reset_state,
+                next_state,
+            )
+
+            return (
+                tuple(all_obs),
+                state,
+                tuple(all_rewards),
+                done,
+                {
+                    "H": H,
+                    "E": E,
+                    "growth": s_new,
+                    "cost": params.w * E,
+                },
+            )
+
+        def _reset(
+            key: chex.PRNGKey, params: EnvParams
+        ) -> Tuple[Tuple, EnvState]:
+            state = EnvState(
+                inner_t=jnp.zeros((), dtype=jnp.int16),
+                outer_t=jnp.zeros((), dtype=jnp.int16),
+                s=params.s_0,
+            )
+            obs = jax.random.uniform(key, (num_players,))
+            obs = jnp.concatenate([obs, jnp.array([state.s])])
+            return tuple([obs for _ in range(num_players)]), state
+
+        self.step = jax.jit(_step)
+        self.reset = jax.jit(_reset)
+
+    @property
+    def num_actions(self) -> int:
+        """Number of actions possible in environment."""
+        return 1
+
+    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box:
+        """Action space of the environment."""
+        return spaces.Box(low=0, high=params.s_max, shape=(1,))
+
+    def observation_space(self, params: EnvParams) -> spaces.Box:
+        """Observation space of the environment."""
+        return spaces.Box(
+            low=0,
+            high=float("inf"),
+            shape=(self.num_players + 1,),
+            dtype=jnp.float32,
+        )
+
+    @staticmethod
+    def equilibrium(params: EnvParams) -> float:
+        return params.s_max * (1 - params.g / params.e / params.P)
diff --git a/pax/envs/infinite_matrix_game.py b/pax/envs/infinite_matrix_game.py
index 1d5b5a0d..33ea79d5 100644
--- a/pax/envs/infinite_matrix_game.py
+++ b/pax/envs/infinite_matrix_game.py
@@ -83,7 +83,7 @@ def _step(
         def _reset(
             key: chex.PRNGKey, params: EnvParams
-        ) -> Tuple[chex.Array, EnvState]:
+        ) -> Tuple[Tuple, EnvState]:
             state = EnvState(
                 inner_t=jnp.zeros((), dtype=jnp.int8),
                 outer_t=jnp.zeros((), dtype=jnp.int8),
@@ -104,9 +104,7 @@ def num_actions(self) -> int:
         """Number of actions possible in environment."""
         return 5
 
-    def action_space(
-        self, params: Optional[EnvParams] = None
-    ) -> spaces.Discrete:
+    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box:
         """Action space of the environment."""
         return spaces.Box(low=0, high=1, shape=(5,))
 
@@ -114,6 +112,6 @@ def observation_space(self, params: EnvParams) -> spaces.Box:
         """Observation space of the environment."""
         return spaces.Box(0, 1, (10,), dtype=jnp.float32)
 
-    def state_space(self, 
params: EnvParams) -> spaces.Dict: + def state_space(self, params: EnvParams) -> spaces.Box: """State space of the environment.""" return spaces.Box(0, 1, (10,), dtype=jnp.float32) diff --git a/pax/envs/iterated_matrix_game.py b/pax/envs/iterated_matrix_game.py index 04403b0e..218edc07 100644 --- a/pax/envs/iterated_matrix_game.py +++ b/pax/envs/iterated_matrix_game.py @@ -85,7 +85,7 @@ def _step( def _reset( key: chex.PRNGKey, params: EnvParams - ) -> Tuple[chex.Array, EnvState]: + ) -> Tuple[Tuple, EnvState]: state = EnvState( inner_t=jnp.zeros((), dtype=jnp.int8), outer_t=jnp.zeros((), dtype=jnp.int8), @@ -113,10 +113,10 @@ def action_space( """Action space of the environment.""" return spaces.Discrete(4) - def observation_space(self, params: EnvParams) -> spaces.Box: + def observation_space(self, params: EnvParams) -> spaces.Discrete: """Observation space of the environment.""" return spaces.Discrete(5) - def state_space(self, params: EnvParams) -> spaces.Dict: + def state_space(self, params: EnvParams) -> spaces.Discrete: """State space of the environment.""" return spaces.Discrete(5) diff --git a/pax/envs/iterated_tensor_game_n_player.py b/pax/envs/iterated_tensor_game_n_player.py index ef0267fe..8335228f 100644 --- a/pax/envs/iterated_tensor_game_n_player.py +++ b/pax/envs/iterated_tensor_game_n_player.py @@ -1,9 +1,8 @@ -from typing import Optional, Tuple +from typing import Tuple import chex import jax import jax.numpy as jnp -from flax import struct from gymnax.environments import environment, spaces @@ -107,7 +106,7 @@ def _step( def _reset( key: chex.PRNGKey, params: EnvParams - ) -> Tuple[chex.Array, EnvState]: + ) -> Tuple[Tuple[chex.Array, ...], EnvState]: state = EnvState( inner_t=jnp.zeros((), dtype=jnp.int8), outer_t=jnp.zeros((), dtype=jnp.int8), diff --git a/pax/envs/rice/27_regions/11.yml b/pax/envs/rice/27_regions/11.yml new file mode 100644 index 00000000..43a27ea8 --- /dev/null +++ b/pax/envs/rice/27_regions/11.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 1.8724419952820714 + xK_0: 0.239419592 + xL_0: 476.878017 + xL_a: 669.593553 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.13880429539557834 + xg_A: 0.12202134941105497 + xgamma: 0.3 + xl_g: 0.034238352160625596 + xsigma_0: 0.4559257467059924 diff --git a/pax/envs/rice/27_regions/12.yml b/pax/envs/rice/27_regions/12.yml new file mode 100644 index 00000000..c1a0a6b7 --- /dev/null +++ b/pax/envs/rice/27_regions/12.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.405493223457656 + xK_0: 3.30354611 + xL_0: 68.394527 + xL_a: 93.497311 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1880269001436297 + xg_A: 0.10300420806704261 + xgamma: 0.3 + xl_g: 0.05753057218640376 + xsigma_0: 0.5289744017993728 diff --git a/pax/envs/rice/27_regions/13.yml b/pax/envs/rice/27_regions/13.yml new file mode 100644 index 00000000..befbd70e --- /dev/null +++ b/pax/envs/rice/27_regions/13.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.5579000509140952 + xK_0: 0.109143954 + xL_0: 64.122372 + xL_a: 135.074132 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.16127452284439697 + xg_A: 0.12735655631209186 + xgamma: 0.3 + xl_g: 0.02623933488387354 + xsigma_0: 0.8162518983719008 diff --git a/pax/envs/rice/27_regions/14.yml b/pax/envs/rice/27_regions/14.yml new file mode 100644 index 00000000..c27b8b91 --- /dev/null +++ b/pax/envs/rice/27_regions/14.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 1.92663301826947 + xK_0: 1.423908312 + xL_0: 284.698846 + xL_a: 465.307807 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.24445012982169362 + 
xg_A: 0.1335428337437049 + xgamma: 0.3 + xl_g: 0.024422285778436918 + xsigma_0: 1.220638524516315 diff --git a/pax/envs/rice/27_regions/15.yml b/pax/envs/rice/27_regions/15.yml new file mode 100644 index 00000000..95ebe5c8 --- /dev/null +++ b/pax/envs/rice/27_regions/15.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.111280036435135 + xK_0: 0.268152174 + xL_0: 28.141422 + xL_a: 23.573851 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.16335430971807735 + xg_A: 0.10573757974990125 + xgamma: 0.3 + xl_g: -0.05715547594428186 + xsigma_0: 0.29029694003558093 diff --git a/pax/envs/rice/27_regions/16.yml b/pax/envs/rice/27_regions/16.yml new file mode 100644 index 00000000..ea3ef19f --- /dev/null +++ b/pax/envs/rice/27_regions/16.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 4.217213133650901 + xK_0: 3.18362519 + xL_0: 548.75442 + xL_a: 560.054221 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1703030497846267 + xg_A: 0.09485139864239062 + xgamma: 0.3 + xl_g: 0.08033413573292254 + xsigma_0: 0.3019631318655498 diff --git a/pax/envs/rice/27_regions/17.yml b/pax/envs/rice/27_regions/17.yml new file mode 100644 index 00000000..8e79670d --- /dev/null +++ b/pax/envs/rice/27_regions/17.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.4913586566019945 + xK_0: 0.043635414 + xL_0: 46.488546 + xL_a: 59.987638 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05834835573455604 + xg_A: 0.049004053769246436 + xgamma: 0.3 + xl_g: 0.03709027315241262 + xsigma_0: 0.4196283605267465 diff --git a/pax/envs/rice/27_regions/18.yml b/pax/envs/rice/27_regions/18.yml new file mode 100644 index 00000000..1298357c --- /dev/null +++ b/pax/envs/rice/27_regions/18.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.5248787296824777 + xK_0: 1.080409098 + xL_0: 69.194146 + xL_a: 100.015768 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.3464232696284064 + xg_A: 0.0785686384327884 + xgamma: 0.3 + xl_g: 0.028895835870575235 + xsigma_0: 1.0104732880546095 diff --git a/pax/envs/rice/27_regions/19.yml b/pax/envs/rice/27_regions/19.yml new file mode 100644 index 00000000..b6f3dc25 --- /dev/null +++ b/pax/envs/rice/27_regions/19.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.4596628149703816 + xK_0: 0.183982308 + xL_0: 513.737375 + xL_a: 1867.771496 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 1.8390289471577375 + xg_A: 0.46217845237530203 + xgamma: 0.3 + xl_g: 0.017149514576045286 + xsigma_0: 0.3103140976545981 diff --git a/pax/envs/rice/27_regions/2.yml b/pax/envs/rice/27_regions/2.yml new file mode 100644 index 00000000..e68f20ec --- /dev/null +++ b/pax/envs/rice/27_regions/2.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 12.157936179442062 + xK_0: 2.64167507 + xL_0: 38.101107 + xL_a: 56.990157 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.13084887390535965 + xg_A: 0.06274070897633105 + xgamma: 0.3 + xl_g: 0.020192884216840113 + xsigma_0: 0.35044418275452427 diff --git a/pax/envs/rice/27_regions/20.yml b/pax/envs/rice/27_regions/20.yml new file mode 100644 index 00000000..3210b7da --- /dev/null +++ b/pax/envs/rice/27_regions/20.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 0.9929511285910457 + xK_0: 0.160199062 + xL_0: 522.481879 + xL_a: 1830.325243 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08560686741591728 + xg_A: 0.06506072277236097 + xgamma: 0.3 + xl_g: 0.01902705663391574 + xsigma_0: 0.23517024551671273 diff --git a/pax/envs/rice/27_regions/21.yml b/pax/envs/rice/27_regions/21.yml new file mode 100644 index 00000000..6ff60218 --- /dev/null +++ b/pax/envs/rice/27_regions/21.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: 
+ xA_0: 5.000360862831762 + xK_0: 2.289358084859004 + xL_0: 165.293239 + xL_a: 230.19114338372032 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.18278991377259912 + xg_A: 0.07108490122759262 + xgamma: 0.3 + xl_g: 0.026773049602328805 + xsigma_0: 0.4187771240034329 diff --git a/pax/envs/rice/27_regions/22.yml b/pax/envs/rice/27_regions/22.yml new file mode 100644 index 00000000..3a3af1db --- /dev/null +++ b/pax/envs/rice/27_regions/22.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 29.853559456004625 + xK_0: 2.019951041942154 + xL_0: 165.75054 + xL_a: 216.9269455 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757142538145 + xg_A: 0.07541285925058157 + xgamma: 0.3 + xl_g: -0.0024986057450947508 + xsigma_0: 0.25439108584131914 diff --git a/pax/envs/rice/27_regions/23.yml b/pax/envs/rice/27_regions/23.yml new file mode 100644 index 00000000..60ce3e51 --- /dev/null +++ b/pax/envs/rice/27_regions/23.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 23.314991608844633 + xK_0: 3.0391651447451187 + xL_0: 109.39535640000001 + xL_a: 143.17178403 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757821836926 + xg_A: 0.07541286104378697 + xgamma: 0.3 + xl_g: -0.002498605753950543 + xsigma_0: 0.25439108584131914 diff --git a/pax/envs/rice/27_regions/24.yml b/pax/envs/rice/27_regions/24.yml new file mode 100644 index 00000000..db156cb1 --- /dev/null +++ b/pax/envs/rice/27_regions/24.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 29.853559456004625 + xK_0: 0.6867833542603324 + xL_0: 56.355183600000004 + xL_a: 73.75516147 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.08802757142538145 + xg_A: 0.07541285925058157 + xgamma: 0.3 + xl_g: -0.002498605778464897 + xsigma_0: 0.25439108584131914 diff --git a/pax/envs/rice/27_regions/25.yml b/pax/envs/rice/27_regions/25.yml new file mode 100644 index 00000000..43ebae03 --- /dev/null +++ b/pax/envs/rice/27_regions/25.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 10.922004036104973 + xK_0: 0.6059142357084183 + xL_0: 705.464681 + xL_a: 532.496728 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877082389193 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/rice/27_regions/26.yml b/pax/envs/rice/27_regions/26.yml new file mode 100644 index 00000000..0dc50d55 --- /dev/null +++ b/pax/envs/rice/27_regions/26.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 9.633893693772771 + xK_0: 0.6076078389971926 + xL_0: 465.60668946000004 + xL_a: 351.44784048 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877094273714 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/rice/27_regions/27.yml b/pax/envs/rice/27_regions/27.yml new file mode 100644 index 00000000..dd580c1b --- /dev/null +++ b/pax/envs/rice/27_regions/27.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 8.620918323265558 + xK_0: 0.45330037729157585 + xL_0: 239.85799154000003 + xL_a: 181.04888752000002 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.09603760623634239 + xg_A: 0.16817991622593043 + xgamma: 0.3 + xl_g: -0.015844877127171766 + xsigma_0: 0.7813181890031158 diff --git a/pax/envs/rice/27_regions/28.yml b/pax/envs/rice/27_regions/28.yml new file mode 100644 index 00000000..8906fe2a --- /dev/null +++ b/pax/envs/rice/27_regions/28.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.1898536850408714 + xK_0: 0.1287514001006796 + xL_0: 690.0021925 + xL_a: 723.512806 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05389348561714375 + xg_A: 
0.06812795377170236 + xgamma: 0.3 + xl_g: -0.012597171762552104 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/rice/27_regions/29.yml b/pax/envs/rice/27_regions/29.yml new file mode 100644 index 00000000..8e40829c --- /dev/null +++ b/pax/envs/rice/27_regions/29.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.033527139192083 + xK_0: 0.3810937821808831 + xL_0: 455.40144705 + xL_a: 477.51845196000005 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.053893489919463196 + xg_A: 0.06812795559760368 + xgamma: 0.3 + xl_g: -0.012597171772754604 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/rice/27_regions/3.yml b/pax/envs/rice/27_regions/3.yml new file mode 100644 index 00000000..8298c5ec --- /dev/null +++ b/pax/envs/rice/27_regions/3.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 13.219587477586199 + xK_0: 16.295084052817813 + xL_0: 502.409662 + xL_a: 445.861101 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.25224119422016716 + xg_A: 0.07423569831381745 + xgamma: 0.3 + xl_g: -0.033398145012670695 + xsigma_0: 0.17048017530013193 diff --git a/pax/envs/rice/27_regions/30.yml b/pax/envs/rice/27_regions/30.yml new file mode 100644 index 00000000..c7d7c205 --- /dev/null +++ b/pax/envs/rice/27_regions/30.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 3.1898536850408714 + xK_0: 0.04377547603423107 + xL_0: 234.60074545 + xL_a: 245.99435404000002 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.05389348561714375 + xg_A: 0.06812795377170236 + xgamma: 0.3 + xl_g: -0.012597171800996251 + xsigma_0: 0.9487399403167854 diff --git a/pax/envs/rice/27_regions/4.yml b/pax/envs/rice/27_regions/4.yml new file mode 100644 index 00000000..aa99fbe7 --- /dev/null +++ b/pax/envs/rice/27_regions/4.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 6.386787299600672 + xK_0: 1.094110266 + xL_0: 317.880267 + xL_a: 287.533185 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.19401001361417486 + xg_A: 0.23666124530625898 + xgamma: 0.3 + xl_g: -0.052512175141607595 + xsigma_0: 0.8402859337043421 diff --git a/pax/envs/rice/27_regions/5.yml b/pax/envs/rice/27_regions/5.yml new file mode 100644 index 00000000..d54301c4 --- /dev/null +++ b/pax/envs/rice/27_regions/5.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 2.480556451289923 + xK_0: 0.090493838 + xL_0: 94.484285 + xL_a: 102.997258 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.20277540450432627 + xg_A: 0.20063089079847785 + xgamma: 0.3 + xl_g: 0.036907187009908436 + xsigma_0: 1.6646404809736024 diff --git a/pax/envs/rice/27_regions/6.yml b/pax/envs/rice/27_regions/6.yml new file mode 100644 index 00000000..769ac92e --- /dev/null +++ b/pax/envs/rice/27_regions/6.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 10.852953800501595 + xK_0: 17.553847656 + xL_0: 222.891134 + xL_a: 168.350837 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.005000000000006631 + xg_A: -0.00046726965128841526 + xgamma: 0.3 + xl_g: -0.011976043898247184 + xsigma_0: 0.2851271547872655 diff --git a/pax/envs/rice/27_regions/7.yml b/pax/envs/rice/27_regions/7.yml new file mode 100644 index 00000000..f3f99553 --- /dev/null +++ b/pax/envs/rice/27_regions/7.yml @@ -0,0 +1,13 @@ +_RICE_CONSTANT: + xA_0: 4.135420683261054 + xK_0: 1.00243116 + xL_0: 103.2943 + xL_a: 87.417937 + xa_1: 0 + xa_2: 0.00236 + xa_3: 2 + xdelta_A: 0.1577590297489783 + xg_A: 0.12252697231832654 + xgamma: 0.3 + xl_g: -0.06254962519160859 + xsigma_0: 0.6013249328720022 diff --git a/pax/envs/rice/27_regions/9.yml b/pax/envs/rice/27_regions/9.yml new file mode 100644 index 00000000..58ded32e --- /dev/null +++ 
b/pax/envs/rice/27_regions/9.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 2.7159049409990375
+  xK_0: 1.0340369
+  xL_0: 573.818276
+  xL_a: 681.210099
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.09686897303646223
+  xg_A: 0.10149076832484585
+  xgamma: 0.3
+  xl_g: 0.04313067107477931
+  xsigma_0: 0.6378326935010085
diff --git a/pax/envs/rice/27_regions/default.yml b/pax/envs/rice/27_regions/default.yml
new file mode 100644
index 00000000..64df6fa6
--- /dev/null
+++ b/pax/envs/rice/27_regions/default.yml
@@ -0,0 +1,68 @@
+_DICE_CONSTANT:
+  xt_0: 2015 # starting year of the whole model
+  xDelta: 5 # the time interval in years
+  xN: 20 # total time steps
+
+  # Climate diffusion parameters
+  xPhi_T: [ [ 0.8718, 0.0088 ], [ 0.025, 0.975 ] ]
+  xB_T: [ 0.1005, 0 ]
+  # xB_T: [0.03, 0]
+
+  # Carbon cycle diffusion parameters (the zeta matrix in the paper)
+  xPhi_M: [ [ 0.88, 0.196, 0 ], [ 0.12, 0.797, 0.001465 ], [ 0, 0.007, 0.99853488 ] ]
+  # xB_M: [0.2727272727272727, 0, 0] # 12/44
+  xB_M: [ 1.36388, 0, 0 ] # 12/44
+  xeta: 3.6813 # apparently unused: not referenced anywhere in the environment code
+
+  xM_AT_1750: 588 # atmospheric carbon mass in the year 1750
+  xf_0: 0.5 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
+  xf_1: 1 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
+  xt_f: 20 # in Eq 3, time-step param for the effect of greenhouse gases other than carbon dioxide
+  xE_L0: 2.6 # 2.6 # in Eq 4, param for the emissions due to land-use changes
+  xdelta_EL: 0.001 # 0.115 # 0.115 # in Eq 4, param for the emissions due to land-use changes
+
+  xM_AT_0: 851.0 # in CAP, initial atmospheric carbon mass
+  xM_UP_0: 460.0 # in CAP, initial carbon mass in the upper ocean
+  xM_LO_0: 1740.0 # in CAP, initial carbon mass in the lower ocean
+  xe_0: 35.85 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+  xq_0: 105.5 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+  xmu_0: 0.03 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+
+  # From the Python implementation PyDICE
+  xF_2x: 3.6813 # 3.6813 # radiative forcing at a doubling of atmospheric carbon
+  xT_2x: 3.1 # 3.1 # equilibrium temperature increase at a doubling of atmospheric carbon
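+
+# Each region file in this folder defines a _RICE_CONSTANT block; any key
+# missing there is filled in from _RICE_CONSTANT_DEFAULT below
+# (see load_rice_params in pax/envs/rice/rice.py).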
+
+_RICE_CONSTANT_DEFAULT:
+  xA_0: 5.115 # in TFP, technology at the starting point
+  xK_0: 223 # in CAP, initial condition for capital
+  xL_0: 7403 # in POP, population at the starting point
+  xL_a: 11500 # in POP, the expected population at convergence
+  xa_1: 0
+  xa_2: 0.00236 # in CAP Eq 6
+  xa_3: 2 # in CAP Eq 6
+  xdelta_A: 0.005 # in TFP, controls the rate of technology growth; smaller -> faster
+  xg_A: 0.076 # in TFP, controls the rate of technology growth; larger -> faster
+  xgamma: 0.3 # in CAP Eq 5, the capital elasticity
+  xl_g: 0.134 # in POP, controls the rate of convergence
+  xsigma_0: 0.3503 # e0/(q0(1-mu0)) in EI, emission intensity at the starting point
+
+_RICE_GLOBAL_CONSTANT:
+  xtheta_2: 2.6 # in CAP Eq 6
+  xdelta_K: 0.1 # in CAP Eq 9, param describing the depreciation of capital
+  xalpha: 1.45 # utility function param
+
+  xrho: 0.015 # discount rate for utility
+
+  xg_sigma: 0.0025 # 0.0152 # 0.0025 in EI, controls the rate of emission-intensity decline; larger -> more emission reduction
+  xdelta_sigma: 0.1 # 0.01 in EI, controls the rate of emission-intensity decline; larger -> less emission reduction
+  xp_b: 550 # 550 # in Eq 2 (estimate of the cost of mitigation), the price of a backstop technology that can remove carbon dioxide from the atmosphere
+  xdelta_pb: 0.001 # 0.025 # in Eq 2, controls how the cost of mitigation changes over time; larger -> cheaper as time goes by
+
+  xscale_1: 0.030245527 # in Eq 29, Nordhaus scaled cost function param
+  xscale_2: 10993.704 # in Eq 29, Nordhaus scaled cost function param
+
+  xT_AT_0: 0.85 # in CAP, initial atmospheric temperature increase (damage function input)
+  xT_LO_0: 0.0068 # in CAP, initial lower-ocean temperature increase (damage function input)
+  r: 0.1 # balance interest rate adjusted for xDelta=5
+
+
diff --git a/pax/envs/rice/5_regions/1.yml b/pax/envs/rice/5_regions/1.yml
new file mode 100644
index 00000000..0fd3e120
--- /dev/null
+++ b/pax/envs/rice/5_regions/1.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 12.61284526211837
+  xK_0: 67.5363532808178
+  xL_0: 1153.559429
+  xL_a: 1168.64817
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.2335591082853097
+  xg_A: 0.056522215873884896
+  xgamma: 0.3
+  xl_g: 0.07975982642810232
+  xsigma_0: 0.20463581436908657
diff --git a/pax/envs/rice/5_regions/2.yml b/pax/envs/rice/5_regions/2.yml
new file mode 100644
index 00000000..12ae3fc1
--- /dev/null
+++ b/pax/envs/rice/5_regions/2.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 5.109992330543673
+  xK_0: 0.871072684
+  xL_0: 285.952538
+  xL_a: 280.243849
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.19198426963422988
+  xg_A: 0.28691234316031794
+  xgamma: 0.3
+  xl_g: 0.03894660908604062
+  xsigma_0: 1.1576063116121318
diff --git a/pax/envs/rice/5_regions/3.yml b/pax/envs/rice/5_regions/3.yml
new file mode 100644
index 00000000..96cc564e
--- /dev/null
+++ b/pax/envs/rice/5_regions/3.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 3.250119930740304
+  xK_0: 27.50443660104453
+  xL_0: 4031.540324
+  xL_a: 3985.643734
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.08239668593209744
+  xg_A: 0.10814096456032099
+  xgamma: 0.3
+  xl_g: 0.09961653941224989
+  xsigma_0: 0.7227875392397705
diff --git a/pax/envs/rice/5_regions/4.yml b/pax/envs/rice/5_regions/4.yml
new file mode 100644
index 00000000..29e3aa4a
--- /dev/null
+++ b/pax/envs/rice/5_regions/4.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 2.045017429173559
+  xK_0: 6.256779604
+  xL_0: 1545.816991
+  xL_a: 4489.488827
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.19819334478483616
+  xg_A: 0.11691865387930357
+  xgamma: 0.3
+  xl_g: 0.019327062449774557
+  xsigma_0: 0.6883369609580036
diff --git a/pax/envs/rice/5_regions/5.yml b/pax/envs/rice/5_regions/5.yml
new file mode 100644
index 00000000..8c6536f6
--- /dev/null
+++ b/pax/envs/rice/5_regions/5.yml
@@ -0,0 +1,13 @@
+_RICE_CONSTANT:
+  xA_0: 4.87298338701072
+  xK_0: 2.23253879
+  xL_0: 618.130152
+  xL_a: 642.301927
+  xa_1: 0
+  xa_2: 0.00236
+  xa_3: 2
+  xdelta_A: 0.15629628201213583
+  xg_A: 0.09776895934200977
+  xgamma: 0.3
+  xl_g: 0.07207745988965227
+  xsigma_0: 0.3245836451389051
diff --git a/pax/envs/rice/5_regions/default.yml b/pax/envs/rice/5_regions/default.yml
new file mode 100644
index 00000000..64df6fa6
--- /dev/null
+++ b/pax/envs/rice/5_regions/default.yml
@@ -0,0 +1,68 @@
+_DICE_CONSTANT:
+  xt_0: 2015 # starting year of the whole model
+  xDelta: 5 # the time interval in years
+  xN: 20 # total time steps
+
+  # Climate diffusion parameters
+  xPhi_T: [ [ 0.8718, 0.0088 ], [ 0.025, 0.975 ] ]
+  xB_T: [ 0.1005, 0 ]
+  # xB_T: [0.03, 0]
+
+  # Carbon cycle diffusion parameters (the zeta matrix in the paper)
+  xPhi_M: [ [ 0.88, 0.196, 0 ], [ 0.12, 0.797, 0.001465 ], [ 0, 0.007, 0.99853488 ] ]
+  # xB_M: [0.2727272727272727, 0, 0] # 12/44
+  xB_M: [ 1.36388, 0, 0 ] # 12/44
+  xeta: 3.6813 # apparently unused: not referenced anywhere in the environment code
+
+  xM_AT_1750: 588 # atmospheric carbon mass in the year 1750
+  xf_0: 0.5 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
+  xf_1: 1 # in Eq 3, param for the effect of greenhouse gases other than carbon dioxide
+  xt_f: 20 # in Eq 3, time-step param for the effect of greenhouse gases other than carbon dioxide
+  xE_L0: 2.6 # 2.6 # in Eq 4, param for the emissions due to land-use changes
+  xdelta_EL: 0.001 # 0.115 # 0.115 # in Eq 4, param for the emissions due to land-use changes
+
+  xM_AT_0: 851.0 # in CAP, initial atmospheric carbon mass
+  xM_UP_0: 460.0 # in CAP, initial carbon mass in the upper ocean
+  xM_LO_0: 1740.0 # in CAP, initial carbon mass in the lower ocean
+  xe_0: 35.85 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+  xq_0: 105.5 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+  xmu_0: 0.03 # in EI, defines the initial sigma_0: e0/(q0(1-mu0))
+
+  # From the Python implementation PyDICE
+  xF_2x: 3.6813 # 3.6813 # radiative forcing at a doubling of atmospheric carbon
+  xT_2x: 3.1 # 3.1 # equilibrium temperature increase at a doubling of atmospheric carbon
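+
+# Note: the columns of xPhi_M above each sum to (approximately) 1, i.e. carbon
+# is conserved as it diffuses between the three reservoirs (atmosphere, upper
+# ocean, lower ocean).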
+
+_RICE_CONSTANT_DEFAULT:
+  xA_0: 5.115 # in TFP, technology at the starting point
+  xK_0: 223 # in CAP, initial condition for capital
+  xL_0: 7403 # in POP, population at the starting point
+  xL_a: 11500 # in POP, the expected population at convergence
+  xa_1: 0
+  xa_2: 0.00236 # in CAP Eq 6
+  xa_3: 2 # in CAP Eq 6
+  xdelta_A: 0.005 # in TFP, controls the rate of technology growth; smaller -> faster
+  xg_A: 0.076 # in TFP, controls the rate of technology growth; larger -> faster
+  xgamma: 0.3 # in CAP Eq 5, the capital elasticity
+  xl_g: 0.134 # in POP, controls the rate of convergence
+  xsigma_0: 0.3503 # e0/(q0(1-mu0)) in EI, emission intensity at the starting point
+
+_RICE_GLOBAL_CONSTANT:
+  xtheta_2: 2.6 # in CAP Eq 6
+  xdelta_K: 0.1 # in CAP Eq 9, param describing the depreciation of capital
+  xalpha: 1.45 # utility function param
+
+  xrho: 0.015 # discount rate for utility
+
+  xg_sigma: 0.0025 # 0.0152 # 0.0025 in EI, controls the rate of emission-intensity decline; larger -> more emission reduction
+  xdelta_sigma: 0.1 # 0.01 in EI, controls the rate of emission-intensity decline; larger -> less emission reduction
+  xp_b: 550 # 550 # in Eq 2 (estimate of the cost of mitigation), the price of a backstop technology that can remove carbon dioxide from the atmosphere
+  xdelta_pb: 0.001 # 0.025 # in Eq 2, controls how the cost of mitigation changes over time; larger -> cheaper as time goes by
+
+  xscale_1: 0.030245527 # in Eq 29, Nordhaus scaled cost function param
+  xscale_2: 10993.704 # in Eq 29, Nordhaus scaled cost function param
+
+  xT_AT_0: 0.85 # in CAP, initial atmospheric temperature increase (damage function input)
+  xT_LO_0: 0.0068 # in CAP, initial lower-ocean temperature increase (damage function input)
+  r: 0.1 # balance interest rate adjusted for xDelta=5
+
+
diff --git a/pax/envs/rice/c_rice.py b/pax/envs/rice/c_rice.py
new file mode 100644
index 00000000..ddf06c12
--- /dev/null
+++ b/pax/envs/rice/c_rice.py
@@ -0,0 +1,623 @@
+from typing import Optional, Tuple
+
+import chex
+import jax
+import jax.debug
+import jax.numpy as jnp
+from gymnax.environments import environment, spaces
+from jax import Array
+
+from pax.envs.rice.rice import (
+    EnvState,
+    EnvParams,
+    get_consumption,
+    get_armington_agg,
+    get_utility,
+    get_social_welfare,
+    get_global_temperature,
+    load_rice_params,
+    get_exogenous_emissions,
+    get_mitigation_cost,
+    get_carbon_intensity,
+    get_abatement_cost,
+    get_damages,
+    get_production,
+    get_gross_output,
+    get_investment,
+    zero_diag,
+    get_max_potential_exports,
+    get_land_emissions,
+    get_aux_m,
+    get_global_carbon_mass,
+    get_capital_depreciation,
+    get_capital,
+    get_labor,
+    get_production_factor,
+    get_carbon_price,
+)
+from pax.utils import float_precision
+
+eps = 1e-5
+
+"""
+This extension of the Rice-N environment adds a simple climate-club mechanism as proposed by Nordhaus.
+The club mechanism is implemented as follows:
+
+1. Regions can join the club at any time
+2. The club has a fixed tariff rate
+3. Club members must implement a minimum mitigation rate
+
+If the mediator is enabled, it can choose the club tariff rate and the minimum mitigation rate.
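+
+As a rough sketch of the tariff logic in _step below: members raise their
+tariffs on non-members to at least the club rate and impose no tariffs on
+fellow members,
+
+    tariffs = where(member, clip(tariffs, min=club_tariff_rate), tariffs)
+    tariffs = tariffs * (1 - outer(member, member))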
+
+"""
+
+
+class ClubRice(environment.Environment):
+    env_id: str = "C-Rice-N"
+
+    def __init__(
+        self,
+        config_folder: str,
+        has_mediator=False,
+        mediator_climate_weight=0,
+        mediator_utility_weight=1,
+        mediator_climate_objective=False,
+        default_club_tariff_rate=0.1,
+        default_club_mitigation_rate=0.3,
+        episode_length: int = 20,
+    ):
+        super().__init__()
+
+        params, num_regions = load_rice_params(config_folder)
+        self.has_mediator = has_mediator
+        self.num_players = num_regions
+        self.episode_length = episode_length
+        self.num_actors = (
+            self.num_players + 1 if self.has_mediator else self.num_players
+        )
+        self.rice_constant = params["_RICE_GLOBAL_CONSTANT"]
+        self.dice_constant = params["_DICE_CONSTANT"]
+        self.region_constants = params["_REGIONS"]
+        self.region_params = params["_REGION_PARAMS"]
+
+        self.savings_action_n = 1
+        self.mitigation_rate_action_n = 1
+        # Each region sets the max allowed export from its own region
+        self.export_action_n = 1
+        # Each region sets import bids (max desired imports from other countries)
+        # TODO Find an "automatic" model for trade imports
+        # Reason: without protocols such as clubs there is no incentive for countries to impose a tariff
+        self.import_actions_n = self.num_players
+        # Each region sets import tariffs imposed on other countries
+        self.tariff_actions_n = self.num_players
+        self.join_club_action_n = 1
+
+        self.actions_n = (
+            self.savings_action_n
+            + self.mitigation_rate_action_n
+            + self.export_action_n
+            + self.import_actions_n
+            + self.tariff_actions_n
+            + self.join_club_action_n
+        )
+
+        # Determine the index of each action to slice them in the step function
+        self.savings_action_index = 0
+        self.mitigation_rate_action_index = (
+            self.savings_action_index + self.savings_action_n
+        )
+        self.export_action_index = (
+            self.mitigation_rate_action_index + self.mitigation_rate_action_n
+        )
+        self.tariffs_action_index = (
+            self.export_action_index + self.export_action_n
+        )
+        self.desired_imports_action_index = (
+            self.tariffs_action_index + self.tariff_actions_n
+        )
+        # The join-club flag is the last action, after the num_players import bids
+        self.join_club_action_index = (
+            self.desired_imports_action_index + self.import_actions_n
+        )
+
+        # Parameters for armington aggregation utility
+        self.sub_rate = jnp.asarray(0.5, dtype=float_precision)
+        self.dom_pref = jnp.asarray(0.5, dtype=float_precision)
+        self.for_pref = jnp.asarray(
+            [0.5 / (self.num_players - 1)] * self.num_players,
+            dtype=float_precision,
+        )
+        self.default_club_tariff_rate = jnp.asarray(
+            default_club_tariff_rate, dtype=float_precision
+        )
+        self.default_club_mitigation_rate = jnp.asarray(
+            default_club_mitigation_rate, dtype=float_precision
+        )
+
+        if mediator_climate_objective:
+            self.mediator_climate_weight = 1
+            self.mediator_utility_weight = 0
+        else:
+            self.mediator_climate_weight = jnp.asarray(
+                mediator_climate_weight, dtype=float_precision
+            )
+            self.mediator_utility_weight = jnp.asarray(
+                mediator_utility_weight, dtype=float_precision
+            )
+
+        def _step(
+            key: chex.PRNGKey,
+            state: EnvState,
+            actions: Tuple[float, ...],
+            params: EnvParams,
+        ):
+            t = state.inner_t + 1  # Rice equations expect to start at t=1
+            key, _ = jax.random.split(key, 2)
+            done = t >= self.episode_length
+
+            t_at = state.global_temperature[0]
+            global_exogenous_emissions = get_exogenous_emissions(
+                self.dice_constant["xf_0"],
+                self.dice_constant["xf_1"],
+                self.dice_constant["xt_f"],
+                t,
+            )
+
+            actions = jnp.asarray(actions).astype(float_precision).squeeze()
+            actions = jnp.clip(actions, a_min=0, a_max=1)
+
+            if self.has_mediator:
+                region_actions = actions[1:]
+            else:
+                region_actions = actions
+
+            # TODO it would be better if this variable were categorical in the agent network
+            club_membership_all = jnp.round(
+                region_actions[:, self.join_club_action_index]
+            ).astype(jnp.int8)
+            mitigation_cost_all = get_mitigation_cost(
+                self.rice_constant["xp_b"],
+                self.rice_constant["xtheta_2"],
+                self.rice_constant["xdelta_pb"],
+                state.intensity_all,
+                t,
+            )
+            intensity_all = get_carbon_intensity(
+                state.intensity_all,
+                self.rice_constant["xg_sigma"],
+                self.rice_constant["xdelta_sigma"],
+                self.dice_constant["xDelta"],
+                t,
+            )
+
+            if self.has_mediator:
+                club_mitigation_rate = actions[
+                    0, self.mitigation_rate_action_index
+                ]
+                club_tariff_rate = actions[0, self.tariffs_action_index]
+            else:
+                # Get the maximum carbon price of non-members from the last timestep
+                # club_price = jnp.max(state.carbon_price_all * (1 - state.club_membership_all))
+                club_tariff_rate = self.default_club_tariff_rate
+                club_mitigation_rate = self.default_club_mitigation_rate
+                # club_mitigation_rate = get_club_mitigation_rates(
+                #     club_price,
+                #     intensity_all,
+                #     self.rice_constant["xtheta_2"],
+                #     mitigation_cost_all,
+                #     state.damages_all
+                # )
+            mitigation_rate_all = jnp.where(
+                club_membership_all == 1,
+                club_mitigation_rate,
+                region_actions[:, self.mitigation_rate_action_index],
+            )
+
+            abatement_cost_all = get_abatement_cost(
+                mitigation_rate_all,
+                mitigation_cost_all,
+                self.rice_constant["xtheta_2"],
+            )
+            damages_all = get_damages(
+                t_at,
+                self.region_params["xa_1"],
+                self.region_params["xa_2"],
+                self.region_params["xa_3"],
+            )
+            production_all = get_production(
+                state.production_factor_all,
+                state.capital_all,
+                state.labor_all,
+                self.region_params["xgamma"],
+            )
+            gross_output_all = get_gross_output(
+                damages_all, abatement_cost_all, production_all
+            )
+            balance_all = state.balance_all * (1 + self.rice_constant["r"])
+            investment_all = get_investment(
+                region_actions[:, self.savings_action_index], gross_output_all
+            )
+
+            # Trade
+            desired_imports = region_actions[
+                :,
+                self.desired_imports_action_index : self.desired_imports_action_index
+                + self.import_actions_n,
+            ]
+            # Countries cannot import from themselves
+            desired_imports = zero_diag(desired_imports)
+            total_desired_imports = desired_imports.sum(axis=1)
+            clipped_desired_imports = jnp.clip(
+                total_desired_imports, 0, gross_output_all
+            )
+            desired_imports = (
+                desired_imports
+                * (
+                    # Add a trailing axis so the ratio broadcasts row-wise
+                    # (each importing region is scaled by its own factor)
+                    clipped_desired_imports
+                    / (total_desired_imports + eps)
+                )[:, jnp.newaxis]
+            )
+            init_capital_multiplier = 10.0
+            debt_ratio = (
+                init_capital_multiplier
+                * balance_all
+                / self.region_params["xK_0"]
+            )
+            debt_ratio = jnp.clip(debt_ratio, -1.0, 0.0)
+            scaled_imports = desired_imports * (1 + debt_ratio)
+
+            max_potential_exports = get_max_potential_exports(
+                region_actions[:, self.export_action_index],
+                gross_output_all,
+                investment_all,
+            )
+            # Summing along columns yields the total exports requested from each region
+            total_desired_exports = jnp.sum(scaled_imports, axis=0)
+            clipped_desired_exports = jnp.clip(
+                total_desired_exports, 0, max_potential_exports
+            )
+            scaled_imports = (
+                scaled_imports
+                * clipped_desired_exports
+                / (total_desired_exports + eps)
+            )
+
+            prev_tariffs = state.future_tariff
+            tariffed_imports = scaled_imports * (1 - prev_tariffs)
+            # calculate tariffed imports, tariff revenue and budget balance
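+            # Orientation note: scaled_imports[i, j] is region i's imports
+            # from region j, so summing over axis=0 attributes the revenue
+            # to exporters, whereas axis=1 would attribute it to the
+            # tariffing importer (rice.py uses axis=1).
+            # In the paper this goes to a "special reserve fund", i.e.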
it's not used + tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=0) + + total_exports = scaled_imports.sum(axis=0) + balance_all = balance_all + self.dice_constant["xDelta"] * ( + total_exports - scaled_imports.sum(axis=1) + ) + + c_dom = get_consumption( + gross_output_all, investment_all, total_exports + ) + consumption_all = get_armington_agg( + c_dom, + tariffed_imports, + self.sub_rate, + self.dom_pref, + self.for_pref, + ) + utility_all = get_utility( + state.labor_all, consumption_all, self.rice_constant["xalpha"] + ) + social_welfare_all = get_social_welfare( + utility_all, + self.rice_constant["xrho"], + self.dice_constant["xDelta"], + t, + ) + + # Update ecology + m_at = state.global_carbon_mass[0] + global_temperature = get_global_temperature( + self.dice_constant["xPhi_T"], + state.global_temperature, + self.dice_constant["xB_T"], + self.dice_constant["xF_2x"], + m_at, + self.dice_constant["xM_AT_1750"], + global_exogenous_emissions, + ) + + global_land_emissions = get_land_emissions( + self.dice_constant["xE_L0"], + self.dice_constant["xdelta_EL"], + t, + self.num_players, + ) + + aux_m_all = get_aux_m( + state.intensity_all, + mitigation_rate_all, + production_all, + global_land_emissions, + ) + + global_carbon_mass = get_global_carbon_mass( + self.dice_constant["xPhi_M"], + state.global_carbon_mass, + self.dice_constant["xB_M"], + jnp.sum(aux_m_all), + ) + + capital_depreciation = get_capital_depreciation( + self.rice_constant["xdelta_K"], self.dice_constant["xDelta"] + ) + capital_all = get_capital( + capital_depreciation, + state.capital_all, + self.dice_constant["xDelta"], + investment_all, + ) + labor_all = get_labor( + state.labor_all, + self.region_params["xL_a"], + self.region_params["xl_g"], + ) + production_factor_all = get_production_factor( + state.production_factor_all, + self.region_params["xg_A"], + self.region_params["xdelta_A"], + self.dice_constant["xDelta"], + t, + ) + carbon_price_all = get_carbon_price( + mitigation_cost_all, + intensity_all, + mitigation_rate_all, + self.rice_constant["xtheta_2"], + damages_all, + ) + + desired_future_tariffs = region_actions[ + :, + self.tariffs_action_index : self.tariffs_action_index + + self.num_players, + ] + # Club members impose a minimum tariff of the club tariff rate + future_tariffs = jnp.where( + (club_membership_all == 1).reshape(-1, 1), + jnp.clip(desired_future_tariffs, a_min=club_tariff_rate), + desired_future_tariffs, + ) + # Club members don't impose tariffs on themselves or other club members + membership_mask = ( + club_membership_all.reshape(-1, 1) * club_membership_all + ) + future_tariffs = future_tariffs * (1 - membership_mask) + future_tariffs = zero_diag(future_tariffs) + + next_state = EnvState( + inner_t=state.inner_t + 1, + outer_t=state.outer_t, + global_temperature=global_temperature, + global_carbon_mass=global_carbon_mass, + global_exogenous_emissions=global_exogenous_emissions, + global_land_emissions=global_land_emissions, + labor_all=labor_all, + capital_all=capital_all, + production_factor_all=production_factor_all, + intensity_all=intensity_all, + balance_all=balance_all, + future_tariff=future_tariffs, + gross_output_all=gross_output_all, + investment_all=investment_all, + production_all=production_all, + utility_all=utility_all, + social_welfare_all=social_welfare_all, + capital_depreciation_all=jnp.asarray([capital_depreciation]), + mitigation_cost_all=mitigation_cost_all, + consumption_all=consumption_all, + damages_all=damages_all, + 
abatement_cost_all=abatement_cost_all, + tariff_revenue_all=tariff_revenue_all, + carbon_price_all=carbon_price_all, + club_membership_all=club_membership_all, + ) + + reset_obs, reset_state = _reset(key, params) + reset_state = reset_state.replace(outer_t=state.outer_t + 1) + + obs = [] + if self.has_mediator: + obs.append( + self._generate_mediator_observation( + next_state, club_mitigation_rate, club_tariff_rate + ) + ) + + for i in range(self.num_players): + obs.append( + self._generate_observation( + i, next_state, club_mitigation_rate, club_tariff_rate + ) + ) + + obs = jax.tree_map( + lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs) + ) + + result_state = jax.tree_map( + lambda x, y: jnp.where(done, x, y), + reset_state, + next_state, + ) + + rewards = result_state.utility_all + if self.has_mediator: + temp_increase = ( + next_state.global_temperature[0] + - state.global_temperature[0] + ) + # Rescale the social reward to make the weights comparable + social_reward = result_state.utility_all.sum() / ( + self.num_players * 10 + ) + mediator_reward = ( + -self.mediator_climate_weight * temp_increase + + self.mediator_utility_weight * social_reward + ) + rewards = jnp.insert( + result_state.utility_all, 0, mediator_reward + ) + + return ( + tuple(obs), + result_state, + tuple(rewards), + done, + {}, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[Tuple, EnvState]: + state = self._get_initial_state() + obs = [] + club_state = jax.random.uniform(key, (2,)) + if self.has_mediator: + obs.append( + self._generate_mediator_observation( + state, club_state[0], club_state[1] + ) + ) + for i in range(self.num_players): + obs.append( + self._generate_observation( + i, state, club_state[0], club_state[1] + ) + ) + return tuple(obs), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + def _get_initial_state(self) -> EnvState: + return EnvState( + inner_t=jnp.zeros((), dtype=jnp.int16), + outer_t=jnp.zeros((), dtype=jnp.int16), + global_temperature=jnp.array( + [self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]] + ), + global_carbon_mass=jnp.array( + [ + self.dice_constant["xM_AT_0"], + self.dice_constant["xM_UP_0"], + self.dice_constant["xM_LO_0"], + ] + ), + global_exogenous_emissions=jnp.zeros((), dtype=float_precision), + global_land_emissions=jnp.zeros((), dtype=float_precision), + labor_all=self.region_params["xL_0"], + capital_all=self.region_params["xK_0"], + production_factor_all=self.region_params["xA_0"], + intensity_all=self.region_params["xsigma_0"], + balance_all=jnp.zeros(self.num_players, dtype=float_precision), + future_tariff=jnp.zeros( + (self.num_players, self.num_players), dtype=float_precision + ), + gross_output_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + investment_all=jnp.zeros(self.num_players, dtype=float_precision), + production_all=jnp.zeros(self.num_players, dtype=float_precision), + utility_all=jnp.zeros(self.num_players, dtype=float_precision), + social_welfare_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + capital_depreciation_all=jnp.zeros(0, dtype=float_precision), + mitigation_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + consumption_all=jnp.zeros(self.num_players, dtype=float_precision), + damages_all=jnp.zeros(self.num_players, dtype=float_precision), + abatement_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + tariff_revenue_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + carbon_price_all=jnp.zeros( + 
self.num_players, dtype=float_precision + ), + club_membership_all=jnp.zeros(self.num_players, dtype=jnp.int8), + ) + + def _generate_observation( + self, + index: int, + state: EnvState, + club_mitigation_rate, + club_tariff_rate, + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.asarray([index]), + jnp.asarray([state.inner_t]), + jnp.array([club_mitigation_rate]), + jnp.array([club_tariff_rate]), + state.global_temperature, + state.global_carbon_mass, + state.gross_output_all, + state.investment_all, + state.abatement_cost_all, + state.tariff_revenue_all, + state.club_membership_all, + ], + dtype=float_precision, + ) + + def _generate_mediator_observation( + self, state: EnvState, club_mitigation_rate, club_tariff_rate + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.zeros(1), + jnp.asarray([state.inner_t]), + jnp.array([club_mitigation_rate]), + jnp.array([club_tariff_rate]), + state.global_temperature, + state.global_carbon_mass, + state.gross_output_all, + state.investment_all, + state.abatement_cost_all, + state.tariff_revenue_all, + state.club_membership_all, + ], + dtype=float_precision, + ) + + @property + def name(self) -> str: + return self.env_id + + @property + def num_actions(self) -> int: + return self.actions_n + + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: + return spaces.Box(low=0, high=1, shape=(self.actions_n,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + init_state = self._get_initial_state() + obs = self._generate_observation(0, init_state, 0, 0) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=float_precision + ) + + +def get_club_mitigation_rates( + coalition_price, intensity, theta_2, mitigation_cost, damages +): + return pow( + coalition_price * intensity / (theta_2 * mitigation_cost * damages), + 1 / (theta_2 - 1), + ) diff --git a/pax/envs/rice/rice.py b/pax/envs/rice/rice.py new file mode 100644 index 00000000..9c5f5df8 --- /dev/null +++ b/pax/envs/rice/rice.py @@ -0,0 +1,776 @@ +import os +from typing import Optional, Tuple + +import chex +import jax +import jax.debug +import jax.numpy as jnp +import yaml +from gymnax.environments import environment, spaces +from jax import Array + +from pax.utils import float_precision + + +@chex.dataclass +class EnvState: + inner_t: int + outer_t: int + + # Ecological + global_temperature: chex.ArrayDevice + global_carbon_mass: chex.ArrayDevice + global_exogenous_emissions: float + global_land_emissions: float + + # Economic + labor_all: chex.ArrayDevice + capital_all: chex.ArrayDevice + production_factor_all: chex.ArrayDevice + intensity_all: chex.ArrayDevice + balance_all: chex.ArrayDevice + + # Tariffs are applied to the next time step + future_tariff: chex.ArrayDevice + # tariff_revenue: chex.ArrayDevice + + # The following values are intermediary values + # that we only track in the state for easier evaluation and logging + gross_output_all: chex.ArrayDevice + investment_all: chex.ArrayDevice + production_all: chex.ArrayDevice + utility_all: chex.ArrayDevice + social_welfare_all: chex.ArrayDevice + capital_depreciation_all: chex.ArrayDevice + mitigation_cost_all: chex.ArrayDevice + consumption_all: chex.ArrayDevice + damages_all: chex.ArrayDevice + abatement_cost_all: chex.ArrayDevice + + tariff_revenue_all: chex.ArrayDevice + carbon_price_all: chex.ArrayDevice + club_membership_all: chex.ArrayDevice + + +@chex.dataclass +class EnvParams: + pass + + +eps = 1e-5 + +""" +Based off the MARL environment 
from https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4189735 +which in turn is an adaptation of the RICE IAM. +""" + + +class Rice(environment.Environment): + env_id: str = "Rice-N" + + def __init__( + self, config_folder: str, has_mediator=False, episode_length=20 + ): + super().__init__() + + # TODO refactor all the constants to use env_params + # 1. Load env params in the experiment.py#env_setup + # 2. type env params as a chex dataclass + # 3. change the references in the code to env params + params, num_regions = load_rice_params(config_folder) + self.has_mediator = has_mediator + self.num_players = num_regions + self.num_actors = ( + self.num_players + 1 if self.has_mediator else self.num_players + ) + self.rice_constant = params["_RICE_GLOBAL_CONSTANT"] + self.dice_constant = params["_DICE_CONSTANT"] + self.region_constants = params["_REGIONS"] + self.region_params = params["_REGION_PARAMS"] + + self.savings_action_n = 1 + self.mitigation_rate_action_n = 1 + # Each region sets max allowed export from own region + self.export_action_n = 1 + # Each region sets import bids (max desired imports from other countries) + # TODO Find an "automatic" model for trade imports + # Reason: without protocols such as clubs there is no incentive for countries to impose a tariff + self.import_actions_n = self.num_players + # Each region sets import tariffs imposed on other countries + self.tariff_actions_n = self.num_players + + self.actions_n = ( + self.savings_action_n + + self.mitigation_rate_action_n + + self.export_action_n + + self.import_actions_n + + self.tariff_actions_n + ) + + # Determine the index of each action to slice them in the step function + self.savings_action_index = 0 + self.mitigation_rate_action_index = ( + self.savings_action_index + self.savings_action_n + ) + self.export_action_index = ( + self.mitigation_rate_action_index + self.mitigation_rate_action_n + ) + self.tariffs_action_index = ( + self.export_action_index + self.export_action_n + ) + self.desired_imports_action_index = ( + self.tariffs_action_index + self.import_actions_n + ) + + # Parameters for armington aggregation utility + self.sub_rate = jnp.asarray(0.5, dtype=float_precision) + self.dom_pref = jnp.asarray(0.5, dtype=float_precision) + self.for_pref = jnp.asarray( + [0.5 / (self.num_players - 1)] * self.num_players, + dtype=float_precision, + ) + self.episode_length = episode_length + + def _step( + key: chex.PRNGKey, + state: EnvState, + actions: Tuple[float, ...], + params: EnvParams, + ): + t = state.inner_t + 1 # Rice equations expect to start at t=1 + key, _ = jax.random.split(key, 2) + done = t >= self.episode_length + + t_at = state.global_temperature[0] + global_exogenous_emissions = get_exogenous_emissions( + self.dice_constant["xf_0"], + self.dice_constant["xf_1"], + self.dice_constant["xt_f"], + t, + ) + + actions = jnp.asarray(actions).astype(float_precision).squeeze() + actions = jnp.clip(actions, a_min=0, a_max=1) + + if self.has_mediator: + region_actions = actions[1:] + else: + region_actions = actions + + mitigation_cost_all = get_mitigation_cost( + self.rice_constant["xp_b"], + self.rice_constant["xtheta_2"], + self.rice_constant["xdelta_pb"], + state.intensity_all, + t, + ) + abatement_cost_all = get_abatement_cost( + region_actions[:, self.mitigation_rate_action_index], + mitigation_cost_all, + self.rice_constant["xtheta_2"], + ) + damages_all = get_damages( + t_at, + self.region_params["xa_1"], + self.region_params["xa_2"], + self.region_params["xa_3"], + ) + production_all = 
get_production(
+                state.production_factor_all,
+                state.capital_all,
+                state.labor_all,
+                self.region_params["xgamma"],
+            )
+            gross_output_all = get_gross_output(
+                damages_all, abatement_cost_all, production_all
+            )
+            balance_all = state.balance_all * (1 + self.rice_constant["r"])
+            investment_all = get_investment(
+                region_actions[:, self.savings_action_index], gross_output_all
+            )
+
+            # Trade
+            desired_imports = region_actions[
+                :,
+                self.desired_imports_action_index : self.desired_imports_action_index
+                + self.import_actions_n,
+            ]
+            # Countries cannot import from themselves
+            desired_imports = zero_diag(desired_imports)
+            total_desired_imports = desired_imports.sum(axis=1)
+            clipped_desired_imports = jnp.clip(
+                total_desired_imports, 0, gross_output_all
+            )
+            desired_imports = (
+                desired_imports
+                * (
+                    # Add a trailing axis so the ratio broadcasts row-wise
+                    # (each importing region is scaled by its own factor)
+                    clipped_desired_imports
+                    / (total_desired_imports + eps)
+                )[:, jnp.newaxis]
+            )
+            init_capital_multiplier = 10.0
+            debt_ratio = (
+                init_capital_multiplier
+                * balance_all
+                / self.region_params["xK_0"]
+            )
+            debt_ratio = jnp.clip(debt_ratio, -1.0, 0.0)
+            scaled_imports = desired_imports * (1 + debt_ratio)
+
+            max_potential_exports = get_max_potential_exports(
+                region_actions[:, self.export_action_index],
+                gross_output_all,
+                investment_all,
+            )
+            # Summing along columns yields the total exports requested from each region
+            total_desired_exports = jnp.sum(scaled_imports, axis=0)
+            clipped_desired_exports = jnp.clip(
+                total_desired_exports, 0, max_potential_exports
+            )
+            scaled_imports = (
+                scaled_imports
+                * clipped_desired_exports
+                / (total_desired_exports + eps)
+            )
+
+            prev_tariffs = state.future_tariff
+            tariffed_imports = scaled_imports * (1 - prev_tariffs)
+            # calculate tariffed imports, tariff revenue and budget balance
+            # In the paper this goes to a "special reserve fund", i.e.
it's not used + tariff_revenue_all = jnp.sum(scaled_imports * prev_tariffs, axis=1) + + total_exports = scaled_imports.sum(axis=0) + balance_all = balance_all + self.dice_constant["xDelta"] * ( + total_exports - scaled_imports.sum(axis=1) + ) + + c_dom = get_consumption( + gross_output_all, investment_all, total_exports + ) + consumption_all = get_armington_agg( + c_dom, + tariffed_imports, + self.sub_rate, + self.dom_pref, + self.for_pref, + ) + utility_all = get_utility( + state.labor_all, consumption_all, self.rice_constant["xalpha"] + ) + social_welfare_all = get_social_welfare( + utility_all, + self.rice_constant["xrho"], + self.dice_constant["xDelta"], + t, + ) + + # Update ecology + m_at = state.global_carbon_mass[0] + global_temperature = get_global_temperature( + self.dice_constant["xPhi_T"], + state.global_temperature, + self.dice_constant["xB_T"], + self.dice_constant["xF_2x"], + m_at, + self.dice_constant["xM_AT_1750"], + global_exogenous_emissions, + ) + + global_land_emissions = get_land_emissions( + self.dice_constant["xE_L0"], + self.dice_constant["xdelta_EL"], + t, + self.num_players, + ) + + aux_m_all = get_aux_m( + state.intensity_all, + region_actions[:, self.mitigation_rate_action_index], + production_all, + global_land_emissions, + ) + + global_carbon_mass = get_global_carbon_mass( + self.dice_constant["xPhi_M"], + state.global_carbon_mass, + self.dice_constant["xB_M"], + jnp.sum(aux_m_all), + ) + + capital_depreciation = get_capital_depreciation( + self.rice_constant["xdelta_K"], self.dice_constant["xDelta"] + ) + capital_all = get_capital( + capital_depreciation, + state.capital_all, + self.dice_constant["xDelta"], + investment_all, + ) + intensity_all = get_carbon_intensity( + state.intensity_all, + self.rice_constant["xg_sigma"], + self.rice_constant["xdelta_sigma"], + self.dice_constant["xDelta"], + t, + ) + labor_all = get_labor( + state.labor_all, + self.region_params["xL_a"], + self.region_params["xl_g"], + ) + production_factor_all = get_production_factor( + state.production_factor_all, + self.region_params["xg_A"], + self.region_params["xdelta_A"], + self.dice_constant["xDelta"], + t, + ) + carbon_price_all = get_carbon_price( + mitigation_cost_all, + intensity_all, + region_actions[:, self.mitigation_rate_action_index], + self.rice_constant["xtheta_2"], + damages_all, + ) + + next_state = EnvState( + inner_t=state.inner_t + 1, + outer_t=state.outer_t, + global_temperature=global_temperature, + global_carbon_mass=global_carbon_mass, + global_exogenous_emissions=global_exogenous_emissions, + global_land_emissions=global_land_emissions, + labor_all=labor_all, + capital_all=capital_all, + production_factor_all=production_factor_all, + intensity_all=intensity_all, + balance_all=balance_all, + future_tariff=region_actions[ + :, + self.tariffs_action_index : self.tariffs_action_index + + self.num_players, + ], + gross_output_all=gross_output_all, + investment_all=investment_all, + production_all=production_all, + utility_all=utility_all, + social_welfare_all=social_welfare_all, + capital_depreciation_all=jnp.asarray([capital_depreciation]), + mitigation_cost_all=mitigation_cost_all, + consumption_all=consumption_all, + damages_all=damages_all, + abatement_cost_all=abatement_cost_all, + tariff_revenue_all=tariff_revenue_all, + carbon_price_all=carbon_price_all, + club_membership_all=jnp.zeros( + self.num_players, dtype=jnp.int8 + ), + ) + + reset_obs, reset_state = _reset(key, params) + reset_state = reset_state.replace(outer_t=state.outer_t + 1) + + obs = [] + if 
self.has_mediator: + obs.append( + self._generate_mediator_observation(actions, next_state) + ) + + for i in range(self.num_players): + obs.append(self._generate_observation(i, actions, next_state)) + + obs = jax.tree_map( + lambda x, y: jnp.where(done, x, y), reset_obs, tuple(obs) + ) + + state = jax.tree_map( + lambda x, y: jnp.where(done, x, y), + reset_state, + next_state, + ) + + if self.has_mediator: + rewards = jnp.insert( + state.utility_all, 0, state.utility_all.sum() + ) + else: + rewards = state.utility_all + + return ( + tuple(obs), + state, + tuple(rewards), + done, + {}, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[Tuple, EnvState]: + state = self._get_initial_state() + actions = jnp.asarray( + [ + jax.random.uniform(key, (self.num_actions,)) + for _ in range(self.num_actors) + ] + ) + obs = [] + if self.has_mediator: + obs.append(self._generate_mediator_observation(actions, state)) + for i in range(self.num_players): + obs.append(self._generate_observation(i, actions, state)) + return tuple(obs), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + def _get_initial_state(self) -> EnvState: + return EnvState( + inner_t=jnp.zeros((), dtype=jnp.int16), + outer_t=jnp.zeros((), dtype=jnp.int16), + global_temperature=jnp.array( + [self.rice_constant["xT_AT_0"], self.rice_constant["xT_LO_0"]] + ), + global_carbon_mass=jnp.array( + [ + self.dice_constant["xM_AT_0"], + self.dice_constant["xM_UP_0"], + self.dice_constant["xM_LO_0"], + ] + ), + global_exogenous_emissions=jnp.zeros((), dtype=float_precision), + global_land_emissions=jnp.zeros((), dtype=float_precision), + labor_all=self.region_params["xL_0"], + capital_all=self.region_params["xK_0"], + production_factor_all=self.region_params["xA_0"], + intensity_all=self.region_params["xsigma_0"], + balance_all=jnp.zeros(self.num_players, dtype=float_precision), + future_tariff=jnp.zeros( + (self.num_players, self.num_players), dtype=float_precision + ), + gross_output_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + investment_all=jnp.zeros(self.num_players, dtype=float_precision), + production_all=jnp.zeros(self.num_players, dtype=float_precision), + utility_all=jnp.zeros(self.num_players, dtype=float_precision), + social_welfare_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + capital_depreciation_all=jnp.zeros(0, dtype=float_precision), + mitigation_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + consumption_all=jnp.zeros(self.num_players, dtype=float_precision), + damages_all=jnp.zeros(self.num_players, dtype=float_precision), + abatement_cost_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + tariff_revenue_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + carbon_price_all=jnp.zeros( + self.num_players, dtype=float_precision + ), + club_membership_all=jnp.zeros(self.num_players, dtype=jnp.int8), + ) + + def _generate_observation( + self, index: int, actions: chex.ArrayDevice, state: EnvState + ) -> Array: + return jnp.concatenate( + [ + # Public features + jnp.asarray([index]), + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + state.tariff_revenue_all, + state.carbon_price_all, + state.club_membership_all, + # Private features + 
jnp.asarray([state.damages_all[index]]), + jnp.asarray([state.abatement_cost_all[index]]), + jnp.asarray([state.production_all[index]]), + # All agent actions + actions.ravel(), + ], + dtype=float_precision, + ) + + def _generate_mediator_observation( + self, actions: chex.ArrayDevice, state: EnvState + ) -> Array: + return jnp.concatenate( + [ + jnp.zeros(1), + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + state.investment_all, + state.balance_all, + state.tariff_revenue_all, + state.carbon_price_all, + state.club_membership_all, + # Maintain same dimensionality as for other players + jnp.zeros(3), + # All agent actions + actions.ravel(), + ], + dtype=float_precision, + ) + + @property + def name(self) -> str: + return self.env_id + + @property + def num_actions(self) -> int: + return self.actions_n + + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: + return spaces.Box(low=0, high=1, shape=(self.actions_n,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + init_state = self._get_initial_state() + obs = self._generate_observation( + 0, jnp.zeros(self.num_actions * self.num_actors), init_state + ) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=float_precision + ) + + +def load_rice_params(config_dir=None): + """Helper function to read yaml data and set environment configs.""" + assert config_dir is not None + base_params = load_yaml_data(os.path.join(config_dir, "default.yml")) + file_list = sorted(os.listdir(config_dir)) # + yaml_files = [] + for file in file_list: + if file[-4:] == ".yml" and file != "default.yml": + yaml_files.append(file) + + region_params = [] + for file in yaml_files: + region_params.append(load_yaml_data(os.path.join(config_dir, file))) + + # _REGIONS is a list of dictionaries + base_params["_REGIONS"] = [] + for param in region_params: + region_to_append = param["_RICE_CONSTANT"] + for k in base_params["_RICE_CONSTANT_DEFAULT"].keys(): + if k not in region_to_append.keys(): + region_to_append[k] = base_params["_RICE_CONSTANT_DEFAULT"][k] + base_params["_REGIONS"].append(region_to_append) + + # _REGION_PARAMS is a dictionary of lists + base_params["_REGION_PARAMS"] = {} + for k in base_params["_RICE_CONSTANT_DEFAULT"].keys(): + base_params["_REGION_PARAMS"][k] = [] + for param in region_params: + parameter_value = param["_RICE_CONSTANT"].get( + k, base_params["_RICE_CONSTANT_DEFAULT"][k] + ) + base_params["_REGION_PARAMS"][k].append(parameter_value) + base_params["_REGION_PARAMS"][k] = jnp.asarray( + base_params["_REGION_PARAMS"][k], dtype=float_precision + ) + + return base_params, len(region_params) + + +def zero_diag(matrix: jax.Array) -> jax.Array: + return matrix - matrix * jnp.eye( + matrix.shape[0], matrix.shape[1], dtype=matrix.dtype + ) + + +def get_exogenous_emissions(f_0, f_1, t_f, timestep): + return f_0 + jnp.min( + jnp.array([f_1 - f_0, (f_1 - f_0) / t_f * (timestep - 1)]) + ) + + +def get_land_emissions(e_l0, delta_el, timestep, num_regions): + return e_l0 * pow(1 - delta_el, timestep - 1) / num_regions + + +def get_mitigation_cost(p_b, theta_2, delta_pb, intensity, timestep): + return p_b / (1000 * theta_2) * pow(1 - delta_pb, timestep - 1) * intensity + + +def get_damages(t_at, a_1, a_2, a_3): + return 1 / (1 + a_1 * t_at + a_2 * pow(t_at, a_3)) + + +def 
get_abatement_cost(mitigation_rate, mitigation_cost, theta_2): + return mitigation_cost * pow(mitigation_rate, theta_2) + + +def get_production(production_factor, capital, labor, gamma): + """Obtain the amount of goods produced.""" + return ( + production_factor * pow(capital, gamma) * pow(labor / 1000, 1 - gamma) + ) + + +def get_gross_output(damages, abatement_cost, production): + return damages * (1 - abatement_cost) * production + + +def get_investment(savings, gross_output): + return savings * gross_output + + +def get_consumption(gross_output, investment, total_exports): + return jnp.clip(gross_output - investment - total_exports, 0) + + +def get_max_potential_exports(x_max, gross_output, investment): + return jnp.min( + jnp.array([x_max * gross_output, gross_output - investment]), axis=0 + ) + + +def get_capital_depreciation(x_delta_k, x_delta): + return pow(1 - x_delta_k, x_delta) + + +# Returns shape 2 +def get_global_temperature( + phi_t, temperature, b_t, f_2x, m_at, m_at_1750, exogenous_emissions +): + return jnp.dot(phi_t, temperature) + jnp.dot( + b_t, + f_2x * jnp.log(m_at / m_at_1750) / jnp.log(2) + exogenous_emissions, + ) + + +def get_aux_m(intensity, mitigation_rate, production, land_emissions): + """Auxiliary variable to denote carbon mass levels.""" + return intensity * (1 - mitigation_rate) * production + land_emissions + + +def get_global_carbon_mass(phi_m, carbon_mass, b_m, aux_m): + return jnp.dot(phi_m, carbon_mass) + jnp.dot(b_m, aux_m) + + +def get_capital(capital_depreciation, capital, delta, investment): + return capital_depreciation * capital + delta * investment + + +def get_labor(labor, l_a, l_g): + return labor * pow((1 + l_a) / (1 + labor), l_g) + + +def get_production_factor(production_factor, g_a, delta_a, delta, timestep): + return production_factor * ( + jnp.exp(0.0033) + g_a * jnp.exp(-delta_a * delta * (timestep - 1)) + ) + + +def get_carbon_price( + mitigation_cost, intensity, mitigation_rate, theta_2, damages +): + return ( + theta_2 + * damages + * pow(mitigation_rate, theta_2 - 1) + * mitigation_cost + / intensity + ) + + +def get_carbon_intensity(intensity, g_sigma, delta_sigma, delta, timestep): + return intensity * jnp.exp( + -g_sigma * pow(1 - delta_sigma, delta * (timestep - 1)) * delta + ) + + +_SMALL_NUM = 1e-0 + + +def get_utility(labor, consumption, alpha): + return ( + (labor / 1000.0) + * (pow(consumption / (labor / 1000.0) + _SMALL_NUM, 1 - alpha) - 1) + / (1 - alpha) + ) + + +def get_social_welfare(utility, rho, delta, timestep): + return utility / pow(1 + rho, delta * timestep) + + +def get_armington_agg( + c_dom, + imports, # np.array + sub_rate=0.5, # in (0,1) + dom_pref=0.5, # in [0,1] + for_pref=None, # np.array +): + """ + Armington aggregate from Lessmann, 2009. + Consumption goods from different regions act as imperfect substitutes. + As such, consumption of domestic and foreign goods are scaled according to + relative preferences, as well as a substitution rate, which are modeled + by a CES functional form. + Inputs : + `C_dom` : A scalar representing domestic consumption. The value of + C_dom is what is left over from initial production after + investment and exports are deducted. + `C_for` : A matrix representing the trade flows between regions. + `sub_rate` : A substitution parameter in (0,1). The elasticity of + substitution is 1 / (1 - sub_rate). + `dom_pref` : A scalar in [0,1] representing the relative preference for + domestic consumption over foreign consumption. + `for_pref` : An array of the same size as `C_for`. 
Each element is the + relative preference for foreign goods from that country. + """ + + c_dom_pref = dom_pref * (c_dom**sub_rate) + # Axis 1 because we consider imports not exports + c_for_pref = jnp.sum(for_pref * pow(imports, sub_rate), axis=1) + + c_agg = (c_dom_pref + c_for_pref) ** (1 / sub_rate) # CES function + return c_agg + + +def load_yaml_data(yaml_file: str): + """Helper function to read yaml configuration data.""" + with open(yaml_file, "r", encoding="utf-8") as file_ptr: + file_data = file_ptr.read() + data = yaml.load(file_data, Loader=yaml.FullLoader) + return rec_array_conversion(data) + + +def rec_array_conversion(data): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, list): + data[key] = jnp.asarray(value, dtype=float_precision) + elif isinstance(value, dict): + data[key] = rec_array_conversion(value) + elif isinstance(value, float): + data[key] = jnp.asarray(value, dtype=float_precision) + elif isinstance(value, int): + data[key] = jnp.asarray(value, dtype=float_precision) + elif isinstance(data, list): + data = jnp.asarray(data, dtype=float_precision) + return data diff --git a/pax/envs/rice/sarl_rice.py b/pax/envs/rice/sarl_rice.py new file mode 100644 index 00000000..9742bd2c --- /dev/null +++ b/pax/envs/rice/sarl_rice.py @@ -0,0 +1,115 @@ +from typing import Optional, Tuple + +import chex +import jax +import jax.debug +import jax.numpy as jnp +from gymnax.environments import environment, spaces + +from pax.envs.rice.rice import Rice, EnvState, EnvParams + +""" +Wrapper to turn Rice-N into a single-agent environment. +""" + + +class SarlRice(environment.Environment): + env_id: str = "SarlRice-N" + + def __init__( + self, + config_folder: str, + fixed_mitigation_rate: int = None, + episode_length: int = 20, + ): + super().__init__() + self.rice = Rice(config_folder, episode_length=episode_length) + self.fixed_mitigation_rate = fixed_mitigation_rate + + def _step( + key: chex.PRNGKey, + state: EnvState, + action: chex.Array, + params: EnvParams, + ): + actions = jnp.asarray(jnp.split(action, self.rice.num_players)) + # To facilitate learning and since we want to optimize a social planner + # we disable defection actions + actions = actions.at[:, self.rice.export_action_index].set(1.0) + actions = actions.at[ + :, + self.rice.tariffs_action_index : self.rice.tariffs_action_index + + self.rice.num_players, + ].set(0.0) + + if self.fixed_mitigation_rate is not None: + actions = actions.at[ + :, self.rice.mitigation_rate_action_index + ].set(self.fixed_mitigation_rate) + + # actions = actions.at[:, self.rice.desired_imports_action_index: + # self.rice.desired_imports_action_index + self.rice.num_players].set(0.0) + obs, state, rewards, done, info = self.rice.step( + key, state, tuple(actions), params + ) + + return ( + self._generate_observation(state), + state, + jnp.asarray(rewards).sum(), + done, + info, + ) + + def _reset( + key: chex.PRNGKey, params: EnvParams + ) -> Tuple[chex.Array, EnvState]: + _, state = self.rice.reset(key, params) + return self._generate_observation(state), state + + self.step = jax.jit(_step) + self.reset = jax.jit(_reset) + + def _generate_observation(self, state: EnvState): + return jnp.concatenate( + [ + # Public features + jnp.asarray([state.inner_t]), + state.global_temperature, + state.global_carbon_mass, + jnp.asarray([state.global_exogenous_emissions]), + jnp.asarray([state.global_land_emissions]), + state.labor_all, + state.capital_all, + state.gross_output_all, + state.consumption_all, + 
state.investment_all, + state.balance_all, + state.production_factor_all, + state.intensity_all, + state.mitigation_cost_all, + state.damages_all, + state.abatement_cost_all, + state.production_all, + state.utility_all, + state.social_welfare_all, + ] + ) + + @property + def name(self) -> str: + return self.env_id + + @property + def num_actions(self) -> int: + return self.rice.num_actions * self.rice.num_players + + def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box: + return spaces.Box(low=0, high=1, shape=(self.num_actions,)) + + def observation_space(self, params: EnvParams) -> spaces.Box: + _, state = self.reset(jax.random.PRNGKey(0), params) + obs = self._generate_observation(state) + return spaces.Box( + low=0, high=float("inf"), shape=obs.shape, dtype=jnp.float32 + ) diff --git a/pax/experiment.py b/pax/experiment.py index b990e885..4467f444 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -8,9 +8,10 @@ import jax import jax.numpy as jnp import omegaconf +import wandb from evosax import CMA_ES, PGPE, OpenES, ParameterReshaper, SimpleGA +from jax.lib import xla_bridge -import wandb from pax.agents.hyper.ppo import make_hyper from pax.agents.lola.lola import make_lola from pax.agents.mfos_ppo.ppo_gru import make_mfos_agent @@ -42,6 +43,10 @@ ) from pax.envs.coin_game import CoinGame from pax.envs.coin_game import EnvParams as CoinGameParams +from pax.envs.cournot import CournotGame +from pax.envs.cournot import EnvParams as CournotParams +from pax.envs.fishery import EnvParams as FisheryParams +from pax.envs.fishery import Fishery from pax.envs.in_the_matrix import EnvParams as InTheMatrixParams from pax.envs.in_the_matrix import InTheMatrix from pax.envs.infinite_matrix_game import EnvParams as InfiniteMatrixGameParams @@ -52,6 +57,10 @@ EnvParams as IteratedTensorGameNPlayerParams, ) from pax.envs.iterated_tensor_game_n_player import IteratedTensorGameNPlayer +from pax.envs.rice.c_rice import ClubRice +from pax.envs.rice.rice import Rice, EnvParams as RiceParams +from pax.envs.rice.sarl_rice import SarlRice +from pax.runners.runner_weight_sharing import WeightSharingRunner from pax.runners.runner_eval import EvalRunner from pax.runners.runner_eval_multishaper import MultishaperEvalRunner from pax.runners.runner_evo import EvoRunner @@ -72,12 +81,12 @@ value_logger_ppo, ) -# NOTE: THIS MUST BE sDONE BEFORE IMPORTING JAX + +# NOTE: THIS MUST BE DONE BEFORE IMPORTING JAX # uncomment to debug multi-devices on CPU # os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" # from jax.config import config - -# config.update("jax_disable_jit", True) +# config.update('jax_disable_jit', True) def global_setup(args): @@ -89,20 +98,23 @@ def global_setup(args): exist_ok=True, ) if args.wandb.log: - print("name", str(args.wandb.name)) + print("run name", str(args.wandb.name)) if args.debug: args.wandb.group = "debug-" + args.wandb.group - wandb.init( + run = wandb.init( reinit=True, entity=str(args.wandb.entity), project=str(args.wandb.project), group=str(args.wandb.group), name=str(args.wandb.name), + mode=str(args.wandb.mode), + tags=args.wandb.tags, config=omegaconf.OmegaConf.to_container( args, resolve=True, throw_on_missing=True ), # type: ignore settings=wandb.Settings(code_dir="."), ) + print("run id", run.id) wandb.run.log_code(".") return save_dir @@ -182,6 +194,76 @@ def env_setup(args, logger=None): logger.info( f"Env Type: InTheMatrix | Inner Episode Length: {args.num_inner_steps}" ) + elif args.env_id == "Cournot": + env_params = 
CournotParams( + a=args.a, b=args.b, marginal_cost=args.marginal_cost + ) + env = CournotGame( + num_players=args.num_players, + num_inner_steps=args.num_inner_steps, + ) + if logger: + logger.info( + f"Env Type: Cournot | Inner Episode Length: {args.num_inner_steps}" + ) + elif args.env_id == Fishery.env_id: + env_params = FisheryParams( + g=args.g, + e=args.e, + P=args.P, + w=args.w, + s_0=args.s_0, + s_max=args.s_max, + ) + env = Fishery( + num_players=args.num_players, + num_inner_steps=args.num_inner_steps, + ) + if logger: + logger.info( + f"Env Type: Fishery | Inner Episode Length: {args.num_inner_steps}" + ) + elif args.env_id == Rice.env_id: + env_params = RiceParams() + env = Rice( + config_folder=args.config_folder, + has_mediator=args.has_mediator, + ) + if logger: + logger.info( + f"Env Type: {env.env_id} | Inner Episode Length: {args.num_inner_steps}" + ) + elif args.env_id == ClubRice.env_id: + env_params = RiceParams() + env = ClubRice( + config_folder=args.config_folder, + has_mediator=args.has_mediator, + mediator_climate_objective=args.get( + "mediator_climate_objective", None + ), + default_club_mitigation_rate=args.get( + "default_club_mitigation_rate", None + ), + default_club_tariff_rate=args.get( + "default_club_tariff_rate", None + ), + mediator_climate_weight=args.get("mediator_climate_weight", None), + mediator_utility_weight=args.get("mediator_utility_weight", None), + ) + if logger: + logger.info( + f"Env Type: {env.env_id} | Inner Episode Length: {args.num_inner_steps}" + ) + elif args.env_id == SarlRice.env_id: + env_params = RiceParams() + env = SarlRice( + config_folder=args.config_folder, + fixed_mitigation_rate=args.get("fixed_mitigation_rate", None), + ) + if logger: + logger.info( + f"Env Type: SarlRice | Inner Episode Length: {args.num_inner_steps}" + ) elif args.runner == "sarl": env, env_params = gymnax.make(args.env_id) else: @@ -296,6 +378,7 @@ def get_pgpe_strategy(agent): save_dir, args, ) + elif args.runner == "multishaper_evo": logger.info("Training with multishaper EVO runner") return MultishaperEvoRunner( @@ -316,6 +399,9 @@ def get_pgpe_strategy(agent): elif args.runner == "sarl": logger.info("Training with SARL Runner") return SARLRunner(agents, env, save_dir, args) + elif args.runner == WeightSharingRunner.id: + logger.info("Training with Weight Sharing Runner") + return WeightSharingRunner(agents, env, save_dir, args) else: raise ValueError(f"Unknown runner type {args.runner}") @@ -350,7 +436,13 @@ def get_LOLA_agent(seed, player_id): ) def get_PPO_memory_agent(seed, player_id): - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id)) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) + num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps @@ -365,13 +457,17 @@ def get_PPO_memory_agent(seed, player_id): ) def get_PPO_agent(seed, player_id): - player_args = omegaconf.OmegaConf.select(args, "ppo" + str(player_id)) + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) + num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps - else: - num_iterations = args.num_iters - ppo_agent = make_agent( + return make_agent( 
args, player_args, obs_spec=obs_shape, @@ -380,10 +476,15 @@ def get_PPO_agent(seed, player_id): seed=seed, player_id=player_id, ) - return ppo_agent def get_PPO_tabular_agent(seed, player_id): - player_args = args.ppo1 if player_id == 1 else args.ppo2 + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + player_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) + num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps @@ -400,7 +501,13 @@ def get_PPO_tabular_agent(seed, player_id): return ppo_agent def get_mfos_agent(seed, player_id): - agent_args = args.ppo1 + default_player_args = omegaconf.OmegaConf.select( + args, "ppo_default", default=None + ) + agent_args = omegaconf.OmegaConf.select( + args, "ppo" + str(player_id), default=default_player_args + ) + num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps @@ -486,29 +593,30 @@ def get_stay_agent(seed, player_id): "HyperTFT": partial(HyperTFT, args.num_envs), } - if args.runner == "sarl": + if args.runner in ["sarl", "weight_sharing"]: assert args.agent1 in strategies - num_agents = 1 seeds = [args.seed] # Create Player IDs by normalizing seeds to 1, 2 respectively pids = [0] agent_1 = strategies[args.agent1](seeds[0], pids[0]) # player 1 - - if args.agent1 in ["PPO", "PPO_memory"] and args.ppo.with_cnn: - logger.info(f"PPO with CNN: {args.ppo.with_cnn}") logger.info(f"Agent Pair: {args.agent1}") logger.info(f"Agent seeds: {seeds[0]}") if args.runner in ["eval", "sarl"]: logger.info("Using Independent Learners") - return agent_1 - + return agent_1 else: - for i in range(1, args.num_players + 1): - assert ( - omegaconf.OmegaConf.select(args, "agent" + str(i)) - in strategies + default_agent = omegaconf.OmegaConf.select( + args, "agent_default", default=None + ) + agent_strategies = [ + omegaconf.OmegaConf.select( + args, "agent" + str(i), default=default_agent ) + for i in range(1, args.num_players + 1) + ] + for strategy in agent_strategies: + assert strategy in strategies seeds = [ seed for seed in range(args.seed, args.seed + args.num_players) @@ -519,15 +627,9 @@ def get_stay_agent(seed, player_id): for seed, i in zip(seeds, range(1, args.num_players + 1)) ] agents = [] - for i in range(args.num_players): - agents.append( - strategies[ - omegaconf.OmegaConf.select(args, "agent" + str(i + 1)) - ](seeds[i], pids[i]) - ) - logger.info( - f"Agent Pair: {[omegaconf.OmegaConf.select(args, 'agent' + str(i)) for i in range(1, args.num_players + 1)]}" - ) + for idx, strategy in enumerate(agent_strategies): + agents.append(strategies[strategy](seeds[idx], pids[idx])) + logger.info(f"Agent Pair: {agents}") logger.info(f"Agent seeds: {seeds}") return agents @@ -554,6 +656,10 @@ def ppo_log(agent): losses = losses_ppo(agent) if args.env_id not in [ "coin_game", + "Cournot", + "Fishery", + "Rice-N", + "C-Rice-N", "InTheMatrix", "iterated_matrix_game", "iterated_nplayer_tensor_game", @@ -624,24 +730,38 @@ def naive_pg_log(agent): "Tabular": ppo_log, "PPO_memory_pretrained": ppo_memory_log, "MFOS_pretrained": dumb_log, + "TitForTatStrictStay": dumb_log, + "TitForTatStrictSwitch": dumb_log, + "TitForTatCooperate": dumb_log, + "TitForTatDefect": dumb_log, } - if args.runner == "sarl": + if args.runner in ["sarl", "weight_sharing"]: assert args.agent1 in strategies agent_1_log = naive_pg_log # strategies[args.agent1] # - return agent_1_log 
else: + agent_log = [] - for i in range(1, args.num_players + 1): - assert getattr(args, f"agent{i}") in strategies - agent_log.append(strategies[getattr(args, f"agent{i}")]) + default_agent = omegaconf.OmegaConf.select( + args, "agent_default", default=None + ) + agent_strategies = [ + omegaconf.OmegaConf.select( + args, "agent" + str(i), default=default_agent + ) + for i in range(1, args.num_players + 1) + ] + for strategy in agent_strategies: + assert strategy in strategies + agent_log.append(strategies[strategy]) return agent_log -@hydra.main(config_path="conf", config_name="config") +@hydra.main(config_path="conf", config_name="config", version_base="1.1") def main(args): + print(f"Jax backend: {xla_bridge.get_backend().platform}") + """Set up main.""" logger = logging.getLogger() with Section("Global setup", logger=logger): @@ -665,7 +785,6 @@ def main(args): if args.runner == "evo" or args.runner == "multishaper_evo": runner.run_loop(env_params, agent_pair, args.num_iters, watchers) - elif args.runner == "rl" or args.runner == "tensor_rl_nplayer": # number of episodes print(f"Number of Episodes: {args.num_iters}") @@ -673,10 +792,12 @@ def main(args): elif args.runner == "ipditm_eval" or args.runner == "multishaper_eval": runner.run_loop(env_params, agent_pair, watchers) - elif args.runner == "sarl": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) + elif args.runner == "weight_sharing": + print(f"Number of Episodes: {args.num_iters}") + runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) elif args.runner == "eval": print(f"Number of Episodes: {args.num_iters}") runner.run_loop(env, env_params, agent_pair, args.num_iters, watchers) diff --git a/pax/runners/README.md b/pax/runners/README.md new file mode 100644 index 00000000..4edd488d --- /dev/null +++ b/pax/runners/README.md @@ -0,0 +1,26 @@ +# Notes on runners + +## MARL runner + +High-level semantics of the training loop: + +```pseudo +init env and agents + +_rollout for each iteration: + agent1.reset_memory() + if env meta: agent2.init() + env.reset() + + _outer_rollout for each outer_step: + + _inner_rollout for each inner_step: + agent1.act() + agent2.act() + env.step() + + + agent1.reset_memory() + agent2.reset_memory() + +``` diff --git a/pax/runners/runner_eval.py b/pax/runners/runner_eval.py index 5edc57b0..c66b95d0 100644 --- a/pax/runners/runner_eval.py +++ b/pax/runners/runner_eval.py @@ -8,6 +8,9 @@ import wandb from pax.utils import load from pax.watchers import cg_visitation, ipd_visitation +from pax.watchers.c_rice import c_rice_eval_stats, c_rice_stats +from pax.watchers.fishery import fishery_stats +from pax.watchers.rice import rice_eval_stats, rice_stats MAX_WANDB_CALLS = 10000 @@ -29,7 +32,7 @@ class EvalRunner: Evaluation runner provides a convenient example for quickly writing a shaping eval runner for PAX. The EvalRunner class can be used to run any two agents together either in a meta-game or regular game, it composes together agents, - watchers, and the environment. Within the init, we declare vmaps and pmaps for training. + watchers, and the environment. Within the init, we declare vmaps for training. Args: agents (Tuple[agents]): The set of agents that will run in the experiment.
Note, ordering is important for @@ -48,6 +51,8 @@ def __init__(self, agents, env, args): self.random_key = jax.random.PRNGKey(args.seed) self.run_path = args.run_path self.model_path = args.model_path + self.run_path2 = args.get("run_path2", None) + self.model_path2 = args.get("model_path2", None) self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) # VMAP for num envs: we vmap over the rng but not params @@ -66,7 +71,7 @@ def __init__(self, agents, env, args): self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) - agent1, agent2 = agents + agent1, agent2 = agents[0], agents[1] if args.agent1 == "NaiveEx": # special case where NaiveEx has a different call signature @@ -129,71 +134,110 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, + agent_order, ) = carry # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, 0, :] - # a1_rng = rngs[:, :, 1, :] - # a2_rng = rngs[:, :, 2, :] rngs = rngs[:, :, 3, :] - a1, a1_state, new_a1_mem = agent1.batch_policy( - a1_state, - obs1, - a1_mem, - ) - a2, a2_state, new_a2_mem = agent2.batch_policy( - a2_state, - obs2, - a2_mem, - ) - (next_obs1, next_obs2), env_state, rewards, done, info = env.step( + a1_actions = [] + new_a1_memories = [] + for _obs, _mem in zip(obs1, a1_mem, strict=True): + a1_action, a1_state, new_a1_memory = agent1.batch_policy( + a1_state, + _obs, + _mem, + ) + a1_actions.append(a1_action) + new_a1_memories.append(new_a1_memory) + + a2_actions = [] + new_a2_memories = [] + for _obs, _mem in zip(obs2, a2_mem, strict=True): + a2_action, a2_state, new_a2_memory = agent2.batch_policy( + a2_state, + _obs, + _mem, + ) + a2_actions.append(a2_action) + new_a2_memories.append(new_a2_memory) + + actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order] + obs, env_state, rewards, done, info = env.step( env_rng, env_state, - (a1, a2), + tuple(actions), env_params, ) - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) - traj2 = Sample( - obs2, - a2, - rewards[1], - new_a2_mem.extras["log_probs"], - new_a2_mem.extras["values"], - done, - a2_mem.hidden, - ) + inv_agent_order = jnp.argsort(agent_order) + obs = jnp.asarray(obs)[inv_agent_order] + rewards = jnp.asarray(rewards)[inv_agent_order] + + a1_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs1, + a1_actions, + rewards[: self.args.agent1_roles], + new_a1_memories, + a1_mem, + strict=True, + ) + ] + a2_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[self.args.agent1_roles :], + new_a2_memories, + a2_mem, + strict=True, + ) + ] + return ( rngs, - next_obs1, - next_obs2, - rewards[0], - rewards[1], + tuple(obs[: self.args.agent1_roles]), + tuple(obs[self.args.agent1_roles :]), + tuple(rewards[: self.args.agent1_roles]), + tuple(rewards[self.args.agent1_roles :]), a1_state, - new_a1_mem, + tuple(new_a1_memories), a2_state, - new_a2_mem, + tuple(new_a2_memories), env_state, env_params, + agent_order, ), ( - traj1, - traj2, + a1_trajectories, + a2_trajectories, + env_state, ) def _outer_rollout(carry, unused): """Runner for
trial""" # play episode of the game - vals, trajectories = jax.lax.scan( + vals, stack = jax.lax.scan( _inner_rollout, carry, None, @@ -211,18 +255,29 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, + agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) # update second agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2, - a2_state, - a2_mem, - ) + if self.args.agent2_learning is False: + new_a2_memories = a2_mem + a2_metrics = {} + else: + new_a2_memories = [] + for _obs, mem, traj in zip( + obs2, a2_mem, stack[1], strict=True + ): + a2_state, a2_mem, a2_metrics = agent2.batch_update( + traj, + _obs, + a2_state, + mem, + ) + new_a2_memories.append(a2_mem) + new_a2_memories = tuple(new_a2_memories) return ( rngs, obs1, @@ -232,10 +287,11 @@ def _outer_rollout(carry, unused): a1_state, a1_mem, a2_state, - a2_mem, + new_a2_memories, env_state, env_params, - ), (*trajectories, a2_metrics) + agent_order, + ), (*stack, a2_metrics) self.rollout = jax.jit(_outer_rollout) @@ -243,19 +299,35 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): """Run evaluation of agents in environment""" print("Training") print("-----------------------") - agent1, agent2 = agents + agent1, agent2 = agents[0], agents[1] rng, _ = jax.random.split(self.random_key) a1_state, a1_mem = agent1._state, agent1._mem a2_state, a2_mem = agent2._state, agent2._mem - if watchers: + preload_agent_2 = ( + self.model_path2 is not None and self.run_path2 is not None + ) + + if watchers and not self.args.wandb.mode == "offline": wandb.restore( name=self.model_path, run_path=self.run_path, root=os.getcwd() ) + if preload_agent_2: + wandb.restore( + name=self.model_path2, + run_path=self.run_path2, + root=os.getcwd(), + ) + pretrained_params = load(self.model_path) a1_state = a1_state._replace(params=pretrained_params) + if preload_agent_2: + pretrained_params = load(self.model_path2) + a2_state = a2_state._replace(params=pretrained_params) + a2_pretrained_params = pretrained_params + num_iters = max( int(num_episodes / (self.args.num_envs * self.args.num_opps)), 1 ) @@ -263,40 +335,63 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): print(f"Log Interval {log_interval}") # RNG are the same for num_opps but different for num_envs - rngs = jnp.concatenate( - [jax.random.split(rng, self.args.num_envs)] * self.args.num_opps - ).reshape((self.args.num_opps, self.args.num_envs, -1)) + + a1_mem = (a1_mem,) * self.args.agent1_roles # run actual loop for i in range(num_episodes): + rng, _ = jax.random.split(rng) + rngs = jnp.concatenate( + [jax.random.split(rng, self.args.num_envs)] + * self.args.num_opps + ).reshape((self.args.num_opps, self.args.num_envs, -1)) + obs, env_state = env.reset(rngs, env_params) - rewards = [ - jnp.zeros((self.args.num_opps, self.args.num_envs)), - jnp.zeros((self.args.num_opps, self.args.num_envs)), - ] + rewards = [jnp.zeros((self.args.num_opps, self.args.num_envs))] * ( + self.args.agent1_roles + self.args.agent2_roles + ) + + if i % self.args.agent2_reset_interval == 0: + if self.args.agent2 == "NaiveEx": + a2_state, a2_mem = agent2.batch_init(obs[1]) + elif self.args.env_type in ["meta"]: + # meta-experiments - init 2nd agent per trial + a2_state, a2_mem = agent2.batch_init( + jax.random.split(rng, self.num_opps), a2_mem.hidden + ) + + if preload_agent_2: + # If we are preloading agent 2 we want to keep the state + a2_state = 
a2_state._replace( + params=jax.tree_util.tree_map( + lambda x: jnp.expand_dims(x, 0), + a2_pretrained_params, + ) + ) + a2_mem = (a2_mem,) * self.args.agent2_roles + + agent_order = jnp.arange(self.args.num_players) + if self.args.shuffle_players: + agent_order = jax.random.permutation(rng, agent_order) - if self.args.agent2 == "NaiveEx": - a2_state, a2_mem = agent2.batch_init(obs[1]) - elif self.args.env_type in ["meta"]: - # meta-experiments - init 2nd agent per trial - a2_state, a2_mem = agent2.batch_init( - jax.random.split(rng, self.num_opps), a2_mem.hidden - ) # run trials vals, stack = jax.lax.scan( self.rollout, ( rngs, - *obs, - *rewards, + tuple(obs[: self.args.agent1_roles]), + tuple(obs[self.args.agent1_roles :]), + tuple(rewards[: self.args.agent1_roles]), + tuple(rewards[self.args.agent1_roles :]), a1_state, a1_mem, a2_state, a2_mem, env_state, env_params, + agent_order, ), None, - length=self.args.num_steps // self.args.num_inner_steps, + length=self.args.num_steps, ) ( @@ -306,28 +401,69 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): r1, r2, a1_state, - a1_mem, + # We rename the memory since we do not carry it to the next rollout + _a1_mem, a2_state, - a2_mem, + _a2_mem, env_state, env_params, + agent_order, ) = vals - traj_1, traj_2, a2_metrics = stack + traj_1, traj_2, env_states, a2_metrics = stack - # reset second agent memory - a2_mem = agent2.batch_reset(a2_mem, False) + rewards_1 = jnp.concatenate([traj.rewards for traj in traj_1]) + rewards_2 = jnp.concatenate([traj.rewards for traj in traj_2]) # logging self.train_episodes += 1 if i % log_interval == 0: - print(f"Episode {i}") + print(f"Episode {i}/{num_episodes}") if self.args.env_id == "coin_game": env_stats = jax.tree_util.tree_map( lambda x: x.item(), self.cg_stats(env_state), ) - rewards_1 = traj_1.rewards.sum(axis=1).mean() - rewards_2 = traj_2.rewards.sum(axis=1).mean() + rewards_1 = rewards_1.sum(axis=1).mean() + rewards_2 = rewards_2.sum(axis=1).mean() + + elif self.args.env_id == "Fishery": + env_stats = fishery_stats( + traj_1 + traj_2, self.args.num_players + ) + rewards_1 = rewards_1.sum(axis=1).mean() + rewards_2 = rewards_2.sum(axis=1).mean() + + elif self.args.env_id == "Rice-N": + env_stats = rice_eval_stats( + traj_1 + traj_2, env_states, env + ) + env_stats = jax.tree_util.tree_map( + lambda x: x.tolist(), + env_stats, + ) + env_stats = env_stats | rice_stats( + traj_1 + traj_2, + self.args.num_players, + self.args.has_mediator, + ) + rewards_1 = rewards_1.sum(axis=1).mean() + rewards_2 = rewards_2.sum(axis=1).mean() + + elif self.args.env_id == "C-Rice-N": + env_stats = c_rice_eval_stats( + traj_1 + traj_2, env_states, env + ) + env_stats = jax.tree_util.tree_map( + lambda x: x.tolist(), + env_stats, + ) + env_stats = env_stats | c_rice_stats( + traj_1 + traj_2, + self.args.num_players, + self.args.has_mediator, + ) + rewards_1 = rewards_1.sum(axis=1).mean() + rewards_2 = rewards_2.sum(axis=1).mean() elif self.args.env_type in [ "meta", @@ -349,7 +485,19 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): rewards_2 = traj_2.rewards.mean() env_stats = {} - print(f"Env Stats: {env_stats}") + # Due to the permutation and roles the rewards are not in order + agent_indices = jnp.array( + [0] * self.args.agent1_roles + [1] * self.args.agent2_roles + ) + agent_indices = agent_indices[agent_order] + + # Omit rich per timestep statistics for cleaner logging + printable_env_stats = { + k: v + for k, v in env_stats.items() + if not k.startswith("states") + } + 
print(f"Env Stats: {printable_env_stats}") print( f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" ) @@ -364,11 +512,13 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): agent2._logger.metrics | flattened_metrics ) - for watcher, agent in zip(watchers, agents): + for watcher, agent in zip(watchers, agents, strict=True): watcher(agent) wandb.log( { "episodes": self.train_episodes, + "agent_order": agent_order, + "agent_indices": agent_indices, "train/episode_reward/player_1": float( rewards_1.mean() ), diff --git a/pax/runners/runner_eval_multishaper.py b/pax/runners/runner_eval_multishaper.py index fcc85929..ea2f698f 100644 --- a/pax/runners/runner_eval_multishaper.py +++ b/pax/runners/runner_eval_multishaper.py @@ -117,7 +117,7 @@ def _reshape_opp_dim(x): # set up agents # batch MemoryState not TrainingState for agent_idx, shaper_agent in enumerate(shapers): - agent_arg = f"agent{agent_idx+1}" + agent_arg = f"agent{agent_idx + 1}" if OmegaConf.select(args, agent_arg) == "NaiveEx": # special case where NaiveEx has a different call signature shaper_agent.batch_init = jax.jit( @@ -141,7 +141,7 @@ def _reshape_opp_dim(x): # go through opponents for agent_idx, target_agent in enumerate(targets): - agent_arg = f"agent{agent_idx+self.num_shapers+1}" + agent_arg = f"agent{agent_idx + self.num_shapers + 1}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": target_agent.batch_init = jax.jit( @@ -161,7 +161,7 @@ def _reshape_opp_dim(x): ) for agent_idx, shaper_agent in enumerate(shapers): - agent_arg = f"agent{agent_idx+1}" + agent_arg = f"agent{agent_idx + 1}" if OmegaConf.select(args, agent_arg) != "NaiveEx": # NaiveEx requires env first step to init. init_hidden = jnp.tile( @@ -175,7 +175,7 @@ def _reshape_opp_dim(x): ) for agent_idx, target_agent in enumerate(targets): - agent_arg = f"agent{agent_idx+self.num_shapers+1}" + agent_arg = f"agent{agent_idx + self.num_shapers + 1}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) != "NaiveEx": # NaiveEx requires env first step to init. 
@@ -321,7 +321,7 @@ def _outer_rollout(carry, unused): ) = vals # MFOS has to take a meta-action for each episode for agent_idx, shaper_agent in enumerate(shapers): - agent_arg = f"agent{agent_idx+1}" + agent_arg = f"agent{agent_idx + 1}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "MFOS": shapers_mem[agent_idx] = shaper_agent.meta_policy( @@ -385,7 +385,7 @@ def _rollout( targets_state = [None] * self.num_targets for agent_idx, target_agent in enumerate(targets): # if eg 2 shapers, agent3 is the first non-shaper - agent_arg = f"agent{agent_idx+1+self.num_shapers}" + agent_arg = f"agent{agent_idx + 1 + self.num_shapers}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": ( @@ -509,11 +509,11 @@ def run_loop(self, env_params, agents, watchers): targets_mem.append(target_agent._mem) for agent_idx, shaper_agent in enumerate(shaper_agents): - model_path = f"model_path{agent_idx+1}" - run_path = f"run_path{agent_idx+1}" + model_path = f"model_path{agent_idx + 1}" + run_path = f"run_path{agent_idx + 1}" if model_path not in self.args: raise ValueError( - f"Please provide a model path for shaper {agent_idx+1}" + f"Please provide a model path for shaper {agent_idx + 1}" ) wandb.restore( @@ -594,7 +594,7 @@ def run_loop(self, env_params, agents, watchers): # log the inner episodes shaper_rewards_log = [ { - f"eval/reward_per_timestep/shaper_{shaper_idx+1}": float( + f"eval/reward_per_timestep/shaper_{shaper_idx + 1}": float( traj.rewards[i].mean().item() ) for (shaper_idx, traj) in enumerate(shaper_traj) @@ -603,7 +603,7 @@ def run_loop(self, env_params, agents, watchers): ] target_rewards_log = [ { - f"eval/reward_per_timestep/target_{target_idx+1}": float( + f"eval/reward_per_timestep/target_{target_idx + 1}": float( traj.rewards[i].mean().item() ) for (target_idx, traj) in enumerate(target_traj) @@ -662,11 +662,11 @@ def run_loop(self, env_params, agents, watchers): | global_welfare_log[i] ) shaper_rewards_log = { - f"eval/meta_reward/shaper{idx+1}": float(rew.mean().item()) + f"eval/meta_reward/shaper{idx + 1}": float(rew.mean().item()) for (idx, rew) in enumerate(shapers_rewards) } target_rewards_log = { - f"eval/meta_reward/target{idx+1}": float(rew.mean().item()) + f"eval/meta_reward/target{idx + 1}": float(rew.mean().item()) for (idx, rew) in enumerate(targets_rewards) } diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index 9ce590b0..f5b4e459 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -8,32 +8,24 @@ from evosax import FitnessShaper import wandb -from pax.utils import MemoryState, TrainingState, save +from pax.utils import MemoryState, TrainingState, save, float_precision, Sample # TODO: import when evosax library is updated # from evosax.utils import ESLog from pax.watchers import ESLog, cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.fishery import fishery_stats +from pax.watchers.cournot import cournot_stats +from pax.watchers.rice import rice_stats +from pax.watchers.c_rice import c_rice_stats MAX_WANDB_CALLS = 1000 -class Sample(NamedTuple): - """Object containing a batch of data""" - - observations: jnp.ndarray - actions: jnp.ndarray - rewards: jnp.ndarray - behavior_log_probs: jnp.ndarray - behavior_values: jnp.ndarray - dones: jnp.ndarray - hiddens: jnp.ndarray - - class EvoRunner: """ - Evoluationary Strategy runner provides a convenient example for quickly writing + Evolutionary Strategy runner provides a convenient example for quickly writing a MARL runner for 
PAX. The EvoRunner class can be used to - run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. + run an RL agent (optimised by an Evolutionary Strategy) against a Reinforcement Learner. It composes together agents, watchers, and the environment. Within the init, we declare vmaps and pmaps for training. The environment provided must conform to a meta-environment. @@ -54,6 +46,8 @@ class EvoRunner: A tuple of experiment arguments used (usually provided by HydraConfig). """ + # TODO fix C901 (function too complex) + # flake8: noqa: C901 def __init__( self, agents, env, strategy, es_params, param_reshaper, save_dir, args ): @@ -77,6 +71,7 @@ def __init__( self.ipditm_stats = jax.jit( jax.vmap(ipditm_stats, in_axes=(0, 2, 2, None)) ) + self.cournot_stats = cournot_stats # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) # Evo Runner also has an additional pmap dim (num_devices, ...) @@ -106,7 +101,7 @@ def __init__( ) self.num_outer_steps = args.num_outer_steps - agent1, agent2 = agents + agent1, agent2 = agents[0], agents[1] # vmap agents accordingly # agent 1 is batched over popsize and num_opps @@ -192,14 +187,12 @@ def _inner_rollout(carry, unused): a2_mem, env_state, env_params, + agent_order, ) = carry # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, :, 0, :] - - # a1_rng = rngs[:, :, :, 1, :] - # a2_rng = rngs[:, :, :, 2, :] rngs = rngs[:, :, :, 3, :] a1, a1_state, new_a1_mem = agent1.batch_policy( @@ -207,18 +200,29 @@ def _inner_rollout(carry, unused): obs1, a1_mem, ) - a2, a2_state, new_a2_mem = agent2.batch_policy( - a2_state, - obs2, - a2_mem, - ) - (next_obs1, next_obs2), env_state, rewards, done, info = env.step( + a2_actions = [] + new_a2_memories = [] + for _obs, _mem in zip(obs2, a2_mem, strict=True): + a2_action, a2_state, new_a2_memory = agent2.batch_policy( + a2_state, + _obs, + _mem, + ) + a2_actions.append(a2_action) + new_a2_memories.append(new_a2_memory) + + actions = jnp.asarray([a1, *a2_actions])[agent_order] + obs, env_state, rewards, done, info = env.step( env_rng, env_state, - (a1, a2), + tuple(actions), env_params, ) + inv_agent_order = jnp.argsort(agent_order) + obs = jnp.asarray(obs)[inv_agent_order] + rewards = jnp.asarray(rewards)[inv_agent_order] + traj1 = Sample( obs1, a1, @@ -228,30 +232,42 @@ def _inner_rollout(carry, unused): done, a1_mem.hidden, ) - traj2 = Sample( - obs2, - a2, - rewards[1], - new_a2_mem.extras["log_probs"], - new_a2_mem.extras["values"], - done, - a2_mem.hidden, - ) + a2_trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, new_memory, memory in zip( + obs2, + a2_actions, + rewards[1:], + new_a2_memories, + a2_mem, + strict=True, + ) + ] + return ( rngs, - next_obs1, - next_obs2, + obs[0], + tuple(obs[1:]), rewards[0], - rewards[1], + tuple(rewards[1:]), a1_state, new_a1_mem, a2_state, - new_a2_mem, + tuple(new_a2_memories), env_state, env_params, + agent_order, ), ( traj1, - traj2, + a2_trajectories, ) def _outer_rollout(carry, unused): @@ -275,18 +291,24 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, + agent_order, ) = vals # MFOS has to take a meta-action for each episode if args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) # update second agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2, - a2_state, - a2_mem, - ) + new_a2_memories = [] + for _obs, 
mem, traj in zip( + obs2, a2_mem, trajectories[1], strict=True + ): + a2_state, a2_mem, a2_metrics = agent2.batch_update( + traj, + _obs, + a2_state, + mem, + ) + new_a2_memories.append(a2_mem) return ( rngs, obs1, @@ -296,9 +318,10 @@ def _outer_rollout(carry, unused): a1_state, a1_mem, a2_state, - a2_mem, + tuple(new_a2_memories), env_state, env_params, + agent_order, ), (*trajectories, a2_metrics) def _rollout( @@ -306,6 +329,7 @@ def _rollout( _rng_run: jnp.ndarray, _a1_state: TrainingState, _a1_mem: MemoryState, + _a2_state: TrainingState, _env_params: Any, ): # env reset @@ -317,9 +341,11 @@ def _rollout( obs, env_state = env.reset(env_rngs, _env_params) rewards = [ - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - jnp.zeros((args.popsize, args.num_opps, args.num_envs)), - ] + jnp.zeros( + (args.popsize, args.num_opps, args.num_envs), + dtype=float_precision, + ) + ] * (1 + args.agent2_roles) # Player 1 _a1_state = _a1_state._replace(params=_params) @@ -327,7 +353,6 @@ def _rollout( # Player 2 if args.agent2 == "NaiveEx": a2_state, a2_mem = agent2.batch_init(obs[1]) - else: # meta-experiments - init 2nd agent per trial a2_rng = jnp.concatenate( @@ -338,19 +363,31 @@ def _rollout( agent2._mem.hidden, ) + if _a2_state is not None: + a2_state = _a2_state + + agent_order = jnp.arange(args.num_players) + if args.shuffle_players: + agent_order = jax.random.permutation(_rng_run, agent_order) + + inv_agent_order = jnp.argsort(agent_order) + obs = jnp.asarray(obs)[inv_agent_order] # run trials vals, stack = jax.lax.scan( _outer_rollout, ( env_rngs, - *obs, - *rewards, + obs[0], + tuple(obs[1:]), + rewards[0], + tuple(rewards[1:]), _a1_state, _a1_mem, a2_state, - a2_mem, + (a2_mem,) * args.agent2_roles, env_state, _env_params, + agent_order, ), None, length=self.num_outer_steps, @@ -368,12 +405,19 @@ def _rollout( a2_mem, env_state, _env_params, + agent_order, ) = vals traj_1, traj_2, a2_metrics = stack # Fitness fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) - other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) + agent_2_rewards = jnp.concatenate( + [traj.rewards for traj in traj_2] + ) + other_fitness = agent_2_rewards.mean(axis=(0, 1, 3, 4)) + rewards_1 = traj_1.rewards.mean() + rewards_2 = agent_2_rewards.mean() + # Stats if args.env_id == "coin_game": env_stats = jax.tree_util.tree_map( @@ -383,7 +427,6 @@ def _rollout( rewards_1 = traj_1.rewards.sum(axis=1).mean() rewards_2 = traj_2.rewards.sum(axis=1).mean() - elif args.env_id in [ "iterated_matrix_game", ]: @@ -395,9 +438,6 @@ def _rollout( obs1, ), ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() - elif args.env_id == "InTheMatrix": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -408,12 +448,23 @@ def _rollout( args.num_envs, ), ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.cournot_stats(traj_1.observations, _env_params, 2), + ) + elif args.env_id == "Fishery": + env_stats = fishery_stats([traj_1] + traj_2, args.num_players) + elif args.env_id == "Rice-N": + env_stats = rice_stats( + [traj_1] + traj_2, args.num_players, args.has_mediator + ) + elif args.env_id == "C-Rice-N": + env_stats = c_rice_stats( + [traj_1] + traj_2, args.num_players, args.has_mediator + ) else: env_stats = {} - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() return ( fitness, @@ -422,11 +473,12 @@ def _rollout( rewards_1, rewards_2, a2_metrics, + a2_state, ) 
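The `run_loop` that follows drives this rollout through evosax's ask/evaluate/tell cycle. A stripped-down sketch of that cycle, assuming the evosax 0.1-era API that the imports above rely on; the toy quadratic objective and population sizes are illustrative stand-ins for the mean episode reward the runner extracts from `traj_1.rewards`:

```python
import jax
import jax.numpy as jnp
from evosax import OpenES, FitnessShaper

rng = jax.random.PRNGKey(0)
strategy = OpenES(popsize=16, num_dims=8)  # illustrative sizes
es_params = strategy.default_params
evo_state = strategy.initialize(rng, es_params)
fit_shaper = FitnessShaper(maximize=True)

for gen in range(100):
    rng, rng_ask = jax.random.split(rng)
    # Ask: sample a population of candidate parameter vectors.
    x, evo_state = strategy.ask(rng_ask, evo_state, es_params)
    # Evaluate: toy objective in place of the vmapped/pmapped rollout.
    fitness = -jnp.sum(x**2, axis=1)
    # Tell: rank-shape the fitness, then update the search distribution.
    evo_state = strategy.tell(
        x, fit_shaper.apply(x, fitness), evo_state, es_params
    )
```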
self.rollout = jax.pmap( _rollout, - in_axes=(0, None, None, None, None), + in_axes=(0, None, None, None, None, None), ) print( @@ -452,7 +504,7 @@ def run_loop( print(f"Log Interval: {log_interval}") print("------------------------------") # Initialize agents and RNG - agent1, agent2 = agents + agent1, agent2 = agents[0], agents[1] rng, _ = jax.random.split(self.random_key) # Initialize evolution @@ -489,6 +541,7 @@ def run_loop( ) a1_state, a1_mem = agent1._state, agent1._mem + a2_state = None for gen in range(num_gens): rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) @@ -500,6 +553,17 @@ def run_loop( params = jax.tree_util.tree_map( lambda x: jax.lax.expand_dims(x, (0,)), params ) + + if gen % self.args.agent2_reset_interval == 0: + a2_state = None + + if self.args.num_devices == 1 and a2_state is not None: + # The first rollout returns a2_state with an extra batch dim that + # will cause issues when passing it back to the vmapped batch_policy + a2_state = jax.tree_util.tree_map( + lambda w: jnp.squeeze(w, axis=0), a2_state + ) + # Evo Rollout ( fitness, @@ -508,10 +572,15 @@ def run_loop( rewards_1, rewards_2, a2_metrics, - ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) + a2_state, + ) = self.rollout( + params, rng_run, a1_state, a1_mem, a2_state, env_params + ) # Aggregate over devices - fitness = jnp.reshape(fitness, popsize * self.args.num_devices) + fitness = jnp.reshape( + fitness, popsize * self.args.num_devices + ).astype(dtype=jnp.float32) env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) # Tell @@ -524,9 +593,12 @@ def run_loop( # Logging log = es_logging.update(log, x, fitness) + is_last_loop = gen == num_iters - 1 # Saving - if gen % self.args.save_interval == 0: - log_savepath = os.path.join(self.save_dir, f"generation_{gen}") + if gen % self.args.save_interval == 0 or is_last_loop: + log_savepath1 = os.path.join( + self.save_dir, f"generation_{gen}" + ) if self.args.num_devices > 1: top_params = param_reshaper.reshape( log["top_gen_params"][0 : self.args.num_devices] @@ -541,15 +613,19 @@ def run_loop( top_params = jax.tree_util.tree_map( lambda x: x.reshape(x.shape[1:]), top_params ) - save(top_params, log_savepath) + save(top_params, log_savepath1) + log_savepath2 = os.path.join( + self.save_dir, f"agent2_iteration_{gen}" + ) + save(a2_state.params, log_savepath2) if watchers: - print(f"Saving generation {gen} locally and to WandB") - wandb.save(log_savepath) + print(f"Saving iteration {gen} locally and to WandB") + wandb.save(log_savepath1) + wandb.save(log_savepath2) else: print(f"Saving iteration {gen} locally") - - if gen % log_interval == 0: - print(f"Generation: {gen}") + if gen % log_interval == 0 or is_last_loop: + print(f"Generation: {gen}/{num_iters}") print( "--------------------------------------------------------------------------" ) @@ -605,13 +681,15 @@ def run_loop( wandb_log.update(env_stats) # loop through population for idx, (overall_fitness, gen_fitness) in enumerate( - zip(log["top_fitness"], log["top_gen_fitness"]) + zip( + log["top_fitness"], log["top_gen_fitness"], strict=True + ) ): wandb_log[ - f"train/fitness/top_overall_agent_{idx+1}" + f"train/fitness/top_overall_agent_{idx + 1}" ] = overall_fitness wandb_log[ - f"train/fitness/top_gen_agent_{idx+1}" + f"train/fitness/top_gen_agent_{idx + 1}" ] = gen_fitness # player 2 metrics @@ -621,7 +699,7 @@ def run_loop( ) agent2._logger.metrics.update(flattened_metrics) - for watcher, agent in zip(watchers, agents): + for watcher, agent in zip(watchers, 
agents, strict=True): watcher(agent) wandb_log = jax.tree_util.tree_map( lambda x: x.item() if isinstance(x, jax.Array) else x, diff --git a/pax/runners/runner_evo_multishaper.py b/pax/runners/runner_evo_multishaper.py index afe0e36e..6aae4902 100644 --- a/pax/runners/runner_evo_multishaper.py +++ b/pax/runners/runner_evo_multishaper.py @@ -137,7 +137,7 @@ def __init__( ) # go through opponents, we start with agent2 for agent_idx, target_agent in enumerate(targets): - agent_arg = f"agent{agent_idx+self.num_shapers+1}" + agent_arg = f"agent{agent_idx + self.num_shapers + 1}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": # special case where NaiveEx has a different call signature @@ -337,7 +337,7 @@ def _outer_rollout(carry, unused): ) = vals # MFOS has to take a meta-action for each episode for agent_idx, shaper_agent in enumerate(shapers): - agent_arg = f"agent{agent_idx+1}" + agent_arg = f"agent{agent_idx + 1}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "MFOS": shapers_mem[agent_idx] = shaper_agent.meta_policy( @@ -408,7 +408,7 @@ def _rollout( ) for agent_idx, target_agent in enumerate(targets): # if eg 2 shapers, agent3 is the first non-shaper - agent_arg = f"agent{agent_idx+1+self.num_shapers}" + agent_arg = f"agent{agent_idx + 1 + self.num_shapers}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": ( @@ -572,7 +572,6 @@ def run_loop( # Reshape a single agent's params before vmapping shaper_agents = agents[: self.num_shapers] - target_agents = agents[self.num_shapers :] init_hiddens = [ jnp.tile( @@ -601,7 +600,7 @@ def run_loop( shapers_params = [] old_evo_states = evo_states evo_states = [] - for shaper_idx, shaper_agent in enumerate(shaper_agents): + for shaper_idx, _ in enumerate(shaper_agents): shaper_rng, rng_evo = jax.random.split(rng_evo, 2) x, evo_state = strategy.ask( shaper_rng, old_evo_states[shaper_idx], es_params @@ -636,21 +635,23 @@ def run_loop( # Tell fitness_re = [ fit_shaper.apply(x, fitness) - for x, fitness in zip(xs, shapers_fitness) + for x, fitness in zip(xs, shapers_fitness, strict=True) ] if self.args.es.mean_reduce: fitness_re = [fit_re - fit_re.mean() for fit_re in fitness_re] evo_states = [ strategy.tell(x, fit_re, evo_state, es_params) - for x, fit_re, evo_state in zip(xs, fitness_re, evo_states) + for x, fit_re, evo_state in zip( + xs, fitness_re, evo_states, strict=True + ) ] # Logging logs = [ es_log.update(log, x, fitness) for es_log, log, x, fitness in zip( - es_logging, logs, xs, shapers_fitness + es_logging, logs, xs, shapers_fitness, strict=True ) ] # Saving @@ -726,7 +727,9 @@ def run_loop( ] rewards_strs = shaper_rewards_strs + target_rewards_strs rewards_val = shaper_rewards_val + target_rewards_val - rewards_dict = dict(zip(rewards_strs, rewards_val)) + rewards_dict = dict( + zip(rewards_strs, rewards_val, strict=True) + ) shaper_fitness_str = [ "train/fitness/shaper_" + str(i) @@ -745,7 +748,9 @@ def run_loop( fitness_strs = shaper_fitness_str + target_fitness_str fitness_vals = shaper_fitness_val + target_fitness_val - fitness_dict = dict(zip(fitness_strs, fitness_vals)) + fitness_dict = dict( + zip(fitness_strs, fitness_vals, strict=True) + ) shaper_welfare = float( sum([reward.mean() for reward in shapers_rewards]) @@ -796,7 +801,9 @@ def run_loop( # other player metrics # metrics [outer_timesteps, num_opps] - for agent, metrics in zip(agents[1:], targets_metrics): + for agent, metrics in zip( + agents[1:], targets_metrics, strict=True + ): 
flattened_metrics = jax.tree_util.tree_map( lambda x: jnp.sum(jnp.mean(x, 1)), metrics ) diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 98263fc4..92d1511c 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -4,16 +4,15 @@ import jax import jax.numpy as jnp - import wandb + +from pax.utils import MemoryState, TrainingState, save from pax.utils import ( - MemoryState, - TrainingState, copy_state_and_mem, - copy_state_and_network, - save, ) from pax.watchers import cg_visitation, ipd_visitation, ipditm_stats +from pax.watchers.cournot import cournot_stats +from pax.watchers.fishery import fishery_stats MAX_WANDB_CALLS = 1000 @@ -105,6 +104,7 @@ def _reshape_opp_dim(x): self.reduce_opp_dim = jax.jit(_reshape_opp_dim) self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) + self.cournot_stats = cournot_stats # VMAP for num_envs self.ipditm_stats = jax.jit(ipditm_stats) # VMAP for num envs: we vmap over the rng but not params @@ -126,7 +126,6 @@ def _reshape_opp_dim(x): self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) num_outer_steps = self.args.num_outer_steps agent1, agent2 = agents - agent1.agent2 = agent2 # Pointer for LOLA # set up agents if args.agent1 == "NaiveEx": @@ -139,11 +138,6 @@ def _reshape_opp_dim(x): (None, 0), (None, 0), ) - if args.agent1 == "LOLA": - # batch for num_opps - agent1.batch_in_lookahead = jax.vmap( - agent1.in_lookahead, (0, None, 0, 0, 0), (0, 0) - ) agent1.batch_reset = jax.jit( jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 ) @@ -239,15 +233,16 @@ def _inner_rollout(carry, unused): a1_mem.hidden, a1_mem.th, ) - traj1 = Sample( - obs1, - a1, - rewards[0], - new_a1_mem.extras["log_probs"], - new_a1_mem.extras["values"], - done, - a1_mem.hidden, - ) + else: + traj1 = Sample( + obs1, + a1, + rewards[0], + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) traj2 = Sample( obs2, a2, @@ -333,9 +328,8 @@ def _rollout( rngs = jnp.concatenate( [jax.random.split(_rng_run, args.num_envs)] * args.num_opps ).reshape((args.num_opps, args.num_envs, -1)) - _rng_run, _ = jax.random.split(_rng_run) - obs, env_state = env.batch_reset(rngs, _env_params) + obs, env_state = env.reset(rngs, _env_params) rewards = [ jnp.zeros((args.num_opps, args.num_envs)), jnp.zeros((args.num_opps, args.num_envs)), @@ -426,6 +420,8 @@ def _rollout( a1_mem = agent1.batch_reset(a1_mem, False) a2_mem = agent2.batch_reset(a2_mem, False) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() # Stats if args.env_id == "coin_game": env_stats = jax.tree_util.tree_map( @@ -435,7 +431,6 @@ def _rollout( rewards_1 = traj_1.rewards.sum(axis=1).mean() rewards_2 = traj_2.rewards.sum(axis=1).mean() - elif args.env_id == "iterated_matrix_game": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -445,8 +440,6 @@ def _rollout( obs1, ), ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() elif args.env_id == "InTheMatrix": env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -457,12 +450,15 @@ def _rollout( args.num_envs, ), ) - rewards_1 = traj_1.rewards.mean() - rewards_2 = traj_2.rewards.mean() + elif args.env_id == "Cournot": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.cournot_stats(traj_1.observations, _env_params, 2), + ) + elif args.env_id == "Fishery": + env_stats = fishery_stats([traj_1, traj_2], 2) else: env_stats = {} - rewards_1 = traj_1.rewards.mean() - rewards_2 = 
traj_2.rewards.mean() return ( env_stats, @@ -536,7 +532,7 @@ def run_loop(self, env_params, agents, num_iters, watchers): for stat in env_stats.keys(): print(stat + f": {env_stats[stat].item()}") print( - f"Reward per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" + f"Average Reward per Timestep: {float(rewards_1.mean()), float(rewards_2.mean())}" ) print() diff --git a/pax/runners/runner_marl_nplayer.py b/pax/runners/runner_marl_nplayer.py index 1c43e819..5244a098 100644 --- a/pax/runners/runner_marl_nplayer.py +++ b/pax/runners/runner_marl_nplayer.py @@ -9,6 +9,8 @@ import wandb from pax.utils import MemoryState, TrainingState, copy_state_and_mem, save from pax.watchers import n_player_ipd_visitation +from pax.watchers.cournot import cournot_stats +from pax.watchers.fishery import fishery_stats MAX_WANDB_CALLS = 1000 @@ -99,6 +101,7 @@ def _reshape_opp_dim(x): self.reduce_opp_dim = jax.jit(_reshape_opp_dim) self.ipd_stats = n_player_ipd_visitation + self.cournot_stats = cournot_stats # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) env.step = jax.vmap( @@ -150,7 +153,7 @@ def _reshape_opp_dim(x): # go through opponents, we start with agent2 for agent_idx, non_first_agent in enumerate(other_agents): - agent_arg = f"agent{agent_idx+2}" + agent_arg = f"agent{agent_idx + 2}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": # special case where NaiveEx has a different call signature @@ -207,11 +210,8 @@ def _inner_rollout(carry, unused): # unpack rngs rngs = self.split(rngs, 4) env_rng = rngs[:, :, 0, :] - # a1_rng = rngs[:, :, 1, :] - # a2_rng = rngs[:, :, 2, :] rngs = rngs[:, :, 3, :] new_other_agent_mem = [None] * len(other_agents) - actions = [] ( first_action, @@ -222,7 +222,7 @@ def _inner_rollout(carry, unused): first_agent_obs, first_agent_mem, ) - actions.append(first_action) + actions = [first_action] for agent_idx, non_first_agent in enumerate(other_agents): ( non_first_action, @@ -377,7 +377,7 @@ def _rollout( for agent_idx, non_first_agent in enumerate(other_agents): # indexing starts at 2 for args - agent_arg = f"agent{agent_idx+2}" + agent_arg = f"agent{agent_idx + 2}" # equivalent of args.agent_n if OmegaConf.select(args, agent_arg) == "NaiveEx": ( @@ -395,6 +395,7 @@ def _rollout( non_first_agent._mem.hidden, ) _rng_run, other_agent_rng = jax.random.split(_rng_run, 2) + # run trials vals, stack = jax.lax.scan( _outer_rollout, @@ -488,7 +489,8 @@ def _rollout( other_agent_mem[agent_idx] = non_first_agent.batch_reset( other_agent_mem[agent_idx], False ) - # Stats + + total_rewards = [traj.rewards.mean() for traj in trajectories] if args.env_id == "iterated_nplayer_tensor_game": total_env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -497,10 +499,19 @@ def _rollout( num_players=args.num_players, ), ) - total_rewards = [traj.rewards.mean() for traj in trajectories] + elif args.env_id == "Cournot": + total_env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cournot_stats( + trajectories[0].observations, + _env_params, + args.num_players, + ), + ) + elif args.env_id == "Fishery": + total_env_stats = fishery_stats(trajectories, args.num_players) else: total_env_stats = {} - total_rewards = [traj.rewards.mean() for traj in trajectories] return ( total_env_stats, @@ -514,15 +525,16 @@ def _rollout( trajectories, ) - # self.rollout = _rollout self.rollout = jax.jit(_rollout) def run_loop(self, env_params, agents, num_iters, watchers): """Run training of agents in 
environment""" print("Training") print("-----------------------") - # log_interval = int(max(num_iters / MAX_WANDB_CALLS, 5)) - log_interval = 100 + num_iters = max( + int(num_iters / (self.args.num_envs * self.num_opps)), 1 + ) + log_interval = int(max(num_iters / MAX_WANDB_CALLS, 5)) save_interval = self.args.save_interval agent1, *other_agents = agents @@ -594,6 +606,15 @@ def run_loop(self, env_params, agents, num_iters, watchers): agent1._logger.metrics = ( agent1._logger.metrics | flattened_metrics_1 ) + for agent, metric in zip( + other_agents, other_agent_metrics + ): + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.mean(x), first_agent_metrics + ) + agent._logger.metrics = ( + agent._logger.metrics | flattened_metrics + ) for watcher, agent in zip(watchers, agents): watcher(agent) diff --git a/pax/runners/runner_sarl.py b/pax/runners/runner_sarl.py index a573e18b..13816699 100644 --- a/pax/runners/runner_sarl.py +++ b/pax/runners/runner_sarl.py @@ -7,6 +7,7 @@ import wandb from pax.utils import MemoryState, TrainingState, save +from pax.watchers.rice import rice_sarl_stats # from jax.config import config # config.update('jax_disable_jit', True) @@ -36,6 +37,7 @@ def __init__(self, agent, env, save_dir, args): self.args = args self.random_key = jax.random.PRNGKey(args.seed) self.save_dir = save_dir + self.rice_stats = rice_sarl_stats # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) @@ -80,8 +82,6 @@ def _inner_rollout(carry, unused): # import pdb; pdb.set_trace() rngs = self.split(rngs, 2) env_rng = rngs[:, 0, :] - # a1_rng = rngs[:, 1, :] - # a2_rng = rngs[:, 2, :] rngs = rngs[:, 1, :] a1, a1_state, new_a1_mem = agent.batch_policy( @@ -162,12 +162,12 @@ def _rollout( _a1_mem, ) - # reset memory _a1_mem = agent.batch_reset(_a1_mem, False) - # Stats rewards = jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8) env_stats = {} + if args.env_id == "SarlRice-N": + env_stats = self.rice_stats(traj, args.num_players) return ( env_stats, @@ -177,8 +177,7 @@ def _rollout( _a1_metrics, ) - self.rollout = _rollout - # self.rollout = jax.jit(_rollout) + self.rollout = jax.jit(_rollout) def run_loop(self, env, env_params, agent, num_iters, watcher): """Run training of agent in environment""" @@ -206,7 +205,9 @@ def run_loop(self, env, env_params, agent, num_iters, watcher): a1_metrics, ) = self.rollout(rng_run, a1_state, a1_mem, env_params) - if i % self.args.save_interval == 0: + is_last_loop = i == num_iters - 1 + + if i % self.args.save_interval == 0 or is_last_loop: log_savepath = os.path.join(self.save_dir, f"iteration_{i}") save(a1_state.params, log_savepath) if watcher: @@ -217,11 +218,13 @@ def run_loop(self, env, env_params, agent, num_iters, watcher): # logging self.train_episodes += 1 - if num_iters % log_interval == 0: - print(f"Episode {i}") + if num_iters % log_interval == 0 or is_last_loop: + print(f"Episode {i}/{num_iters}") - print(f"Env Stats: {env_stats}") - print(f"Total Episode Reward: {float(rewards_1.mean())}") + print( + f"Env Stats: {jax.tree_map(lambda x: x.item(), env_stats)}" + ) + print(f"Total Reward per Episode: {float(rewards_1.mean())}") print() if watcher: diff --git a/pax/runners/runner_weight_sharing.py b/pax/runners/runner_weight_sharing.py new file mode 100644 index 00000000..291d6ec4 --- /dev/null +++ b/pax/runners/runner_weight_sharing.py @@ -0,0 +1,265 @@ +import os +import time +from typing import Any, List, Tuple + +import jax +import jax.numpy as jnp +import wandb + +from pax.utils 
import MemoryState, TrainingState, save, Sample +from pax.watchers import fishery_stats +from pax.watchers.rice import rice_stats +from pax.watchers.c_rice import c_rice_stats + +MAX_WANDB_CALLS = 1000000 + +""" +This runner implements weight sharing by letting the agent assume each role in the environment per turn. +""" + + +class WeightSharingRunner: + """Holds the runner's state.""" + + id = "weight_sharing" + + def __init__(self, agent, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + + self.split = jax.vmap(jax.random.split, (0, None)) + # set up agent + if args.agent1 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent.batch_init = jax.jit(jax.vmap(agent.make_initial_state)) + else: + # batch MemoryState not TrainingState + agent.batch_init = jax.jit(agent.make_initial_state) + + agent.batch_reset = jax.jit(agent.reset_memory, static_argnums=1) + + agent.batch_policy = jax.jit(agent._policy) + + if args.agent1 != "NaiveEx": + # NaiveEx requires env first step to init. + init_hidden = jnp.tile(agent._mem.hidden, (1)) + agent._state, agent._mem = agent.batch_init( + agent._state.random_key, init_hidden + ) + + def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]: + """Runner for inner episode""" + ( + rngs, + obs, + a1_state, + memories, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 2) + env_rng = rngs[:, 0, :] + rngs = rngs[:, 1, :] + + actions = [] + new_memories = [] + for i in range(len(obs)): + action, a1_state, new_a1_mem = agent.batch_policy( + a1_state, + obs[i], + memories[i], + ) + actions.append(action) + new_memories.append(new_a1_mem) + + next_obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + tuple(actions), + env_params, + ) + + trajectories = [ + Sample( + observation, + action, + reward * jnp.logical_not(done), + new_memory.extras["log_probs"], + new_memory.extras["values"], + done, + memory.hidden, + ) + for observation, action, reward, memory, new_memory in zip( + obs, actions, rewards, memories, new_memories, strict=True + ) + ] + + return ( + rngs, + next_obs, + a1_state, + new_memories, + env_state, + env_params, + ), trajectories + + def _rollout( + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _memories: List[MemoryState], + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + ).reshape((args.num_envs, -1)) + + obs, env_state = env.reset(rngs, _env_params) + _memories = [agent.batch_reset(mem, False) for mem in _memories] + + # run trials + vals, trajectories = jax.lax.scan( + _inner_rollout, + ( + rngs, + obs, + _a1_state, + _memories, + env_state, + _env_params, + ), + None, + length=args.num_steps, + ) + + ( + rngs, + obs, + _a1_state, + _memories, + env_state, + env_params, + ) = vals + + for i in range(len(trajectories)): + _a1_state, _, _a1_metrics = agent.update( + trajectories[i], + obs[i], + _a1_state, + _memories[i], + ) + + # reset memory + _memories = [agent.batch_reset(_mem, False) for _mem in _memories] + + # Stats + rewards = [] + num_episodes = jnp.sum(trajectories[0].dones) + for traj in trajectories: + rewards.append( + 
jnp.where(
+                        num_episodes != 0,
+                        jnp.sum(traj.rewards) / num_episodes,
+                        0,
+                    )
+                )
+            env_stats = {}
+            if args.env_id == "Rice-N":
+                env_stats = rice_stats(
+                    trajectories, args.num_players, args.has_mediator
+                )
+            elif args.env_id == "C-Rice-N":
+                env_stats = c_rice_stats(
+                    trajectories, args.num_players, args.has_mediator
+                )
+            elif args.env_id == "Fishery":
+                env_stats = fishery_stats(trajectories, args.num_players)
+
+            return (
+                env_stats,
+                rewards,
+                _a1_state,
+                _memories,
+                _a1_metrics,
+            )
+
+        self.rollout = jax.jit(_rollout)
+
+    def run_loop(self, env, env_params, agent, num_iters, watcher):
+        """Run training of agent in environment"""
+        print("Training")
+        print("-----------------------")
+        # a single shared agent assumes every role in the environment
+        rng, _ = jax.random.split(self.random_key)
+
+        a1_state, a1_mem = agent._state, agent._mem
+        memories = tuple([a1_mem for _ in range(self.args.num_players)])
+        num_iters = max(int(num_iters / (self.args.num_envs)), 1)
+        log_interval = int(max(num_iters / MAX_WANDB_CALLS, 5))
+
+        print(f"Log Interval {log_interval}")
+        print(f"Running for total iterations: {num_iters}")
+        # run actual loop
+        for i in range(num_iters):
+            rng, rng_run = jax.random.split(rng, 2)
+            # RL Rollout
+            (
+                env_stats,
+                rewards,
+                a1_state,
+                memories,
+                a1_metrics,
+            ) = self.rollout(rng_run, a1_state, memories, env_params)
+
+            is_last_iter = i == num_iters - 1
+            if i % self.args.save_interval == 0 or is_last_iter:
+                log_savepath = os.path.join(self.save_dir, f"iteration_{i}")
+                save(a1_state.params, log_savepath)
+                if watcher:
+                    print(f"Saving iteration {i} locally and to WandB")
+                    wandb.save(log_savepath)
+                else:
+                    print(f"Saving iteration {i} locally")
+
+            # logging
+            self.train_episodes += 1
+            if i % log_interval == 0 or is_last_iter:
+                print(f"Episode {i}/{num_iters}")
+
+                print(f"Env Stats: {env_stats}")
+                print(f"Total Episode Reward: {float(sum(rewards))}")
+                print()
+
+                if watcher:
+                    # metrics [outer_timesteps]
+                    flattened_metrics_1 = jax.tree_util.tree_map(
+                        lambda x: jnp.mean(x), a1_metrics
+                    )
+                    agent._logger.metrics = (
+                        agent._logger.metrics | flattened_metrics_1
+                    )
+
+                    watcher(agent)
+                    wandb.log(
+                        {
+                            "episodes": self.train_episodes,
+                        }
+                        | env_stats,
+                    )
+
+        agent._state = a1_state
+        return agent
diff --git a/pax/utils.py b/pax/utils.py
index 284e337c..3acbd7d1 100644
--- a/pax/utils.py
+++ b/pax/utils.py
@@ -5,9 +5,21 @@
 import chex
 import haiku as hk
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
+from jax import numpy as jnp
+
+
+class Sample(NamedTuple):
+    """Object containing a batch of data"""
+
+    observations: jnp.ndarray
+    actions: jnp.ndarray
+    rewards: jnp.ndarray
+    behavior_log_probs: jnp.ndarray
+    behavior_values: jnp.ndarray
+    dones: jnp.ndarray
+    hiddens: jnp.ndarray


 class Section(object):
@@ -209,3 +221,7 @@ def copy_extended_state_and_network(agent):
     policy_network = agent.policy_network
     value_network = agent.value_network
     return state, policy_network, value_network
+
+
+# TODO make this part of the args
+float_precision = jnp.float32
diff --git a/pax/version.py b/pax/version.py
index 64a343ff..da1cc9a3 100644
--- a/pax/version.py
+++ b/pax/version.py
@@ -1,2 +1,3 @@
-__version__ = "0.1.0b"
-git_version = "962c352f152f5e739413c10c39542c1480435465"
+# fmt: off
+__version__ = "0.1.0b+2f5c2ed"
+git_version = "2f5c2eda3f66bec784cd996cf40f7a7ac76b8436"
diff --git a/pax/watchers.py b/pax/watchers/__init__.py
similarity index 95%
rename from pax/watchers.py
rename to pax/watchers/__init__.py
index 14ca65a8..089842b5 100644
--- a/pax/watchers.py
+++
b/pax/watchers/__init__.py @@ -2,7 +2,7 @@ import itertools import pickle from functools import partial -from typing import NamedTuple +from typing import NamedTuple, Any import chex import jax @@ -12,8 +12,11 @@ import pax.agents.hyper.ppo as HyperPPO import pax.agents.ppo.ppo as PPO from pax.agents.naive_exact import NaiveExact -from pax.envs.in_the_matrix import InTheMatrix -from pax.envs.iterated_matrix_game import EnvState, IteratedMatrixGame +from pax.envs.iterated_matrix_game import EnvState + + +from .fishery import fishery_stats +from .cournot import cournot_stats # five possible states START = jnp.array([[0, 0, 0, 0, 1]]) @@ -39,7 +42,7 @@ def policy_logger(agent) -> dict: ] # [layer_name]['w'] log_pi = nn.softmax(weights) probs = { - "policy/" + str(s): p[0] for (s, p) in zip(State, log_pi) + "policy/" + str(s): p[0] for (s, p) in zip(State, log_pi, strict=True) } # probability of cooperating is p[0] return probs @@ -47,10 +50,14 @@ def policy_logger(agent) -> dict: def value_logger(agent) -> dict: weights = agent.critic_optimizer.target["Dense_0"]["kernel"] values = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + f"value/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, weights, strict=True) } values.update( - {f"value/{str(s)}.defect": p[1] for (s, p) in zip(State, weights)} + { + f"value/{str(s)}.defect": p[1] + for (s, p) in zip(State, weights, strict=True) + } ) return values @@ -63,12 +70,12 @@ def policy_logger_dqn(agent) -> None: target_steps = agent.target_step_updates probs = { f"policy/player_{str(pid)}/{str(s)}.cooperate": p[0] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } probs.update( { f"policy/player_{str(pid)}/{str(s)}.defect": p[1] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } ) probs.update({"policy/target_step_updates": target_steps}) @@ -81,12 +88,12 @@ def value_logger_dqn(agent) -> dict: target_steps = agent.target_step_updates values = { f"value/player_{str(pid)}/{str(s)}.cooperate": p[0] - for (s, p) in zip(State, weights) + for (s, p) in zip(State, weights, strict=True) } values.update( { f"value/player_{str(pid)}/{str(s)}.defect": p[1] - for (s, p) in zip(State, weights) + for (s, p) in zip(State, weights, strict=True) } ) values.update({"value/target_step_updates": target_steps}) @@ -97,7 +104,10 @@ def policy_logger_ppo(agent: PPO) -> dict: weights = agent._state.params["categorical_value_head/~/linear"]["w"] pi = nn.softmax(weights) sgd_steps = agent._total_steps / agent._num_steps - probs = {f"policy/{str(s)}.cooperate": p[0] for (s, p) in zip(State, pi)} + probs = { + f"policy/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, pi, strict=True) + } probs.update({"policy/total_steps": sgd_steps}) return probs @@ -108,7 +118,8 @@ def value_logger_ppo(agent: PPO) -> dict: ] # 5 x 1 matrix sgd_steps = agent._total_steps / agent._num_steps probs = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + f"value/{str(s)}.cooperate": p[0] + for (s, p) in zip(State, weights, strict=True) } probs.update({"value/total_steps": sgd_steps}) return probs @@ -138,19 +149,18 @@ def policy_logger_ppo_with_memory(agent) -> dict: return {} -def naive_pg_losses(agent) -> None: +def naive_pg_losses(agent) -> dict[str, Any]: pid = agent.player_id sgd_steps = agent._logger.metrics["sgd_steps"] loss_total = agent._logger.metrics["loss_total"] loss_policy = agent._logger.metrics["loss_policy"] loss_value = agent._logger.metrics["loss_value"] - losses = { + return { 
f"train/naive_learner{pid}/sgd_steps": sgd_steps, f"train/naive_learner{pid}/total": loss_total, f"train/naive_learner{pid}/policy": loss_policy, f"train/naive_learner{pid}/value": loss_value, } - return losses def logger_hyper(agent: HyperPPO) -> dict: @@ -212,7 +222,7 @@ def policy_logger_naive(agent) -> None: sgd_steps = agent._total_steps / agent._num_steps probs = { f"policy/{str(s)}/{agent.player_id}.cooperate": p[0] - for (s, p) in zip(State, pi) + for (s, p) in zip(State, pi, strict=True) } probs.update({"policy/total_steps": sgd_steps}) return probs @@ -277,6 +287,7 @@ def update( self, log: chex.ArrayTree, x: chex.Array, fitness: chex.Array ) -> chex.ArrayTree: """Update the logging storage with newest data.""" + # Check if there are solutions better than current archive def get_top_idx(maximize: bool, vals: jnp.ndarray) -> jnp.ndarray: top_idx = maximize * ((-1) * vals).argsort() + ( @@ -497,10 +508,10 @@ def generate_grouped_combs_strs(num_players): ] visitation_dict = ( - dict(zip(visitation_strs, state_freq)) - | dict(zip(prob_strs, state_probs)) - | dict(zip(grouped_visitation_strs, grouped_state_freq)) - | dict(zip(grouped_prob_strs, grouped_state_probs)) + dict(zip(visitation_strs, state_freq, strict=True)) + | dict(zip(prob_strs, state_probs, strict=True)) + | dict(zip(grouped_visitation_strs, grouped_state_freq, strict=True)) + | dict(zip(grouped_prob_strs, grouped_state_probs, strict=True)) ) return visitation_dict @@ -796,14 +807,20 @@ def third_party_punishment_visitation( total_punishment_str = "total_punishment" visitation_dict = ( - dict(zip(all_game_visitation_strs, action_freq)) - | dict(zip(all_game_prob_strs, action_probs)) - | dict(zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq)) - | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs)) - | dict(zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq)) - | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs)) - | dict(zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq)) - | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs)) + dict(zip(all_game_visitation_strs, action_freq, strict=True)) + | dict(zip(all_game_prob_strs, action_probs, strict=True)) + | dict( + zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq, strict=True) + ) + | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs, strict=True)) + | dict( + zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq, strict=True) + ) + | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs, strict=True)) + | dict( + zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq, strict=True) + ) + | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs, strict=True)) | {pl1_total_defects_prob_str: pl1_total_defects_prob} | {pl2_total_defects_prob_str: pl2_total_defects_prob} | {pl3_total_defects_prob_str: pl3_total_defects_prob} @@ -861,10 +878,6 @@ def third_party_random_visitation( game2 = jnp.stack([prev_cd_actions[:, 1], prev_cd_actions[:, 2]], axis=-1) game3 = jnp.stack([prev_cd_actions[:, 2], prev_cd_actions[:, 0]], axis=-1) - selected_game_actions = jnp.stack( # len_idx3x2 - [game1, game2, game3], axis=1 - ) - b2i = 2 ** jnp.arange(2 - 1, -1, -1) pl1_vs_pl2 = (game1 * b2i).sum(axis=-1) pl2_vs_pl3 = (game2 * b2i).sum(axis=-1) @@ -993,10 +1006,10 @@ def third_party_random_visitation( ) visitation_dict = ( - dict(zip(game_prob_strs, action_probs)) - | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs)) - | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs)) - | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs)) + dict(zip(game_prob_strs, action_probs, 
strict=True)) + | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs, strict=True)) + | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs, strict=True)) + | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs, strict=True)) | {game_selected_punish_str: game_selected_punish} | {pl1_defects_prob_str: pl1_defect_prob} | {pl2_defects_prob_str: pl2_defect_prob} diff --git a/pax/watchers/c_rice.py b/pax/watchers/c_rice.py new file mode 100644 index 00000000..fec7967a --- /dev/null +++ b/pax/watchers/c_rice.py @@ -0,0 +1,174 @@ +from functools import partial +from typing import NamedTuple, List + +import jax +from jax import numpy as jnp + +from pax.envs.rice.c_rice import ClubRice +from pax.envs.rice.rice import EnvState, Rice + + +@partial(jax.jit, static_argnums=(1, 2)) +def c_rice_stats( + trajectories: List[NamedTuple], num_players: int, mediator: bool +) -> dict: + traj = trajectories[0] + # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim + result = { + "club_mitigation_rate": jnp.mean(traj.observations[..., 2]), + "club_tariff_rate": jnp.mean(traj.observations[..., 3]), + "temperature": jnp.mean(traj.observations[..., 4]), + "land_temperature": jnp.mean(traj.observations[..., 5]), + "carbon_atmosphere": jnp.mean(traj.observations[..., 6]), + "carbon_upper_ocean": jnp.mean(traj.observations[..., 7]), + "carbon_land": jnp.mean(traj.observations[..., 8]), + } + + region_vars = [ + "gross_output", + "investment", + "abatement_cost", + "tariff_revenue", + "club_membership", + ] + + offset = 9 + for i, label in enumerate(region_vars): + start = offset + i * num_players + end = offset + (i + 1) * num_players + result[label] = jnp.mean(traj.observations[..., start:end]) + + num_episodes = jnp.sum(traj.dones) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 4]) / num_episodes, + 0, + ) + + region_trajectories = ( + trajectories if mediator is False else trajectories[1:] + ) + total_reward = jnp.array( + [jnp.sum(_traj.rewards) for _traj in region_trajectories] + ).sum() + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, total_reward / num_episodes, 0 + ) + + # Reward per region (necessary to compare weight_sharing against distributed methods) + for i, _traj in enumerate(region_trajectories): + result[f"train/total_reward_per_episode_region_{i}"] = jnp.where( + num_episodes != 0, _traj.rewards.sum() / num_episodes, 0 + ) + + return result + + +ep_length = 20 + + +@partial(jax.jit, static_argnums=(2)) +def c_rice_eval_stats( + trajectories: List[NamedTuple], env_state: EnvState, env: ClubRice +) -> dict: + # In the stacked env_state the inner steps are one long sequence in the first dimension + # But to compute step statistics we need it to be episodes x steps x ... 
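+    # e.g. leaves stacked as (num_rollouts, ep_count * ep_length, ...) become
+    # (num_rollouts * ep_count, ep_length, ...), i.e. one row per episode
+    # (the episode length is assumed fixed at 20 steps here)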
+ ep_count = int(env_state.global_temperature.shape[1] / 20) + env_state = jax.tree_util.tree_map( + lambda x: x.reshape((x.shape[0] * ep_count, ep_length, *x.shape[2:])), + env_state, + ) + # Because the initial obs are not included the timesteps are shifted like so: + # 1, 2, ..., 19, 0 + # Since the initial obs are always the same we can just shift them to the start + env_state: EnvState = jax.tree_util.tree_map( + lambda x: jnp.concatenate((x[:, -1:, ...], x[:, :-1, ...]), axis=1), + env_state, + ) + + def add_atrib(name, value, axis): + result[f"states/{name}_mean"] = value.mean(axis=axis) + result[f"states/{name}_std"] = value.std(axis=axis) + result[f"states/{name}_min"] = value.min(axis=axis) + result[f"states/{name}_max"] = value.max(axis=axis) + + result = {} + + add_atrib("temperature", env_state.global_temperature, axis=(0, 2, 3)) + add_atrib("carbon_mass", env_state.global_carbon_mass, axis=(0, 2, 3)) + add_atrib("labor", env_state.labor_all, axis=(0, 2, 3)) + add_atrib("capital", env_state.capital_all, axis=(0, 2, 3)) + add_atrib( + "production_factor", env_state.production_factor_all, axis=(0, 2, 3) + ) + add_atrib("intensity", env_state.intensity_all, axis=(0, 2, 3)) + add_atrib("balance", env_state.balance_all, axis=(0, 2, 3)) + add_atrib("gross_output", env_state.gross_output_all, axis=(0, 2, 3)) + add_atrib("investment", env_state.investment_all, axis=(0, 2, 3)) + add_atrib("production", env_state.production_all, axis=(0, 2, 3)) + add_atrib("consumption", env_state.consumption_all, axis=(0, 2, 3)) + add_atrib("abatement_cost", env_state.abatement_cost_all, axis=(0, 2, 3)) + add_atrib("tariff_revenue", env_state.tariff_revenue_all, axis=(0, 2, 3)) + add_atrib("carbon_price", env_state.carbon_price_all, axis=(0, 2, 3)) + add_atrib("club_membership", env_state.club_membership_all, axis=(0, 2, 3)) + add_atrib("utility", env_state.utility_all, axis=(0, 2, 3)) + add_atrib("social_welfare", env_state.social_welfare_all, axis=(0, 2, 3)) + add_atrib("mitigation_cost", env_state.mitigation_cost_all, axis=(0, 2, 3)) + add_atrib("damages", env_state.damages_all, axis=(0, 2, 3)) + + actions = jnp.stack([traj.actions for traj in trajectories]) + # n_players x rollouts x steps x ... x n_actions + # reshape -> n_players x (rollouts * n_episodes) x episode_steps x ... x n_actions + # transpose -> episode_steps x n_players x (rollouts * n_episodes) x ... 
x n_actions
+    actions = jax.tree_util.tree_map(
+        lambda x: x.reshape(
+            (x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:])
+        ).transpose((2, 0, 1, 3, 4, 5)),
+        actions,
+    )
+    add_atrib(
+        "savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4)
+    )
+    add_atrib(
+        "mitigation_rate",
+        actions[..., env.mitigation_rate_action_index],
+        axis=(2, 3, 4),
+    )
+    add_atrib(
+        "export_limit", actions[..., env.export_action_index], axis=(2, 3, 4)
+    )
+    add_atrib(
+        "club_join_action",
+        actions[..., env.join_club_action_index],
+        axis=(2, 3, 4),
+    )
+    add_atrib(
+        "imports",
+        actions[
+            ...,
+            env.desired_imports_action_index : env.desired_imports_action_index
+            + env.import_actions_n,
+        ],
+        axis=(2, 3, 4),
+    )
+    add_atrib(
+        "tariffs",
+        actions[
+            ...,
+            env.tariffs_action_index : env.tariffs_action_index
+            + env.tariff_actions_n,
+        ],
+        axis=(2, 3, 4),
+    )
+
+    observations = trajectories[0].observations
+    observations = jax.tree_util.tree_map(
+        lambda x: x.reshape(
+            (x.shape[0] * ep_count, ep_length, *x.shape[2:])
+        ).transpose((1, 0, 2, 3, 4)),
+        observations,
+    )
+    add_atrib("club_mitigation_rate", observations[..., 2], axis=(1, 2, 3))
+    add_atrib("club_tariff_rate", observations[..., 3], axis=(1, 2, 3))
+
+    return result
diff --git a/pax/watchers/cournot.py b/pax/watchers/cournot.py
new file mode 100644
index 00000000..63439a72
--- /dev/null
+++ b/pax/watchers/cournot.py
@@ -0,0 +1,28 @@
+from functools import partial
+
+import jax
+from jax import numpy as jnp
+
+from pax.envs.cournot import EnvParams as CournotEnvParams, CournotGame
+
+
+@partial(jax.jit, static_argnums=2)
+def cournot_stats(
+    observations: jnp.ndarray, params: CournotEnvParams, num_players: int
+) -> dict:
+    opt_quantity = CournotGame.nash_policy(params)
+
+    actions = observations[..., :num_players]
+    average_quantity = actions.mean()
+
+    stats = {
+        "cournot/average_quantity": average_quantity,
+        "cournot/quantity_loss": jnp.mean(
+            (opt_quantity - average_quantity) ** 2
+        ),
+    }
+
+    for i in range(num_players):
+        stats["cournot/quantity_" + str(i)] = jnp.mean(observations[..., i])
+
+    return stats
diff --git a/pax/watchers/fishery.py b/pax/watchers/fishery.py
new file mode 100644
index 00000000..49f4251a
--- /dev/null
+++ b/pax/watchers/fishery.py
@@ -0,0 +1,74 @@
+from functools import partial
+from typing import NamedTuple, List
+
+import jax
+import numpy as np
+import wandb
+from jax import numpy as jnp
+
+
+@partial(jax.jit, static_argnums=1)
+def fishery_stats(trajectories: List[NamedTuple], num_players: int) -> dict:
+    traj = trajectories[0]
+    # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim
+    stock_obs = traj.observations[..., -1]
+    actions = traj.observations[..., :num_players]
+    completed_episodes = jnp.sum(traj.dones)
+    stats = {
+        "fishery/stock": jnp.mean(stock_obs),
+        "fishery/effort_mean": actions.mean(),
+        "fishery/effort_std": actions.std(),
+        "fishery/final_stock": jnp.where(
+            completed_episodes != 0,
+            jnp.sum(traj.dones * stock_obs) / jnp.sum(traj.dones),
+            0,
+        ),
+    }
+
+    for i in range(num_players):
+        stats["fishery/effort_" + str(i)] = jnp.mean(traj.observations[..., i])
+        stats["fishery/effort_" + str(i) + "_std"] = jnp.std(
+            traj.observations[..., i]
+        )
+        stats["train/total_reward_" + str(i)] = jnp.where(
+            completed_episodes != 0,
+            trajectories[i].rewards.sum() / completed_episodes,
+            0,
+        )
+
+    total_reward = jnp.array(
+        [jnp.sum(_traj.rewards) for _traj in trajectories]
+    ).sum()
+    stats["train/total_reward_per_episode"] =
jnp.where( + completed_episodes != 0, total_reward / completed_episodes, 0 + ) + + return stats + + +def fishery_eval_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict: + # Calculate effort for both agents + effort_1 = jax.nn.sigmoid(traj1.actions).squeeze().tolist() + effort_2 = jax.nn.sigmoid(traj2.actions).squeeze().tolist() + ep_length = np.arange(len(effort_1)).tolist() + + # Plot the effort on the same graph + effort_plot = wandb.plot.line_series( + xs=ep_length, + ys=[effort_1, effort_2], + keys=["effort_1", "effort_2"], + xname="step", + title="Agent effort over one episode", + ) + + stock_obs = traj1.observations[..., 0].squeeze().tolist() + stock_table = wandb.Table( + data=[[x, y] for (x, y) in zip(ep_length, stock_obs, strict=True)], + columns=["step", "stock"], + ) + # Plot the stock in a separate graph + stock_plot = wandb.plot.line( + stock_table, x="step", y="stock", title="Stock over one episode" + ) + + return {"fishery/stock": stock_plot, "fishery/effort": effort_plot} diff --git a/pax/watchers/rice.py b/pax/watchers/rice.py new file mode 100644 index 00000000..293449cd --- /dev/null +++ b/pax/watchers/rice.py @@ -0,0 +1,208 @@ +from functools import partial +from typing import NamedTuple, List + +import jax +from jax import numpy as jnp + +from pax.envs.rice.rice import EnvState, Rice + + +@partial(jax.jit, static_argnums=(1, 2)) +def rice_stats( + trajectories: List[NamedTuple], num_players: int, mediator: bool +) -> dict: + traj = trajectories[0] + # obs shape: num_outer_steps x num_inner_steps x num_opponents x num_envs x obs_dim + result = { + "temperature": jnp.mean(traj.observations[..., 2]), + "land_temperature": jnp.mean(traj.observations[..., 3]), + "carbon_atmosphere": jnp.mean(traj.observations[..., 4]), + "carbon_upper_ocean": jnp.mean(traj.observations[..., 5]), + "carbon_land": jnp.mean(traj.observations[..., 6]), + # Omitted: exogenous_emissions, land_emissions + } + + region_vars = [ + "labor", + "capital", + "gross_output", + "consumption", + "investment", + "balance", + "tariff_revenue", + "carbon_price", + "club_membership", + ] + + offset = 9 + for i, label in enumerate(region_vars): + start = offset + i * num_players + end = offset + (i + 1) * num_players + result[label] = jnp.mean(traj.observations[..., start:end]) + + num_episodes = jnp.sum(traj.dones) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 2]) / num_episodes, + 0, + ) + + region_trajectories = ( + trajectories if mediator is False else trajectories[1:] + ) + total_reward = jnp.array( + [jnp.sum(_traj.rewards) for _traj in region_trajectories] + ).sum() + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, total_reward / num_episodes, 0 + ) + + # Reward per region (necessary to compare weight_sharing against distributed methods) + for i, _traj in enumerate(region_trajectories): + result[f"train/total_reward_per_episode_region_{i}"] = jnp.where( + num_episodes != 0, _traj.rewards.sum() / num_episodes, 0 + ) + + return result + + +ep_length = 20 + + +@partial(jax.jit, static_argnums=(2)) +def rice_eval_stats( + trajectories: List[NamedTuple], env_state: EnvState, env: Rice +) -> dict: + # In the stacked env_state the inner steps are one long sequence in the first dimension + # But to compute step statistics we need it to be episodes x steps x ... 
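+    # as in the club-Rice watcher above: (num_rollouts, ep_count * ep_length, ...)
+    # -> (num_rollouts * ep_count, ep_length, ...); the stacked length is assumed
+    # to be a multiple of the fixed 20-step episode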
+ ep_count = int(env_state.global_temperature.shape[1] / 20) + env_state = jax.tree_util.tree_map( + lambda x: x.reshape((x.shape[0] * ep_count, ep_length, *x.shape[2:])), + env_state, + ) + # Because the initial obs are not included the timesteps are shifted like so: + # 1, 2, ..., 19, 0 + # Since the initial obs are always the same we can just shift them to the start + env_state: EnvState = jax.tree_util.tree_map( + lambda x: jnp.concatenate((x[:, -1:, ...], x[:, :-1, ...]), axis=1), + env_state, + ) + + def add_atrib(name, value, axis): + result[f"states/{name}_mean"] = value.mean(axis=axis) + result[f"states/{name}_std"] = value.std(axis=axis) + result[f"states/{name}_min"] = value.min(axis=axis) + result[f"states/{name}_max"] = value.max(axis=axis) + + result = {} + + add_atrib("temperature", env_state.global_temperature, axis=(0, 2, 3)) + add_atrib("carbon_mass", env_state.global_carbon_mass, axis=(0, 2, 3)) + add_atrib("labor", env_state.labor_all, axis=(0, 2, 3)) + add_atrib("capital", env_state.capital_all, axis=(0, 2, 3)) + add_atrib( + "production_factor", env_state.production_factor_all, axis=(0, 2, 3) + ) + add_atrib("intensity", env_state.intensity_all, axis=(0, 2, 3)) + add_atrib("balance", env_state.balance_all, axis=(0, 2, 3)) + add_atrib("gross_output", env_state.gross_output_all, axis=(0, 2, 3)) + add_atrib("investment", env_state.investment_all, axis=(0, 2, 3)) + add_atrib("production", env_state.production_all, axis=(0, 2, 3)) + add_atrib("consumption", env_state.consumption_all, axis=(0, 2, 3)) + add_atrib("abatement_cost", env_state.abatement_cost_all, axis=(0, 2, 3)) + add_atrib("tariff_revenue", env_state.tariff_revenue_all, axis=(0, 2, 3)) + add_atrib("carbon_price", env_state.carbon_price_all, axis=(0, 2, 3)) + add_atrib("club_membership", env_state.club_membership_all, axis=(0, 2, 3)) + add_atrib("utility", env_state.utility_all, axis=(0, 2, 3)) + add_atrib("social_welfare", env_state.social_welfare_all, axis=(0, 2, 3)) + add_atrib("mitigation_cost", env_state.mitigation_cost_all, axis=(0, 2, 3)) + add_atrib("damages", env_state.damages_all, axis=(0, 2, 3)) + + actions = jnp.stack([traj.actions for traj in trajectories]) + # n_players x rollouts x steps x ... x n_actions + # reshape -> n_players x (rollouts * n_episodes) x episode_steps x ... x n_actions + # transpose -> episode_steps x n_players x (rollouts * n_episodes) x ... 
x n_actions + actions = jax.tree_util.tree_map( + lambda x: x.reshape( + (x.shape[0], x.shape[1] * ep_count, ep_length, *x.shape[3:]) + ).transpose((2, 0, 1, 3, 4, 5)), + actions, + ) + add_atrib( + "savings_rate", actions[..., env.savings_action_index], axis=(2, 3, 4) + ) + add_atrib( + "mitigation_rate", + actions[..., env.mitigation_rate_action_index], + axis=(2, 3, 4), + ) + add_atrib( + "export_limit", actions[..., env.export_action_index], axis=(2, 3, 4) + ) + add_atrib( + "imports", + actions[ + ..., + env.desired_imports_action_index : env.desired_imports_action_index + + env.import_actions_n, + ], + axis=(2, 3, 4), + ) + add_atrib( + "tariffs", + actions[ + ..., + env.tariffs_action_index : env.tariffs_action_index + + env.tariff_actions_n, + ], + axis=(2, 3, 4), + ) + + return result + + +@partial(jax.jit, static_argnums=(1,)) +def rice_sarl_stats(traj: NamedTuple, num_players: int) -> dict: + # obs shape: num_steps x num_envs x obs_dim + result = { + # Actions + "savings_rate": jnp.mean(traj.actions[..., 0]), + "mitigation_rate": jnp.mean(traj.actions[..., 1]), + "temperature": jnp.mean(traj.observations[..., 1]), + "land_temperature": jnp.mean(traj.observations[..., 2]), + "carbon_atmosphere": jnp.mean(traj.observations[..., 3]), + "carbon_upper_ocean": jnp.mean(traj.observations[..., 4]), + "carbon_land": jnp.mean(traj.observations[..., 5]), + "exogenous_emissions": jnp.mean(traj.observations[..., 6]), + "land_emissions": jnp.mean(traj.observations[..., 7]), + } + + region_vars = [ + "labor", + "capital", + "gross_output", + "consumption", + "investment", + "balance", + "tariff_revenue", + "carbon_price", + "club_membership", + ] + + offset = 8 + for i, label in enumerate(region_vars): + start = offset + i * num_players + end = offset + (i + 1) * num_players + result[label] = jnp.mean(traj.observations[..., start:end]) + + num_episodes = jnp.sum(traj.dones) + result["final_temperature"] = jnp.where( + num_episodes != 0, + jnp.sum(traj.dones * traj.observations[..., 1]) / num_episodes, + 0, + ) + result["train/total_reward_per_episode"] = jnp.where( + num_episodes != 0, traj.rewards.sum() / num_episodes, 0 + ) + + return result diff --git a/requirements.txt b/requirements.txt index 229946ae..0d42fe91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,6 @@ numpy optax pytest wandb -pytest -pytest-cov \ No newline at end of file +pytest-cov +bsuite +tqdm diff --git a/test/envs/test_cournot.py b/test/envs/test_cournot.py new file mode 100644 index 00000000..6d17e796 --- /dev/null +++ b/test/envs/test_cournot.py @@ -0,0 +1,46 @@ +import jax +import jax.numpy as jnp + +from pax.envs.cournot import EnvParams, CournotGame + + +def test_single_cournot_game(): + rng = jax.random.PRNGKey(0) + + for n_player in [2, 3, 12]: + env = CournotGame(num_players=n_player, num_inner_steps=1) + # This means the optimum production quantity is Q = q1 + q2 = 2(a-marginal_cost)/3b = 60 + env_params = EnvParams(a=100, b=1, marginal_cost=10) + nash_q = CournotGame.nash_policy(env_params) + assert nash_q == 60 + nash_action = jnp.array([nash_q / n_player]) + + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, + env_state, + tuple([nash_action for _ in range(n_player)]), + env_params, + ) + + assert all(element == rewards[0] for element in rewards) + # p_opt = 100 - (30 + 30) = 40 + # r1_opt = 40 * 30 - 10 * 30 = 900 + nash_reward = CournotGame.nash_reward(env_params) + assert nash_reward == 1800 + assert jnp.isclose(nash_reward / n_player, 
rewards[0], atol=0.01) + expected_obs = jnp.array( + [60 / n_player for _ in range(n_player)] + [40] + ) + assert jnp.allclose(obs[0], expected_obs, atol=0.01) + assert jnp.allclose(obs[0], obs[1], atol=0.0) + + social_opt_action = jnp.array([45 / n_player]) + obs, env_state = env.reset(rng, env_params) + obs, env_state, rewards, done, info = env.step( + rng, + env_state, + tuple([social_opt_action for _ in range(n_player)]), + env_params, + ) + assert jnp.asarray(rewards).sum() == 2025 diff --git a/test/envs/test_fishery.py b/test/envs/test_fishery.py new file mode 100644 index 00000000..659e83bd --- /dev/null +++ b/test/envs/test_fishery.py @@ -0,0 +1,50 @@ +import jax +import jax.numpy as jnp + +from pax.envs.fishery import EnvParams, Fishery + + +def test_fishery_convergence(): + rng = jax.random.PRNGKey(0) + ep_length = 300 + + env = Fishery(num_players=2, num_inner_steps=ep_length) + env_params = EnvParams(g=0.15, e=0.009, P=200, w=0.9, s_0=1.0, s_max=1.0) + # response parameter + + obs, env_state = env.reset(rng, env_params) + d = 0.4 + E = 1.0 + step_reward = 0 + for i in range(3 * ep_length + 1): + if i != 0 and env_state.s > 0: + E = E + d * step_reward + + obs, env_state, rewards, done, info = env.step( + rng, env_state, (E / 2, E / 2), env_params + ) + step_reward = rewards[0] + rewards[1] + + # Check convergence at the end of an episode + if i % ep_length == ep_length - 2: + S_star = env_params.w / ( + env_params.P * env_params.e * env_params.s_max + ) # 0.5 + assert jnp.isclose(S_star, env_state.s, atol=0.01) + + H_star = ( + env_params.w * env_params.g / (env_params.P * env_params.e) + ) * ( + 1 - S_star + ) # 0.0375 + assert jnp.isclose(H_star, info["H"], atol=0.01) + + E_star = (env_params.g / env_params.e) * (1 - S_star) # ~ 8.3333 + assert jnp.isclose(E_star, E, atol=0.01) + + # Check that the environment resets correctly + if done is True: + assert env_state.inner_t == 0 + assert env_state.outer_t == i // ep_length + assert env_state.s == env_params.s_0 + assert step_reward == 0 diff --git a/test/envs/test_iterated_tensor_game_n_player.py b/test/envs/test_iterated_tensor_game_n_player.py index bffbd423..a80ab63e 100644 --- a/test/envs/test_iterated_tensor_game_n_player.py +++ b/test/envs/test_iterated_tensor_game_n_player.py @@ -9,7 +9,7 @@ IteratedTensorGameNPlayer, ) -####### 2 PLAYERS ####### +# 2 PLAYERS payoff_table_2pl = [ [4, jnp.nan], [2, 5], @@ -25,7 +25,7 @@ dc_obs = 2 dd_obs = 3 -####### 3 PLAYERS ####### +# 3 PLAYERS payoff_table_3pl = [ [4, jnp.nan], [2.66, 5.66], @@ -84,7 +84,7 @@ ddc_obs = 6 ddd_obs = 7 -####### 4 PLAYERS ####### +# 4 PLAYERS payoff_table_4pl = [ [4, jnp.nan], [3, 6], @@ -207,6 +207,7 @@ dddc_obs = 14 dddd_obs = 15 + # ###### Begin actual tests ####### @pytest.mark.parametrize("payoff", [payoff_table_2pl]) def test_single_batch_2pl(payoff) -> None: @@ -214,7 +215,7 @@ def test_single_batch_2pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 2 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=2, num_inner_steps=5, num_outer_steps=1 ) @@ -230,7 +231,7 @@ def test_single_batch_2pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 2 player + # test 2 player # cc obs, env_state, rewards, done, info = env.step( rng, env_state, (0 * action, 0 * action), env_params @@ -248,7 +249,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##dc + # dc obs, env_state, rewards, 
done, info = env.step( rng, env_state, (1 * action, 0 * action), env_params ) @@ -264,7 +265,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##cd + # cd obs, env_state, rewards, done, info = env.step( rng, env_state, (0 * action, 1 * action), env_params ) @@ -279,7 +280,7 @@ def test_single_batch_2pl(payoff) -> None: assert jnp.array_equal(obs[0], expected_obs1) assert jnp.array_equal(obs[1], expected_obs2) - ##dd + # dd obs, env_state, rewards, done, info = env.step( rng, env_state, (1 * action, 1 * action), env_params ) @@ -301,7 +302,7 @@ def test_single_batch_4pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 4 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=num_players, num_inner_steps=100, num_outer_steps=1 ) @@ -317,7 +318,7 @@ def test_single_batch_4pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 4 player + # test 4 player # cccc cccc = (0 * action, 0 * action, 0 * action, 0 * action) obs, env_state, rewards, done, info = env.step( @@ -510,7 +511,7 @@ def test_single_batch_3pl(payoff) -> None: rng = jax.random.PRNGKey(0) num_players = 3 len_one_hot = 2**num_players + 1 - ##### setup + # setup env = IteratedTensorGameNPlayer( num_players=num_players, num_inner_steps=100, num_outer_steps=1 ) @@ -526,7 +527,7 @@ def test_single_batch_3pl(payoff) -> None: ) obs, env_state = env.reset(rng, env_params) - ###### test 3 player + # test 3 player # ccc ccc = (0 * action, 0 * action, 0 * action) obs, env_state, rewards, done, info = env.step( @@ -617,9 +618,6 @@ def test_single_batch_3pl(payoff) -> None: def test_batch_diff_actions() -> None: - num_envs = 5 - all_ones = jnp.ones((num_envs,)) - rng = jax.random.PRNGKey(0) # 5 envs, with actions ccc, dcc, cdc, ddd, ccd diff --git a/test/envs/test_rice.py b/test/envs/test_rice.py new file mode 100644 index 00000000..82dbdc6a --- /dev/null +++ b/test/envs/test_rice.py @@ -0,0 +1,91 @@ +import os +import time + +import jax + +from pax.envs.rice.rice import Rice, EnvParams + +file_dir = os.path.join(os.path.dirname(__file__)) +config_folder = os.path.join(file_dir, "../../pax/envs/rice/5_regions") +num_players = 5 +ep_length = 20 + + +def test_rice(): + rng = jax.random.PRNGKey(0) + + env = Rice(config_folder=config_folder, episode_length=ep_length) + env_params = EnvParams() + obs, env_state = env.reset(rng, EnvParams()) + + key, _ = jax.random.split(rng, 2) + for i in range(100 * ep_length + 1): + # Do random actions + key, _ = jax.random.split(key, 2) + actions = jax.random.uniform(key, (num_players, env.num_actions)) + actions = tuple([action for action in actions]) + obs, env_state, rewards, done, info = env.step( + rng, env_state, actions, env_params + ) + for j in range(num_players): + # assert all obs positive + assert ( + env_state.consumption_all[j].item() >= 0 + ), "consumption cannot be negative!" + assert ( + env_state.production_all[j].item() >= 0 + ), "production cannot be negative!" + assert ( + env_state.labor_all[j].item() >= 0 + ), "labor cannot be negative!" + assert ( + env_state.capital_all[j].item() >= 0 + ), "capital cannot be negative!" + assert ( + env_state.gross_output_all[j].item() >= 0 + ), "gross output cannot be negative!" + assert ( + env_state.investment_all[j].item() >= 0 + ), "investment cannot be negative!" 
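+
+        # done should only be set on the final step of each ep_length-step
+        # episode, after which the env resets and inner_t wraps back to 0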
+        if done is True:
+            assert env_state.inner_t == 0
+            assert (i + 1) % ep_length == 0
+        assert (
+            i + 1
+        ) % ep_length == env_state.inner_t, "inner_t not updating correctly"
+
+
+def rice_performance_benchmark():
+    rng = jax.random.PRNGKey(0)
+    iterations = 1000
+
+    env = Rice(config_folder=config_folder, episode_length=ep_length)
+    env_params = EnvParams()
+    obs, env_state = env.reset(rng, EnvParams())
+
+    start_time = time.time()
+
+    for _ in range(iterations * ep_length):
+        # Do random actions
+        key, _ = jax.random.split(rng, 2)
+        action = jax.random.uniform(rng, (env.num_actions,))
+        actions = tuple([action for _ in range(num_players)])
+        obs, env_state, rewards, done, info = env.step(
+            rng, env_state, actions, env_params
+        )
+
+    end_time = time.time()  # End timing
+    total_time = end_time - start_time
+
+    # Print or log the total time taken for all iterations
+    print(f"Total iterations:\t{iterations * ep_length}")
+    print(f"Total time taken:\t{total_time:.4f} seconds")
+    print(
+        f"Average step duration:\t{total_time / (iterations * ep_length):.6f} seconds"
+    )
+
+
+# Run a benchmark
+if __name__ == "__main__":
+    rice_performance_benchmark()
diff --git a/test/envs/test_rice/test_rice_regression.yml b/test/envs/test_rice/test_rice_regression.yml
new file mode 100644
index 00000000..c3765319
--- /dev/null
+++ b/test/envs/test_rice/test_rice_regression.yml
@@ -0,0 +1,14780 @@
+0: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - - 162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 865.7999877929688 + - 471.3000183105469 + - 1740.7000732421875 + - 0.5 + - 0.5 + - - 1.0 + - 0.0 + - - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 1 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - - 1154.800048828125 + - 285.70001220703125 + - 4026.900146484375 + - 1578.0 + - 619.7999877929688 + -
162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 49.29999923706055 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 1.0 + - 1.0 + - 0.0 + - 865.7999877929688 + - 471.3000183105469 + - 1740.7000732421875 + - 0.5 + - 0.5 + - 1154.800048828125 + - 285.70001220703125 + - 4026.900146484375 + - 1578.0 + - 619.7999877929688 + - 162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 2.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 1.0 + - 1.0 + - 0.0 + - 865.7999877929688 + - 471.3000183105469 + - 1740.7000732421875 + - 0.5 + - 0.5 + - 1154.800048828125 + - 285.70001220703125 + - 4026.900146484375 + - 1578.0 + - 619.7999877929688 + - 162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 
0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 23.30000114440918 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 1.0 + - 1.0 + - 0.0 + - 865.7999877929688 + - 471.3000183105469 + - 1740.7000732421875 + - 0.5 + - 0.5 + - 1154.800048828125 + - 285.70001220703125 + - 4026.900146484375 + - 1578.0 + - 619.7999877929688 + - 162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 4.800000190734863 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 1.0 + - 1.0 + - 0.0 + - 865.7999877929688 + - 471.3000183105469 + - 1740.7000732421875 + - 0.5 + - 0.5 + - 1154.800048828125 + - 285.70001220703125 + - 4026.900146484375 + - 1578.0 + - 619.7999877929688 + - 162.10000610351562 + - 5.400000095367432 + - 73.0 + - 15.40000057220459 + - 12.199999809265137 + - 48.900001525878906 + - 2.0 + - 22.700000762939453 + - 4.700000286102295 + - 4.400000095367432 + - 7.300000190734863 + - 0.10000000149011612 + - 3.4000000953674316 + - 0.4000000059604645 + - 0.30000001192092896 + - 24.399999618530273 + - 1.0 + - 11.300000190734863 + - 2.299999952316284 + - 2.200000047683716 + - 1.2000000476837158 + - -4.900000095367432 + - 1.2000000476837158 + - 1.2000000476837158 + - 1.2000000476837158 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 4.400000095367432 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.5 + - 0.10000000149011612 + - 2.200000047683716 + - 0.30000001192092896 + - 0.20000000298023224 +1: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 3.799999952316284 + - -15.40000057220459 + - 3.799999952316284 + - 3.799999952316284 + - 3.799999952316284 + - - 264.20001220703125 + - 14.100000381469727 + - 127.4000015258789 + - 26.5 + - 27.30000114440918 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 9.0 + - 1.0 + - 4.5 + - 0.6000000238418579 + - 0.800000011920929 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 892.7000122070312 + - 482.1000061035156 + - 1741.4000244140625 + - 0.5 + - 0.5 + - - 1.100000023841858 + - 0.10000000149011612 + - - 67.4000015258789 + - 4.400000095367432 + - 33.70000076293945 + - 7.0 + - 8.0 + - 2 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 33.70000076293945 + - 2.200000047683716 + - 16.899999618530273 + - 3.5 + - 4.0 + - - 1155.9000244140625 + - 285.5 + - 4022.800048828125 + - 1610.2000732421875 + - 621.4000244140625 + - - 0.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 68.0 + - 4.599999904632568 + - 34.70000076293945 + - 7.200000286102295 + - 8.100000381469727 + - - 13.600000381469727 + - 7.300000190734863 + - 3.9000000953674316 + - 2.4000000953674316 + - 5.599999904632568 + - - 1.399999976158142 + - 0.30000001192092896 + - 2.200000047683716 + - 0.4000000059604645 + - 0.4000000059604645 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.600000023841858 + - 0.30000001192092896 + - 2.6000001430511475 + - 0.5 + - 0.4000000059604645 + obs: + - - 0.0 + - 2.0 + - 1.100000023841858 + - 0.10000000149011612 + - 892.7000122070312 + - 482.1000061035156 + - 1741.4000244140625 + - 0.5 + - 0.5 + - 1155.9000244140625 + - 285.5 + - 4022.800048828125 + - 1610.2000732421875 + - 621.4000244140625 + - 264.20001220703125 + - 14.100000381469727 + - 127.4000015258789 + - 26.5 + - 27.30000114440918 + - 67.4000015258789 + - 4.400000095367432 + - 33.70000076293945 + - 7.0 + - 8.0 + - 9.0 + - 1.0 + - 4.5 + - 0.6000000238418579 + - 0.800000011920929 + - 33.70000076293945 + - 2.200000047683716 + - 16.899999618530273 + - 3.5 + - 4.0 + - 3.799999952316284 + - -15.40000057220459 + - 3.799999952316284 + - 3.799999952316284 + - 3.799999952316284 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 68.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
+# --- regression snapshot: test/envs/test_rice/test_rice_regression.yml ---
+# The raw numeric dump for timesteps 1-8 of this 5-region Rice-N rollout
+# is elided here for readability; the full machine-written fixture runs
+# to ~14,780 lines. Each top-level key is a timestep mapping to the
+# nested per-region env_state arrays, one flat observation vector per
+# region, and the per-region rewards. Every entry follows the same
+# schema, e.g. for step 2:
+#
+#   2:
+#     env_state: [ ...nested per-region state arrays... ]
+#     obs: [ ...one flat observation vector per region... ]
+#     rewards:
+#     - 1.7000000476837158
+#     - 0.30000001192092896
+#     - 2.9000000953674316
+#     - 0.699999988079071
+#     - 0.5
+ - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 9.0 + - 2.200000047683716 + - 0.30000001192092896 + - 1191.2000732421875 + - 602.9000244140625 + - 1749.300048828125 + - 0.699999988079071 + - 0.5 + - 1161.5 + - 284.20001220703125 + - 4003.400146484375 + - 1835.4000244140625 + - 629.7999877929688 + - 607.0 + - 61.20000076293945 + - 384.8999938964844 + - 75.20000457763672 + - 93.30000305175781 + - 103.80000305175781 + - 10.699999809265137 + - 67.4000015258789 + - 13.199999809265137 + - 16.200000762939453 + - 13.800000190734863 + - 2.0 + - 9.0 + - 1.600000023841858 + - 2.0 + - 51.900001525878906 + - 5.300000190734863 + - 33.70000076293945 + - 6.599999904632568 + - 8.100000381469727 + - 31.200000762939453 + - -124.80000305175781 + - 31.200000762939453 + - 31.200000762939453 + - 31.200000762939453 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 16.600000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.0 + - 0.699999988079071 +9: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.0 + - 2.0 + - 9.199999809265137 + - 1.600000023841858 + - 2.1000001430511475 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - - 2.4000000953674316 + - 0.30000001192092896 + - - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 10 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 106.9000015258789 + - 11.40000057220459 + - 71.30000305175781 + - 14.100000381469727 + - 16.899999618530273 + - - 14.100000381469727 + - 8.100000381469727 + - 4.599999904632568 + - 2.5 + - 6.0 + - - 0.800000011920929 + - 0.20000000298023224 + - 1.8000000715255737 + - 0.5 + - 0.30000001192092896 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 
1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.0 + - 0.699999988079071 + obs: + - - 0.0 + - 10.0 + - 2.4000000953674316 + - 0.30000001192092896 + - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 14.0 + - 2.0 + - 9.199999809265137 + - 1.600000023841858 + - 2.1000001430511475 + - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 106.9000015258789 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 10.0 + - 2.4000000953674316 + - 0.30000001192092896 + - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 14.0 + - 2.0 + - 9.199999809265137 + - 1.600000023841858 + - 2.1000001430511475 + - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 11.40000057220459 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 10.0 + - 2.4000000953674316 + - 0.30000001192092896 + - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 14.0 + - 2.0 + - 9.199999809265137 + - 
1.600000023841858 + - 2.1000001430511475 + - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 71.30000305175781 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 10.0 + - 2.4000000953674316 + - 0.30000001192092896 + - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 14.0 + - 2.0 + - 9.199999809265137 + - 1.600000023841858 + - 2.1000001430511475 + - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 14.100000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 10.0 + - 2.4000000953674316 + - 0.30000001192092896 + - 1237.300048828125 + - 626.0 + - 1751.0 + - 0.699999988079071 + - 0.5 + - 1162.0 + - 284.1000061035156 + - 4001.699951171875 + - 1867.4000244140625 + - 630.7000122070312 + - 620.7999877929688 + - 63.29999923706055 + - 399.3000183105469 + - 78.4000015258789 + - 96.30000305175781 + - 105.0 + - 10.90000057220459 + - 68.80000305175781 + - 13.600000381469727 + - 16.5 + - 14.0 + - 2.0 + - 9.199999809265137 + - 1.600000023841858 + - 2.1000001430511475 + - 52.5 + - 5.400000095367432 + - 34.400001525878906 + - 6.800000190734863 + - 8.199999809265137 + - 36.79999923706055 + - -147.3000030517578 + - 36.79999923706055 + - 36.79999923706055 + - 36.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 16.899999618530273 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.0 + - 0.699999988079071 +10: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - - 2.5 + - 0.4000000059604645 + - - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 11 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 108.0 + - 11.600000381469727 + - 72.5 + - 14.5 + - 17.100000381469727 + - - 14.199999809265137 + - 8.100000381469727 + - 4.599999904632568 + - 2.5 + - 6.0 + - - 0.800000011920929 + - 0.20000000298023224 + - 1.600000023841858 + - 0.5 + - 0.30000001192092896 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.0 + - 0.699999988079071 + obs: + - - 0.0 + - 11.0 + - 2.5 + - 0.4000000059604645 + - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 + - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 108.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 11.0 + - 2.5 + - 0.4000000059604645 + - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 + - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 11.600000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 11.0 + - 2.5 + - 0.4000000059604645 + - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 + - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 72.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 11.0 + - 2.5 + - 0.4000000059604645 + - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 
+ - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 14.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 11.0 + - 2.5 + - 0.4000000059604645 + - 1283.5 + - 650.0 + - 1752.800048828125 + - 0.800000011920929 + - 0.5 + - 1162.5999755859375 + - 283.8999938964844 + - 4000.10009765625 + - 1899.300048828125 + - 631.6000366210938 + - 631.2999877929688 + - 64.9000015258789 + - 410.3000183105469 + - 81.20000457763672 + - 98.5999984741211 + - 105.9000015258789 + - 11.0 + - 69.80000305175781 + - 14.0 + - 16.700000762939453 + - 14.100000381469727 + - 2.0 + - 9.300000190734863 + - 1.7000000476837158 + - 2.1000001430511475 + - 52.900001525878906 + - 5.5 + - 34.900001525878906 + - 7.0 + - 8.300000190734863 + - 43.0 + - -172.10000610351562 + - 43.0 + - 43.0 + - 43.0 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.100000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.0 + - 0.699999988079071 +11: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - - 2.700000047683716 + - 0.5 + - - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 12 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 
0.699999988079071 + - 0.30000001192092896 + - - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - - 1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 108.9000015258789 + - 11.699999809265137 + - 73.4000015258789 + - 14.800000190734863 + - 17.30000114440918 + - - 14.199999809265137 + - 8.100000381469727 + - 4.599999904632568 + - 2.5 + - 6.0 + - - 0.699999988079071 + - 0.20000000298023224 + - 1.5 + - 0.4000000059604645 + - 0.30000001192092896 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.100000023841858 + - 0.699999988079071 + obs: + - - 0.0 + - 12.0 + - 2.700000047683716 + - 0.5 + - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - 1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 108.9000015258789 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 12.0 + - 2.700000047683716 + - 0.5 + - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - 1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 11.699999809265137 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 12.0 + - 2.700000047683716 + - 0.5 + - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - 1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 73.4000015258789 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 12.0 + - 2.700000047683716 + - 0.5 + - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - 1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 14.800000190734863 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 12.0 + - 2.700000047683716 + - 0.5 + - 1329.5999755859375 + - 674.6000366210938 + - 1754.800048828125 + - 0.800000011920929 + - 0.5 + - 
1163.0999755859375 + - 283.8000183105469 + - 3998.60009765625 + - 1931.0999755859375 + - 632.2999877929688 + - 639.2000122070312 + - 66.0 + - 418.8000183105469 + - 83.70000457763672 + - 100.30000305175781 + - 106.5999984741211 + - 11.100000381469727 + - 70.5999984741211 + - 14.300000190734863 + - 16.80000114440918 + - 14.199999809265137 + - 2.0 + - 9.40000057220459 + - 1.7000000476837158 + - 2.1000001430511475 + - 53.29999923706055 + - 5.5 + - 35.29999923706055 + - 7.099999904632568 + - 8.40000057220459 + - 49.79999923706055 + - -199.3000030517578 + - 49.79999923706055 + - 49.79999923706055 + - 49.79999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.30000114440918 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.100000023841858 + - 0.699999988079071 +12: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 + - 57.29999923706055 + - - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 0.800000011920929 + - 0.5 + - - 2.9000000953674316 + - 0.5 + - - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 13 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 109.70000457763672 + - 11.800000190734863 + - 74.20000457763672 + - 15.199999809265137 + - 17.399999618530273 + - - 14.300000190734863 + - 8.100000381469727 + - 4.599999904632568 + - 2.6000001430511475 + - 6.099999904632568 + - - 0.699999988079071 + - 0.10000000149011612 + - 1.399999976158142 + - 0.4000000059604645 + - 0.30000001192092896 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.100000023841858 + - 0.699999988079071 + obs: + - - 0.0 + - 13.0 + - 2.9000000953674316 + - 0.5 + - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 
0.800000011920929 + - 0.5 + - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 + - 57.29999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 109.70000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 13.0 + - 2.9000000953674316 + - 0.5 + - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 0.800000011920929 + - 0.5 + - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 + - 57.29999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 11.800000190734863 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 13.0 + - 2.9000000953674316 + - 0.5 + - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 0.800000011920929 + - 0.5 + - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 
+ - 57.29999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 74.20000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 13.0 + - 2.9000000953674316 + - 0.5 + - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 0.800000011920929 + - 0.5 + - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 + - 57.29999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 15.199999809265137 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 13.0 + - 2.9000000953674316 + - 0.5 + - 1375.9000244140625 + - 699.7999877929688 + - 1757.0 + - 0.800000011920929 + - 0.5 + - 1163.5 + - 283.6000061035156 + - 3997.300048828125 + - 1962.800048828125 + - 633.0 + - 645.2999877929688 + - 66.9000015258789 + - 425.20001220703125 + - 85.9000015258789 + - 101.5999984741211 + - 107.0999984741211 + - 11.199999809265137 + - 71.20000457763672 + - 14.600000381469727 + - 17.0 + - 14.300000190734863 + - 2.0 + - 9.5 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.60000228881836 + - 5.599999904632568 + - 35.60000228881836 + - 7.300000190734863 + - 8.5 + - 57.29999923706055 + - -229.1999969482422 + - 57.29999923706055 + - 57.29999923706055 + - 57.29999923706055 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.399999618530273 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + 
- 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.700000047683716 + - 1.100000023841858 + - 0.699999988079071 +13: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - - 3.1000001430511475 + - 0.6000000238418579 + - - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 53.79999923706055 + - 5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 110.5 + - 11.90000057220459 + - 74.80000305175781 + - 15.600000381469727 + - 17.600000381469727 + - - 14.300000190734863 + - 8.199999809265137 + - 4.599999904632568 + - 2.6000001430511475 + - 6.099999904632568 + - - 0.6000000238418579 + - 0.10000000149011612 + - 1.3000000715255737 + - 0.4000000059604645 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.100000023841858 + - 0.699999988079071 + obs: + - - 0.0 + - 14.0 + - 3.1000001430511475 + - 0.6000000238418579 + - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.79999923706055 + - 5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 110.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + 
- 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 14.0 + - 3.1000001430511475 + - 0.6000000238418579 + - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.79999923706055 + - 5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 11.90000057220459 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 14.0 + - 3.1000001430511475 + - 0.6000000238418579 + - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.79999923706055 + - 5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 74.80000305175781 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 14.0 + - 3.1000001430511475 + - 0.6000000238418579 + - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.79999923706055 + - 
5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 15.600000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 14.0 + - 3.1000001430511475 + - 0.6000000238418579 + - 1422.0999755859375 + - 725.4000244140625 + - 1759.300048828125 + - 0.800000011920929 + - 0.5 + - 1163.9000244140625 + - 283.5 + - 3996.199951171875 + - 1994.5 + - 633.7000122070312 + - 650.0 + - 67.5 + - 430.1000061035156 + - 88.0 + - 102.5999984741211 + - 107.5999984741211 + - 11.199999809265137 + - 71.5999984741211 + - 14.90000057220459 + - 17.100000381469727 + - 14.300000190734863 + - 2.1000001430511475 + - 9.600000381469727 + - 1.8000000715255737 + - 2.1000001430511475 + - 53.79999923706055 + - 5.599999904632568 + - 35.79999923706055 + - 7.5 + - 8.5 + - 65.5 + - -262.1000061035156 + - 65.5 + - 65.5 + - 65.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.600000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.100000023841858 + - 0.699999988079071 +14: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 89.9000015258789 + - 103.5 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - - 3.200000047683716 + - 0.6000000238418579 + - - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 15 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 
0.699999988079071 + - 0.30000001192092896 + - - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 111.0999984741211 + - 12.0 + - 75.4000015258789 + - 15.90000057220459 + - 17.700000762939453 + - - 14.40000057220459 + - 8.199999809265137 + - 4.700000286102295 + - 2.6000001430511475 + - 6.099999904632568 + - - 0.6000000238418579 + - 0.10000000149011612 + - 1.2000000476837158 + - 0.4000000059604645 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.100000023841858 + - 0.699999988079071 + obs: + - - 0.0 + - 15.0 + - 3.200000047683716 + - 0.6000000238418579 + - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 89.9000015258789 + - 103.5 + - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 111.0999984741211 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 15.0 + - 3.200000047683716 + - 0.6000000238418579 + - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 89.9000015258789 + - 103.5 + - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 12.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 
+ - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 15.0 + - 3.200000047683716 + - 0.6000000238418579 + - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 89.9000015258789 + - 103.5 + - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 75.4000015258789 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 15.0 + - 3.200000047683716 + - 0.6000000238418579 + - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 89.9000015258789 + - 103.5 + - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 15.90000057220459 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 15.0 + - 3.200000047683716 + - 0.6000000238418579 + - 1468.4000244140625 + - 751.4000244140625 + - 1761.800048828125 + - 0.800000011920929 + - 0.5 + - 1164.300048828125 + - 283.3999938964844 + - 3995.10009765625 + - 2026.0 + - 634.2999877929688 + - 653.7000122070312 + - 68.0 + - 433.8999938964844 + - 
89.9000015258789 + - 103.5 + - 108.0 + - 11.199999809265137 + - 72.0 + - 15.199999809265137 + - 17.100000381469727 + - 14.40000057220459 + - 2.1000001430511475 + - 9.600000381469727 + - 1.899999976158142 + - 2.200000047683716 + - 54.0 + - 5.599999904632568 + - 36.0 + - 7.599999904632568 + - 8.600000381469727 + - 74.5999984741211 + - -298.3000183105469 + - 74.5999984741211 + - 74.5999984741211 + - 74.5999984741211 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.700000762939453 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.100000023841858 + - 0.699999988079071 +15: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - - 656.6000366210938 + - 68.30000305175781 + - 436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - - 3.4000000953674316 + - 0.699999988079071 + - - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 16 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 111.70000457763672 + - 12.0 + - 75.80000305175781 + - 16.200000762939453 + - 17.80000114440918 + - - 14.40000057220459 + - 8.199999809265137 + - 4.700000286102295 + - 2.6000001430511475 + - 6.099999904632568 + - - 0.5 + - 0.10000000149011612 + - 1.100000023841858 + - 0.4000000059604645 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 + obs: + - - 0.0 + - 16.0 + - 3.4000000953674316 + - 0.699999988079071 + - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - 656.6000366210938 + - 68.30000305175781 + - 
436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 111.70000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 16.0 + - 3.4000000953674316 + - 0.699999988079071 + - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - 656.6000366210938 + - 68.30000305175781 + - 436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 12.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 16.0 + - 3.4000000953674316 + - 0.699999988079071 + - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - 656.6000366210938 + - 68.30000305175781 + - 436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 
0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 75.80000305175781 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 16.0 + - 3.4000000953674316 + - 0.699999988079071 + - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - 656.6000366210938 + - 68.30000305175781 + - 436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 16.200000762939453 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 16.0 + - 3.4000000953674316 + - 0.699999988079071 + - 1514.7000732421875 + - 777.6000366210938 + - 1764.5 + - 0.9000000357627869 + - 0.5 + - 1164.5999755859375 + - 283.3000183105469 + - 3994.199951171875 + - 2057.400146484375 + - 634.9000244140625 + - 656.6000366210938 + - 68.30000305175781 + - 436.70001220703125 + - 91.80000305175781 + - 104.0999984741211 + - 108.30000305175781 + - 11.300000190734863 + - 72.20000457763672 + - 15.5 + - 17.200000762939453 + - 14.40000057220459 + - 2.1000001430511475 + - 9.699999809265137 + - 1.899999976158142 + - 2.200000047683716 + - 54.10000228881836 + - 5.599999904632568 + - 36.10000228881836 + - 7.700000286102295 + - 8.600000381469727 + - 84.5 + - -338.1000061035156 + - 84.5 + - 84.5 + - 84.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.80000114440918 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + 
- 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 +16: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - - 3.6000001430511475 + - 0.800000011920929 + - - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 17 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 54.29999923706055 + - 5.599999904632568 + - 36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 112.20000457763672 + - 12.100000381469727 + - 76.20000457763672 + - 16.600000381469727 + - 17.899999618530273 + - - 14.5 + - 8.199999809265137 + - 4.700000286102295 + - 2.6000001430511475 + - 6.099999904632568 + - - 0.5 + - 0.10000000149011612 + - 1.100000023841858 + - 0.30000001192092896 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 + obs: + - - 0.0 + - 17.0 + - 3.6000001430511475 + - 0.800000011920929 + - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.29999923706055 + - 5.599999904632568 + - 36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 112.20000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + 
- 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 17.0 + - 3.6000001430511475 + - 0.800000011920929 + - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.29999923706055 + - 5.599999904632568 + - 36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 12.100000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 17.0 + - 3.6000001430511475 + - 0.800000011920929 + - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.29999923706055 + - 5.599999904632568 + - 36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 76.20000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 17.0 + - 3.6000001430511475 + - 0.800000011920929 + - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.29999923706055 + - 5.599999904632568 + - 
36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 16.600000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 17.0 + - 3.6000001430511475 + - 0.800000011920929 + - 1561.0999755859375 + - 804.1000366210938 + - 1767.300048828125 + - 0.9000000357627869 + - 0.5 + - 1165.0 + - 283.1000061035156 + - 3993.300048828125 + - 2088.60009765625 + - 635.4000244140625 + - 659.0 + - 68.5999984741211 + - 438.8999938964844 + - 93.5999984741211 + - 104.5999984741211 + - 108.5 + - 11.300000190734863 + - 72.4000015258789 + - 15.800000190734863 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.29999923706055 + - 5.599999904632568 + - 36.20000076293945 + - 7.900000095367432 + - 8.600000381469727 + - 95.5 + - -382.0 + - 95.5 + - 95.5 + - 95.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.899999618530273 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 +17: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - - 3.700000047683716 + - 0.800000011920929 + - - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 18 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 
54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 112.70000457763672 + - 12.100000381469727 + - 76.5999984741211 + - 16.899999618530273 + - 18.0 + - - 14.5 + - 8.300000190734863 + - 4.700000286102295 + - 2.6000001430511475 + - 6.200000286102295 + - - 0.5 + - 0.10000000149011612 + - 1.0 + - 0.30000001192092896 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 + obs: + - - 0.0 + - 18.0 + - 3.700000047683716 + - 0.800000011920929 + - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 112.70000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 18.0 + - 3.700000047683716 + - 0.800000011920929 + - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 12.100000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 
0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 18.0 + - 3.700000047683716 + - 0.800000011920929 + - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 76.5999984741211 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 18.0 + - 3.700000047683716 + - 0.800000011920929 + - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 16.899999618530273 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 18.0 + - 3.700000047683716 + - 0.800000011920929 + - 1607.5999755859375 + - 830.7999877929688 + - 1770.4000244140625 + - 0.9000000357627869 + - 0.5 + - 1165.2000732421875 + - 283.0 + - 3992.60009765625 + - 2119.699951171875 + - 635.9000244140625 + - 660.9000244140625 + - 68.80000305175781 + - 440.6000061035156 + - 95.30000305175781 + - 105.0 + - 108.70000457763672 + - 11.300000190734863 + - 72.5999984741211 + - 16.0 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 
2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.29999923706055 + - 8.0 + - 8.699999809265137 + - 107.5 + - -430.1000061035156 + - 107.5 + - 107.5 + - 107.5 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 18.0 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 +18: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - [] + - - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - - 3.9000000953674316 + - 0.9000000357627869 + - - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 19 + - - 0.20000000298023224 + - 1.100000023841858 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - - 0.0 + - 0.20000000298023224 + - 0.10000000149011612 + - 0.10000000149011612 + - 0.10000000149011612 + - 0 + - - 113.20000457763672 + - 12.199999809265137 + - 76.9000015258789 + - 17.200000762939453 + - 18.100000381469727 + - - 14.5 + - 8.300000190734863 + - 4.700000286102295 + - 2.6000001430511475 + - 6.200000286102295 + - - 0.4000000059604645 + - 0.10000000149011612 + - 0.9000000357627869 + - 0.30000001192092896 + - 0.20000000298023224 + - - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.8000000715255737 + - 0.4000000059604645 + - 3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 + obs: + - - 0.0 + - 19.0 + - 3.9000000953674316 + - 0.9000000357627869 + - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 
54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 113.20000457763672 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 1.0 + - 19.0 + - 3.9000000953674316 + - 0.9000000357627869 + - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 12.199999809265137 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 2.0 + - 19.0 + - 3.9000000953674316 + - 0.9000000357627869 + - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 76.9000015258789 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + 
- 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 3.0 + - 19.0 + - 3.9000000953674316 + - 0.9000000357627869 + - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 17.200000762939453 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - - 4.0 + - 19.0 + - 3.9000000953674316 + - 0.9000000357627869 + - 1654.2000732421875 + - 857.7000122070312 + - 1773.5999755859375 + - 1.0 + - 0.5 + - 1165.5 + - 282.8999938964844 + - 3991.900146484375 + - 2150.699951171875 + - 636.4000244140625 + - 662.5 + - 68.9000015258789 + - 441.8999938964844 + - 97.0 + - 105.4000015258789 + - 108.9000015258789 + - 11.300000190734863 + - 72.70000457763672 + - 16.30000114440918 + - 17.30000114440918 + - 14.5 + - 2.1000001430511475 + - 9.699999809265137 + - 2.0 + - 2.200000047683716 + - 54.400001525878906 + - 5.700000286102295 + - 36.400001525878906 + - 8.100000381469727 + - 8.699999809265137 + - 120.80000305175781 + - -483.20001220703125 + - 120.80000305175781 + - 120.80000305175781 + - 120.80000305175781 + - 1.0 + - 0.0 + - 1.0 + - 1.0 + - 1.0 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.20000000298023224 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1.0 + - 0.0 + - 18.100000381469727 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + - 0.5 + rewards: + - 1.8000000715255737 + - 0.4000000059604645 + - 
3.799999952316284 + - 1.2000000476837158 + - 0.699999988079071 +19: + env_state: + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 67.5 + - 0.9000000357627869 + - 27.5 + - 6.300000190734863 + - 2.200000047683716 + - [] + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0 + - 0 + - 0 + - 0 + - 0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - - 0.800000011920929 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0 + - - 0.20000000298023224 + - 1.2000000476837158 + - 0.699999988079071 + - 0.699999988079071 + - 0.30000001192092896 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 1 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 12.600000381469727 + - 5.099999904632568 + - 3.299999952316284 + - 2.0 + - 4.900000095367432 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + obs: + - - 0.0 + - 0.0 + - 0.800000011920929 + - 0.0 + - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - 67.5 + - 0.9000000357627869 + - 27.5 + - 6.300000190734863 + - 2.200000047683716 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - - 1.0 + - 0.0 + - 0.800000011920929 + - 0.0 + - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - 67.5 + - 0.9000000357627869 + - 27.5 + - 6.300000190734863 + - 2.200000047683716 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + 
- 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - - 2.0 + - 0.0 + - 0.800000011920929 + - 0.0 + - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - 67.5 + - 0.9000000357627869 + - 27.5 + - 6.300000190734863 + - 2.200000047683716 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - - 3.0 + - 0.0 + - 0.800000011920929 + - 0.0 + - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - 67.5 + - 0.9000000357627869 + - 27.5 
+ - 6.300000190734863 + - 2.200000047683716 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - - 4.0 + - 0.0 + - 0.800000011920929 + - 0.0 + - 851.0 + - 460.0 + - 1740.0 + - 0.0 + - 0.0 + - 1153.5999755859375 + - 286.0 + - 4031.5 + - 1545.800048828125 + - 618.1000366210938 + - 67.5 + - 0.9000000357627869 + - 27.5 + - 6.300000190734863 + - 2.200000047683716 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + - 0.9000000357627869 + - 0.4000000059604645 + - 0.30000001192092896 + - 0.699999988079071 + - 0.30000001192092896 + - 0.800000011920929 + - 0.5 + - 0.0 + - 0.9000000357627869 + - 0.20000000298023224 + - 0.4000000059604645 + - 0.30000001192092896 + - 1.0 + rewards: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 diff --git 
a/test/test_strategies.py b/test/test_strategies.py
index ccc38fcc..1d4f9b9b 100644
--- a/test/test_strategies.py
+++ b/test/test_strategies.py
@@ -247,58 +247,6 @@ def test_naive_tft():
     assert jnp.allclose(0.99, reward0, atol=0.01)
 
 
-def test_naive_tft():
-    num_envs = 2
-    rng = jnp.concatenate([jax.random.PRNGKey(0)] * num_envs).reshape(
-        num_envs, -1
-    )
-
-    env = InfiniteMatrixGame(num_steps=jnp.inf)
-    env_params = InfiniteMatrixGameParams(
-        payoff_matrix=[[2, 2], [0, 3], [3, 0], [1, 1]], gamma=0.96
-    )
-
-    # vmaps
-    split = jax.vmap(jax.random.split, in_axes=(0, None))
-    env.reset = jax.jit(
-        jax.vmap(env.reset, in_axes=(0, None), out_axes=(0, None))
-    )
-    env.step = jax.jit(
-        jax.vmap(
-            env.step, in_axes=(0, None, 0, None), out_axes=(0, None, 0, 0, 0)
-        )
-    )
-
-    (obs1, _), env_state = env.reset(rng, env_params)
-    agent = NaiveExact(
-        action_dim=5,
-        env_params=env_params,
-        lr=1,
-        num_envs=num_envs,
-        player_id=0,
-    )
-    agent_state, agent_memory = agent.make_initial_state(obs1)
-    tft_action = jnp.tile(
-        20 * jnp.array([1.0, -1.0, 1.0, -1.0, 1.0]), (num_envs, 1)
-    )
-
-    for _ in range(50):
-        rng, _ = split(rng, 2)
-        action, agent_state, agent_memory = agent._policy(
-            agent_state, obs1, agent_memory
-        )
-
-        (obs1, _), env_state, rewards, _, _ = env.step(
-            rng, env_state, (action, tft_action), env_params
-        )
-
-    action, _, _ = agent._policy(agent_state, obs1, agent_memory)
-    _, _, (reward0, reward1), _, _ = env.step(
-        rng, env_state, (action, tft_action), env_params
-    )
-    assert jnp.allclose(2.0, reward0, atol=0.01)
-
-
 def test_naive_tft_as_second_player():
     num_envs = 2
     rng = jnp.concatenate([jax.random.PRNGKey(0)] * num_envs).reshape(
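Note on the test/test_strategies.py hunk above: the module previously defined test_naive_tft twice, and because a second def with the same name rebinds it at import time, pytest only ever collected the later copy; the first version (the one asserting a reward near 0.99, kept as context above the hunk) was silently shadowed. Deleting the duplicate re-enables the original test without changing any code under test.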
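The bulk of this patch is test/envs/test_rice/test_rice_regression.yml, a recorded rollout of the 5-region Rice-N environment: each timestep key stores the env_state arrays, one observation vector per region, and the per-region rewards, with an episode reset visible at step 19. Below is a minimal sketch of how such a fixture typically backs a regression test. The timestep/obs/rewards layout matches the file above, but the helper name, path handling, and the rollout callable are illustrative assumptions, not Pax's actual test code (which lives in test/envs/test_rice.py).

import yaml
import numpy as np


def assert_matches_fixture(fixture_path, rollout):
    # Compare a fresh environment rollout against the recorded fixture.
    # rollout(step) is assumed to return a dict with the same "obs" and
    # "rewards" entries as one timestep of the YAML file; env_state is
    # skipped here because its entries are ragged (note the empty lists).
    with open(fixture_path) as f:
        recorded = yaml.safe_load(f)

    for step in sorted(recorded):
        actual = rollout(step)
        for key in ("obs", "rewards"):
            # Fixture values were serialized as float32, so compare with a
            # tolerance rather than exact equality.
            np.testing.assert_allclose(
                np.asarray(actual[key], dtype=np.float32),
                np.asarray(recorded[step][key], dtype=np.float32),
                atol=1e-5,
                err_msg=f"step {step}: {key} diverged from recorded fixture",
            )

Pinning the full observation and reward tensors, rather than a scalar summary, is what lets the vectorized Rice-N rewrite in this patch be checked for step-for-step parity with the reference implementation.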