Commit

adding an introductory notebook
Aidandos committed Nov 17, 2023
1 parent 4acde5a commit 4d5518d
Showing 1 changed file with 339 additions and 0 deletions.
339 changes: 339 additions & 0 deletions proof_of_concept.ipynb
@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pax: Proof of Concept\n",
"\n",
"This notebook serves as a minimal working version for all concepts touched upon in `pax`. It does not create meaningful results, it simply exists to familiarize yourselfs with the code. If you are looking for working examples, please refer to the repository. There are many working examples of runners, environments and agents."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environments: A simple rollout\n",
"\n",
"Let's have a look at how you typically initial environments, how to parameterize them and how to vmap them. In `pax`, environments are initialized in the `experiments.py` file. Rollouts are usually defined in a runners file, e.g. `runners/runner_marl.py`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]\n"
]
}
],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
"from pax.envs.iterated_matrix_game import (\n",
" IteratedMatrixGame,\n",
" EnvParams,\n",
")\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"devices = jax.local_devices()\n",
"print(devices)\n",
"# batch over env initalisations\n",
"num_envs = 2\n",
"payoff = [[2, 2], [0, 3], [3, 0], [1, 1]]\n",
"rollout_length = 50\n",
"\n",
"rng = jnp.concatenate(\n",
" [jax.random.PRNGKey(0), jax.random.PRNGKey(1)]\n",
").reshape(num_envs, -1)\n",
"\n",
"env = IteratedMatrixGame(num_inner_steps=rollout_length, num_outer_steps=1)\n",
"env_params = EnvParams(payoff_matrix=payoff)\n",
"\n",
"action = jnp.ones((num_envs,), dtype=jnp.float32)\n",
"\n",
"# we want to batch over rngs, actions\n",
"env.step = jax.vmap(\n",
" env.step,\n",
" in_axes=(0, None, 0, None),\n",
" out_axes=(0, None, 0, 0, 0),\n",
")\n",
"env.reset = jax.vmap(\n",
" env.reset, in_axes=(0, None), out_axes=(0, None))\n",
"obs, env_state = env.reset(rng, env_params)\n",
"\n",
"# lets scan the rollout for speed\n",
"def rollout(carry, unused):\n",
" last_obs, env_state, env_rng = carry\n",
" actions = (action, action)\n",
" obs, env_state, rewards, done, info = env.step(\n",
" env_rng, env_state, actions, env_params\n",
" )\n",
"\n",
" return (obs, env_state, env_rng), (\n",
" obs,\n",
" actions,\n",
" rewards,\n",
" done,\n",
" )\n",
"\n",
"\n",
"final_state, trajectory = jax.lax.scan(\n",
" rollout, (obs, env_state, rng), None, rollout_length\n",
")"
]
},
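{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check: `jax.lax.scan` stacks every field of the per-step output along a new leading time axis of length `rollout_length`. A minimal sketch to inspect the stacked shapes (the exact leaves depend on the environment's observation structure):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# every leaf in the trajectory pytree should have a leading axis of\n",
"# length rollout_length, followed by its per-step (batched) shape\n",
"jax.tree_util.tree_map(lambda x: x.shape, trajectory)"
]
},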
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agents\n",
"\n",
"Now let's define an agent, in this case a PPO agent. Our ppo agent requires some arguments to get initialized. These are typically defined in a `.yaml` file, as we use `Hydra` to keep track of hyperparameters. They are stored in `conf/experiment/`. For a simple example, have a look at `conf/experiment/ipd/ppo_v_tft.yaml`, that defines an experiment where a PPO agent plays against a TitForTat agent in the Iterated Prisoner's Dilemma."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"\n",
"class EnvArgs(NamedTuple):\n",
" env_id='iterated_matrix_game'\n",
" runner='rl'\n",
" num_envs=num_envs\n",
"\n",
"class PPOArgs(NamedTuple):\n",
" num_minibatches=10\n",
" num_epochs=4\n",
" gamma=0.96\n",
" gae_lambda=0.95\n",
" ppo_clipping_epsilon=0.2\n",
" value_coeff=0.5\n",
" clip_value=True\n",
" max_gradient_norm=0.5\n",
" anneal_entropy=True\n",
" entropy_coeff_start=0.2\n",
" entropy_coeff_horizon=5e5\n",
" entropy_coeff_end=0.01\n",
" lr_scheduling=True\n",
" learning_rate=2.5e-2\n",
" adam_epsilon=1e-5\n",
" with_memory=False\n",
" with_cnn=False\n",
" hidden_size=16"
]
},
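{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the actual experiments these hyperparameters come from the Hydra config files rather than from hard-coded NamedTuples. A minimal sketch of inspecting such a config (assuming `omegaconf`, Hydra's config backend, is installed and the path exists in your checkout; the exact keys are defined by the repository):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: load and print the example config referenced above\n",
"from omegaconf import OmegaConf\n",
"\n",
"cfg = OmegaConf.load(\"conf/experiment/ipd/ppo_v_tft.yaml\")\n",
"print(OmegaConf.to_yaml(cfg))"
]
},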
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After having defined the arguments, we're ready to initialize the agents. We also vmap over the reset, policy and update function. We do not vmap over the initialisation, as the initialisation already assumes that we are running over multiple environments. If you were to add additional batch dimensions, you'd want to vmap over the initialisation too. The vmapping of the agents is typically done in the runner file."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Making network for iterated_matrix_game\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"from pax.agents.ppo.ppo import make_agent\n",
"\n",
"args = EnvArgs()\n",
"agent_args = PPOArgs()\n",
"agent = make_agent(args, \n",
" agent_args=agent_args,\n",
" obs_spec=env.observation_space(env_params).n,\n",
" action_spec=env.num_actions,\n",
" seed=42,\n",
" num_iterations=1e3,\n",
" player_id=0,\n",
" tabular=False,)\n",
"\n",
"\n",
"# batch MemoryState not TrainingState\n",
"agent.batch_reset = jax.jit(\n",
" jax.vmap(agent.reset_memory, (0, None), 0),\n",
" static_argnums=1,\n",
")\n",
"\n",
"agent.batch_policy = jax.jit(\n",
" jax.vmap(agent._policy, (None, 0, 0), (0, None, 0))\n",
")\n",
"\n",
"agent.batch_update = jax.vmap(\n",
" agent.update, (0, 0, None, 0), (None, 0, 0)\n",
")"
]
},
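{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you did add an extra batch dimension (say, a small population of agents), a minimal sketch of also vmapping the initialisation could look like the following. This is an assumption about how you might batch it, not code used elsewhere in `pax`: we batch over per-agent rngs and share the initial hidden state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical extra batch dimension over agents\n",
"num_agents = 4\n",
"init_rngs = jax.random.split(jax.random.PRNGKey(0), num_agents)\n",
"init_hidden = jnp.zeros((agent_args.hidden_size))\n",
"\n",
"# vmap the initialisation over the rng axis, sharing the initial hidden state;\n",
"# calling batch_init(init_rngs, init_hidden) would then return batched\n",
"# states and memories, provided the agent internals are vmappable\n",
"batch_init = jax.vmap(agent.make_initial_state, in_axes=(0, None))"
]
},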
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll define a transition. The transitions is what we're stacking when we're gonna use `jax.lax.scan`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"\n",
"class Sample(NamedTuple):\n",
" \"\"\"Object containing a batch of data\"\"\"\n",
"\n",
" observations: jnp.ndarray\n",
" actions: jnp.ndarray\n",
" rewards: jnp.ndarray\n",
" behavior_log_probs: jnp.ndarray\n",
" behavior_values: jnp.ndarray\n",
" dones: jnp.ndarray\n",
" hiddens: jnp.ndarray"
]
},
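{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the stacking behaviour concrete, here is a tiny self-contained sketch (not part of `pax`): when a NamedTuple is returned as the per-step output of `jax.lax.scan`, each field is stacked along a new leading time axis."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Tiny(NamedTuple):\n",
"    x: jnp.ndarray\n",
"    y: jnp.ndarray\n",
"\n",
"def _step(carry, _):\n",
"    carry = carry + 1\n",
"    return carry, Tiny(x=carry, y=2 * carry)\n",
"\n",
"_, stacked = jax.lax.scan(_step, jnp.zeros(()), None, length=5)\n",
"# both fields gain a leading time axis of length 5\n",
"print(stacked.x.shape, stacked.y.shape)"
]
},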
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's define a very simple rollout. Note that the agent plays against itself in the IPD, which isn't particularly meaningful. If you had a second agent, you'd obviously want to fill in that code accordingly. Please refer to the many runners within the repo that already have that code in place."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 5) next obs shape\n"
]
}
],
"source": [
"def _rollout(carry, unused):\n",
" \"\"\"Runner for inner episode\"\"\"\n",
" (\n",
" rng,\n",
" obs,\n",
" a_state,\n",
" a_mem,\n",
" env_state,\n",
" env_params,\n",
" ) = carry\n",
" # unpack rngs\n",
" rngs = jax.random.split(rng, num_envs+1)\n",
" rngs, rng = rngs[:-1], rngs[-1]\n",
"\n",
" action, a_state, new_a_mem = agent.batch_policy(\n",
" a_state,\n",
" obs[0],\n",
" a_mem,\n",
" )\n",
"\n",
" next_obs, env_state, rewards, done, info = env.step(\n",
" rngs,\n",
" env_state,\n",
" (action, action),\n",
" env_params,\n",
" )\n",
"\n",
" traj = Sample(\n",
" obs[0],\n",
" action,\n",
" rewards[0],\n",
" new_a_mem.extras[\"log_probs\"],\n",
" new_a_mem.extras[\"values\"],\n",
" done,\n",
" a_mem.hidden,\n",
" )\n",
"\n",
" return (\n",
" rng,\n",
" next_obs,\n",
" a_state,\n",
" new_a_mem,\n",
" env_state,\n",
" env_params,\n",
" ), traj"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can put everything together and train the agent. This already covers almost everything there is to know in `pax`. All the runners are just modified versions of this notebook, where initialisations are centralized in `experiments.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = jax.random.PRNGKey(42)\n",
"init_hidden = jnp.zeros((agent_args.hidden_size))\n",
"rng, _rng = jax.random.split(rng)\n",
"a_state, a_memory = agent.make_initial_state(_rng, init_hidden)\n",
"rngs = jax.random.split(rng, num_envs)\n",
"obs, env_state = env.reset(rngs, env_params)\n",
"\n",
"for _ in range(10):\n",
" carry = (rng, obs, a_state, a_memory, env_state, env_params)\n",
" final_timestep, batch_trajectory = jax.lax.scan(\n",
" _rollout,\n",
" carry,\n",
" None,\n",
" 10,\n",
" )\n",
"\n",
" rng, obs, a_state, a_memory, env_state, env_params = final_timestep\n",
"\n",
" a_state, a_memory, stats = agent.update(\n",
" batch_trajectory, obs[0], a_state, a_memory\n",
" )"
]
}
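,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check after training, you can look at the collected trajectory. A minimal sketch, assuming the rewards in `batch_trajectory` carry a leading time axis as in the rollout above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# average reward over the last rollout (time and environment axes)\n",
"print(\"mean reward:\", batch_trajectory.rewards.mean())"
]
}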
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
