diff --git a/proof_of_concept.ipynb b/proof_of_concept.ipynb
new file mode 100644
index 00000000..783b3c32
--- /dev/null
+++ b/proof_of_concept.ipynb
@@ -0,0 +1,339 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Pax: Proof of Concept\n",
+    "\n",
+    "This notebook serves as a minimal working example of the concepts used in `pax`. It does not produce meaningful results; it simply exists to familiarize yourself with the code. If you are looking for working examples, please refer to the repository, which contains many working examples of runners, environments and agents."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Environments: A simple rollout\n",
+    "\n",
+    "Let's have a look at how you typically initialize environments, how to parameterize them and how to vmap over them. In `pax`, environments are initialized in the `experiments.py` file. Rollouts are usually defined in a runner file, e.g. `runners/runner_marl.py`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
+    "from pax.envs.iterated_matrix_game import (\n",
+    "    IteratedMatrixGame,\n",
+    "    EnvParams,\n",
+    ")\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "devices = jax.local_devices()\n",
+    "print(devices)\n",
+    "# batch over env initialisations\n",
+    "num_envs = 2\n",
+    "payoff = [[2, 2], [0, 3], [3, 0], [1, 1]]\n",
+    "rollout_length = 50\n",
+    "\n",
+    "rng = jnp.concatenate(\n",
+    "    [jax.random.PRNGKey(0), jax.random.PRNGKey(1)]\n",
+    ").reshape(num_envs, -1)\n",
+    "\n",
+    "env = IteratedMatrixGame(num_inner_steps=rollout_length, num_outer_steps=1)\n",
+    "env_params = EnvParams(payoff_matrix=payoff)\n",
+    "\n",
+    "action = jnp.ones((num_envs,), dtype=jnp.float32)\n",
+    "\n",
+    "# we want to batch over rngs, actions\n",
+    "env.step = jax.vmap(\n",
+    "    env.step,\n",
+    "    in_axes=(0, None, 0, None),\n",
+    "    out_axes=(0, None, 0, 0, 0),\n",
+    ")\n",
+    "env.reset = jax.vmap(\n",
+    "    env.reset, in_axes=(0, None), out_axes=(0, None))\n",
+    "obs, env_state = env.reset(rng, env_params)\n",
+    "\n",
+    "# let's scan the rollout for speed\n",
+    "def rollout(carry, unused):\n",
+    "    last_obs, env_state, env_rng = carry\n",
+    "    actions = (action, action)\n",
+    "    obs, env_state, rewards, done, info = env.step(\n",
+    "        env_rng, env_state, actions, env_params\n",
+    "    )\n",
+    "\n",
+    "    return (obs, env_state, env_rng), (\n",
+    "        obs,\n",
+    "        actions,\n",
+    "        rewards,\n",
+    "        done,\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "final_state, trajectory = jax.lax.scan(\n",
+    "    rollout, (obs, env_state, rng), None, rollout_length\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Agents\n",
+    "\n",
+    "Now let's define an agent, in this case a PPO agent. Our PPO agent requires some arguments to be initialized. These are typically defined in a `.yaml` file, as we use `Hydra` to keep track of hyperparameters; they are stored in `conf/experiment/`. For a simple example, have a look at `conf/experiment/ipd/ppo_v_tft.yaml`, which defines an experiment where a PPO agent plays against a TitForTat agent in the Iterated Prisoner's Dilemma."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import NamedTuple\n",
+    "\n",
+    "class EnvArgs(NamedTuple):\n",
+    "    env_id='iterated_matrix_game'\n",
+    "    runner='rl'\n",
+    "    num_envs=num_envs\n",
+    "\n",
+    "class PPOArgs(NamedTuple):\n",
+    "    num_minibatches=10\n",
+    "    num_epochs=4\n",
+    "    gamma=0.96\n",
+    "    gae_lambda=0.95\n",
+    "    ppo_clipping_epsilon=0.2\n",
+    "    value_coeff=0.5\n",
+    "    clip_value=True\n",
+    "    max_gradient_norm=0.5\n",
+    "    anneal_entropy=True\n",
+    "    entropy_coeff_start=0.2\n",
+    "    entropy_coeff_horizon=5e5\n",
+    "    entropy_coeff_end=0.01\n",
+    "    lr_scheduling=True\n",
+    "    learning_rate=2.5e-2\n",
+    "    adam_epsilon=1e-5\n",
+    "    with_memory=False\n",
+    "    with_cnn=False\n",
+    "    hidden_size=16"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Having defined the arguments, we're ready to initialize the agent. We also vmap over the reset, policy and update functions. We do not vmap over the initialisation, as it already assumes that we are running over multiple environments. If you were to add additional batch dimensions, you'd want to vmap over the initialisation too. The vmapping of the agents is typically done in the runner file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Making network for iterated_matrix_game\n"
+     ]
+    }
+   ],
+   "source": [
+    "import jax.numpy as jnp\n",
+    "from pax.agents.ppo.ppo import make_agent\n",
+    "\n",
+    "args = EnvArgs()\n",
+    "agent_args = PPOArgs()\n",
+    "agent = make_agent(args,\n",
+    "    agent_args=agent_args,\n",
+    "    obs_spec=env.observation_space(env_params).n,\n",
+    "    action_spec=env.num_actions,\n",
+    "    seed=42,\n",
+    "    num_iterations=1e3,\n",
+    "    player_id=0,\n",
+    "    tabular=False,)\n",
+    "\n",
+    "\n",
+    "# batch MemoryState not TrainingState\n",
+    "agent.batch_reset = jax.jit(\n",
+    "    jax.vmap(agent.reset_memory, (0, None), 0),\n",
+    "    static_argnums=1,\n",
+    ")\n",
+    "\n",
+    "agent.batch_policy = jax.jit(\n",
+    "    jax.vmap(agent._policy, (None, 0, 0), (0, None, 0))\n",
+    ")\n",
+    "\n",
+    "agent.batch_update = jax.vmap(\n",
+    "    agent.update, (0, 0, None, 0), (None, 0, 0)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we'll define a transition. The transition is what gets stacked along the time axis when we use `jax.lax.scan`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import NamedTuple\n",
+    "\n",
+    "class Sample(NamedTuple):\n",
+    "    \"\"\"Object containing a batch of data\"\"\"\n",
+    "\n",
+    "    observations: jnp.ndarray\n",
+    "    actions: jnp.ndarray\n",
+    "    rewards: jnp.ndarray\n",
+    "    behavior_log_probs: jnp.ndarray\n",
+    "    behavior_values: jnp.ndarray\n",
+    "    dones: jnp.ndarray\n",
+    "    hiddens: jnp.ndarray"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now let's define a very simple rollout. Note that the agent plays against itself in the IPD, which isn't particularly meaningful. If you had a second agent, you would fill in that code accordingly. Please refer to the many runners within the repo that already have that code in place."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(2, 5) next obs shape\n"
+     ]
+    }
+   ],
+   "source": [
+    "def _rollout(carry, unused):\n",
+    "    \"\"\"Runner for inner episode\"\"\"\n",
+    "    (\n",
+    "        rng,\n",
+    "        obs,\n",
+    "        a_state,\n",
+    "        a_mem,\n",
+    "        env_state,\n",
+    "        env_params,\n",
+    "    ) = carry\n",
+    "    # unpack rngs\n",
+    "    rngs = jax.random.split(rng, num_envs+1)\n",
+    "    rngs, rng = rngs[:-1], rngs[-1]\n",
+    "\n",
+    "    action, a_state, new_a_mem = agent.batch_policy(\n",
+    "        a_state,\n",
+    "        obs[0],\n",
+    "        a_mem,\n",
+    "    )\n",
+    "\n",
+    "    next_obs, env_state, rewards, done, info = env.step(\n",
+    "        rngs,\n",
+    "        env_state,\n",
+    "        (action, action),\n",
+    "        env_params,\n",
+    "    )\n",
+    "\n",
+    "    traj = Sample(\n",
+    "        obs[0],\n",
+    "        action,\n",
+    "        rewards[0],\n",
+    "        new_a_mem.extras[\"log_probs\"],\n",
+    "        new_a_mem.extras[\"values\"],\n",
+    "        done,\n",
+    "        a_mem.hidden,\n",
+    "    )\n",
+    "\n",
+    "    return (\n",
+    "        rng,\n",
+    "        next_obs,\n",
+    "        a_state,\n",
+    "        new_a_mem,\n",
+    "        env_state,\n",
+    "        env_params,\n",
+    "    ), traj"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can put everything together and train the agent. This already covers almost everything there is to know in `pax`. All the runners are just modified versions of this notebook, where initialisations are centralized in `experiments.py`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rng = jax.random.PRNGKey(42)\n",
+    "init_hidden = jnp.zeros((agent_args.hidden_size))\n",
+    "rng, _rng = jax.random.split(rng)\n",
+    "a_state, a_memory = agent.make_initial_state(_rng, init_hidden)\n",
+    "rngs = jax.random.split(rng, num_envs)\n",
+    "obs, env_state = env.reset(rngs, env_params)\n",
+    "\n",
+    "for _ in range(10):\n",
+    "    carry = (rng, obs, a_state, a_memory, env_state, env_params)\n",
+    "    final_timestep, batch_trajectory = jax.lax.scan(\n",
+    "        _rollout,\n",
+    "        carry,\n",
+    "        None,\n",
+    "        10,\n",
+    "    )\n",
+    "\n",
+    "    rng, obs, a_state, a_memory, env_state, env_params = final_timestep\n",
+    "\n",
+    "    a_state, a_memory, stats = agent.update(\n",
+    "        batch_trajectory, obs[0], a_state, a_memory\n",
+    "    )"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.16"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
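
The notebook above leans on one core JAX pattern: batch an environment's `step` over rng keys, states and actions with `jax.vmap` (leaving shared parameters unbatched via `in_axes=None`), then roll it out with `jax.lax.scan`. The following is a minimal, self-contained sketch of that pattern that runs without `pax`; the `toy_step` function, its state layout, and the `params` scalar are made up for illustration and are not part of the pax API.

```python
import jax
import jax.numpy as jnp

num_envs = 4
rollout_length = 5


def toy_step(key, state, action, params):
    """One step of a toy environment: state counts steps, reward is random."""
    reward = params * jax.random.uniform(key)
    next_state = state + 1
    obs = jnp.stack([next_state.astype(jnp.float32), action])
    return obs, next_state, reward


# Batch over rng keys, states and actions; `params` is shared across the batch
# (in_axes=None), mirroring how the notebook treats `env_params`.
batched_step = jax.vmap(toy_step, in_axes=(0, 0, 0, None), out_axes=(0, 0, 0))


def rollout(carry, _):
    key, state = carry
    key, subkey = jax.random.split(key)
    step_keys = jax.random.split(subkey, num_envs)  # one key per environment
    actions = jnp.ones((num_envs,))
    obs, state, reward = batched_step(step_keys, state, actions, 2.0)
    return (key, state), (obs, reward)


init_state = jnp.zeros((num_envs,), dtype=jnp.int32)
(_, final_state), (obs_traj, reward_traj) = jax.lax.scan(
    rollout, (jax.random.PRNGKey(0), init_state), None, length=rollout_length
)
print(obs_traj.shape, reward_traj.shape)  # (5, 4, 2) (5, 4)
```

The stacked outputs of `jax.lax.scan` gain a leading time axis, which is exactly what the notebook's `Sample` transitions rely on when they are later fed to the agent's update.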