Analysis

Inference results are useful not only for learning posterior distributions, but also for verifying that inference ran correctly. We'll cover common techniques for analyzing results in this section. As is the case for everything else in this Overview, the code for this section is available as a notebook on GitHub and Colab.

Results of Inference

Bean Machine stores the results of inference in an object of type MonteCarloSamples. Internally, this class uses a dictionary to map random variables to PyTorch tensors of posterior samples. The class can be accessed like a dictionary, and there are additional wrapper methods to make function calls more explicit.

In the Inference section, we obtained results on the disease modeling example by running inference:

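import beanmachine.ppl as bm

# reproduction_rate() and observations are defined in the Inference section.
samples = bm.CompositionalInference().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
    num_chains=4,
)
samples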
Out:

<beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples>

Extracting Samples for a Specific Variable

In order to perform inference on the random variable reproduction_rate(), we added it to the queries list. We can see that it, and no other random variable, is available in samples:

list(samples.keys())
Out:

[RVIdentifier(function=<function reproduction_rate>, arguments=())]

To extract the inference results for reproduction_rate(), we can use get_variable():

samples.get_variable(reproduction_rate())
Out:

tensor([[1.0000, 0.4386, 0.2751, ..., 0.2177, 0.2177, 0.2193],
        [0.2183, 0.2183, 0.2184, ..., 0.2177, 0.2177, 0.2177],
        [0.2170, 0.2180, 0.2183, ..., 0.2180, 0.2180, 0.2180],
        [0.2180, 0.2180, 0.2172, ..., 0.2180, 0.2180, 0.2176]])

The result has shape $4 \times 7000$, representing the $7000$ samples that we drew in each of the four chains of inference from the posterior distribution.

MonteCarloSamples supports convenient dictionary indexing syntax to obtain the same information:

samples[reproduction_rate()]
Out:

tensor([[1.0000, 0.4386, 0.2751, ..., 0.2177, 0.2177, 0.2193],
        [0.2183, 0.2183, 0.2184, ..., 0.2177, 0.2177, 0.2177],
        [0.2170, 0.2180, 0.2183, ..., 0.2180, 0.2180, 0.2180],
        [0.2180, 0.2180, 0.2172, ..., 0.2180, 0.2180, 0.2176]])

Please note that many inference methods need to draw a number of initial samples before their samples start to resemble the posterior distribution. The 3000 samples that we specified in num_adaptive_samples were already excluded for us, so nothing needs to be done here. However, if you use no adaptive samples, we recommend that you discard at least the first few hundred samples before using your inference results.
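If you do need to discard early samples yourself, slicing along the sample dimension is enough. Here is a minimal sketch that uses plain Python lists as a stand-in for the tensor of posterior samples. The helper name discard_warmup, the synthetic data, and the cutoff of 300 are illustrative assumptions and not part of Bean Machine.

```python
import random


def discard_warmup(chains, num_warmup):
    """Drop the first `num_warmup` draws from every chain.

    `chains` is a list of per-chain lists, mirroring the
    (num_chains, num_samples) layout of the tensors shown above.
    """
    if num_warmup < 0:
        raise ValueError("num_warmup must be non-negative")
    return [chain[num_warmup:] for chain in chains]


# Synthetic stand-in data: 4 chains of 1000 draws each (not real inference output).
rng = random.Random(0)
raw = [[rng.gauss(0.22, 0.001) for _ in range(1000)] for _ in range(4)]

kept = discard_warmup(raw, num_warmup=300)
print(len(kept), len(kept[0]))  # 4 700
```

With a PyTorch tensor of shape (num_chains, num_samples), the same idea would be the slice tensor[:, 300:].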

Extracting Samples for a Specific Chain

We'll see how to make use of chains in Diagnostics; for inspecting the samples themselves, it is often useful to examine each chain individually. The recommended way to access the results of a specific chain is with get_chain():

chain = samples.get_chain(chain=0)
chain
Out:

<beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples>

This returns a new MonteCarloSamples object which is limited to the specified chain. Tensors no longer have a dimension representing the chain:

chain[reproduction_rate()]
Out:

tensor([1.0000, 0.4386, 0.2751, ..., 0.2177, 0.2177, 0.2193])

Visualizing Distributions

Visualizing the results of inference can be a great help in understanding them. Since you now know how to access posterior samples, you're free to use whatever visualization tools you prefer.

reproduction_rate_samples = samples.get_chain(0)[reproduction_rate()]
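Any plotting library will do for this. As a dependency-free illustration, here is a rough sketch that bins one chain's samples and prints a text histogram. The synthetic samples, the helper name histogram_counts, and the bin count are assumptions made for the example. In practice you would pass the values extracted from samples.get_chain(0)[reproduction_rate()] to your preferred plotting tool.

```python
import random


def histogram_counts(values, num_bins=10):
    """Count how many values fall into each of `num_bins` equal-width bins."""
    low, high = min(values), max(values)
    width = (high - low) / num_bins or 1.0  # fall back when all values are equal
    counts = [0] * num_bins
    for value in values:
        # Clamp so that the maximum value lands in the last bin.
        index = min(int((value - low) / width), num_bins - 1)
        counts[index] += 1
    return low, width, counts


# Synthetic stand-in for one chain of posterior samples (not real output).
rng = random.Random(0)
chain_samples = [rng.gauss(0.22, 0.0003) for _ in range(2000)]

low, width, counts = histogram_counts(chain_samples)
for i, count in enumerate(counts):
    left_edge = low + i * width
    print(f"{left_edge:.4f} | {'#' * (count // 20)}")
```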

Diagnostics

After running inference, it is useful to run diagnostic tools to assess the reliability of the inference run. Bean Machine provides two standard types of such diagnostic tools, discussed below.

Summary Statistics

Bean Machine provides important summary statistics for individual, numerically-valued random variables. Let's take a look at the code to generate them, and then we'll break down the statistics themselves.

Bean Machine's interface to the ArviZ library makes it easy to generate a Pandas DataFrame presenting these statistics for all queried random variables:

import arviz as az

az.rcParams["stats.hdi_prob"] = 0.89
az.summary(samples.to_xarray(), round_to=5)
                       mean      sd  hdi_5.5%  hdi_94.5%  mcse_mean  mcse_sd   ess_bulk    ess_tail   r_hat
reproduction_rate()  0.2196  0.0003    0.2192     0.2201        0.0      0.0  19252.377  19175.7875  1.0002

We recommend reading the official ArviZ documentation for a full explanation, but the statistics presented are:

  1. Mean and standard deviation (SD).
  2. 89% highest density interval (HDI).
  3. Markov chain standard error (MCSE).
  4. Effective sample size (ESS), $N_\text{eff}$.
  5. Convergence diagnostic $\hat{R}$.

We choose to display the 89% highest density interval (HDI), following recommendations in Statistical Rethinking: A Bayesian Course with Examples in R and Stan (McElreath, 2020). The statistics above provide different insights into the quality of the results of inference. For instance, we can use them in combination with synthetically generated data for which we know the ground truth values of the parameters, and then check that these values fall within some HDI of our posterior samples. Of course, in doing so it is important to keep in mind that the prior distributions in our model (and not just the data) will have an influence on the posterior distribution. Similarly, we can use the size of the HDI to gain insights into the model: if it is large, this could indicate either that we have too few observations or that the prior is too weak.
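To make the ground-truth check concrete, here is a minimal sketch that computes an 89% HDI as the narrowest window over the sorted samples and tests whether a known value falls inside it. The function hdi, the synthetic draws, and the value of true_rate are assumptions for illustration. ArviZ's own HDI computation may differ in details such as rounding and the handling of multimodal distributions.

```python
import math
import random


def hdi(samples, prob=0.89):
    """Narrowest interval containing a fraction `prob` of the samples.

    Assumes a unimodal distribution, so that a single interval is meaningful.
    """
    ordered = sorted(samples)
    n = len(ordered)
    window = max(1, math.ceil(prob * n))  # number of draws inside the interval
    # Slide a window of that size over the sorted draws and keep the narrowest.
    best_start = min(
        range(n - window + 1),
        key=lambda i: ordered[i + window - 1] - ordered[i],
    )
    return ordered[best_start], ordered[best_start + window - 1]


# Synthetic "posterior" centered on a hypothetical ground-truth value.
rng = random.Random(1)
true_rate = 0.22
draws = [rng.gauss(true_rate, 0.001) for _ in range(4000)]

low, high = hdi(draws, prob=0.89)
print(f"89% HDI: [{low:.4f}, {high:.4f}]")
print("ground truth inside HDI:", low <= true_rate <= high)
```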

$\hat{R} \in [1, \infty)$ summarizes how effective inference was at converging on the correct posterior distribution for a particular random variable. It uses information from all chains run in order to assess whether inference had a good understanding of the distribution or not. Values very close to $1.0$ indicate that all chains discovered similar distributions for a particular random variable. We do not recommend using inference results where $\hat{R} > 1.01$, as inference may not have converged. In that case, you may want to run inference for more samples. However, there are situations in which increasing the number of samples will not improve convergence. In this case, it is possible that the prior is too far from the posterior, or that the particular inference method is unable to reliably explore the posterior distribution.
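To build intuition for how a multi-chain convergence statistic behaves, here is a rough sketch of the classic Gelman-Rubin statistic, which compares between-chain and within-chain variance. This is a simplified relative of what ArviZ reports: by default, ArviZ uses a rank-normalized split variant, so the numbers will not match exactly. The function name gelman_rubin and the synthetic chains are assumptions for the example.

```python
import math
import random
import statistics


def gelman_rubin(chains):
    """Classic Gelman-Rubin statistic for equally long chains of draws."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # W: average of the within-chain sample variances.
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    # B: between-chain variance, scaled by the chain length.
    between = n * statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)


rng = random.Random(2)
# Four chains that all sample the same distribution.
agreeing = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Same setup, but the last chain is stuck in a different region.
disagreeing = [list(chain) for chain in agreeing[:3]]
disagreeing.append([rng.gauss(1.0, 1.0) for _ in range(1000)])

print("agreeing chains:   ", round(gelman_rubin(agreeing), 4))
print("disagreeing chains:", round(gelman_rubin(disagreeing), 4))
```

When the chains agree, the statistic sits very close to 1. A single chain exploring a different region is enough to push it well past the 1.01 threshold. As a finite-sample estimate, this simplified version can also come out marginally below 1.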

$N_\text{eff} \in [1, \texttt{num\_samples}]$ summarizes how independent posterior samples are from one another. Although inference was run for num_samples iterations, it's possible that those samples were very similar to each other (due to the way inference is implemented), and may not each be representative of the full posterior space. Larger numbers are better here, and if your particular use case calls for a certain number of samples to be considered, you should ensure that $N_\text{eff}$ is at least that large. For more information on $\hat{R}$ and $N_\text{eff}$, see the Diagnostics Section.
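One simple way to see why correlated samples count for less is the batch means method: split a chain into batches, and use the spread of the batch averages to estimate how uncertain the overall mean really is. The sketch below is a swapped-in, single-chain estimator chosen for brevity. It is not the multi-chain estimator behind ArviZ's ess_bulk and ess_tail, and the function name batch_means_ess and the synthetic chains are assumptions.

```python
import math
import random
import statistics


def batch_means_ess(chain):
    """Estimate the effective sample size of one chain using batch means."""
    n = len(chain)
    num_batches = int(math.sqrt(n))
    batch_size = n // num_batches
    used = chain[: num_batches * batch_size]  # drop leftover draws, if any
    batch_means = [
        statistics.fmean(used[i * batch_size:(i + 1) * batch_size])
        for i in range(num_batches)
    ]
    # Variance of the overall mean, estimated from the spread of the batch means.
    variance_of_mean = statistics.variance(batch_means) / num_batches
    # Independent draws would give variance(used) / len(used) for that quantity,
    # so the ratio below is the number of independent draws with equal precision.
    return statistics.variance(used) / variance_of_mean


rng = random.Random(3)
n = 4900
independent = [rng.gauss(0.0, 1.0) for _ in range(n)]
# A "sticky" chain: each draw stays close to the previous one.
sticky = [0.0]
for _ in range(n - 1):
    sticky.append(0.9 * sticky[-1] + rng.gauss(0.0, 1.0))

print("independent chain:", round(batch_means_ess(independent)))
print("sticky chain:     ", round(batch_means_ess(sticky)))
```

The estimate is noisy, and for nearly independent draws it can land somewhat above the number of samples. The point of the comparison is the gap between the two chains.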

In the case of our example model, we have a healthy $\hat{R}$ value very close to $1.0$, and a healthy relative number of effective samples.

Diagnostic Plots

Bean Machine can also plot diagnostic information to assess the health of the inference run. Let's take a look:

az.plot_trace(
    {"Reproduction rate": samples[reproduction_rate()]},
    compact=False,
)
az.plot_autocorr({"Reproduction rate": samples[reproduction_rate()]})

The diagnostics output shows two diagnostic plots for individual random variables: trace plots and autocorrelation plots.

  • Trace plots are simply a time series of values assigned to random variables over each iteration of inference. The concrete values assigned are usually problem-specific. However, it's important that these values are "mixing" well over time. This means that they don't tend to get stuck in one region for large periods of time, and that each of the chains ends up exploring the same space as the other chains throughout the course of inference.
  • Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). (We compute autocorrelation approximately, so it may sometimes exceed these bounds.) In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works. The autocorrelation plots here show, for each lag on the x axis, how strongly samples are correlated with the samples drawn that many iterations earlier in the same chain.
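To see what an autocorrelation plot summarizes, here is a minimal sketch of the sample autocorrelation at a given lag, evaluated on an independent chain and on a deliberately sticky one. The function autocorrelation and both synthetic chains are assumptions for illustration. ArviZ computes its plotted values with its own routine.

```python
import random


def autocorrelation(chain, lag):
    """Sample autocorrelation of `chain` at a non-negative `lag`."""
    n = len(chain)
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    variance = sum(c * c for c in centered)
    # Covariance between each draw and the draw `lag` iterations later.
    covariance = sum(centered[i] * centered[i + lag] for i in range(n - lag))
    return covariance / variance


rng = random.Random(4)
independent = [rng.gauss(0.0, 1.0) for _ in range(3000)]
# A "sticky" chain: each draw is 0.9 times the previous draw plus noise.
sticky = [0.0]
for _ in range(2999):
    sticky.append(0.9 * sticky[-1] + rng.gauss(0.0, 1.0))

for lag in (1, 5, 20):
    print(
        f"lag {lag:2d}: independent {autocorrelation(independent, lag):+.3f}, "
        f"sticky {autocorrelation(sticky, lag):+.3f}"
    )
```

A well-mixing chain behaves like the independent case, with values near zero at every lag. A chain that moves slowly shows large values at small lags that decay only gradually.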

For our example model, we see from the trace plots that each of the chains is healthy: they don't get stuck, and do not explore a chain-specific subset of the space. From the autocorrelation plots, we see the absolute magnitude of autocorrelation to be very small, often well below $0.1$, indicating a healthy exploration of the space.


Congratulations, you've made it through the Overview! If you're looking to get an even deeper understanding of Bean Machine, check out the Framework topics next. Or, if you're looking to get to coding, check out our Tutorials. In either case, happy modeling!