Skip to content

Commit

Permalink
[examples] Fix notebook code style.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Dec 14, 2021
1 parent 56f7bfb commit a9c62da
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions examples/rllib.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,17 @@
"from itertools import islice\n",
"\n",
"with make_env() as env:\n",
" # The two datasets we will be using:\n",
" npb = env.datasets[\"npb-v0\"]\n",
" chstone = env.datasets[\"chstone-v0\"]\n",
" # The two datasets we will be using:\n",
" npb = env.datasets[\"npb-v0\"]\n",
" chstone = env.datasets[\"chstone-v0\"]\n",
"\n",
" # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
" # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n",
" # handful of benchmarks for training and validation.\n",
" train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
" train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
" # We will use the entire chstone-v0 dataset for testing.\n",
" test_benchmarks = list(chstone.benchmarks())\n",
" # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
" # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n",
" # handful of benchmarks for training and validation.\n",
" train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
" train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
" # We will use the entire chstone-v0 dataset for testing.\n",
" test_benchmarks = list(chstone.benchmarks())\n",
"\n",
"print(\"Number of benchmarks for training:\", len(train_benchmarks))\n",
"print(\"Number of benchmarks for validation:\", len(val_benchmarks))\n",
Expand Down Expand Up @@ -221,11 +221,11 @@
"from compiler_gym.wrappers import CycleOverBenchmarks\n",
"\n",
"def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:\n",
" \"\"\"Make a reinforcement learning environment that cycles over the\n",
" set of training benchmarks in use.\n",
" \"\"\"\n",
" del args # Unused env_config argument passed by ray\n",
" return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
" \"\"\"Make a reinforcement learning environment that cycles over the\n",
" set of training benchmarks in use.\n",
" \"\"\"\n",
" del args # Unused env_config argument passed by ray\n",
" return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
"\n",
"tune.register_env(\"compiler_gym\", make_training_env)"
]
Expand All @@ -245,12 +245,12 @@
"# Lets cycle through a few calls to reset() to demonstrate that this environment\n",
"# selects a new benchmark for each episode.\n",
"with make_training_env() as env:\n",
" env.reset()\n",
" print(env.benchmark)\n",
" env.reset()\n",
" print(env.benchmark)\n",
" env.reset()\n",
" print(env.benchmark)"
" env.reset()\n",
" print(env.benchmark)\n",
" env.reset()\n",
" print(env.benchmark)\n",
" env.reset()\n",
" print(env.benchmark)"
]
},
{
Expand Down Expand Up @@ -282,7 +282,7 @@
"\n",
"# (Re)Start the ray runtime.\n",
"if ray.is_initialized():\n",
" ray.shutdown()\n",
" ray.shutdown()\n",
"ray.init(include_dashboard=False, ignore_reinit_error=True)\n",
"\n",
"tune.register_env(\"compiler_gym\", make_training_env)\n",
Expand Down Expand Up @@ -370,18 +370,18 @@
"# performance on a set of benchmarks.\n",
"\n",
"def run_agent_on_benchmarks(benchmarks):\n",
" \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
" with make_env() as env:\n",
" \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
" rewards = []\n",
" for i, benchmark in enumerate(benchmarks, start=1):\n",
" observation, done = env.reset(benchmark=benchmark), False\n",
" while not done:\n",
" action = agent.compute_action(observation)\n",
" observation, _, done, _ = env.step(action)\n",
" rewards.append(env.episode_reward)\n",
" print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
" with make_env() as env:\n",
" for i, benchmark in enumerate(benchmarks, start=1):\n",
" observation, done = env.reset(benchmark=benchmark), False\n",
" while not done:\n",
" action = agent.compute_action(observation)\n",
" observation, _, done, _ = env.step(action)\n",
" rewards.append(env.episode_reward)\n",
" print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
"\n",
" return rewards\n",
" return rewards\n",
"\n",
"# Evaluate agent performance on the validation set.\n",
"val_rewards = run_agent_on_benchmarks(val_benchmarks)"
Expand Down Expand Up @@ -417,14 +417,15 @@
"outputs": [],
"source": [
"# Finally lets plot our results to see how we did!\n",
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def plot_results(x, y, name, ax):\n",
" plt.sca(ax)\n",
" plt.bar(range(len(y)), y)\n",
" plt.ylabel(\"Reward (higher is better)\")\n",
" plt.xticks(range(len(x)), x, rotation = 90)\n",
" plt.title(f\"Performance on {name} set\")\n",
" plt.sca(ax)\n",
" plt.bar(range(len(y)), y)\n",
" plt.ylabel(\"Reward (higher is better)\")\n",
" plt.xticks(range(len(x)), x, rotation = 90)\n",
" plt.title(f\"Performance on {name} set\")\n",
"\n",
"fig, (ax1, ax2) = plt.subplots(1, 2)\n",
"fig.set_size_inches(13, 3)\n",
Expand Down

0 comments on commit a9c62da

Please sign in to comment.