From a9c62da381496e466eb6d28546c718bee96e7a4e Mon Sep 17 00:00:00 2001
From: Chris Cummins
Date: Thu, 7 Oct 2021 17:23:19 +0100
Subject: [PATCH] [examples] Fix notebook code style.

---
 examples/rllib.ipynb | 75 ++++++++++++++++++++++----------------------
 1 file changed, 38 insertions(+), 37 deletions(-)

diff --git a/examples/rllib.ipynb b/examples/rllib.ipynb
index 1913b7bb09..b512b78f86 100644
--- a/examples/rllib.ipynb
+++ b/examples/rllib.ipynb
@@ -182,17 +182,17 @@
     "from itertools import islice\n",
     "\n",
     "with make_env() as env:\n",
-    "  # The two datasets we will be using:\n",
-    "  npb = env.datasets[\"npb-v0\"]\n",
-    "  chstone = env.datasets[\"chstone-v0\"]\n",
+    "    # The two datasets we will be using:\n",
+    "    npb = env.datasets[\"npb-v0\"]\n",
+    "    chstone = env.datasets[\"chstone-v0\"]\n",
     "\n",
-    "  # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
-    "  # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n",
-    "  # handful of benchmarks for training and validation.\n",
-    "  train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
-    "  train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
-    "  # We will use the entire chstone-v0 dataset for testing.\n",
-    "  test_benchmarks = list(chstone.benchmarks())\n",
+    "    # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
+    "    # benchmarks within the dataset. Here we will use iterator slicing to grab a\n",
+    "    # handful of benchmarks for training and validation.\n",
+    "    train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
+    "    train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
+    "    # We will use the entire chstone-v0 dataset for testing.\n",
+    "    test_benchmarks = list(chstone.benchmarks())\n",
     "\n",
     "print(\"Number of benchmarks for training:\", len(train_benchmarks))\n",
     "print(\"Number of benchmarks for validation:\", len(val_benchmarks))\n",
@@ -221,11 +221,11 @@
     "from compiler_gym.wrappers import CycleOverBenchmarks\n",
     "\n",
     "def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:\n",
-    "  \"\"\"Make a reinforcement learning environment that cycles over the\n",
-    "  set of training benchmarks in use.\n",
-    "  \"\"\"\n",
-    "  del args  # Unused env_config argument passed by ray\n",
-    "  return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
+    "    \"\"\"Make a reinforcement learning environment that cycles over the\n",
+    "    set of training benchmarks in use.\n",
+    "    \"\"\"\n",
+    "    del args  # Unused env_config argument passed by ray\n",
+    "    return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
     "\n",
     "tune.register_env(\"compiler_gym\", make_training_env)"
    ]
@@ -245,12 +245,12 @@
     "# Lets cycle through a few calls to reset() to demonstrate that this environment\n",
     "# selects a new benchmark for each episode.\n",
     "with make_training_env() as env:\n",
-    "  env.reset()\n",
-    "  print(env.benchmark)\n",
-    "  env.reset()\n",
-    "  print(env.benchmark)\n",
-    "  env.reset()\n",
-    "  print(env.benchmark)"
+    "    env.reset()\n",
+    "    print(env.benchmark)\n",
+    "    env.reset()\n",
+    "    print(env.benchmark)\n",
+    "    env.reset()\n",
+    "    print(env.benchmark)"
    ]
   },
   {
@@ -282,7 +282,7 @@
     "\n",
     "# (Re)Start the ray runtime.\n",
     "if ray.is_initialized():\n",
-    "  ray.shutdown()\n",
+    "    ray.shutdown()\n",
     "ray.init(include_dashboard=False, ignore_reinit_error=True)\n",
     "\n",
     "tune.register_env(\"compiler_gym\", make_training_env)\n",
@@ -370,18 +370,18 @@
     "# performance on a set of benchmarks.\n",
     "\n",
     "def run_agent_on_benchmarks(benchmarks):\n",
-    "  \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
-    "  with make_env() as env:\n",
+    "    \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
     "    rewards = []\n",
-    "    for i, benchmark in enumerate(benchmarks, start=1):\n",
-    "      observation, done = env.reset(benchmark=benchmark), False\n",
-    "      while not done:\n",
-    "        action = agent.compute_action(observation)\n",
-    "        observation, _, done, _ = env.step(action)\n",
-    "      rewards.append(env.episode_reward)\n",
-    "      print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
+    "    with make_env() as env:\n",
+    "        for i, benchmark in enumerate(benchmarks, start=1):\n",
+    "            observation, done = env.reset(benchmark=benchmark), False\n",
+    "            while not done:\n",
+    "                action = agent.compute_action(observation)\n",
+    "                observation, _, done, _ = env.step(action)\n",
+    "            rewards.append(env.episode_reward)\n",
+    "            print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
     "\n",
-    "  return rewards\n",
+    "    return rewards\n",
     "\n",
     "# Evaluate agent performance on the validation set.\n",
     "val_rewards = run_agent_on_benchmarks(val_benchmarks)"
@@ -417,14 +417,15 @@
    "outputs": [],
    "source": [
     "# Finally lets plot our results to see how we did!\n",
+    "%matplotlib inline\n",
     "from matplotlib import pyplot as plt\n",
     "\n",
     "def plot_results(x, y, name, ax):\n",
-    "  plt.sca(ax)\n",
-    "  plt.bar(range(len(y)), y)\n",
-    "  plt.ylabel(\"Reward (higher is better)\")\n",
-    "  plt.xticks(range(len(x)), x, rotation = 90)\n",
-    "  plt.title(f\"Performance on {name} set\")\n",
+    "    plt.sca(ax)\n",
+    "    plt.bar(range(len(y)), y)\n",
+    "    plt.ylabel(\"Reward (higher is better)\")\n",
+    "    plt.xticks(range(len(x)), x, rotation=90)\n",
+    "    plt.title(f\"Performance on {name} set\")\n",
     "\n",
     "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
     "fig.set_size_inches(13, 3)\n",
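
For reference, here is the evaluation helper as it reads with the patch applied, assembled from the `+` lines of the final hunks into plain Python. The `make_env()` body below is an illustrative stand-in only; the notebook defines its own `make_env()` and the trained RLlib `agent` in cells that this diff does not touch.

import compiler_gym

def make_env() -> compiler_gym.envs.CompilerEnv:
    # Illustrative stand-in for the notebook's make_env() cell: an LLVM
    # environment with the Autophase observation space and an instruction
    # count reward. The real cell may configure the environment differently.
    return compiler_gym.make(
        "llvm-v0",
        observation_space="Autophase",
        reward_space="IrInstructionCountOz",
    )

def run_agent_on_benchmarks(benchmarks):
    """Run agent on a list of benchmarks and return a list of cumulative rewards."""
    rewards = []
    with make_env() as env:
        for i, benchmark in enumerate(benchmarks, start=1):
            observation, done = env.reset(benchmark=benchmark), False
            while not done:
                # `agent` is the trained RLlib policy restored in an earlier
                # notebook cell; it is assumed to be in scope here.
                action = agent.compute_action(observation)
                observation, _, done, _ = env.step(action)
            rewards.append(env.episode_reward)
            print(f"[{i}/{len(benchmarks)}] {env.state}")

    return rewards

Besides re-indenting to four spaces, the hunk hoists `rewards = []` out of the `with` block; Python scoping permits either placement, but initializing the list at the same level as the function's `return rewards` reads more clearly.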