Bayesian Neural Networks with ADVI
Variational inference in Bean Machine
Bean Machine also includes support for variational inference, similar to Pyro. In this tutorial we will explore settings and scenarios where a variational inference approach can be preferable to the MCMC methods we often use by default.
Learning Outcomes
- Create and run stochastic variational inference on a Bean Machine model
- Understand which probabilistic models benefit most from a VI approach
- Diagnose and debug problems with VI models
Prerequisites
We will be using the following packages within this tutorial.
- arviz and bokeh for interactive visualizations; and
- pandas for data manipulation.
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
!pip install beanmachine
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
import torch.optim as optim
from beanmachine.ppl.inference.vi.variational_infer import (
    kl_reverse,
    monte_carlo_approximate_reparam,
    VariationalInfer,
)
from beanmachine.ppl.inference.vi.autoguide import ADVI
from beanmachine.tutorials.utils import plots
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.palettes import Colorblind3
from bokeh.plotting import gridplot, show
from IPython.display import Markdown
from tqdm import tqdm
What is Variational Inference?
In sampling-based inference algorithms, we define a probability distribution $p(z, x)$ representing a model with latent variables $z$ and observed data $x$. Because sampling directly from the posterior $p(z \mid x)$ is challenging for most distributions, we use techniques such as MCMC to generate samples from it.
Variational inference, by contrast, is a class of inference algorithms which takes a parameterised family of surrogate probability distributions $q_\phi(z)$ and uses optimisation to find the parameters $\phi$ for which $q_\phi(z)$ is closest to the posterior distribution $p(z \mid x)$.
The main disadvantages of this class of algorithms are that the learned distribution rarely matches the tail behaviour of the true posterior and that it tends to underestimate the posterior variance. The main advantage, as we will see, is that the algorithms themselves are substantially faster.
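Concretely, with the reverse KL divergence used later in this tutorial (kl_reverse), the optimisation problem is

$$
\phi^{*} = \arg\min_{\phi} \; \mathrm{KL}\big(q_{\phi}(z) \,\|\, p(z \mid x)\big) = \arg\max_{\phi} \; \mathbb{E}_{q_{\phi}(z)}\big[\log p(x, z) - \log q_{\phi}(z)\big],
$$

where the right-hand side is the evidence lower bound (ELBO); the loss we plot in the monitoring section below tracks this objective (up to sign). The reverse KL is mode-seeking, which is one way to see why the fitted guide tends to understate the posterior variance.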
Dataβ
We will demonstrate variational inference on a toy dataset consisting of two overlapping half-circles.
from sklearn.datasets import make_moons
X, Y = make_moons(noise=0.2, random_state=0, n_samples=1000)
color = ['yellow' if y == 0 else 'cyan' for y in Y]
output_notebook(hide_banner=True)
cds = ColumnDataSource({"x": X[:, 0].tolist(), "y": X[:, 1].tolist(), "color": color})
moon_plot = plots.scatter_plot(
    plot_sources=[cds],
    tooltips=[[("x", "@x{0.000}"), ("y", "@y{0.000}")]],
    figure_kwargs={
        "x_axis_label": "X",
        "y_axis_label": "Y",
        "title": "Binary classification data",
    },
    plot_kwargs={"color": "color"},
)
show(moon_plot)
To do well on this dataset, we will need to explicitly model its non-linearity.
Toy Model
But first, let's introduce all the machinery at play on a much simpler model.
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
@bm.random_variable
def mu():
    return dist.Normal(0., 1.)


@bm.random_variable
def x():
    return dist.Normal(mu(), 1.)
To approximate mu, we will create a guide distribution for it: a Normal distribution whose parameters are learnable. The softplus transform keeps the learned standard deviation positive.
@bm.param
def phi():
    return torch.zeros(2)


@bm.random_variable
def q_mu():
    phi_mean, phi_sd = phi()
    softplus = torch.nn.Softplus()
    return dist.Normal(phi_mean, softplus(phi_sd))
world = VariationalInfer(
    queries_to_guides={mu(): q_mu()},
    observations={x(): torch.tensor(5.)},
    optimizer=lambda params: optim.Adam(params, lr=1e-1),
).infer(
    num_steps=100,
    num_samples=5,
    discrepancy_fn=kl_reverse,
    mc_approx=monte_carlo_approximate_reparam,
)
world.get_variable(q_mu()).distribution
In this particular setting the prior is conjugate, so the posterior has a closed-form solution and we can compare our variational distribution to the true posterior.
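To spell out where the numbers in the next cell come from: with prior $\mu \sim \mathcal{N}(0, 1)$, likelihood $x \mid \mu \sim \mathcal{N}(\mu, 1)$, and the single observation $x = 5$, the standard Normal-Normal conjugacy result gives

$$
\mu \mid x \;\sim\; \mathcal{N}\!\left(\tfrac{x}{2},\, \tfrac{1}{2}\right) = \mathcal{N}(2.5,\, 0.5),
$$

where $0.5$ is the posterior variance, hence the sqrt() on the scale in the code below.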
true_posterior = dist.Normal(2.5, torch.tensor(0.5).sqrt())
true_posterior
We also compare against a NUTS sampler which is asymptotically guaranteed to converge to the true posterior.
nuts_samples = bm.GlobalNoUTurnSampler().infer(
    [mu()],
    {x(): torch.tensor(5.)},
    num_samples=2000,
    num_chains=1,
)
nuts_samples[mu()].std()
As you can see, the mean of mu is reasonably approximated, but VI has less success approximating the standard deviation.
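For a quick side-by-side comparison, we can print all three estimates together. This is a minimal sketch that assumes the world, true_posterior, and nuts_samples objects from the cells above are still in scope.

# Compare the exact posterior, the learned ADVI guide, and the NUTS estimate.
q_dist = world.get_variable(q_mu()).distribution
print("true posterior:", true_posterior.mean.item(), true_posterior.stddev.item())
print("ADVI guide:    ", q_dist.mean.item(), q_dist.stddev.item())
print("NUTS samples:  ", nuts_samples[mu()].mean().item(), nuts_samples[mu()].std().item())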
Bayesian Neural Network model
We will model this problem using a Bayesian neural network (BNN). A BNN is a probabilistic model in which the weights of the neural network are treated as latent random variables, so the posterior captures our uncertainty about them.
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
class BayesianNeuralNetwork:
    def __init__(self, X, Y, hidden_size):
        self.X = X
        self.Y = Y
        self.hidden_size = hidden_size

    @bm.random_variable
    def input_layer(self):
        return dist.Normal(0., 1.).expand((self.X.shape[1], self.hidden_size))

    @bm.random_variable
    def hidden_layer(self):
        return dist.Normal(0., 1.).expand((self.hidden_size, self.hidden_size))

    @bm.random_variable
    def output_layer(self):
        return dist.Normal(0., 1.).expand((self.hidden_size,))

    @bm.random_variable
    def forward(self):
        y1 = torch.tanh(torch.mm(self.X, self.input_layer()))
        y2 = torch.tanh(torch.mm(y1, self.hidden_layer()))
        y3 = torch.sigmoid(torch.matmul(y2, self.output_layer()))
        return dist.Bernoulli(y3)
We create queries and observations just as we would for MCMC inference.
X = torch.tensor(X, dtype=torch.float)
Y = torch.tensor(Y, dtype=torch.float)
nn = BayesianNeuralNetwork(X, Y, 10)
queries = [nn.input_layer(), nn.hidden_layer(), nn.output_layer()]
observations = {nn.forward(): Y}
The easiest way to get started with variational inference is ADVI (Automatic Differentiation Variational Inference), which automatically creates a guide distribution to match each of your queries.
bnn_world = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=4e-2),
).infer(
    num_steps=8000,
    num_samples=1,
    discrepancy_fn=kl_reverse,
)
Notice how quickly this model was fit. Modify the above cell to use bm.GlobalNoUTurnSampler() instead of ADVI to really appreciate the speed difference.
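For reference, a minimal sketch of that NUTS call is below (an illustration, not run here; even a modest number of samples can take far longer than the ADVI fit above):

# Sketch only: sampling the full BNN posterior with NUTS is slow.
nuts_bnn_samples = bm.GlobalNoUTurnSampler().infer(
    queries=queries,
    observations=observations,
    num_samples=500,
    num_chains=1,
)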
Monitoring and diagnosing models inferred with Variational Inference
We can plot the divergence over time to make sure the optimization didn't get stuck. The divergence measures the difference between the approximate distribution we are learning and the true posterior we are trying to approximate.
vi = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=1e-2),
)
num_steps = 1000
losses = []
for i in tqdm(range(num_steps)):
    loss = vi.step()
    losses.append(loss.item())
vi.initialize_world()
output_notebook(hide_banner=True)
cds = ColumnDataSource({"x": list(range(num_steps)), "y": losses})
elbo_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("ELBO", "@y{0.000}"), ("step", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "steps",
        "y_axis_label": "ELBO",
        "title": "Divergence between model and guide programs",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(elbo_plot)
We can also check the convergence of the guide parameters.
vi_world = ADVI(
    queries=queries,
    observations=observations,
    optimizer=lambda params: optim.Adam(params, lr=1e-2),
)
num_steps = 5000
vals = []
for i in tqdm(range(num_steps)):
    vi_world.step()
    param = list(vi_world.params.values())[0]
    vals.append(param[0, 0].item())
output_notebook(hide_banner=True)
cds = ColumnDataSource({"x": list(range(num_steps)), "y": vals})
param_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("param", "@y{0.000}"), ("step", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "steps",
        "y_axis_label": "param",
        "title": "Value of param during VI",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(param_plot)
We first introduce a small helper. The call for accessing the guide distribution associated with each model variable is quite long, so to keep the prediction code easy to read we add the following shorthand.
def d(x):
    return bnn_world.get_guide_distribution(x)
We create posterior predictive samples by using the learned distributions to generate observations. This is just the forward method from our original model with some minor changes to enable batch sampling.
def predictions(X, samples=100):
    # Draw `samples` independent sets of weights from the learned guides.
    il = d(nn.input_layer()).expand((samples, -1, -1)).sample()
    y1 = torch.tanh(torch.matmul(X, il))
    hl = d(nn.hidden_layer()).expand((samples, -1, -1)).sample()
    y2 = torch.tanh(torch.matmul(y1, hl))
    ol = d(nn.output_layer()).expand((samples, -1)).sample()
    y3 = torch.sigmoid(torch.einsum('bij,bj->bi', y2, ol))
    # One Bernoulli draw per weight sample and data point.
    return dist.Bernoulli(y3).sample()
We visualise our predictions to show that the confidence of the predictions decreases as we approach the decision boundary.
import numpy as np
x_points = 100
y_points = 100
x = np.linspace(-3, 3, x_points)
y = np.linspace(-3, 3, y_points)
xx, yy = np.meshgrid(x, y)
grid_2d = torch.tensor([xx, yy], dtype=torch.float).reshape(2, -1).T
preds = predictions(grid_2d).mean(axis=0).reshape(x_points, y_points)
from bokeh.plotting import figure, show

p = figure(width=400, height=400)
p.x_range.range_padding = p.y_range.range_padding = 0
# Align the image with the prediction grid, which spans [-3, 3] in both dimensions.
p.image(image=[preds.numpy()], x=-3, y=-3, dw=6, dh=6, palette="PRGn11", level="image")
p.grid.grid_line_width = 0
p.circle(X[Y == 0, 0].numpy(), X[Y == 0, 1].numpy(), color="yellow")
p.circle(X[Y == 1, 0].numpy(), X[Y == 1, 1].numpy(), color="cyan")
show(p)
As we can see, the posterior inference is effective: it not only separates the classes but also highlights where there is uncertainty in our model's predictions.
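As a rough additional sanity check, we can reuse the predictions helper and the X and Y tensors from above to look at the training accuracy of the mean posterior predictive; this is a sketch rather than part of the original analysis.

# Average the Bernoulli draws over weight samples to get a predictive probability,
# then threshold at 0.5 to get a hard class label.
train_probs = predictions(X).mean(axis=0)
accuracy = ((train_probs > 0.5).float() == Y).float().mean().item()
print(f"Training accuracy of the posterior predictive mean: {accuracy:.2%}")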