Skip to content

Commit a4bd04e

Browse files
authored
Update version and ruff (#140)
* Update ruff * Version bump
1 parent 9ef71d5 commit a4bd04e

File tree

6 files changed

+85
-35
lines changed

6 files changed

+85
-35
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.3.7
3+
rev: v0.11.5
44
hooks:
55
- id: ruff
66
args: ["--fix"]

examples/bayes_llama3/llama3/eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, config: FrozenConfigDict):
4343
config["pretrained_model_name_or_path"]
4444
)
4545
else:
46-
assert os.path.isdir(
47-
config["checkpoints_folder"]
48-
), "Provided checkpoints is not a path to a folder"
46+
assert os.path.isdir(config["checkpoints_folder"]), (
47+
"Provided checkpoints is not a path to a folder"
48+
)
4949
checkpoints = [
5050
os.path.join(config["checkpoints_folder"], path)
5151
for path in os.listdir(config["checkpoints_folder"])

examples/continual_regression.ipynb

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@
6262
"outputs": [],
6363
"source": [
6464
"episode_x_boundaries = torch.linspace(0, n_episodes, n_episodes + 1)\n",
65-
"xs = torch.stack([torch.linspace(episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode) for i in range(n_episodes)])\n",
65+
"xs = torch.stack(\n",
66+
" [\n",
67+
" torch.linspace(\n",
68+
" episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode\n",
69+
" )\n",
70+
" for i in range(n_episodes)\n",
71+
" ]\n",
72+
")\n",
6673
"ys = torch.stack([true_f(x) + y_sd * torch.randn_like(x) for x in xs])"
6774
]
6875
},
@@ -85,18 +92,20 @@
8592
"source": [
8693
"plt_linsp = torch.linspace(-1, episode_x_boundaries[-1] + 1, 1000)\n",
8794
"\n",
95+
"\n",
8896
"def plot_data(ax, up_to_episode=None):\n",
8997
" if up_to_episode is None:\n",
9098
" up_to_episode = n_episodes\n",
91-
" \n",
92-
" ax.plot(xs.flatten(), ys.flatten(), 'o', color='gray', alpha=0.2)\n",
99+
"\n",
100+
" ax.plot(xs.flatten(), ys.flatten(), \"o\", color=\"gray\", alpha=0.2)\n",
93101
" for i in range(up_to_episode):\n",
94-
" ax.plot(xs[i], ys[i], 'o', color='orange')\n",
95-
" \n",
102+
" ax.plot(xs[i], ys[i], \"o\", color=\"orange\")\n",
103+
"\n",
96104
" for v in episode_x_boundaries:\n",
97-
" ax.axvline(v, color='gray', linestyle='--', alpha=0.75)\n",
98-
" ax.plot(plt_linsp, true_f(plt_linsp), color='green', zorder=10)\n",
99-
" ax.set_ylim(-2., 2.5)\n",
105+
" ax.axvline(v, color=\"gray\", linestyle=\"--\", alpha=0.75)\n",
106+
" ax.plot(plt_linsp, true_f(plt_linsp), color=\"green\", zorder=10)\n",
107+
" ax.set_ylim(-2.0, 2.5)\n",
108+
"\n",
100109
"\n",
101110
"fig, ax = plt.subplots()\n",
102111
"plot_data(ax)"
@@ -166,11 +175,21 @@
166175
"outputs": [],
167176
"source": [
168177
"def log_prior(p, prior_mean, prior_sd: float):\n",
169-
" all_vals = tree_map(lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False).log_prob(p).sum(), p, prior_mean, prior_sd)\n",
178+
" all_vals = tree_map(\n",
179+
" lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False)\n",
180+
" .log_prob(p)\n",
181+
" .sum(),\n",
182+
" p,\n",
183+
" prior_mean,\n",
184+
" prior_sd,\n",
185+
" )\n",
170186
" return tree_reduce(torch.add, all_vals)\n",
171-
" \n",
187+
"\n",
188+
"\n",
172189
"def log_likelihood(y_pred, y):\n",
173-
" return torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()"
190+
" return (\n",
191+
" torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()\n",
192+
" )"
174193
]
175194
},
176195
{
@@ -182,7 +201,10 @@
182201
"def log_posterior(params, batch, prior_mean, prior_sd):\n",
183202
" x, y = batch\n",
184203
" y_pred = mlp_functional(params, x)\n",
185-
" log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
204+
" log_post = (\n",
205+
" log_likelihood(y_pred, y)\n",
206+
" + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n",
207+
" )\n",
186208
" return log_post, y_pred"
187209
]
188210
},
@@ -213,7 +235,13 @@
213235
"outputs": [],
214236
"source": [
215237
"batch_size = 3\n",
216-
"dataloaders = [torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)), batch_size=batch_size) for x, y in zip(xs, ys)]"
238+
"dataloaders = [\n",
239+
" torch.utils.data.DataLoader(\n",
240+
" torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)),\n",
241+
" batch_size=batch_size,\n",
242+
" )\n",
243+
" for x, y in zip(xs, ys)\n",
244+
"]"
217245
]
218246
},
219247
{
@@ -227,7 +255,9 @@
227255
" for _ in range(n_epochs):\n",
228256
" for batch in dataloader:\n",
229257
" opt.zero_grad()\n",
230-
" loss = -log_posterior(dict(mlp.named_parameters()), batch, prior_mean, prior_sd)[0]\n",
258+
" loss = -log_posterior(\n",
259+
" dict(mlp.named_parameters()), batch, prior_mean, prior_sd\n",
260+
" )[0]\n",
231261
" loss.backward()\n",
232262
" opt.step()"
233263
]
@@ -252,10 +282,10 @@
252282
"metadata": {},
253283
"outputs": [],
254284
"source": [
255-
"def plot_predictions(params, ax, x, sd=y_sd, alpha=1.):\n",
285+
"def plot_predictions(params, ax, x, sd=y_sd, alpha=1.0):\n",
256286
" preds = mlp_functional(params, x.unsqueeze(-1)).detach().numpy().squeeze()\n",
257-
" ax.plot(x, preds, color='blue', alpha=alpha)\n",
258-
" ax.fill_between(x, preds - sd, preds + sd, color='blue', alpha=0.2)"
287+
" ax.plot(x, preds, color=\"blue\", alpha=alpha)\n",
288+
" ax.fill_between(x, preds - sd, preds + sd, color=\"blue\", alpha=0.2)"
259289
]
260290
},
261291
{
@@ -275,12 +305,14 @@
275305
}
276306
],
277307
"source": [
278-
"fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
308+
"fig, axes = plt.subplots(\n",
309+
" 1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
310+
")\n",
279311
"\n",
280312
"for i, ax in enumerate(axes):\n",
281-
" plot_data(ax, up_to_episode=i+1)\n",
313+
" plot_data(ax, up_to_episode=i + 1)\n",
282314
" plot_predictions(trained_params[i], ax, plt_linsp)\n",
283-
" ax.set_title(f\"After Episode {i+1}\")"
315+
" ax.set_title(f\"After Episode {i + 1}\")"
284316
]
285317
},
286318
{
@@ -318,7 +350,13 @@
318350
"def train_for_vi(dataloader, prior_mean, prior_sd, n_epochs=200, init_log_sds=None):\n",
319351
" seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)\n",
320352
" optimizer = torchopt.adam(lr=2e-3)\n",
321-
" transform = posteriors.vi.diag.build(seq_log_post, optimizer, temperature=1/samps_per_episode, init_log_sds=init_log_sds, stl=False)\n",
353+
" transform = posteriors.vi.diag.build(\n",
354+
" seq_log_post,\n",
355+
" optimizer,\n",
356+
" temperature=1 / samps_per_episode,\n",
357+
" init_log_sds=init_log_sds,\n",
358+
" stl=False,\n",
359+
" )\n",
322360
" state = transform.init(dict(mlp.named_parameters()))\n",
323361
" nelbos = []\n",
324362
" for _ in range(n_epochs):\n",
@@ -346,9 +384,18 @@
346384
"nelbos = []\n",
347385
"for i in range(n_episodes):\n",
348386
" seq_prior_mean = prior_mean if i == 0 else vi_states[i - 1].params\n",
349-
" seq_prior_sd = prior_sd if i == 0 else tree_map(lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd ** 2), vi_states[i - 1].log_sd_diag)\n",
350-
" seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6., mlp.state_dict())\n",
351-
" state, nelbos_i = train_for_vi(dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds)\n",
387+
" seq_prior_sd = (\n",
388+
" prior_sd\n",
389+
" if i == 0\n",
390+
" else tree_map(\n",
391+
" lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd**2),\n",
392+
" vi_states[i - 1].log_sd_diag,\n",
393+
" )\n",
394+
" )\n",
395+
" seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6.0, mlp.state_dict())\n",
396+
" state, nelbos_i = train_for_vi(\n",
397+
" dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds\n",
398+
" )\n",
352399
" vi_states += [state]\n",
353400
" nelbos += [nelbos_i]\n",
354401
" mlp.load_state_dict(vi_states[i].params)"
@@ -371,16 +418,18 @@
371418
}
372419
],
373420
"source": [
374-
"fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n",
421+
"fig, axes = plt.subplots(\n",
422+
" 1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n",
423+
")\n",
375424
"\n",
376425
"n_samples = 20\n",
377426
"\n",
378427
"for i, ax in enumerate(axes):\n",
379428
" for _ in range(n_samples):\n",
380429
" sample = posteriors.vi.diag.sample(vi_states[i])\n",
381430
" plot_predictions(sample, ax, plt_linsp, sd=y_sd, alpha=0.2)\n",
382-
" plot_data(ax, up_to_episode=i+1)\n",
383-
" ax.set_title(f\"After Episode {i+1}\")"
431+
" plot_data(ax, up_to_episode=i + 1)\n",
432+
" ax.set_title(f\"After Episode {i + 1}\")"
384433
]
385434
},
386435
{

examples/pyro_pima_indians_sghmc.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@
365365
" samples[:, i] = torch.stack([state.params for state in states])\n",
366366
" if i > N_warmup:\n",
367367
" j = i - N_warmup\n",
368-
" gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup:i + 1])\n"
368+
" gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup : i + 1])"
369369
]
370370
},
371371
{
@@ -469,7 +469,7 @@
469469
"for ind, ax in enumerate(axes.flatten()):\n",
470470
" ax.hist(samples[:, N_warmup:, ind].flatten(), bins=50, density=True)\n",
471471
" ax.set_title(column_names[ind])\n",
472-
"fig.tight_layout()\n"
472+
"fig.tight_layout()"
473473
]
474474
},
475475
{

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "posteriors"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
description = "Uncertainty quantification with PyTorch"
55
readme = "README.md"
66
requires-python =">=3.9"

tests/test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def test_model_to_function():
110110

111111
func_output2 = func_lm(dict(lm.named_parameters()), input_ids, attention_mask)
112112

113-
assert type(output) == type(func_output1) == type(func_output2)
113+
assert type(output) is type(func_output1)
114+
assert type(output) is type(func_output2)
114115
assert torch.allclose(output["logits"], func_output1["logits"])
115116
assert torch.allclose(output["logits"], func_output2["logits"])
116117

0 commit comments

Comments
 (0)