
Commit 750b733

[examples] Fix notebook code style.
1 parent: 2813d15

1 file changed

examples/rllib.ipynb

Lines changed: 38 additions & 37 deletions
@@ -182,17 +182,17 @@
 "from itertools import islice\n",
 "\n",
 "with make_env() as env:\n",
-"  # The two datasets we will be using:\n",
-"  npb = env.datasets[\"npb-v0\"]\n",
-"  chstone = env.datasets[\"chstone-v0\"]\n",
+"    # The two datasets we will be using:\n",
+"    npb = env.datasets[\"npb-v0\"]\n",
+"    chstone = env.datasets[\"chstone-v0\"]\n",
 "\n",
-"  # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
-"  # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n",
-"  # handful of benchmarks for training and validation.\n",
-"  train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
-"  train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
-"  # We will use the entire chstone-v0 dataset for testing.\n",
-"  test_benchmarks = list(chstone.benchmarks())\n",
+"    # Each dataset has a `benchmarks()` method that returns an iterator over the\n",
+"    # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n",
+"    # handful of benchmarks for training and validation.\n",
+"    train_benchmarks = list(islice(npb.benchmarks(), 55))\n",
+"    train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n",
+"    # We will use the entire chstone-v0 dataset for testing.\n",
+"    test_benchmarks = list(chstone.benchmarks())\n",
 "\n",
 "print(\"Number of benchmarks for training:\", len(train_benchmarks))\n",
 "print(\"Number of benchmarks for validation:\", len(val_benchmarks))\n",
@@ -221,11 +221,11 @@
 "from compiler_gym.wrappers import CycleOverBenchmarks\n",
 "\n",
 "def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:\n",
-"  \"\"\"Make a reinforcement learning environment that cycles over the\n",
-"  set of training benchmarks in use.\n",
-"  \"\"\"\n",
-"  del args # Unused env_config argument passed by ray\n",
-"  return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
+"    \"\"\"Make a reinforcement learning environment that cycles over the\n",
+"    set of training benchmarks in use.\n",
+"    \"\"\"\n",
+"    del args # Unused env_config argument passed by ray\n",
+"    return CycleOverBenchmarks(make_env(), train_benchmarks)\n",
 "\n",
 "tune.register_env(\"compiler_gym\", make_training_env)"
 ]
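CycleOverBenchmarks rotates through train_benchmarks deterministically, one benchmark per reset(), and tune.register_env makes the factory available to RLlib under the "compiler_gym" id; RLlib passes an env_config argument to the factory, which is why make_training_env() discards *args. A minimal sketch of the rotation (the URI printout is illustrative; the exact equality semantics of env.benchmark are not shown in this diff):

# One full pass over the training set brings the wrapper back to the start.
with make_training_env() as env:
    for _ in range(len(train_benchmarks) + 1):
        env.reset()
    print(env.benchmark)  # Expected: the URI of train_benchmarks[0] again.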
@@ -245,12 +245,12 @@
 "# Lets cycle through a few calls to reset() to demonstrate that this environment\n",
 "# selects a new benchmark for each episode.\n",
 "with make_training_env() as env:\n",
-"  env.reset()\n",
-"  print(env.benchmark)\n",
-"  env.reset()\n",
-"  print(env.benchmark)\n",
-"  env.reset()\n",
-"  print(env.benchmark)"
+"    env.reset()\n",
+"    print(env.benchmark)\n",
+"    env.reset()\n",
+"    print(env.benchmark)\n",
+"    env.reset()\n",
+"    print(env.benchmark)"
 ]
 },
 {
@@ -282,7 +282,7 @@
 "\n",
 "# (Re)Start the ray runtime.\n",
 "if ray.is_initialized():\n",
-"  ray.shutdown()\n",
+"    ray.shutdown()\n",
 "ray.init(include_dashboard=False, ignore_reinit_error=True)\n",
 "\n",
 "tune.register_env(\"compiler_gym\", make_training_env)\n",
@@ -370,18 +370,18 @@
 "# performance on a set of benchmarks.\n",
 "\n",
 "def run_agent_on_benchmarks(benchmarks):\n",
-"  \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
-"  with make_env() as env:\n",
+"    \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n",
 "    rewards = []\n",
-"    for i, benchmark in enumerate(benchmarks, start=1):\n",
-"      observation, done = env.reset(benchmark=benchmark), False\n",
-"      while not done:\n",
-"        action = agent.compute_action(observation)\n",
-"        observation, _, done, _ = env.step(action)\n",
-"      rewards.append(env.episode_reward)\n",
-"      print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
+"    with make_env() as env:\n",
+"        for i, benchmark in enumerate(benchmarks, start=1):\n",
+"            observation, done = env.reset(benchmark=benchmark), False\n",
+"            while not done:\n",
+"                action = agent.compute_action(observation)\n",
+"                observation, _, done, _ = env.step(action)\n",
+"            rewards.append(env.episode_reward)\n",
+"            print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n",
 "\n",
-"  return rewards\n",
+"    return rewards\n",
 "\n",
 "# Evaluate agent performance on the validation set.\n",
 "val_rewards = run_agent_on_benchmarks(val_benchmarks)"
@@ -417,14 +417,15 @@
 "outputs": [],
 "source": [
 "# Finally lets plot our results to see how we did!\n",
+"%matplotlib inline\n",
 "from matplotlib import pyplot as plt\n",
 "\n",
 "def plot_results(x, y, name, ax):\n",
-"  plt.sca(ax)\n",
-"  plt.bar(range(len(y)), y)\n",
-"  plt.ylabel(\"Reward (higher is better)\")\n",
-"  plt.xticks(range(len(x)), x, rotation = 90)\n",
-"  plt.title(f\"Performance on {name} set\")\n",
+"    plt.sca(ax)\n",
+"    plt.bar(range(len(y)), y)\n",
+"    plt.ylabel(\"Reward (higher is better)\")\n",
+"    plt.xticks(range(len(x)), x, rotation = 90)\n",
+"    plt.title(f\"Performance on {name} set\")\n",
 "\n",
 "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
 "fig.set_size_inches(13, 3)\n",
