# Neal's funnel

## Tutorial: **Neal's funnel**

This tutorial demonstrates modeling and running inference on the so-called Neal's funnel model in Bean Machine.

Neal's funnel has proven difficult to handle for classical inference methods. This tutorial demonstrates how to overcome this by using second-order gradient methods in Bean Machine. It also demonstrates how to implement models with factors in Bean Machine through custom distributions.

## Problem

Neal's funnel is a synthetic model that is fairly simple, but has proven challenging for automatic inference engines to handle due to its unusual geometry. This model has an unfavorable, exponential geometry in one direction, and a narrow "funnel" bending into that direction.

## Prerequisites

Let's model this in Bean Machine! Import the Bean Machine library and some fundamental PyTorch classes.

```python
# Install Bean Machine in Colab if using Colab.
import sys

if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
```

```python
import math
import os
import warnings

import arviz as az
import beanmachine.ppl as bm
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.bmg_inference import BMGInference
from IPython.display import Markdown
```

The next cell includes convenient configuration settings to improve the notebook presentation as well as setting a manual seed for reproducibility.

```python
# Eliminate excess UserWarnings from Python.
warnings.filterwarnings("ignore")

# Plotting settings
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)

# Manual seed
torch.manual_seed(12)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
```

## Model

Neal's funnel is defined mathematically as follows:

- $z\sim\mathcal{N}(0,3)$
- $x\sim\mathcal{N}(0,e^{z/2})$

Let's visualize the model's density. To do this, we recognize that the joint density is factored as follows:

- $P(z,x)=\mathcal{N}(x;0,e^{z/2})\cdot\mathcal{N}(z;0,3)$

```python
xs, zs = torch.meshgrid(
    torch.arange(-50, 50, 0.1),
    torch.arange(-15.0, 15.0, 0.1),
)
density = (
    dist.Normal(0.0, (zs / 2.0).exp()).log_prob(xs).exp()
    * dist.Normal(0.0, 3.0).log_prob(zs).exp()
)
```

```python
plt.contourf(xs, zs, density, levels=[0.0001] + torch.linspace(0.001, 0.1, 10).tolist())
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar();
```

Plotting the log density is usually easier to visualize and reason about.

```python
plt.contourf(xs, zs, density.log(), levels=range(-10, 0))
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar();
```

As we can see, the funnel's neck is particularly sharp because of the exponential function applied to $z$. The density decays exponentially the farther that $x$ deviates from $0$. This makes it challenging to learn a good scale for proposal updates. Let's go about modeling this in Bean Machine!
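To put numbers on how quickly the scale of $x$ changes along the funnel, the sketch below evaluates the conditional standard deviation $e^{z/2}$ at a few values of $z$. It uses only the standard library; `conditional_scale` is a hypothetical helper introduced for illustration and is not part of Bean Machine.

```python
import math


def conditional_scale(z: float) -> float:
    """Standard deviation of x given z in Neal's funnel: exp(z / 2)."""
    return math.exp(z / 2.0)


# z has standard deviation 3, so values from -6 to 6 (two standard
# deviations either side of zero) are all plausible under the model.
for z_value in (-6.0, -3.0, 0.0, 3.0, 6.0):
    print(f"z = {z_value:+.0f}: sd of x = {conditional_scale(z_value):.4f}")

# Ratio between the widest and narrowest scales in this range.
ratio = conditional_scale(6.0) / conditional_scale(-6.0)
print(f"scale ratio between z = 6 and z = -6: {ratio:.1f}")
```

Across this range the appropriate proposal scale for $x$ varies by a factor of $e^{6}$, roughly 403, so no single fixed step size suits both the neck and the mouth of the funnel.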

Compared to many of the other tutorials, this one is more of a contrived model. In the
typical problem setup found in other tutorials, there are ground-truth values that we're
trying to infer distributions about based on observed data. In this case, however, we're
trying to exactly replicate a ground-truth *distribution*, by using the mechanics of the
Bean Machine inference engine to guide the sampling process.

Since Neal's funnel describes a mathematical relationship instead of a generative process, we'll have to reframe it into a generative process in order to run inference on it. We'll do this as follows:

- Sample priors for $z$ and $x$.
- Imagine weighting the probabilities of $z$ and $x$ according to how likely they are under the true Neal's funnel model. We can do this by imagining we're flipping a coin, where the probability of it landing heads is the probability of drawing that $z$ and $x$ from the Neal's funnel model, but where we've actually *observed* it to be heads.
- Later, we will inform the inference engine that we observed heads. This will cause the engine to find values for $z$ and $x$ that are consistent with samples from the true Neal's funnel posterior, since those are the samples that would have resulted in the observed "heads" from our coin flip!

A few notes for advanced readers (feel free to skip over these):

- In the above statistical model, we already *had* definitions for $z\sim\mathcal{N}(0,3)$ and $x\sim\mathcal{N}(0,e^{z/2})$. We are free to reuse these definitions as priors. However, that's giving inference an unfair advantage, since our priors exactly match our posterior. Instead, in this tutorial we will choose non-informative priors.
- It is common to refer to this coin-flipping approach as a "factor". Traditionally, PPLs have been implemented by weighting a particular run of inference according to the log probability of that run of inference. We're doing exactly that in this model: based on a particular draw of $z$ and $x$, we're weighting that overall run by the probability that those values would have been sampled from the true Neal's funnel posterior.
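The coin-flip construction can be checked by hand. A Bernoulli variable with heads probability $p$ that is observed to be heads contributes $\log p$ to the model's log probability, so setting $p$ to the funnel density adds exactly the funnel log density as a weight. The sketch below works through this with plain floats. The helpers `normal_log_pdf`, `funnel_log_density`, and `bernoulli_log_prob` are hypothetical names written for this illustration, not Bean Machine functions. It also checks, on a grid, that the funnel density stays below 1, which it must in order to be usable as a probability.

```python
import math


def normal_log_pdf(value: float, mean: float, sd: float) -> float:
    """Log density of Normal(mean, sd) evaluated at value."""
    standardized = (value - mean) / sd
    return -math.log(sd) - 0.5 * math.log(2.0 * math.pi) - 0.5 * standardized**2


def funnel_log_density(z: float, x: float) -> float:
    """log N(x; 0, exp(z / 2)) + log N(z; 0, 3)."""
    return normal_log_pdf(x, 0.0, math.exp(z / 2.0)) + normal_log_pdf(z, 0.0, 3.0)


def bernoulli_log_prob(p_heads: float, outcome: int) -> float:
    """Log probability of a coin flip (1 = heads, 0 = tails)."""
    return math.log(p_heads) if outcome == 1 else math.log1p(-p_heads)


# Observing heads contributes exactly the funnel log density.
z_value, x_value = 1.0, -2.0
p_heads = math.exp(funnel_log_density(z_value, x_value))
factor = bernoulli_log_prob(p_heads, 1)
print(factor, funnel_log_density(z_value, x_value))

# For fixed z the density is largest at x = 0, so scanning z along x = 0
# finds the overall maximum. It must be at most 1 to be a valid probability.
max_density = max(
    math.exp(funnel_log_density(step / 100.0, 0.0)) for step in range(-1500, 1501)
)
print(f"largest density on the grid: {max_density:.4f}")
```

Because the very wide priors are nearly constant over the region where the funnel has mass, the posterior given "heads" is, to a close approximation, proportional to the funnel density itself.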

We can implement this model in Bean Machine by defining random variable objects with the `@bm.random_variable` decorator. These functions behave differently than ordinary Python functions.

`@bm.random_variable` functions:

- They must return PyTorch `Distribution` objects.
- Though they return distributions, callers actually receive *samples* from the distribution. The machinery for obtaining samples from distributions is handled internally by Bean Machine.
- Inference runs the model through many iterations. During a particular inference iteration, a distinct random variable will correspond to exactly one sampled value: **calls to the same random variable function with the same arguments will receive the same sampled value within one inference iteration**. This makes it easy for multiple components of your model to refer to the same logical random variable.
- Consequently, to define distinct random variables that correspond to different sampled values during a particular inference iteration, an effective practice is to add a dummy "indexing" parameter to the function. Distinct random variables can be referred to with different values for this index.
- Please see the documentation for more information about this decorator.
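The "same call, same value" rule can be pictured as a per-iteration cache keyed by the function and its arguments. The toy class below imitates that behavior using only the standard library. It is a hypothetical illustration of the semantics described above, not how Bean Machine is implemented; `ToyWorld` and its `sample` method are invented names.

```python
import random


class ToyWorld:
    """Caches one sampled value per (name, arguments) pair for one iteration."""

    def __init__(self, seed: int) -> None:
        self._rng = random.Random(seed)
        self._values = {}

    def sample(self, name: str, *args) -> float:
        key = (name, args)
        if key not in self._values:
            # First request for this random variable: draw and remember it.
            self._values[key] = self._rng.gauss(0.0, 1.0)
        return self._values[key]


world = ToyWorld(seed=0)

# Repeated requests for the same variable return the same value.
first = world.sample("x")
second = world.sample("x")

# A dummy index argument creates distinct random variables.
x_0 = world.sample("x_indexed", 0)
x_1 = world.sample("x_indexed", 1)

print(first == second, x_0 == x_1)
```

In Bean Machine itself, the index would be an ordinary parameter of the decorated function, so that calls with different argument values refer to different random variables.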

```python
@bm.random_variable
def z():
    """
    An uninformative (flat) prior for z.
    """
    # TODO(tingley): Replace with Flat once it's part of the framework.
    return dist.Normal(0, 10000)


@bm.random_variable
def x():
    """
    An uninformative (flat) prior for x.
    """
    # TODO(tingley): Replace with Flat once it's part of the framework.
    return dist.Normal(0, 10000)


@bm.random_variable
def neals_funnel_coin_flip():
    """
    Flip a "coin", which is heads with probability equal to the probability
    of drawing z and x from the true Neal's funnel posterior.
    """
    return dist.Bernoulli(
        (
            dist.Normal(0.0, (z() / 2.0).exp()).log_prob(x())
            + dist.Normal(0.0, 3.0).log_prob(z())
        ).exp()
    )
```

## Inference

Inference is the process of combining *model* with *data* to obtain *insights*, in the
form of probability distributions over values of interest. Bean Machine offers a
powerful and general inference framework to enable fitting arbitrary models to data.

As discussed in the previous section, we'll pretend that we've observed heads when flipping a coin whose heads rate is weighted according to how likely the z and x values were to be drawn from the true Neal's funnel posterior. Let's set that up right now.

Our inference algorithms expect observations in the form of a dictionary. This dictionary should consist of `@bm.random_variable` invocations as keys, and tensor data as values.

```python
observations = {neals_funnel_coin_flip(): torch.tensor(1.0)}
```

Next, we'll run inference on the model and observations.

Since this model is composed entirely of differentiable random variables, we'll make use of the Newtonian Monte Carlo (NMC) inference method. NMC is a second-order method, which uses the Hessian to automatically scale the step size in each dimension. The hope is that this inference method will take the exponential growth rate of Neal's funnel into account, and explore the entire posterior surface, including the neck of the funnel. Check out the documentation for more information on NMC.
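To see why curvature information helps here, consider the funnel log density as a function of $x$ with $z$ held fixed. Its second derivative is $-e^{-z}$, so a Newton-style proposal whose variance is the negative inverse of that curvature has standard deviation $e^{z/2}$, which is exactly the conditional scale of $x$. The sketch below estimates the curvature by finite differences and recovers that scale. It is a simplified illustration of the idea behind second-order proposals under these assumptions, not Bean Machine's NMC implementation, and all function names are hypothetical.

```python
import math


def log_density_in_x(x: float, z: float) -> float:
    """Terms of the funnel log density that depend on x, for fixed z."""
    return -0.5 * x * x * math.exp(-z)


def second_derivative(func, point: float, step: float = 1e-3) -> float:
    """Central finite-difference estimate of the second derivative."""
    return (func(point + step) - 2.0 * func(point) + func(point - step)) / step**2


def newton_proposal_sd(z: float, x: float = 0.0) -> float:
    """Proposal standard deviation from the negative inverse curvature in x."""
    curvature = second_derivative(lambda value: log_density_in_x(value, z), x)
    return math.sqrt(-1.0 / curvature)


for z_value in (-4.0, 0.0, 4.0):
    print(
        f"z = {z_value:+.0f}: curvature-based sd = {newton_proposal_sd(z_value):.4f}, "
        f"exp(z / 2) = {math.exp(z_value / 2.0):.4f}"
    )
```

The curvature-based scale shrinks in the neck and grows in the mouth without any tuning, which is the behavior a fixed step size cannot provide.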

Running inference consists of a few arguments:

| Name | Usage |
| --- | --- |
| `queries` | A list of `@bm.random_variable` targets to fit posterior distributions for. |
| `observations` | The `Dict` of observations we built up, above. |
| `num_samples` | Number of samples to build up distributions for the values listed in `queries`. |
| `num_chains` | Number of separate inference runs to use. Multiple chains can verify inference ran correctly. |

Let's run inference:

```python
num_samples = 2 if smoke_test else 1000
num_chains = 1 if smoke_test else 4

single_site_nmc_samples = bm.SingleSiteNewtonianMonteCarlo().infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)
```

## Analysis

`single_site_nmc_samples` now contains our inference results.

```python
z_marginal = single_site_nmc_samples[z()].flatten().detach()
x_marginal = single_site_nmc_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
```

Let's plot our inferred posterior, along with the marginal distributions for $z$ and $x$.

```python
grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=30,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);
```

The samples appear to match the correct posterior well! Inference also seems to have successfully fit the $z$ marginal, which is analytically known.

Bean Machine provides a Diagnostics package that provides helpful statistics about the result of the inference algorithm. We can query this information as follows:

```python
summary_df = az.summary(single_site_nmc_samples.to_inference_data())
Markdown(summary_df.to_markdown())
```

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| x() | -0.099 | 7.478 | -7.908 | 6.342 | 0.13 | 0.339 | 3202 | 261 | 1.02 |
| z() | -0.079 | 2.952 | -6.228 | 4.837 | 0.256 | 0.181 | 136 | 109 | 1.03 |

$z$ and $x$ have means very close to zero, which is expected. $x$ has a much higher standard deviation than $z$, which is expected as well. The highest-density interval bounds (`hdi_3%` and `hdi_97%`) give useful insights into the spread of the two variables.

The diagnostics output shows two diagnostic statistics: $\hat{R}$ (`r_hat`) and $N_\text{eff}$ (effective sample size, reported in the table as `ess_bulk` and `ess_tail`).

- $\hat{R}\in[1,\infty)$ summarizes how effective inference was at converging on the correct posterior distribution for a particular random variable. It uses information from all chains run in order to assess whether inference had a good understanding of the distribution or not. Values very close to one indicate that all chains discovered similar distributions for a particular random variable. We do not recommend using inference results where $\hat{R}>1.1$, as inference may not have converged. In that case, you may want to run inference for more samples.
- $N_\text{eff}\in[1,\texttt{num}\_\texttt{samples}]$ summarizes how independent posterior samples are from one another. Although inference was run for `num_samples` iterations, it's possible that those samples were very similar to each other (due to the way inference is implemented), and may not each be representative of the full posterior space. Larger numbers are better here, and if your particular use case calls for a certain number of samples to be considered, you should ensure that $N_\text{eff}$ is at least that large.

In this case, $\hat{R}$ and $N_\text{eff}$ seem to have acceptable values.
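To make the idea behind $\hat{R}$ concrete, the sketch below computes the classic Gelman-Rubin statistic, which compares the variance between chains with the variance within chains. This is a minimal standard-library version for illustration. ArviZ reports a more robust rank-normalized split variant, so its numbers will differ somewhat from this formula. The function name `gelman_rubin` and the synthetic chains are assumptions made for this example.

```python
import math
import random


def gelman_rubin(chains):
    """Classic (non-split) Gelman-Rubin statistic for equal-length chains."""
    num_chains = len(chains)
    num_draws = len(chains[0])
    chain_means = [sum(chain) / num_draws for chain in chains]
    grand_mean = sum(chain_means) / num_chains

    # W: average within-chain sample variance.
    within = sum(
        sum((draw - mean) ** 2 for draw in chain) / (num_draws - 1)
        for chain, mean in zip(chains, chain_means)
    ) / num_chains

    # B: between-chain variance, scaled by the number of draws.
    between = (
        num_draws
        * sum((mean - grand_mean) ** 2 for mean in chain_means)
        / (num_chains - 1)
    )

    pooled = (num_draws - 1) / num_draws * within + between / num_draws
    return math.sqrt(pooled / within)


rng = random.Random(12)

# Four chains drawn from the same distribution: the chains agree.
agreeing = [[rng.gauss(0.0, 3.0) for _ in range(1000)] for _ in range(4)]

# Four chains stuck in different regions: the chains disagree.
stuck = [[rng.gauss(center, 1.0) for _ in range(1000)] for center in (-6, -2, 2, 6)]

print(f"agreeing chains: {gelman_rubin(agreeing):.3f}")
print(f"stuck chains:    {gelman_rubin(stuck):.3f}")
```

Chains that explore the same distribution give a value near one, while chains that each settle in a different region give a value well above the 1.1 threshold mentioned above.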

Bean Machine can also plot diagnostic information to assess model fit. Let's take a look:

```python
bm.Diagnostics(single_site_nmc_samples).plot(display=True);
```

The diagnostics output shows two diagnostic plots for individual random variables: trace plots and autocorrelation plots.

Trace plots are simply a time series of values assigned to random variables over each iteration of inference. The concrete values assigned are usually problem-specific. However, it's important that these values are "mixing" well over time. This means that they don't tend to get stuck in one region for large periods of time, and that each of the chains ends up exploring the same space as the other chains throughout the course of inference.

Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). (We compute autocorrelation approximately, so it may sometimes exceed these bounds.) In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works.
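As a small worked example of what an autocorrelation plot displays, the sketch below computes the sample autocorrelation of a sequence at a given lag using only built-in Python. This is one common estimator and may differ in detail from the one used by Bean Machine's diagnostics. The function name `autocorrelation` and the artificial alternating chain are assumptions for illustration.

```python
def autocorrelation(series, lag: int) -> float:
    """Sample autocorrelation of a sequence at the given lag."""
    count = len(series)
    mean = sum(series) / count
    variance_sum = sum((value - mean) ** 2 for value in series)
    covariance_sum = sum(
        (series[index] - mean) * (series[index + lag] - mean)
        for index in range(count - lag)
    )
    return covariance_sum / variance_sum


# A chain that flips sign on every iteration is strongly anticorrelated at
# lag 1 and strongly correlated at lag 2.
alternating = [1.0 if index % 2 == 0 else -1.0 for index in range(100)]

lag_0 = autocorrelation(alternating, 0)
lag_1 = autocorrelation(alternating, 1)
lag_2 = autocorrelation(alternating, 2)
print(lag_0, lag_1, lag_2)
```

A well-mixing chain would instead show values near zero at every lag greater than zero.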

From the autocorrelation plots, we see the absolute magnitude of autocorrelation tends to be quite small. The trace plots are a little more suspicious, especially for $x$. Let's take a deeper look at the spike in chain 3. Here, if we look at the corresponding trace plot for $z$ at this time, we see that it is exploring large outlier values for $z$, around 6 or greater. We expect $x$ to have high variance when $z$ is large, so this is as expected.

This concludes the main Neal's funnel tutorial! However, we'll also walk through the same model using Hamiltonian Monte Carlo inference to compare relative performance.

## Appendix: Adaptive Hamiltonian Monte Carlo

Hamiltonian Monte Carlo is a classic gradient-based inference method. HMC proceeds by taking a sequence of steps towards the gradient, but with some injected noise, before proposing a candidate sample. Bean Machine provides an implementation of HMC that we can use to fit Neal's funnel. Bean Machine uses an adaptive version of HMC by default to select some parameter values. For an in-depth discussion of this inference method, check out our documentation on HMC.

Compared to single-site inference methods, this version of HMC is a global inference method. That means that it proposes new values for all random variables in the model at once, and accepts or rejects them jointly. Bean Machine also provides a single-site variant of this method called `SingleSiteHamiltonianMonteCarlo`. Although we will not cover that method in this tutorial, it is a useful component for use within `CompositionalInference`.
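One way to see why a joint gradient-based proposal can struggle on this model is to look at the leapfrog integrator commonly used by HMC to simulate trajectories. On a one-dimensional Normal target with standard deviation $\sigma$ and unit mass, leapfrog is stable only when the step size is below $2\sigma$. The sketch below runs leapfrog with the same step size on two such targets, using the conditional scale of $x$ at $z=0$ and at $z=-8$. This is a textbook property of the integrator shown on a simplified target, not a description of Bean Machine's HMC implementation or its step-size adaptation. The helper `leapfrog_energy` is hypothetical.

```python
import math


def leapfrog_energy(sd: float, step_size: float, num_steps: int) -> float:
    """Run leapfrog on a 1-D Normal(0, sd) target and return the final energy.

    The starting point has energy exactly 1, so the return value can be read
    as the ratio of final to initial energy.
    """
    position, momentum = sd, 1.0
    for _ in range(num_steps):
        # Half step for momentum; the log-density gradient is -position / sd**2.
        momentum -= 0.5 * step_size * position / sd**2
        # Full step for position.
        position += step_size * momentum
        # Second half step for momentum.
        momentum -= 0.5 * step_size * position / sd**2
    return 0.5 * (position / sd) ** 2 + 0.5 * momentum**2


step_size = 0.1
for z_value in (0.0, -8.0):
    sd = math.exp(z_value / 2.0)  # conditional scale of x at this z
    energy = leapfrog_energy(sd, step_size, num_steps=20)
    print(f"z = {z_value:+.0f}: sd = {sd:.4f}, final energy = {energy:.3e}")
```

With the wide target the energy is almost conserved, so the proposal would likely be accepted. With the narrow target the same step size makes the trajectory diverge, which in a real sampler leads to rejection. A step size small enough for the neck would in turn make progress in the mouth very slow.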

```python
hmc_samples = bm.GlobalHamiltonianMonteCarlo(trajectory_length=0.1).infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)
```

```python
z_marginal = hmc_samples[z()].flatten().detach()
x_marginal = hmc_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
```

```python
grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);
```

As we can see, HMC isn't able to fully explore values for $z$, which prevents it from correctly recovering the posterior. We can confirm that HMC hasn't mixed well by examining the $\hat{R}$ values:

```python
hmc_summary_df = az.summary(hmc_samples.to_inference_data())
Markdown(hmc_summary_df.to_markdown())
```

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| x() | 0.037 | 1.166 | -1.472 | 3.72 | 0.21 | 0.265 | 41 | 16 | 1.26 |
| z() | -0.909 | 2.576 | -5.982 | 3.268 | 0.67 | 0.484 | 15 | 28 | 1.19 |

And the unhealthy trace and autocorrelation plots:

```python
bm.Diagnostics(hmc_samples).plot(display=True);
```

## Appendix: Metropolis-Adjusted Langevin Algorithm

The Metropolis-Adjusted Langevin Algorithm (MALA) is a special case of HMC, where only a single gradient step is taken before proposing a sample. Let's try it on Neal's funnel.

```python
single_site_mala_samples = bm.SingleSiteHamiltonianMonteCarlo(
    trajectory_length=0.5,
    initial_step_size=0.5,
    adapt_step_size=False,
).infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=2 * num_samples,
    num_chains=num_chains,
)
```

```python
z_marginal = single_site_mala_samples[z()].flatten().detach()
x_marginal = single_site_mala_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")
```

```python
grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);
```

Interestingly, MALA seems capable of fitting Neal's funnel better than HMC!
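The description of MALA as HMC with a single gradient step can be made concrete. If the standard-normal draw used as the HMC momentum is reused as the Langevin noise, one leapfrog step and the Langevin proposal land on the same position. The sketch below checks this on a one-dimensional standard Normal target. The function names are hypothetical and the example is a minimal illustration of the equivalence, not Bean Machine code. In the MALA cell above, the trajectory length and the step size are both 0.5, which is consistent with a one-step trajectory.

```python
def grad_log_density(position: float) -> float:
    """Gradient of the log density of a standard Normal target."""
    return -position


def one_leapfrog_step(position: float, momentum: float, step_size: float) -> float:
    """Position after a single leapfrog step (half kick, then full drift)."""
    momentum_half = momentum + 0.5 * step_size * grad_log_density(position)
    return position + step_size * momentum_half


def langevin_proposal(position: float, noise: float, step_size: float) -> float:
    """MALA proposal: gradient drift plus scaled Gaussian noise."""
    drift = 0.5 * step_size**2 * grad_log_density(position)
    return position + drift + step_size * noise


# Using the same standard-normal draw as the HMC momentum and as the
# Langevin noise yields the same proposed position.
position, draw, step_size = 1.5, -0.3, 0.5
hmc_position = one_leapfrog_step(position, draw, step_size)
mala_position = langevin_proposal(position, draw, step_size)
print(hmc_position, mala_position)
```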

## BMGInference

Bean Machine Graph (BMG) Inference is an experimental feature of the Bean Machine framework that aims to deliver higher performance for specialized models. The model in this tutorial is *almost* but not quite in the subset of the language supported by `BMGInference`. In particular, `log_prob` is not yet supported. But we can rewrite our model to avoid the need for `log_prob` as follows:

```python
def normal_log_prob(mu, sigma, x):
    z = (x - mu) / sigma
    return -math.log(sigma) + (-1.0 / 2.0) * math.log(2.0 * math.pi) - (z**2.0 / 2.0)


@bm.random_variable
def neals_funnel_coin_flip_bmg():
    """
    Flip a "coin", which is heads with probability equal to the probability
    of drawing z and x from the true Neal's funnel posterior.
    """
    return dist.Bernoulli(
        (
            normal_log_prob(0.0, 3.0, z())
            + normal_log_prob(0.0, (z() / 2.0).exp(), x())
        ).exp()
    )
```
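As a sanity check on the hand-written formula, the sketch below restates `normal_log_prob` for plain floats and compares it with the log of the Normal density from Python's standard library (`statistics.NormalDist`). The test points are arbitrary values chosen for illustration, and the check runs outside Bean Machine.

```python
import math
from statistics import NormalDist


def normal_log_prob(mu, sigma, x):
    # Same formula as in the tutorial, evaluated here on plain floats.
    z = (x - mu) / sigma
    return -math.log(sigma) + (-1.0 / 2.0) * math.log(2.0 * math.pi) - (z**2.0 / 2.0)


# Compare against the standard library's Normal density at a few points.
checks = [(0.0, 3.0, 1.2), (0.0, math.exp(0.5), -2.0), (1.0, 0.25, 1.1)]
max_error = 0.0
for mu, sigma, x in checks:
    reference = math.log(NormalDist(mu, sigma).pdf(x))
    max_error = max(max_error, abs(normal_log_prob(mu, sigma, x) - reference))

print(f"largest absolute difference: {max_error:.2e}")
```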

Now we can use this definition in our model, and in particular, in our notion of observation:

```python
observations = {neals_funnel_coin_flip_bmg(): torch.tensor(1.0)}
```

To run our model using BMGInference, the only change needed is the following:

```python
%%time
single_site_bmg_samples = BMGInference().infer(
    queries=[z(), x()],
    observations=observations,
    num_samples=num_samples,
    num_chains=num_chains,
)
```

Wall time numbers will naturally vary on different platforms, but with these parameters (model, observations, queries, sample size, and number of chains) the speedup on the author's machine is about 7x. Generally speaking, larger speedups are expected with larger sample sizes. More information about BMGInference can be found in the "Advanced" section of the documentation on the website.

We can confirm that `BMGInference` provides good accuracy by examining $\hat{R}$ values and examining the marginal plots in the next two code cells.

```python
single_site_bmg_summary_df = az.summary(single_site_bmg_samples.to_inference_data())
Markdown(single_site_bmg_summary_df.to_markdown())
```

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| x() | -0.058 | 5.746 | -8.347 | 7.75 | 0.102 | 0.174 | 3154 | 807 | 1 |
| z() | -0.085 | 3.047 | -5.253 | 6.06 | 0.182 | 0.129 | 274 | 253 | 1.01 |

```python
z_marginal = single_site_bmg_samples[z()].flatten().detach()
x_marginal = single_site_bmg_samples[x()].flatten().detach()

print(f"z_marginal: {z_marginal}\n" f"x_marginal: {x_marginal}")

grid = mpl.gridspec.GridSpec(4, 4)

plt.subplot(grid[1:, :3])
plt.contour(
    xs.numpy(),
    zs.numpy(),
    density.log().numpy(),
    levels=range(-10, 0),
    zorder=0,
)
plt.scatter(x_marginal.numpy(), z_marginal.numpy(), alpha=0.25)
plt.xlabel("x")
plt.ylabel("z")
plt.xlim(-50, 50)
plt.ylim(-15, 15)

plt.subplot(grid[0, :3])
plt.hist(x_marginal.numpy(), bins=60, density=True, range=(-50, 50))
plt.ylabel("density")
plt.xlim(-50, 50)
plt.gca().axes.get_xaxis().set_ticklabels([])

plt.subplot(grid[1:, 3])
zs_marginal = torch.linspace(-10, 10, 100)
plt.hist(
    z_marginal.numpy(),
    bins=60,
    density=True,
    range=(-15, 15),
    orientation="horizontal",
)
plt.plot(
    dist.Normal(0, 3).log_prob(zs_marginal).exp().numpy(),
    zs_marginal.numpy(),
    color="black",
)
plt.xlabel("density")
plt.ylim(-15, 15)
plt.gca().axes.get_yaxis().set_ticklabels([]);
```