# Variational inference in Bean Machine

Bean Machine also supports variational inference (VI), in a style similar to Pyro. In this tutorial we explore settings and scenarios where a variational approach may be preferable to the MCMC methods we often use by default.

## Learning Outcomes

• Create and run stochastic variational inference on a Bean Machine model
• Understand which probabilistic models benefit most from a VI approach
• Diagnose and debug problems with VI models

## Prerequisites

We will be using the following packages within this tutorial.

• arviz and bokeh for interactive visualizations; and
• pandas for data manipulation.

    # Install Bean Machine in Colab if using Colab.
    import sys

    if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
        !pip install beanmachine

    import beanmachine.ppl as bm
    import torch
    import torch.distributions as dist
    import torch.optim as optim
    from beanmachine.ppl.inference.vi.variational_infer import (
        kl_reverse,
        monte_carlo_approximate_reparam,
        VariationalInfer,
    )
    from beanmachine.ppl.inference.vi.autoguide import ADVI

    from beanmachine.tutorials.utils import plots
    from bokeh.io import output_notebook
    from bokeh.models import ColumnDataSource
    from bokeh.palettes import Colorblind3
    from bokeh.plotting import gridplot, show
    from IPython.display import Markdown
    from tqdm import tqdm

## What is Variational Inference?

In sampling-based inference algorithms, we define a probabilistic model with latent variables $z$ and observed data $y$, and we are interested in the posterior distribution $p(z | y)$. Because sampling directly from most posterior distributions is challenging, we use techniques such as MCMC to generate samples.

Alternatively, variational inference is a class of inference algorithms that take a parameterised family of surrogate probability distributions $q_\phi(z)$ and use optimisation to find the parameters $\phi$ for which $q_\phi(z)$ is closest to the posterior distribution $p(z | y)$.
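
One common way to make "closest" precise is the reverse Kullback-Leibler divergence, which is presumably what the `kl_reverse` discrepancy used later in this tutorial selects. As a sketch in the notation above, with $p(y, z)$ denoting the joint density of the model (a symbol introduced here for illustration), the log evidence decomposes as:

$$
\log p(y) = \underbrace{\mathbb{E}_{q_\phi(z)}\left[\log p(y, z) - \log q_\phi(z)\right]}_{\mathrm{ELBO}(\phi)} + \mathrm{KL}\left(q_\phi(z) \,\|\, p(z | y)\right)
$$

Because $\log p(y)$ does not depend on $\phi$, maximising the ELBO is equivalent to minimising the KL term. The ELBO only requires the joint density, which the model defines directly, so the intractable posterior never has to be evaluated.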

The main disadvantage of this class of algorithms is that the learned distribution often fails to match the tails of the true posterior and tends to underestimate its variance. The main advantage, as we will see, is that the algorithms themselves are substantially faster.
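
The variance underestimation can be seen in a case with a known answer. The sketch below is not Bean Machine code. It assumes an illustrative target, a zero-mean two-dimensional Gaussian with correlation `rho = 0.9`, and uses the standard result that the fully factorised Gaussian minimising the reverse KL divergence to a Gaussian target has variances equal to the inverse of the diagonal entries of the target's precision matrix.

```python
# Illustrative target: a zero-mean 2-D Gaussian with unit variances and
# correlation rho. These numbers are made up for the example.
rho = 0.9
cov = [[1.0, rho], [rho, 1.0]]


def inverse_2x2(m):
    """Invert a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


precision = inverse_2x2(cov)

# For a Gaussian target, the factorised Gaussian q(z1) q(z2) that minimises
# KL(q || p) has variances equal to the inverse of the diagonal precisions.
mean_field_var = [1.0 / precision[i][i] for i in range(2)]
true_marginal_var = [cov[i][i] for i in range(2)]

print("true marginal variances:", true_marginal_var)
print("mean-field variances:   ", mean_field_var)
```

The stronger the correlation between the latent variables, the more a factorised approximation shrinks each marginal variance relative to the truth.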

## Data

We will demonstrate variational inference on a toy dataset made up of two overlapping half-circles.

    from sklearn.datasets import make_moons

    X, Y = make_moons(noise=0.2, random_state=0, n_samples=1000)
    color = ["yellow" if y == 0 else "cyan" for y in Y]

    output_notebook(hide_banner=True)
    cds = ColumnDataSource({"x": X[:, 0].tolist(), "y": X[:, 1].tolist(), "color": color})
    moon_plot = plots.scatter_plot(
        plot_sources=[cds],
        # Each tooltip label now shows the column of the same name.
        tooltips=[[("x", "@x{0.000}"), ("y", "@y{0.000}")]],
        figure_kwargs={
            "x_axis_label": "X",
            "y_axis_label": "Y",
            "title": "Binary classification data",
        },
        plot_kwargs={"color": "color"},
    )
    show(moon_plot)

To do well on this data, our model will need to capture its non-linearity explicitly.

## Toy Model

But first, let's walk through all the machinery at play using a very simple model.

    import beanmachine.ppl as bm
    import torch
    import torch.distributions as dist


    @bm.random_variable
    def mu():
        return dist.Normal(0., 1.)


    @bm.random_variable
    def x():
        return dist.Normal(mu(), 1.)

To approximate the posterior of mu, we will use a Normal distribution as the *guide* distribution and make its parameters learnable.

    @bm.param
    def phi():
        return torch.zeros(2)


    @bm.random_variable
    def q_mu():
        phi_mean, phi_sd = phi()
        # Softplus maps the unconstrained parameter to a positive scale.
        softplus = torch.nn.Softplus()
        return dist.Normal(phi_mean, softplus(phi_sd))

    world = VariationalInfer(
        queries_to_guides={mu(): q_mu()},
        observations={x(): 5.},
        optimizer=lambda params: optim.Adam(params, lr=1e-1),
    ).infer(
        num_steps=100,
        num_samples=5,
        discrepancy_fn=kl_reverse,
        mc_approx=monte_carlo_approximate_reparam,
    )
    world.get_variable(q_mu()).distribution
Out:
Normal(loc: 2.4571237564086914, scale: 0.6683022379875183)
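
To give a feel for what a reparameterised reverse-KL optimisation does for this toy model, here is a minimal sketch in plain Python. It is not how Bean Machine implements the call above: it swaps Adam for plain stochastic gradient descent, uses hand-derived gradients instead of automatic differentiation, and the names `m`, `rho`, `lr`, `num_steps` and `num_samples` as well as their values are assumptions chosen for illustration.

```python
import math
import random

random.seed(0)

x_obs = 5.0          # the observed value of x used above
m, rho = 0.0, 0.0    # guide mean and unconstrained scale, both starting at zero
lr, num_steps, num_samples = 0.05, 2000, 32  # illustrative settings


def softplus(v):
    return math.log1p(math.exp(v))


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


for _ in range(num_steps):
    s = softplus(rho)
    grad_m = grad_s = 0.0
    for _ in range(num_samples):
        eps = random.gauss(0.0, 1.0)
        z = m + s * eps  # reparameterised draw from q = Normal(m, s)
        # Per-sample loss: log q(z) - log p(z) - log p(x_obs | z), where
        # log q(z) = -log(s) - eps**2 / 2 + const under the reparameterisation.
        dloss_dz = z - (x_obs - z)  # derivative of -log p(z) - log p(x_obs | z)
        grad_m += dloss_dz                    # dz/dm = 1
        grad_s += -1.0 / s + dloss_dz * eps   # dz/ds = eps
    m -= lr * grad_m / num_samples
    # Chain rule through softplus: d softplus(rho) / d rho = sigmoid(rho).
    rho -= lr * (grad_s / num_samples) * sigmoid(rho)

s = softplus(rho)
print(f"guide: Normal(loc={m:.3f}, scale={s:.3f})")
```

Writing the sample as `m + s * eps` is what lets the gradient flow through the random draw to the guide parameters, which is the idea behind the reparameterisation trick.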

In this particular setting, the prior is conjugate to the likelihood, so the posterior has a closed form and we can compare our variational distribution to the true posterior.

    true_posterior = dist.Normal(2.5, torch.tensor(0.5).sqrt())
    true_posterior
Out:
Normal(loc: 2.5, scale: 0.7071067690849304)
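
The mean 2.5 and variance 0.5 used above follow from the standard Normal-Normal conjugate update: precisions (inverse variances) add, and the posterior mean is the precision-weighted average of the prior mean and the observation. A small sketch, where the helper name `normal_normal_posterior` is hypothetical and not part of Bean Machine:

```python
def normal_normal_posterior(prior_mean, prior_var, obs, noise_var):
    """Posterior mean and variance of mu given a single observation, for
    mu ~ Normal(prior_mean, prior_var) and obs ~ Normal(mu, noise_var)."""
    posterior_precision = 1.0 / prior_var + 1.0 / noise_var
    posterior_var = 1.0 / posterior_precision
    posterior_mean = posterior_var * (prior_mean / prior_var + obs / noise_var)
    return posterior_mean, posterior_var


# Toy model from this section: mu ~ Normal(0, 1), x ~ Normal(mu, 1), x = 5.
post_mean, post_var = normal_normal_posterior(0.0, 1.0, 5.0, 1.0)
print(post_mean, post_var, post_var ** 0.5)
```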

We also compare against a NUTS sampler which is asymptotically guaranteed to converge to the true posterior.

    nuts_samples = bm.GlobalNoUTurnSampler().infer(
        [mu()],
        {x(): torch.tensor(5.)},
        num_samples=2000,
        num_chains=1,
    )

    nuts_samples[mu()].std()
Out:
tensor(0.6904)

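As you can see, the variational approximation recovers the mean of $\mu$ well (about 2.46 against a true value of 2.5), but it is less successful with the standard deviation, which it underestimates (about 0.67 against 0.71).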

## Bayesian Neural Network model

Returning to the half-circles data, we will model the classification problem using a Bayesian neural network (BNN). A BNN is a probabilistic model in which the weights of the neural network are treated as latent variables, so that our uncertainty about them is part of the model.

    import beanmachine.ppl as bm
    import torch
    import torch.distributions as dist


    class BayesianNeuralNetwork:
        def __init__(self, X, Y, hidden_size):
            self.X = X
            self.Y = Y
            self.hidden_size = hidden_size

        @bm.random_variable
        def input_layer(self):
            # Weights from the input features (columns of X) to the first hidden layer.
            return dist.Normal(0., 1.).expand((self.X.shape[1], self.hidden_size))

        @bm.random_variable
        def hidden_layer(self):
            return dist.Normal(0., 1.).expand((self.hidden_size, self.hidden_size))

        @bm.random_variable
        def output_layer(self):
            return dist.Normal(0., 1.).expand((self.hidden_size,))

        @bm.random_variable
        def forward(self):
            y1 = torch.tanh(torch.mm(self.X, self.input_layer()))
            y2 = torch.tanh(torch.mm(y1, self.hidden_layer()))
            # Sigmoid turns the final activation into a class probability.
            y3 = torch.sigmoid(torch.matmul(y2, self.output_layer()))
            return dist.Bernoulli(y3)

We create queries and observations just as we would for MCMC inference.

    X = torch.tensor(X, dtype=torch.float)
    Y = torch.tensor(Y, dtype=torch.float)
    nn = BayesianNeuralNetwork(X, Y, 10)
    queries = [nn.input_layer(), nn.hidden_layer(), nn.output_layer()]
    observations = {nn.forward(): Y}

The easiest way to get started with variational inference is ADVI (automatic differentiation variational inference), which automatically selects a guide distribution for each of your queries.

    bnn_world = ADVI(
        queries=queries,
        observations=observations,
        optimizer=lambda params: optim.Adam(params, lr=4e-2),
    ).infer(
        num_steps=8000,
        num_samples=1,
        discrepancy_fn=kl_reverse,
    )

Notice how quickly this model was fitted. Modify the above cell to use bm.GlobalNoUTurnSampler() instead of ADVI to really appreciate the speed difference.

### Monitoring and diagnosing models inferred with Variational Inference

We can plot the divergence over the course of optimization to make sure it didn't get stuck. The divergence measures the difference between the approximate distribution we are learning and the true distribution we are trying to approximate.

    vi = ADVI(
        queries=queries,
        observations=observations,
        optimizer=lambda params: optim.Adam(params, lr=1e-2),
    )
    num_steps = 1000
    losses = []
    for i in tqdm(range(num_steps)):
        # Take one optimization step and record the loss it returns.
        loss = vi.step()
        losses.append(loss.item())
    vi.initialize_world()
Out:
100%|██████████| 1000/1000 [00:04<00:00, 202.75it/s]
Out:
<beanmachine.ppl.inference.vi.variational_world.VariationalWorld at 0x7fa371dd5ed0>
    output_notebook(hide_banner=True)
    cds = ColumnDataSource({"x": range(num_steps), "y": losses})
    elbo_plot = plots.line_plot(
        plot_sources=[cds],
        tooltips=[[("ELBO", "@y{0.000}"), ("step", "@x{0.000}")]],
        figure_kwargs={
            "x_axis_label": "steps",
            "y_axis_label": "ELBO",
            "title": "Divergence between model and guide programs",
        },
        plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
    )
    show(elbo_plot)
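
Besides inspecting the plot by eye, a simple numerical check on the recorded `losses` can flag whether the optimization has levelled off. The helper below is a hypothetical sketch, not a Bean Machine utility: `has_plateaued`, `window` and `rel_tol` are names and defaults chosen for illustration. It compares the mean loss over the last `window` steps with the mean over the window before it, which smooths out the step-to-step noise of stochastic estimates.

```python
def has_plateaued(losses, window=100, rel_tol=1e-2):
    """Return True if the mean loss over the last `window` steps changed by
    less than `rel_tol` (relative) compared with the window before it."""
    if len(losses) < 2 * window:
        return False  # not enough history to compare two windows
    previous = sum(losses[-2 * window:-window]) / window
    recent = sum(losses[-window:]) / window
    return abs(recent - previous) <= rel_tol * max(abs(previous), 1e-12)


# Synthetic example curves (not output from the model above).
still_improving = [1000.0 / (step + 1) for step in range(300)]
flat = [50.0 + (0.01 if step % 2 else -0.01) for step in range(300)]

print(has_plateaued(still_improving))
print(has_plateaued(flat))
```

A plateau only says that the optimizer has stopped making progress. It does not by itself show that the guide is a good approximation, so it is worth combining with checks such as the comparison against NUTS shown for the toy model.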

We can also check the convergence of the parameters.

    vi_world = ADVI(
        queries=queries,
        observations=observations,
        optimizer=lambda params: optim.Adam(params, lr=1e-2),
    )
    num_steps = 5000
    vals = []
    for i in tqdm(range(num_steps)):
        vi_world.step()
        params = list(vi_world.params.values())
        # Track the first element of the first parameter tensor; a list
        # cannot be indexed with [0, 0] directly.
        vals.append(params[0].flatten()[0].item())
Out:
100%|██████████| 5000/5000 [00:22<00:00, 219.71it/s]

    output_notebook(hide_banner=True)
    cds = ColumnDataSource({"x": range(num_steps), "y": vals})
    param_plot = plots.line_plot(
        plot_sources=[cds],
        tooltips=[[("param", "@y{0.000}"), ("step", "@x{0.000}")]],
        figure_kwargs={
            "x_axis_label": "steps",
            "y_axis_label": "param",
            "title": "Value of param during VI",
        },
        plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
    )
    show(param_plot)

Before generating predictions, we introduce a helper function. The default way of accessing the guide distribution associated with each model variable is quite verbose, so to make the prediction code easier to follow we define the following shorthand.

    def d(x):
        return bnn_world.get_guide_distribution(x)

We create posterior predictive samples by using the learned distributions to generate observations. This is just the forward method from our original model with some minor changes to enable batch sampling.

    def predictions(X, samples=100):
        # Draw `samples` weight matrices per layer from the learned guides.
        il = d(nn.input_layer()).expand((samples, -1, -1)).sample()
        y1 = torch.tanh(torch.matmul(X, il))
        hl = d(nn.hidden_layer()).expand((samples, -1, -1)).sample()
        y2 = torch.tanh(torch.matmul(y1, hl))
        ol = d(nn.output_layer()).expand((samples, -1)).sample()
        # Batched matrix-vector product: one output weight vector per sample.
        y3 = torch.sigmoid(torch.einsum('bij,bj->bi', y2, ol))
        return dist.Bernoulli(y3).sample()

We visualise our predictions to show that the confidence of the prediction decreases as we approach the boundary.

    import numpy as np
    from bokeh.plotting import figure, show

    x_points = 100
    y_points = 100
    x = np.linspace(-3, 3, x_points)
    y = np.linspace(-3, 3, y_points)
    xx, yy = np.meshgrid(x, y)
    # Stack into one array first; building a tensor from a list of arrays is slow.
    grid_2d = torch.tensor(np.stack([xx, yy]), dtype=torch.float).reshape(2, -1).T
    preds = predictions(grid_2d).mean(axis=0).reshape(x_points, y_points)

    p = figure(width=400, height=400)
    p.x_range.range_padding = p.y_range.range_padding = 0
    # Place the image over the same [-3, 3] x [-3, 3] square the grid was evaluated on.
    p.image(image=[preds.numpy()], x=-3, y=-3, dw=6, dh=6, palette="PRGn11", level="image")
    p.grid.grid_line_width = 0
    p.circle(X[Y==0,0].numpy(), X[Y==0,1].numpy(), color="yellow")
    p.circle(X[Y==1,0].numpy(), X[Y==1,1].numpy(), color="cyan")
    show(p)
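
The plot above colours each grid point by the mean of the sampled 0/1 predictions. To turn that mean into an explicit uncertainty score, one simple option is the Bernoulli variance $p(1 - p)$, which is largest when the mean prediction $p$ is 0.5. The sketch below uses made-up draws for two hypothetical inputs rather than output from the model, and `predictive_summary` is an illustrative name.

```python
def predictive_summary(draws):
    """Given a list of 0/1 posterior predictive draws for one input, return
    the mean prediction p and the Bernoulli variance p * (1 - p)."""
    p = sum(draws) / len(draws)
    return p, p * (1.0 - p)


# Made-up draws for two hypothetical inputs.
far_from_boundary = [1] * 98 + [0] * 2
near_boundary = [1] * 55 + [0] * 45

p_far, var_far = predictive_summary(far_from_boundary)
p_near, var_near = predictive_summary(near_boundary)
print(p_far, var_far)
print(p_near, var_near)
```

Applied to the `preds` array, the same quantity would be `preds * (1 - preds)`, which should be highest along the band where the two classes overlap.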