|
182 | 182 | "from itertools import islice\n",
|
183 | 183 | "\n",
|
184 | 184 | "with make_env() as env:\n",
|
185 |
| - " # The two datasets we will be using:\n", |
186 |
| - " npb = env.datasets[\"npb-v0\"]\n", |
187 |
| - " chstone = env.datasets[\"chstone-v0\"]\n", |
| 185 | + " # The two datasets we will be using:\n", |
| 186 | + " npb = env.datasets[\"npb-v0\"]\n", |
| 187 | + " chstone = env.datasets[\"chstone-v0\"]\n", |
188 | 188 | "\n",
|
189 |
| - " # Each dataset has a `benchmarks()` method that returns an iterator over the\n", |
190 |
| - " # benchmarks within the dataset. Here we will use iterator sliceing to grab a \n", |
191 |
| - " # handful of benchmarks for training and validation.\n", |
192 |
| - " train_benchmarks = list(islice(npb.benchmarks(), 55))\n", |
193 |
| - " train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n", |
194 |
| - " # We will use the entire chstone-v0 dataset for testing.\n", |
195 |
| - " test_benchmarks = list(chstone.benchmarks())\n", |
| 189 | + " # Each dataset has a `benchmarks()` method that returns an iterator over the\n", |
| 190 | + "    # benchmarks within the dataset. Here we will use iterator slicing to grab a\n", |
| 191 | + " # handful of benchmarks for training and validation.\n", |
| 192 | + " train_benchmarks = list(islice(npb.benchmarks(), 55))\n", |
| 193 | + " train_benchmarks, val_benchmarks = train_benchmarks[:50], train_benchmarks[50:]\n", |
| 194 | + " # We will use the entire chstone-v0 dataset for testing.\n", |
| 195 | + " test_benchmarks = list(chstone.benchmarks())\n", |
196 | 196 | "\n",
|
197 | 197 | "print(\"Number of benchmarks for training:\", len(train_benchmarks))\n",
|
198 | 198 | "print(\"Number of benchmarks for validation:\", len(val_benchmarks))\n",
|
|
221 | 221 | "from compiler_gym.wrappers import CycleOverBenchmarks\n",
|
222 | 222 | "\n",
|
223 | 223 | "def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:\n",
|
224 |
| - " \"\"\"Make a reinforcement learning environment that cycles over the\n", |
225 |
| - " set of training benchmarks in use.\n", |
226 |
| - " \"\"\"\n", |
227 |
| - " del args # Unused env_config argument passed by ray\n", |
228 |
| - " return CycleOverBenchmarks(make_env(), train_benchmarks)\n", |
| 224 | + " \"\"\"Make a reinforcement learning environment that cycles over the\n", |
| 225 | + " set of training benchmarks in use.\n", |
| 226 | + " \"\"\"\n", |
| 227 | + " del args # Unused env_config argument passed by ray\n", |
| 228 | + " return CycleOverBenchmarks(make_env(), train_benchmarks)\n", |
229 | 229 | "\n",
|
230 | 230 | "tune.register_env(\"compiler_gym\", make_training_env)"
|
231 | 231 | ]
|
|
245 | 245 | "# Let's cycle through a few calls to reset() to demonstrate that this environment\n",
|
246 | 246 | "# selects a new benchmark for each episode.\n",
|
247 | 247 | "with make_training_env() as env:\n",
|
248 |
| - " env.reset()\n", |
249 |
| - " print(env.benchmark)\n", |
250 |
| - " env.reset()\n", |
251 |
| - " print(env.benchmark)\n", |
252 |
| - " env.reset()\n", |
253 |
| - " print(env.benchmark)" |
| 248 | + " env.reset()\n", |
| 249 | + " print(env.benchmark)\n", |
| 250 | + " env.reset()\n", |
| 251 | + " print(env.benchmark)\n", |
| 252 | + " env.reset()\n", |
| 253 | + " print(env.benchmark)" |
254 | 254 | ]
|
255 | 255 | },
|
256 | 256 | {
|
|
282 | 282 | "\n",
|
283 | 283 | "# (Re)Start the ray runtime.\n",
|
284 | 284 | "if ray.is_initialized():\n",
|
285 |
| - " ray.shutdown()\n", |
| 285 | + " ray.shutdown()\n", |
286 | 286 | "ray.init(include_dashboard=False, ignore_reinit_error=True)\n",
|
287 | 287 | "\n",
|
288 | 288 | "tune.register_env(\"compiler_gym\", make_training_env)\n",
|
|
370 | 370 | "# performance on a set of benchmarks.\n",
|
371 | 371 | "\n",
|
372 | 372 | "def run_agent_on_benchmarks(benchmarks):\n",
|
373 |
| - " \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n", |
374 |
| - " with make_env() as env:\n", |
| 373 | + " \"\"\"Run agent on a list of benchmarks and return a list of cumulative rewards.\"\"\"\n", |
375 | 374 | " rewards = []\n",
|
376 |
| - " for i, benchmark in enumerate(benchmarks, start=1):\n", |
377 |
| - " observation, done = env.reset(benchmark=benchmark), False\n", |
378 |
| - " while not done:\n", |
379 |
| - " action = agent.compute_action(observation)\n", |
380 |
| - " observation, _, done, _ = env.step(action)\n", |
381 |
| - " rewards.append(env.episode_reward)\n", |
382 |
| - " print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n", |
| 375 | + " with make_env() as env:\n", |
| 376 | + " for i, benchmark in enumerate(benchmarks, start=1):\n", |
| 377 | + " observation, done = env.reset(benchmark=benchmark), False\n", |
| 378 | + " while not done:\n", |
| 379 | + " action = agent.compute_action(observation)\n", |
| 380 | + " observation, _, done, _ = env.step(action)\n", |
| 381 | + " rewards.append(env.episode_reward)\n", |
| 382 | + " print(f\"[{i}/{len(benchmarks)}] {env.state}\")\n", |
383 | 383 | "\n",
|
384 |
| - " return rewards\n", |
| 384 | + " return rewards\n", |
385 | 385 | "\n",
|
386 | 386 | "# Evaluate agent performance on the validation set.\n",
|
387 | 387 | "val_rewards = run_agent_on_benchmarks(val_benchmarks)"
|
|
417 | 417 | "outputs": [],
|
418 | 418 | "source": [
|
419 | 419 | "# Finally, let's plot our results to see how we did!\n",
|
| 420 | + "%matplotlib inline\n", |
420 | 421 | "from matplotlib import pyplot as plt\n",
|
421 | 422 | "\n",
|
422 | 423 | "def plot_results(x, y, name, ax):\n",
|
423 |
| - " plt.sca(ax)\n", |
424 |
| - " plt.bar(range(len(y)), y)\n", |
425 |
| - " plt.ylabel(\"Reward (higher is better)\")\n", |
426 |
| - " plt.xticks(range(len(x)), x, rotation = 90)\n", |
427 |
| - " plt.title(f\"Performance on {name} set\")\n", |
| 424 | + " plt.sca(ax)\n", |
| 425 | + " plt.bar(range(len(y)), y)\n", |
| 426 | + " plt.ylabel(\"Reward (higher is better)\")\n", |
| 427 | + " plt.xticks(range(len(x)), x, rotation = 90)\n", |
| 428 | + " plt.title(f\"Performance on {name} set\")\n", |
428 | 429 | "\n",
|
429 | 430 | "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
|
430 | 431 | "fig.set_size_inches(13, 3)\n",
|
|
0 commit comments