62 | 62 | "outputs": [],
63 | 63 | "source": [
64 | 64 | "episode_x_boundaries = torch.linspace(0, n_episodes, n_episodes + 1)\n",
65 |    | - "xs = torch.stack([torch.linspace(episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode) for i in range(n_episodes)])\n",
   | 65 | + "xs = torch.stack(\n",
   | 66 | + "    [\n",
   | 67 | + "        torch.linspace(\n",
   | 68 | + "            episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode\n",
   | 69 | + "        )\n",
   | 70 | + "        for i in range(n_episodes)\n",
   | 71 | + "    ]\n",
   | 72 | + ")\n",
66 | 73 | "ys = torch.stack([true_f(x) + y_sd * torch.randn_like(x) for x in xs])"
67 | 74 | ]
68 | 75 | },
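(Reviewer note, not part of the diff: the reformatted cell lays the data out one row per episode. A minimal runnable sketch of the resulting shapes, with hypothetical stand-ins for `n_episodes`, `samps_per_episode`, `y_sd` and `true_f`, which the notebook defines in earlier cells outside this diff.)

```python
import torch

# Hypothetical stand-ins; the notebook defines these in earlier cells.
n_episodes, samps_per_episode, y_sd = 4, 30, 0.2
true_f = torch.sin

episode_x_boundaries = torch.linspace(0, n_episodes, n_episodes + 1)
xs = torch.stack(
    [
        torch.linspace(
            episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode
        )
        for i in range(n_episodes)
    ]
)
ys = torch.stack([true_f(x) + y_sd * torch.randn_like(x) for x in xs])
print(xs.shape, ys.shape)  # both (n_episodes, samps_per_episode)
```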
85 | 92 | "source": [
86 | 93 | "plt_linsp = torch.linspace(-1, episode_x_boundaries[-1] + 1, 1000)\n",
87 | 94 | "\n",
   | 95 | + "\n",
88 | 96 | "def plot_data(ax, up_to_episode=None):\n",
89 | 97 | "    if up_to_episode is None:\n",
90 | 98 | "        up_to_episode = n_episodes\n",
91 |    | - "    \n",
92 |    | - "    ax.plot(xs.flatten(), ys.flatten(), 'o', color='gray', alpha=0.2)\n",
   | 99 | + "\n",
   | 100 | + "    ax.plot(xs.flatten(), ys.flatten(), \"o\", color=\"gray\", alpha=0.2)\n",
93 | 101 | "    for i in range(up_to_episode):\n",
94 |    | - "        ax.plot(xs[i], ys[i], 'o', color='orange')\n",
95 |    | - "    \n",
   | 102 | + "        ax.plot(xs[i], ys[i], \"o\", color=\"orange\")\n",
   | 103 | + "\n",
96 | 104 | "    for v in episode_x_boundaries:\n",
97 |    | - "        ax.axvline(v, color='gray', linestyle='--', alpha=0.75)\n",
98 |    | - "    ax.plot(plt_linsp, true_f(plt_linsp), color='green', zorder=10)\n",
99 |    | - "    ax.set_ylim(-2., 2.5)\n",
   | 105 | + "        ax.axvline(v, color=\"gray\", linestyle=\"--\", alpha=0.75)\n",
   | 106 | + "    ax.plot(plt_linsp, true_f(plt_linsp), color=\"green\", zorder=10)\n",
   | 107 | + "    ax.set_ylim(-2.0, 2.5)\n",
   | 108 | + "\n",
100 | 109 | "\n",
101 | 110 | "fig, ax = plt.subplots()\n",
102 | 111 | "plot_data(ax)"
166 | 175 | "outputs": [],
167 | 176 | "source": [
168 | 177 | "def log_prior(p, prior_mean, prior_sd: float):\n",
169 |     | - "    all_vals = tree_map(lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False).log_prob(p).sum(), p, prior_mean, prior_sd)\n",
    | 178 | + "    all_vals = tree_map(\n",
    | 179 | + "        lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False)\n",
    | 180 | + "        .log_prob(p)\n",
    | 181 | + "        .sum(),\n",
    | 182 | + "        p,\n",
    | 183 | + "        prior_mean,\n",
    | 184 | + "        prior_sd,\n",
    | 185 | + "    )\n",
170 | 186 | "    return tree_reduce(torch.add, all_vals)\n",
171 |     | - "    \n",
    | 187 | + "\n",
    | 188 | + "\n",
172 | 189 | "def log_likelihood(y_pred, y):\n",
173 |     | - "    return torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()"
    | 190 | + "    return (\n",
    | 191 | + "        torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()\n",
    | 192 | + "    )"
174 | 193 | ]
175 | 194 | },
176 | 195 | {
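(Reviewer note, not part of the diff: `log_prior` maps a Gaussian log-density over every leaf of the parameter dict and then sums across leaves. A minimal sketch of the same pattern, assuming `tree_map`/`tree_reduce` come from `optree` as in other `posteriors` examples; the notebook's import cell is outside this diff.)

```python
import torch
from optree import tree_map, tree_reduce

params = {"weight": torch.zeros(3, 2), "bias": torch.zeros(3)}
prior_mean = tree_map(torch.zeros_like, params)
prior_sd = tree_map(torch.ones_like, params)

# One scalar log-density per leaf, mapping over the three matching trees...
leaf_vals = tree_map(
    lambda p, m, sd: torch.distributions.Normal(m, sd).log_prob(p).sum(),
    params,
    prior_mean,
    prior_sd,
)
# ...then folded into a single scalar with torch.add.
print(tree_reduce(torch.add, leaf_vals))
```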
182 | 201 | "def log_posterior(params, batch, prior_mean, prior_sd):\n",
183 | 202 | "    x, y = batch\n",
184 | 203 | "    y_pred = mlp_functional(params, x)\n",
185 |     | - "    log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
    | 204 | + "    log_post = (\n",
    | 205 | + "        log_likelihood(y_pred, y)\n",
    | 206 | + "        + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
    | 207 | + "    )\n",
186 | 208 | "    return log_post, y_pred"
187 | 209 | ]
188 | 210 | },
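(Reviewer note: the division by `samps_per_episode` follows the `posteriors` convention of working with the log joint scaled per data point. Since `log_likelihood` already returns a batch mean, this cell computes, up to a constant,

$$\frac{1}{N}\log p(\theta, \mathcal{D}) \approx \frac{1}{|B|}\sum_{i \in B} \log p(y_i \mid \theta, x_i) + \frac{1}{N}\log p(\theta),$$

with $N$ = `samps_per_episode` and $B$ the minibatch; the matching `temperature=1 / samps_per_episode` in the VI build further down undoes this scaling.)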
213 | 235 | "outputs": [],
214 | 236 | "source": [
215 | 237 | "batch_size = 3\n",
216 |     | - "dataloaders = [torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)), batch_size=batch_size) for x, y in zip(xs, ys)]"
    | 238 | + "dataloaders = [\n",
    | 239 | + "    torch.utils.data.DataLoader(\n",
    | 240 | + "        torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)),\n",
    | 241 | + "        batch_size=batch_size,\n",
    | 242 | + "    )\n",
    | 243 | + "    for x, y in zip(xs, ys)\n",
    | 244 | + "]"
217 | 245 | ]
218 | 246 | },
219 | 247 | {
227 | 255 | "    for _ in range(n_epochs):\n",
228 | 256 | "        for batch in dataloader:\n",
229 | 257 | "            opt.zero_grad()\n",
230 |     | - "            loss = -log_posterior(dict(mlp.named_parameters()), batch, prior_mean, prior_sd)[0]\n",
    | 258 | + "            loss = -log_posterior(\n",
    | 259 | + "                dict(mlp.named_parameters()), batch, prior_mean, prior_sd\n",
    | 260 | + "            )[0]\n",
231 | 261 | "            loss.backward()\n",
232 | 262 | "            opt.step()"
233 | 263 | ]
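(Reviewer note: this loop is MAP training, plain gradient descent on the negative log posterior through the live `mlp` parameters. The `trained_params[i]` indexed by the plotting cell below is presumably a per-episode snapshot taken in a cell outside this diff; a hypothetical version of that bookkeeping, with the driver name `train_map` assumed:)

```python
import copy

# Hypothetical bookkeeping; the notebook's actual cell is not in this diff.
trained_params = []
for dataloader in dataloaders:
    train_map(dataloader, prior_mean, prior_sd)  # the MAP loop above (name assumed)
    trained_params.append(copy.deepcopy(dict(mlp.named_parameters())))
```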
252 | 282 | "metadata": {},
253 | 283 | "outputs": [],
254 | 284 | "source": [
255 |     | - "def plot_predictions(params, ax, x, sd=y_sd, alpha=1.):\n",
    | 285 | + "def plot_predictions(params, ax, x, sd=y_sd, alpha=1.0):\n",
256 | 286 | "    preds = mlp_functional(params, x.unsqueeze(-1)).detach().numpy().squeeze()\n",
257 |     | - "    ax.plot(x, preds, color='blue', alpha=alpha)\n",
258 |     | - "    ax.fill_between(x, preds - sd, preds + sd, color='blue', alpha=0.2)"
    | 287 | + "    ax.plot(x, preds, color=\"blue\", alpha=alpha)\n",
    | 288 | + "    ax.fill_between(x, preds - sd, preds + sd, color=\"blue\", alpha=0.2)"
259 | 289 | ]
260 | 290 | },
261 | 291 | {
275 | 305 | }
276 | 306 | ],
277 | 307 | "source": [
278 |     | - "fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
    | 308 | + "fig, axes = plt.subplots(\n",
    | 309 | + "    1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
    | 310 | + ")\n",
279 | 311 | "\n",
280 | 312 | "for i, ax in enumerate(axes):\n",
281 |     | - "    plot_data(ax, up_to_episode=i+1)\n",
    | 313 | + "    plot_data(ax, up_to_episode=i + 1)\n",
282 | 314 | "    plot_predictions(trained_params[i], ax, plt_linsp)\n",
283 |     | - "    ax.set_title(f\"After Episode {i+1}\")"
    | 315 | + "    ax.set_title(f\"After Episode {i + 1}\")"
284 | 316 | ]
285 | 317 | },
286 | 318 | {
318 | 350 | "def train_for_vi(dataloader, prior_mean, prior_sd, n_epochs=200, init_log_sds=None):\n",
319 | 351 | "    seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)\n",
320 | 352 | "    optimizer = torchopt.adam(lr=2e-3)\n",
321 |     | - "    transform = posteriors.vi.diag.build(seq_log_post, optimizer, temperature=1/samps_per_episode, init_log_sds=init_log_sds, stl=False)\n",
    | 353 | + "    transform = posteriors.vi.diag.build(\n",
    | 354 | + "        seq_log_post,\n",
    | 355 | + "        optimizer,\n",
    | 356 | + "        temperature=1 / samps_per_episode,\n",
    | 357 | + "        init_log_sds=init_log_sds,\n",
    | 358 | + "        stl=False,\n",
    | 359 | + "    )\n",
322 | 360 | "    state = transform.init(dict(mlp.named_parameters()))\n",
323 | 361 | "    nelbos = []\n",
324 | 362 | "    for _ in range(n_epochs):\n",
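(Reviewer note: `posteriors` transforms follow an init/update pattern, and the rest of `train_for_vi` is truncated by the diff. A sketch of how the loop presumably continues, reusing the cell's own names and assuming the usual `vi.diag` state fields, `nelbo` in particular. `stl` toggles the sticking-the-landing ELBO gradient estimator, and `temperature=1 / samps_per_episode` matches the per-datapoint scaling in `log_posterior`.)

```python
# Sketch only: the loop body after the lines shown above is assumed.
state = transform.init(dict(mlp.named_parameters()))
nelbos = []
for _ in range(n_epochs):
    for batch in dataloader:
        state = transform.update(state, batch)  # one stochastic ELBO step
        nelbos.append(state.nelbo.item())  # track negative ELBO for plotting
```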
346 | 384 | "nelbos = []\n",
347 | 385 | "for i in range(n_episodes):\n",
348 | 386 | "    seq_prior_mean = prior_mean if i == 0 else vi_states[i - 1].params\n",
349 |     | - "    seq_prior_sd = prior_sd if i == 0 else tree_map(lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd ** 2), vi_states[i - 1].log_sd_diag)\n",
350 |     | - "    seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6., mlp.state_dict())\n",
351 |     | - "    state, nelbos_i = train_for_vi(dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds)\n",
    | 387 | + "    seq_prior_sd = (\n",
    | 388 | + "        prior_sd\n",
    | 389 | + "        if i == 0\n",
    | 390 | + "        else tree_map(\n",
    | 391 | + "            lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd**2),\n",
    | 392 | + "            vi_states[i - 1].log_sd_diag,\n",
    | 393 | + "        )\n",
    | 394 | + "    )\n",
    | 395 | + "    seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6.0, mlp.state_dict())\n",
    | 396 | + "    state, nelbos_i = train_for_vi(\n",
    | 397 | + "        dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds\n",
    | 398 | + "    )\n",
352 | 399 | "    vi_states += [state]\n",
353 | 400 | "    nelbos += [nelbos_i]\n",
354 | 401 | "    mlp.load_state_dict(vi_states[i].params)"
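(Reviewer note on the expression this hunk reflows: episode $i$'s prior is the previous variational posterior diffused by a Gaussian random walk on the parameters,

$$p(\theta_i \mid \mathcal{D}_{1:i-1}) \approx \mathcal{N}\!\left(\mu_{i-1},\ \sigma_{i-1}^{2} + \sigma_{\mathrm{trans}}^{2}\right),$$

so the leaf-wise update exponentiates the stored log standard deviation, adds the transition variance, and takes the square root, which is exactly `torch.sqrt(torch.exp(lsd) ** 2 + transition_sd**2)`. The `- 6.0` initialization restarts each episode's variational fit with a small standard deviation of $e^{-6} \approx 2.5 \times 10^{-3}$.)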
371 | 418 | }
372 | 419 | ],
373 | 420 | "source": [
374 |     | - "fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
    | 421 | + "fig, axes = plt.subplots(\n",
    | 422 | + "    1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
    | 423 | + ")\n",
375 | 424 | "\n",
376 | 425 | "n_samples = 20\n",
377 | 426 | "\n",
378 | 427 | "for i, ax in enumerate(axes):\n",
379 | 428 | "    for _ in range(n_samples):\n",
380 | 429 | "        sample = posteriors.vi.diag.sample(vi_states[i])\n",
381 | 430 | "        plot_predictions(sample, ax, plt_linsp, sd=y_sd, alpha=0.2)\n",
382 |     | - "    plot_data(ax, up_to_episode=i+1)\n",
383 |     | - "    ax.set_title(f\"After Episode {i+1}\")"
    | 431 | + "    plot_data(ax, up_to_episode=i + 1)\n",
    | 432 | + "    ax.set_title(f\"After Episode {i + 1}\")"
384 | 433 | ]
385 | 434 | },
386 | 435 | {
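(Reviewer note: each `posteriors.vi.diag.sample(vi_states[i])` call draws one full parameter tree from the fitted diagonal Gaussian, so the 20 translucent curves per panel visualize the posterior predictive. A minimal sketch of a single draw, reusing the notebook's names from the cells above:)

```python
# One posterior draw -> one predictive curve (names as in the cells above).
sample = posteriors.vi.diag.sample(vi_states[-1])
preds = mlp_functional(sample, plt_linsp.unsqueeze(-1)).detach().squeeze()
```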