

Tutorial: Neal's funnel

This tutorial demonstrates modeling and running inference on the so-called Neal's funnel model in Bean Machine.

Neal's funnel has proven difficult to handle for classical inference methods. This tutorial demonstrates how to overcome this by using second-order gradient methods in Bean Machine. It also demonstrates how to implement models with factors in Bean Machine through custom distributions.

Problem

Neal's funnel is a synthetic model that is fairly simple, but has proven challenging for automatic inference engines to handle due to its unusual geometry. The scale of one variable grows exponentially with another, so the density forms a wide mouth in one direction that narrows into a sharp "funnel" neck in the other.

Prerequisites

Let's model this in Bean Machine! Import the Bean Machine library and some fundamental PyTorch classes.

# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import math
import os
import warnings

import arviz as az
import beanmachine.ppl as bm
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.bmg_inference import BMGInference
from IPython.display import Markdown

The next cell includes convenient configuration settings to improve the notebook presentation, as well as a manual seed for reproducibility.

# Eliminate excess UserWarnings from Python.
warnings.filterwarnings("ignore")

# Plotting settings
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)

# Manual seed
torch.manual_seed(12)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

Model

Neal's funnel is defined mathematically as follows:

  • $z \sim \mathcal{N}(0, 3)$
  • $x \sim \mathcal{N}(0, e^{z/2})$

Let's visualize the model's density. To do this, we recognize that the joint density is factored as follows:

  • $P(z, x) = \mathcal{N}(x; 0, e^{z/2}) \cdot \mathcal{N}(z; 0, 3)$

xs, zs = torch.meshgrid(
    torch.arange(-50, 50, 0.1),
    torch.arange(-15.0, 15.0, 0.1),
)
density = (
    dist.Normal(0.0, (zs / 2.0).exp()).log_prob(xs).exp()
    * dist.Normal(0.0, 3.0).log_prob(zs).exp()
)
plt.contourf(xs, zs, density, levels=[0.0001] + torch.linspace(0.001, 0.1, 10).tolist())
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar();

The log density is usually easier to visualize and reason about, so let's plot that as well.

plt.contourf(xs, zs, density.log(), levels=range(-10, 0))
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar();

As we can see, the funnel's neck is particularly sharp because of the exponential function applied to $z$. The density decays exponentially the farther $x$ deviates from 0. This makes it challenging to learn a good scale for proposal updates. Let's go about modeling this in Bean Machine!

Compared to many of the other tutorials, this one uses a somewhat contrived model. In the typical problem setup found in other tutorials, there are ground-truth values that we're trying to infer distributions for, based on observed data. In this case, however, we're trying to exactly replicate a ground-truth distribution, using the mechanics of the Bean Machine inference engine to guide the sampling process.

Since Neal's funnel describes a mathematical relationship instead of a generative process, we'll have to reframe it into a generative process in order to run inference on it. We'll do this as follows:

  1. Sample priors for $z$ and $x$.
  2. Imagine weighting the probabilities of $z$ and $x$ according to how likely they are under the true Neal's funnel model. We can do this by imagining we're flipping a coin whose probability of landing heads is the probability of drawing that $z$ and $x$ from the Neal's funnel model, and where we've actually observed it to come up heads.
  3. Later, we will inform the inference engine that we observed heads. This will cause the engine to find values for $z$ and $x$ that are consistent with samples from the true Neal's funnel posterior, since those are the samples that would have resulted in the observed "heads" from our coin flip!

A few notes for advanced readers (feel free to skip over these):

  • In the above statistical model, we already had definitions for $z \sim \mathcal{N}(0, 3)$ and $x \sim \mathcal{N}(0, e^{z/2})$. We are free to reuse these definitions as priors. However, that would give inference an unfair advantage, since our priors would exactly match our posterior. Instead, in this tutorial we will choose non-informative priors.
  • It is common to refer to this coin-flipping approach as a "factor". Traditionally, PPLs have been implemented by weighting a particular run of inference according to the log probability of that run of inference. We're exactly doing that in this model β€” based on a particular draw of zz and xx, we're weighting that overall run by the probability that those values would have been sampled from the true Neal's funnel posterior.

We can implement this model in Bean Machine by defining random variable objects with the @bm.random_variable decorator. These functions behave differently than ordinary Python functions.

Semantics for @bm.random_variable functions:
  • They must return PyTorch Distribution objects.
  • Though they return distributions, callers actually receive samples. The machinery for obtaining samples from distributions is handled internally by Bean Machine.
  • Inference runs the model through many iterations. During a particular inference iteration, a distinct random variable will correspond to exactly one sampled value: calls to the same random variable function with the same arguments will receive the same sampled value within one inference iteration. This makes it easy for multiple components of your model to refer to the same logical random variable.
  • Consequently, to define distinct random variables that correspond to different sampled values during a particular inference iteration, an effective practice is to add a dummy "indexing" parameter to the function. Distinct random variables can be referred to with different values for this index (see the toy sketch just after this list).
  • Please see the documentation for more information about this decorator.
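
As a toy illustration of these semantics (entirely separate from the Neal's funnel model, with hypothetical names), the sketch below defines one shared weight variable and a family of flips indexed by i. Every call to coin_weight() within a single inference iteration refers to the same sampled value, while coin_flip(0) and coin_flip(1) are distinct random variables:

@bm.random_variable
def coin_weight():
    # One logical random variable, shared by every flip below.
    return dist.Beta(2.0, 2.0)


@bm.random_variable
def coin_flip(i):
    # The dummy index i distinguishes flip 0, flip 1, ... as separate variables.
    return dist.Bernoulli(coin_weight())

With those semantics in mind, here is the Neal's funnel model: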
@bm.random_variable
def z():
    """
    An uninformative (flat) prior for z.
    """
    # TODO(tingley): Replace with Flat once it's part of the framework.
    return dist.Normal(0, 10000)


@bm.random_variable
def x():
    """
    An uninformative (flat) prior for x.
    """
    # TODO(tingley): Replace with Flat once it's part of the framework.
    return dist.Normal(0, 10000)


@bm.random_variable
def neals_funnel_coin_flip():
    """
    Flip a "coin", which is heads with probability equal to the probability
    of drawing z and x from the true Neal's funnel posterior.
    """
    return dist.Bernoulli(
        (
            dist.Normal(0.0, (z() / 2.0).exp()).log_prob(x())
            + dist.Normal(0.0, 3.0).log_prob(z())
        ).exp()
    )

Inference

Inference is the process of combining a model with data to obtain insights, in the form of probability distributions over values of interest. Bean Machine offers a powerful and general inference framework to enable fitting arbitrary models to data.

As discussed in the previous section, we'll pretend that we've observed heads when flipping a coin whose heads rate is weighted according to how likely the z and x values were to be drawn from the true Neal's funnel posterior. Let's set that up right now.

Our inference algorithms expect observations in the form of a dictionary. This dictionary should consist of @bm.random_variable invocations as keys, and tensor data as values.

observations = {neals_funnel_coin_flip(): torch.tensor(1.0)}

Next, we'll run inference on the model and observations.

Since this model is composed entirely of differentiable random variables, we'll make use of the Newtonian Monte Carlo (NMC) inference method. NMC is a second-order method, which uses the Hessian to automatically scale the step size in each dimension. The hope is that this inference method will take the exponential growth rate of Neal's funnel into account, and explore the entire posterior surface, including the neck of the funnel. Check out the documentation for more information on NMC.
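
Roughly speaking, for an unconstrained real-valued variable $\theta$ with log density $\log p(\theta)$, NMC proposes from a Gaussian built from a second-order Taylor expansion of the log density at the current value. This is only a sketch: it ignores the adjustments NMC makes when the Hessian is not negative definite, as well as the specialized proposers it uses for constrained variables.

  • $q(\theta^{*} \mid \theta) = \mathcal{N}\big(\theta - [\nabla^{2} \log p(\theta)]^{-1} \nabla \log p(\theta),\; -[\nabla^{2} \log p(\theta)]^{-1}\big)$

The proposal is then accepted or rejected with a standard Metropolis-Hastings test. Because the proposal covariance comes from the local curvature, steps automatically shrink inside the narrow neck of the funnel and widen in its mouth, which is exactly the behavior we want here.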

The infer method takes a few arguments:

  • queries: A list of @bm.random_variable targets to fit posterior distributions for.
  • observations: The Dict of observations we built up, above.
  • num_samples: Number of samples to build up distributions for the values listed in queries.
  • num_chains: Number of separate inference runs to use. Multiple chains can verify inference ran correctly.

Let's run inference:

num_samples = 2 if smoke_test else 1000
num_chains = 1 if smoke_test else 4

single_site_nmc_samples = bm.SingleSiteNewtonianMonteCarlo().infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)

Analysis

single_site_nmc_samples now contains our inference results.

z_marginal = single_site_nmc_samples[z()].flatten().detach()
x_marginal = single_site_nmc_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
Out:

z_marginal: tensor([-1.1416, -0.1679, -0.5534, ..., 3.4613, 3.4613, 3.4613])

x_marginal: tensor([-0.2330, 0.4234, -0.7036, ..., -0.6279, -7.6473, 3.4894])

Let's plot our inferred posterior, along with the marginal distributions for $z$ and $x$.

grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=30,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);

The samples appear to match the correct posterior well! Inference also seems to have successfully fit the $z$ marginal, which is analytically known.

Bean Machine provides a Diagnostics package with helpful statistics about the result of the inference algorithm. We can query summary statistics through ArviZ as follows:

summary_df = az.summary(single_site_nmc_samples.to_inference_data())
Markdown(summary_df.to_markdown())
|     | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| x() | -0.099 | 7.478 | -7.908 | 6.342   | 0.13      | 0.339   | 3202     | 261      | 1.02  |
| z() | -0.079 | 2.952 | -6.228 | 4.837   | 0.256     | 0.181   | 136      | 109      | 1.03  |

$z$ and $x$ have means very close to zero, which is expected. $x$ has a much higher standard deviation than $z$, which is expected as well. The interval columns (hdi_3% and hdi_97%) give useful insights into the spread of the two variables.

The diagnostics output shows two diagnostic statistics: $\hat{R}$ (r_hat) and the effective sample size $N_\text{eff}$ (reported as ess_bulk and ess_tail in the table above).

  • $\hat{R} \in [1, \infty)$ summarizes how effective inference was at converging on the correct posterior distribution for a particular random variable. It uses information from all chains run in order to assess whether inference had a good understanding of the distribution or not. Values very close to 1 indicate that all chains discovered similar distributions for a particular random variable. We do not recommend using inference results where $\hat{R} > 1.1$, as inference may not have converged. In that case, you may want to run inference for more samples.
  • $N_\text{eff} \in [1, \texttt{num\_samples}]$ summarizes how independent posterior samples are from one another. Although inference was run for num_samples iterations, it's possible that those samples were very similar to each other (due to the way inference is implemented), and may not each be representative of the full posterior space. Larger numbers are better here, and if your particular use case calls for a certain number of samples to be considered, you should ensure that $N_\text{eff}$ is at least that large.

In this case, $\hat{R}$ and $N_\text{eff}$ seem to have acceptable values.
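
If you want to query these statistics programmatically rather than reading them off the summary table, ArviZ also exposes them as standalone functions. A minimal sketch, reusing the to_inference_data() conversion from above (the variable name nmc_inference_data is ours):

nmc_inference_data = single_site_nmc_samples.to_inference_data()

# Per-variable R-hat, computed across all chains.
print(az.rhat(nmc_inference_data))

# Bulk and tail effective sample sizes, matching the ess_bulk / ess_tail columns.
print(az.ess(nmc_inference_data, method="bulk"))
print(az.ess(nmc_inference_data, method="tail"))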

Bean Machine can also plot diagnostic information to assess model fit. Let's take a look:

bm.Diagnostics(single_site_nmc_samples).plot(display=True);

The diagnostics output shows two diagnostic plots for individual random variables: trace plots and autocorrelation plots.

  • Trace plots are simply a time series of values assigned to random variables over each iteration of inference. The concrete values assigned are usually problem-specific. However, it's important that these values are "mixing" well over time. This means that they don't tend to get stuck in one region for large periods of time, and that each of the chains ends up exploring the same space as the other chains throughout the course of inference.

  • Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). (We compute autocorrelation approximately, so it may sometimes exceed these bounds.) In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works.
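
If you prefer static plots, similar trace and autocorrelation plots can also be drawn with ArviZ. A minimal sketch, again reusing the InferenceData conversion from above:

nmc_inference_data = single_site_nmc_samples.to_inference_data()

# One panel row per random variable, one color per chain.
az.plot_trace(nmc_inference_data, compact=False)
plt.show()

# Autocorrelation of the posterior samples, per variable and chain.
az.plot_autocorr(nmc_inference_data)
plt.show()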

From the autocorrelation plots, we see the absolute magnitude of autocorrelation tends to be quite small. The trace plots are a little more suspicious, especially for $x$. Let's take a deeper look at the spike in chain 3. Here, if we look at the corresponding trace plot for $z$ at this time, we see that it is exploring large outlier values for $z$, around 6 or greater. We expect $x$ to have high variance when $z$ is large, so this is as expected.

This concludes the main Neal's funnel tutorial! However, we'll also walk through the same model using Hamiltonian Monte Carlo inference to compare relative performance.

Appendix: Adaptive Hamiltonian Monte Carlo

Hamiltonian Monte Carlo (HMC) is a classic gradient-based inference method. HMC builds each candidate sample by taking a sequence of steps guided by the gradient of the log density, with some injected noise in the form of a randomly drawn momentum. Bean Machine provides an implementation of HMC that we can use to fit Neal's funnel, and by default it uses an adaptive version of HMC to select some parameter values. For an in-depth discussion of this inference method, check out our documentation on HMC.
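
Concretely, each HMC proposal simulates Hamiltonian dynamics with a leapfrog integrator. A rough sketch of a single leapfrog step, for position $\theta$, momentum $p$ (freshly drawn at the start of each proposal, which is where the injected noise enters), step size $\epsilon$, and mass matrix $M$:

  • $p_{1/2} = p + \tfrac{\epsilon}{2} \nabla_{\theta} \log \pi(\theta)$
  • $\theta' = \theta + \epsilon M^{-1} p_{1/2}$
  • $p' = p_{1/2} + \tfrac{\epsilon}{2} \nabla_{\theta} \log \pi(\theta')$

After enough steps to cover the requested trajectory length, the candidate is accepted or rejected with a Metropolis-Hastings test; the adaptive variant additionally tunes $\epsilon$ (and related parameters) during an initial adaptation phase.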

Compared to single-site inference methods, this version of HMC is a global inference method. That means it proposes new values for all random variables in the model at once, and accepts or rejects them jointly. Bean Machine also provides a single-site variant of this method called SingleSiteHamiltonianMonteCarlo. Although we will not cover that method in this tutorial, it is a useful component for use within CompositionalInference.

hmc_samples = bm.GlobalHamiltonianMonteCarlo(trajectory_length=0.1).infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)

z_marginal = hmc_samples[z()].flatten().detach()
x_marginal = hmc_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
Out:

z_marginal: tensor([-3.0901, -2.1569, -2.1951, ..., 0.5573, 0.3464, 0.3715])

x_marginal: tensor([-0.2174, 0.1300, -0.1934, ..., -0.3308, -0.2397, -0.0842])

grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);

As we can see, HMC isn't able to fully explore values for $z$, which prevents it from correctly recovering the posterior. We can confirm that HMC hasn't mixed well by examining the $\hat{R}$ values:

hmc_summary_df = az.summary(hmc_samples.to_inference_data())
Markdown(hmc_summary_df.to_markdown())
|     | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| x() | 0.037  | 1.166 | -1.472 | 3.72    | 0.21      | 0.265   | 41       | 16       | 1.26  |
| z() | -0.909 | 2.576 | -5.982 | 3.268   | 0.67      | 0.484   | 15       | 28       | 1.19  |

And the unhealthy trace and autocorrelation plots:

bm.Diagnostics(hmc_samples).plot(display=True);

Appendix: Metropolis-Adjusted Langevin Algorithm

The Metropolis-Adjusted Langevin Algorithm (MALA) is a special case of HMC, where only a single gradient step is taken before proposing a sample. Let's try it on Neal's funnel.
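
For reference, a rough sketch of the MALA proposal for a variable $\theta$ with log density $\log \pi(\theta)$ and step size $\epsilon$:

  • $\theta^{*} = \theta + \tfrac{\epsilon^{2}}{2} \nabla_{\theta} \log \pi(\theta) + \epsilon \xi, \quad \xi \sim \mathcal{N}(0, I)$

followed by a Metropolis-Hastings accept/reject step. In the cell below we approximate this with Bean Machine's single-site HMC by setting trajectory_length equal to the step size, so that (roughly) a single leapfrog step is taken per proposal, and by disabling step-size adaptation.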

single_site_mala_samples = bm.SingleSiteHamiltonianMonteCarlo(
    trajectory_length=0.5,
    initial_step_size=0.5,
    adapt_step_size=False,
).infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=2 * num_samples,
    num_chains=num_chains,
)

z_marginal = single_site_mala_samples[z()].flatten().detach()
x_marginal = single_site_mala_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
Out:

z_marginal: tensor([ 0.6981, 0.0460, -0.2532, ..., 2.0986, 1.8570, 1.0852])

x_marginal: tensor([ 0.0381, 0.1879, -0.6272, ..., 3.0555, 3.7887, 2.4106])

grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);

Interestingly, MALA seems capable of fitting Neal's funnel better than HMC!

BMGInference

Bean Machine Graph (BMG) Inference is an experimental feature of the Bean Machine framework that aims to deliver higher performance for specialized models. The model in this tutorial is almost, but not quite, in the subset of the language supported by BMGInference. In particular, log_prob is not yet supported. But we can rewrite our model to avoid the need for log_prob as follows:

def normal_log_prob(mu, sigma, x):
    z = (x - mu) / sigma
    return -math.log(sigma) + (-1.0 / 2.0) * math.log(2.0 * math.pi) - (z**2.0 / 2.0)


@bm.random_variable
def neals_funnel_coin_flip_bmg():
    """
    Flip a "coin", which is heads with probability equal to the probability
    of drawing z and x from the true Neal's funnel posterior.
    """
    return dist.Bernoulli(
        (
            normal_log_prob(0.0, 3.0, z())
            + normal_log_prob(0.0, (z() / 2.0).exp(), x())
        ).exp()
    )

Now we can use this definition in our model, and in particular, in our notion of observation:

observations = {neals_funnel_coin_flip_bmg(): torch.tensor(1.0)}

To run our model using BMGInference, the only change needed is the following:

%%time

single_site_bmg_samples = BMGInference().infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)
Out:

CPU times: user 2.42 s, sys: 9.94 ms, total: 2.43 s

Wall time: 2.42 s

Wall time numbers will naturally vary on different platforms, but with these parameters (model, observations, queries, sample size, and number of chains) the speedup on the author's machine is about 7x. Generally speaking, larger speedups are expected with larger sample sizes. More information about BMGInference can be found in the "Advanced" section of the documentation on the website.

We can confirm that BMGInference provides good accuracy by examining the $\hat{R}$ values and the marginal plots in the next two code cells.

single_site_bmg_summary_df = az.summary(single_site_bmg_samples.to_inference_data())
Markdown(single_site_bmg_summary_df.to_markdown())
|     | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| x() | -0.058 | 5.746 | -8.347 | 7.75    | 0.102     | 0.174   | 3154     | 807      | 1     |
| z() | -0.085 | 3.047 | -5.253 | 6.06    | 0.182     | 0.129   | 274      | 253      | 1.01  |

z_marginal = single_site_bmg_samples[z()].flatten().detach()
x_marginal = single_site_bmg_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")

grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);
Out:

z_marginal: tensor([-6.8999, -4.1096, -5.5291, ..., 1.4906, 1.4906, 1.3204])

x_marginal: tensor([-0.0360, 0.0126, 0.1043, ..., -0.5560, -1.7656, -2.2602])