
Bayesian Neural Networks with ADVI

Variational inference in Bean Machine

Bean Machine also includes support for Variational Inference, similar to Pyro. In this tutorial we explore settings and scenarios where one might benefit from a variational inference approach as opposed to the MCMC methods we often use by default.

Learning Outcomes​

  • Create and run stochastic variational inference on a Bean Machine model
  • Understand which probabilistic models benefit most from a VI approach
  • Diagnose and debug problems with VI models

Prerequisites

We will be using the following packages within this tutorial.

  • arviz and bokeh for interactive visualizations; and
  • pandas for data manipulation.
# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
!pip install beanmachine
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
import torch.optim as optim

from beanmachine.ppl.inference.vi.variational_infer import (
kl_reverse,
monte_carlo_approximate_reparam,
VariationalInfer,
)
from beanmachine.ppl.inference.vi.autoguide import ADVI
from beanmachine.tutorials.utils import plots
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.palettes import Colorblind3
from bokeh.plotting import gridplot, show
from IPython.display import Markdown
from tqdm import tqdm

What is Variational Inference?​

In sampling-based inference algorithms, we define a probabilistic model $p(z \mid y)$ with latent variables $z$ and observed data $y$. Because sampling directly from most posterior distributions is challenging, we use various techniques (such as MCMC) to generate samples from it.

Alternatively, variational inference is a class of inference algorithms that takes a parameterised family of surrogate probability distributions $q_\phi(z)$ and uses optimisation to find the parameters $\phi$ that bring $q_\phi(z)$ closest to the posterior distribution $p(z \mid y)$.

The main disadvantage of this class of algorithms is that the learned distribution often fails to match the tails of the true posterior and tends to underestimate its variance. The main advantage, as we will see, is that the algorithms themselves are substantially faster.
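For reference, the reverse KL objective used throughout this tutorial (kl_reverse below) minimises $\mathrm{KL}(q_\phi(z) \,\|\, p(z \mid y))$, which is equivalent to maximising the evidence lower bound (ELBO):

$$
\mathrm{ELBO}(\phi) \;=\; \mathbb{E}_{q_\phi(z)}\big[\log p(y, z) - \log q_\phi(z)\big] \;=\; \log p(y) - \mathrm{KL}\big(q_\phi(z) \,\|\, p(z \mid y)\big).
$$

Since $\log p(y)$ does not depend on $\phi$, maximising the ELBO is the same as minimising the divergence, and the expectation can be estimated with Monte Carlo samples drawn from $q_\phi$ (the role of monte_carlo_approximate_reparam below).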

Data​

We will demonstrate how you might use variational inference using a toy dataset which generates overlapping half-circles.

from sklearn.datasets import make_moons

X, Y = make_moons(noise=0.2, random_state=0, n_samples=1000)

color = ['yellow' if y == 0 else 'cyan' for y in Y]
output_notebook(hide_banner=True)

cds = ColumnDataSource({"x": X[:, 0].tolist(), "y": X[:, 1].tolist(), "color": color})

moon_plot = plots.scatter_plot(
    plot_sources=[cds],
    tooltips=[[("x", "@y{0.000}"), ("y", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "X",
        "y_axis_label": "Y",
        "title": "Binary classification data",
    },
    plot_kwargs={"color": "color"},
)
show(moon_plot)

To do well on this data, we will need to explicitly model its non-linearity.

Toy Model​

But first, to exercise all the machinery at play, let's start with a very simple model.

import beanmachine.ppl as bm
import torch
import torch.distributions as dist

@bm.random_variable
def mu():
    return dist.Normal(0., 1.)

@bm.random_variable
def x():
    return dist.Normal(mu(), 1.)

To approximate mu() we will create a guide distribution: a Normal distribution whose parameters are learnable.

@bm.param
def phi():
    return torch.zeros(2)

@bm.random_variable
def q_mu():
    phi_mean, phi_sd = phi()
    softplus = torch.nn.Softplus()
    return dist.Normal(phi_mean, softplus(phi_sd))

world = VariationalInfer(
    queries_to_guides={mu(): q_mu()},
    observations={x(): 5.},
    optimizer=lambda params: optim.Adam(params, lr=1e-1),
).infer(
    num_steps=100,
    num_samples=5,
    discrepancy_fn=kl_reverse,
    mc_approx=monte_carlo_approximate_reparam,
)
world.get_variable(q_mu()).distribution
Out:

Normal(loc: 2.4571237564086914, scale: 0.6683022379875183)

In this particular setting the prior is conjugate, so the posterior has a closed-form solution and we can compare our variational distribution to the true posterior.
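To spell this out: with prior $\mu \sim \mathcal{N}(0, 1)$, likelihood $x \mid \mu \sim \mathcal{N}(\mu, 1)$, and the single observation $x = 5$, the standard Normal-Normal conjugacy result gives a posterior precision of $1 + 1 = 2$ and a posterior mean of $x / 2$:

$$
\mu \mid x \;\sim\; \mathcal{N}\!\left(\tfrac{x}{2},\ \tfrac{1}{2}\right) \;=\; \mathcal{N}\!\left(2.5,\ \text{variance} = 0.5\right),
$$

i.e. a posterior standard deviation of $\sqrt{0.5} \approx 0.707$, which is exactly what we construct below.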

true_posterior = dist.Normal(2.5, torch.tensor(0.5).sqrt())
true_posterior
Out:

Normal(loc: 2.5, scale: 0.7071067690849304)

We also compare against a NUTS sampler, which is asymptotically guaranteed to converge to the true posterior.

nuts_samples = bm.GlobalNoUTurnSampler().infer(
    [mu()],
    {x(): torch.tensor(5.)},
    num_samples=2000,
    num_chains=1,
)

nuts_samples[mu()].std()
Out:

tensor(0.6904)

As you can see, the mean of $\mu$ is reasonably approximated, but VI has less success approximating the standard deviation.
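To put the three estimates side by side (a small sketch reusing the objects defined above; exact numbers will differ from run to run):

q_approx = world.get_variable(q_mu()).distribution
# Compare the scale of the variational guide, the NUTS sample standard
# deviation, and the true posterior scale.
print("VI guide scale:       ", q_approx.scale.item())
print("NUTS sample std:      ", nuts_samples[mu()].std().item())
print("True posterior scale: ", true_posterior.scale.item())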

Bayesian Neural Network model​

We will model this problem using a Bayesian neural network (BNN). A BNN is a probabilistic model in which the weights of the neural network are treated as latent random variables, letting us capture our uncertainty about them.

import beanmachine.ppl as bm
import torch
import torch.distributions as dist


class BayesianNeuralNetwork:
    def __init__(self, X, Y, hidden_size):
        self.X = X
        self.Y = Y
        self.hidden_size = hidden_size

    @bm.random_variable
    def input_layer(self):
        return dist.Normal(0., 1.).expand((self.X.shape[1], self.hidden_size))

    @bm.random_variable
    def hidden_layer(self):
        return dist.Normal(0., 1.).expand((self.hidden_size, self.hidden_size))

    @bm.random_variable
    def output_layer(self):
        return dist.Normal(0., 1.).expand((self.hidden_size,))

    @bm.random_variable
    def forward(self):
        y1 = torch.tanh(torch.mm(self.X, self.input_layer()))
        y2 = torch.tanh(torch.mm(y1, self.hidden_layer()))
        y3 = torch.sigmoid(torch.matmul(y2, self.output_layer()))
        return dist.Bernoulli(y3)

We create queries and observations just as we would for MCMC inference.

X = torch.tensor(X, dtype=torch.float)
Y = torch.tensor(Y, dtype=torch.float)
nn = BayesianNeuralNetwork(X, Y, 10)

queries = [nn.input_layer(), nn.hidden_layer(), nn.output_layer()]
observations = {nn.forward(): Y}

The easiest way to get started with variational inference is ADVI, which will automatically create a guide distribution to match each of your queries.
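Conceptually, ADVI attaches to each query the same kind of guide we wrote by hand for mu() in the toy example: a Normal with a learnable location and a positively constrained learnable scale. As a rough illustration only (the names phi_out and q_output_layer are hypothetical and not ADVI's actual internals), a hand-written guide for the output layer could look like this:

@bm.param
def phi_out():
    # one location and one unconstrained scale parameter per output-layer weight
    return torch.zeros(2, nn.hidden_size)

@bm.random_variable
def q_output_layer():
    loc, unconstrained_scale = phi_out()
    return dist.Normal(loc, torch.nn.Softplus()(unconstrained_scale))

ADVI saves us from writing one of these for every query: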

bnn_world = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=4e-2),
).infer(
    num_steps=8000,
    num_samples=1,
    discrepancy_fn=kl_reverse,
)

Notice how quickly fitting this model ran. Modify the above cell to use bm.GlobalNoUTurnSampler() instead of ADVI to really appreciate the speed difference.
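Such a modification might look roughly like the following (a sketch only, not run here; on a model of this size expect it to take far longer than ADVI, so you may want a small num_samples):

bnn_nuts_samples = bm.GlobalNoUTurnSampler().infer(
    queries,
    observations,
    num_samples=200,
    num_chains=1,
)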

Monitoring and diagnosing models fit with Variational Inference

We can plot the divergence over time to make sure the optimization didn't get stuck. The divergence measures the difference between the approximate distribution we are learning and the true distribution we are trying to approximate.

vi = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=1e-2),
)

num_steps = 1000
losses = []
for i in tqdm(range(num_steps)):
    loss = vi.step()
    losses.append(loss.item())

vi.initialize_world()
Out:

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1000/1000 [00:04<00:00, 202.75it/s]

Out:

<beanmachine.ppl.inference.vi.variational_world.VariationalWorld at 0x7fa371dd5ed0>

output_notebook(hide_banner=True)

cds = ColumnDataSource({"x": range(num_steps), "y": losses})

elbo_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("ELBO", "@y{0.000}"), ("step", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "steps",
        "y_axis_label": "ELBO",
        "title": "Divergence between model and guide programs",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(elbo_plot)

We can also check the convergence of the variational parameters.

vi_world = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=1e-2),
)

num_steps = 5000
vals = []
for i in tqdm(range(num_steps)):
    vi_world.step()
    param = list(vi_world.params.values())[0]
    vals.append(param[0, 0].item())

Out:

100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [00:22<00:00, 219.71it/s]

output_notebook(hide_banner=True)

cds = ColumnDataSource({"x": range(num_steps), "y": vals})

param_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("param", "@y{0.000}"), ("step", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "steps",
        "y_axis_label": "param",
        "title": "Value of param during VI",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(param_plot)

We need to introduce a helper method. The default way of accessing the guide distribution associated with each model variable is quite verbose, so to make the prediction code easier to read we add the following shorthand.

def d(x):
    return bnn_world.get_guide_distribution(x)

We create posterior predictive samples by using the learned distributions to generate observations. This is just the forward method from our original model with some minor changes to enable batch sampling.

def predictions(X, samples=100):
    il = d(nn.input_layer()).expand((samples, -1, -1)).sample()
    y1 = torch.tanh(torch.matmul(X, il))
    hl = d(nn.hidden_layer()).expand((samples, -1, -1)).sample()
    y2 = torch.tanh(torch.matmul(y1, hl))
    ol = d(nn.output_layer()).expand((samples, -1)).sample()
    y3 = torch.sigmoid(torch.einsum('bij,bj->bi', y2, ol))
    return dist.Bernoulli(y3).sample()

We visualise our predictions to show that the confidence of the predictions decreases as we approach the boundary between the classes.

import numpy as np

x_points = 100
y_points = 100

x = np.linspace(-3, 3, x_points)
y = np.linspace(-3, 3, y_points)
xx, yy = np.meshgrid(x, y)
grid_2d = torch.tensor(np.stack([xx, yy]), dtype=torch.float).reshape(2, -1).T
preds = predictions(grid_2d).mean(axis=0).reshape(x_points, y_points)


from bokeh.plotting import figure, output_file, show
p = figure(width=400, height=400)
p.x_range.range_padding = p.y_range.range_padding = 0

p.image(image=[preds.numpy()], x=-2, y=-2, dw=5, dh=4, palette="PRGn11", level="image")
p.grid.grid_line_width = 0

p.circle(X[Y==0,0].numpy(), X[Y==0,1].numpy(), color="yellow")
p.circle(X[Y==1,0].numpy(), X[Y==1,1].numpy(), color="cyan")

show(p)

As we can see, the posterior inference is effective: it not only separates the classes, but also highlights where our model is uncertain.