This repository contains the accompanying code for the paper "Effective Bayesian Heteroscedastic Regression with Deep Neural Networks" published at NeurIPS 2023. The PDF is available here.
For the experiments, we used Python 3.9 and torch 1.12.1.
Additional dependencies that are available online are listed in requirements.txt and have to be installed with pip install -r requirements.txt.
Further, the dependencies/ directory contains modified versions of laplace-torch and asdl, which have to be installed with pip install -e dependencies/laplace and pip install -e dependencies/asdl.
To install the local utilities, run pip install -e . in the root directory of this repository.
A simple example using an MLP on the Skafte data set can be found in plot_skafte_example.py. It reproduces the first figure in the paper, which compares homoscedastic with heteroscedastic regression and additionally shows the epistemic uncertainty given by the respective Laplace approximations.
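To make the homoscedastic-versus-heteroscedastic comparison concrete, here is a minimal pure-Python sketch on synthetic data. It is an illustration under stated assumptions: the data are not the Skafte data set, no MLP or Laplace approximation is involved, and make_data and compare are hypothetical helpers. Both models use the true mean, so they differ only in how the noise variance is modeled.

```python
import math
import random

# Illustrative sketch only: synthetic data with input-dependent noise,
# NOT the Skafte data set and NOT the code in plot_skafte_example.py.

def make_data(n=2000, seed=0):
    """Draw x in [0, 10] and y = sin(x) + noise whose std grows with x (assumed toy setup)."""
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(n):
        x = rng.uniform(0.0, 10.0)
        std = 0.05 + 0.1 * x  # heteroscedastic: noise level increases with x
        xs.append(x)
        ys.append(math.sin(x) + rng.gauss(0.0, std))
    return xs, ys

def gaussian_nll(y, mu, var):
    """Negative log-likelihood of y under N(mu, var)."""
    return 0.5 * math.log(2 * math.pi * var) + (y - mu) ** 2 / (2 * var)

def compare(xs, ys, n_bins=10):
    """Average in-sample NLL of a homoscedastic and a (binned) heteroscedastic variance model."""
    resid2 = [(y - math.sin(x)) ** 2 for x, y in zip(xs, ys)]
    # Homoscedastic: one maximum-likelihood variance shared by all inputs.
    var_homo = sum(resid2) / len(resid2)
    # Heteroscedastic: one maximum-likelihood variance per x-bin (a crude input-dependent variance).
    bins = [min(int(x / 10.0 * n_bins), n_bins - 1) for x in xs]
    sums, counts = [0.0] * n_bins, [0] * n_bins
    for b, r in zip(bins, resid2):
        sums[b] += r
        counts[b] += 1
    var_bin = [s / c for s, c in zip(sums, counts)]
    nll_homo = sum(gaussian_nll(y, math.sin(x), var_homo) for x, y in zip(xs, ys)) / len(xs)
    nll_het = sum(gaussian_nll(y, math.sin(x), var_bin[b])
                  for x, y, b in zip(xs, ys, bins)) / len(xs)
    return nll_homo, nll_het
```

Calling compare(*make_data()) returns the two average NLLs. Because the binned model contains the homoscedastic one as the special case of equal per-bin variances, its in-sample NLL can never be higher; the gap shows how much a single shared variance misrepresents input-dependent noise.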
We extended laplace-torch to support (natural) heteroscedastic Gaussian likelihoods and asdl to support their Fisher/GGN curvature approximations as described in the paper.
The modifications can be found in dependencies/.
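For a network whose two outputs are the natural parameters, the generalized Gauss-Newton (GGN) has the form sum_n J_n^T Lambda_n J_n, where J_n is the Jacobian of (eta_1, eta_2) with respect to the parameters and Lambda_n is the Hessian of the negative log-likelihood with respect to the natural parameters. The pure-Python sketch below computes this quantity for a hypothetical two-output linear model eta = W x; it only illustrates the quantity and is not the asdl implementation, which handles general networks. Since Lambda_n does not depend on the target y, the GGN here coincides with the Fisher information.

```python
import random

def natural_hessian(eta1, eta2):
    """Hessian of the Gaussian NLL w.r.t. (eta1, eta2); it does not depend on the target y."""
    h12 = eta1 / (2 * eta2 ** 2)
    return [
        [-1.0 / (2 * eta2), h12],
        [h12, -eta1 ** 2 / (2 * eta2 ** 3) + 1.0 / (2 * eta2 ** 2)],
    ]

def ggn_linear(W, X):
    """GGN sum_n J_n^T Lambda_n J_n for the hypothetical linear model eta = W x.

    Parameters are ordered as theta = [W[0][0..d-1], W[1][0..d-1]], so the Jacobian
    entry d eta_k / d W[l][a] equals x[a] if k == l and 0 otherwise.
    """
    d = len(X[0])
    G = [[0.0] * (2 * d) for _ in range(2 * d)]
    for x in X:
        eta = [sum(w * xi for w, xi in zip(W[k], x)) for k in range(2)]
        assert eta[1] < 0, "eta_2 must be negative for a valid Gaussian"
        lam = natural_hessian(eta[0], eta[1])
        # (J^T Lambda J)[k*d + a][l*d + b] = Lambda[k][l] * x[a] * x[b]
        for k in range(2):
            for l in range(2):
                for a in range(d):
                    for b in range(d):
                        G[k * d + a][l * d + b] += lam[k][l] * x[a] * x[b]
    return G
```

With positive inputs and a negative second row of W (for example W = [[0.3, -0.2], [-0.5, -0.1]] and inputs [1.0, x2] with x2 > 0), eta_2 stays negative and the resulting GGN is symmetric and positive semidefinite.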
As described in the paper, the Laplace approximation requires the natural parameterization of the likelihood so that the Hessian is positive semidefinite. To accomplish this, the network's two outputs have to be the natural parameters [eta_1, eta_2] = [mu / sigma^2, -1 / (2 sigma^2)], as in hetreg.models.NaturalHead.
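Why the natural parameterization matters can be checked on a single data point with analytic derivatives. The sketch below (all function names hypothetical) writes the Gaussian negative log-likelihood in natural parameters, computes its Hessian with respect to (eta_1, eta_2), and contrasts it with the second derivative with respect to the variance in the mean-variance parameterization. It relies on a standard exponential-family fact: the terms involving y are linear in eta, so the Hessian equals the Hessian of the log-partition function, which is the covariance of the sufficient statistics (y, y^2) and hence positive semidefinite.

```python
import math

def nll_natural(y, eta1, eta2):
    """Gaussian negative log-likelihood in natural parameters (requires eta2 < 0)."""
    return (-eta1 * y - eta2 * y ** 2
            - eta1 ** 2 / (4 * eta2) - 0.5 * math.log(-2 * eta2)
            + 0.5 * math.log(2 * math.pi))

def hessian_natural(eta1, eta2):
    """Analytic Hessian entries (h11, h12, h22) of nll_natural w.r.t. (eta1, eta2).

    The y-dependent terms are linear in eta, so y drops out of the Hessian.
    """
    h11 = -1.0 / (2 * eta2)
    h12 = eta1 / (2 * eta2 ** 2)
    h22 = -eta1 ** 2 / (2 * eta2 ** 3) + 1.0 / (2 * eta2 ** 2)
    return h11, h12, h22

def hessian_var_entry(y, mu, var):
    """Second derivative of the NLL w.r.t. the variance in the mean-variance parameterization."""
    return (y - mu) ** 2 / var ** 3 - 1.0 / (2 * var ** 2)

def is_psd_2x2(h11, h12, h22):
    """A symmetric 2x2 matrix is PSD iff its diagonal and determinant are non-negative."""
    return h11 >= 0 and h22 >= 0 and h11 * h22 - h12 ** 2 >= 0
```

For mu = 1.5 and var = 0.4, i.e. eta = (3.75, -1.25), the natural-parameter Hessian has determinant 2 var^3 > 0 for every y. In the mean-variance parameterization, the curvature in the variance direction becomes negative whenever (y - mu)^2 < var / 2, for example when y equals the mean.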
To use an existing network with a mean-variance parameterization, apply hetreg.models.NaturalReparamHead, which transforms the outputs accordingly.
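The transformation between the two parameterizations is a closed-form, invertible map. The sketch below shows it in pure Python; mean_var_to_natural, natural_to_mean_var, and reparam_head are hypothetical names, and the softplus link for an unconstrained variance output is an assumption for illustration, not necessarily what hetreg.models.NaturalReparamHead does.

```python
import math

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def mean_var_to_natural(mu, var):
    """Map a Gaussian mean and variance to natural parameters (eta1, eta2)."""
    if var <= 0:
        raise ValueError("variance must be positive")
    return mu / var, -1.0 / (2.0 * var)

def natural_to_mean_var(eta1, eta2):
    """Inverse map from natural parameters back to (mean, variance); needs eta2 < 0."""
    if eta2 >= 0:
        raise ValueError("eta2 must be negative")
    var = -1.0 / (2.0 * eta2)
    return eta1 * var, var

def reparam_head(mean_out, raw_var_out, min_var=1e-6):
    """Hypothetical head: mean output plus unconstrained variance output -> natural parameters.

    Assumption: the variance is obtained with a softplus link plus a small floor.
    """
    return mean_var_to_natural(mean_out, softplus(raw_var_out) + min_var)
```

For example, (mu, var) = (-0.7, 2.5) maps to (eta_1, eta_2) = (-0.28, -0.2) and back. Any link that keeps the variance strictly positive guarantees eta_2 < 0, which is what makes the natural parameters valid.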
The remaining code is for the experiments and figures in the paper.
To use the proposed rotated image regression datasets, see hetreg/image_datasets.py and their use in run_image_regression.py.
In hetreg/models.py, we provide implementations of an MLP, LeNet, and ResNet18 that are compatible with our method.
Commands for the image, UCI, and CRISPR regression experiments can be generated with python commands/jobs_image_generate.py, python commands/jobs_uci_generate.py, and python commands/jobs_crispr_generate.py, respectively; these scripts also serve as examples of how to run our method.
For example, python commands/jobs_image_generate.py > jobs.sh writes the image regression commands to jobs.sh, one command per line, so that they can be submitted to a scheduler.
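The generated file is meant for a scheduler. For running a handful of jobs locally without one, a minimal sequential runner could look like the sketch below; read_jobs and run_jobs are hypothetical helpers, and the sketch assumes each line is a plain command without shell syntax such as pipes, redirects, or environment-variable assignments.

```python
import shlex
import subprocess
from pathlib import Path

def read_jobs(path):
    """Return the non-empty, non-comment lines of a generated jobs file (one command per line)."""
    lines = Path(path).read_text().splitlines()
    return [ln.strip() for ln in lines if ln.strip() and not ln.strip().startswith("#")]

def run_jobs(commands, dry_run=True):
    """Run commands one after another.

    With dry_run=True, only the parsed argument lists are returned; otherwise each
    command is executed without a shell and its return code is collected.
    """
    results = []
    for cmd in commands:
        args = shlex.split(cmd)
        if dry_run:
            results.append(args)
        else:
            # check=False so that one failing job does not abort the remaining ones.
            completed = subprocess.run(args, check=False)
            results.append(completed.returncode)
    return results
```

Usage: run_jobs(read_jobs("jobs.sh")) first prints nothing and returns the parsed commands for inspection; run_jobs(read_jobs("jobs.sh"), dry_run=False) then executes them sequentially.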
Alternatively, plot_skafte_example.py provides an example for the toy-data visualizations used in the paper.
The generated commands call the corresponding scripts, run_image_regression.py and run_uci_crispr_regression.py, which save their results to wandb.
The result tables can then be obtained by aggregating the results in the wandb frontend or by downloading them through the wandb API.
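When downloading results through the API, the aggregation step might look like the sketch below. The config keys "dataset" and "method", the summary metric "test_loglik", and the "entity/project" path are hypothetical placeholders; the names actually logged by run_image_regression.py and run_uci_crispr_regression.py may differ. The wandb import is deferred so that the aggregation function can be used on already-downloaded records without wandb installed.

```python
import math
from collections import defaultdict

def fetch_runs(path):
    """Download (config, summary) pairs for all runs in "entity/project" (placeholder path)."""
    import wandb  # deferred import: only needed when actually querying the API

    api = wandb.Api()
    records = []
    for run in api.runs(path):
        config = {k: v for k, v in run.config.items() if not k.startswith("_")}
        records.append((config, dict(run.summary._json_dict)))
    return records

def aggregate(records, group_keys=("dataset", "method"), metric="test_loglik"):
    """Group runs by config values and return {group: (mean, standard error, count)}."""
    groups = defaultdict(list)
    for config, summary in records:
        if metric in summary:
            key = tuple(config.get(k) for k in group_keys)
            groups[key].append(float(summary[metric]))
    table = {}
    for key, values in groups.items():
        n = len(values)
        mean = sum(values) / n
        if n > 1:
            var = sum((v - mean) ** 2 for v in values) / (n - 1)
            se = math.sqrt(var / n)
        else:
            se = float("nan")  # a single run has no spread estimate
        table[key] = (mean, se, n)
    return table
```

Runs that did not log the chosen metric are skipped rather than counted, so crashed or unfinished runs do not silently bias the reported means.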
We thank the authors of previous work for sharing the CRISPR-Cas-13 datasets we use.