
Zero inflated count data

Tutorial: Marginalizing discrete variables in zero inflated count data

In this tutorial we investigate data that originated with Berry and was subsequently analyzed by Farewell and Sprott, from a study of the efficacy of a medication that helps prevent irregular heartbeats. Counts of patients' irregular heartbeats were observed for 60 seconds before the drug was administered and for 60 seconds after the medication was taken. A large percentage of records show zero irregular heartbeats in the 60 seconds after taking the medication; there are more observed zeros than we would expect if we were sampling from one of the common discrete statistical distributions (see Wikipedia-Distributions for a list of common discrete distributions). The problem we face is modeling these zero counts in order to appropriately quantify the medication's impact on reducing irregular heartbeats.

Learning outcomes

On completion of this tutorial, you should be able to:

  • execute a zero-inflated count data model with Bean Machine;
  • increment log probabilities using Bean Machine;
  • run diagnostics and understand what Bean Machine is doing.

Problem

We will address how to model response data that contain more zeros than would be expected under a common discrete statistical distribution. Data with a large number of zeros are called zero-inflated. Recognizing when your data are zero-inflated is straightforward using a histogram, as we will see in the data section below. Zero inflation occurs in many different scientific fields and application areas. Examples include social science, traffic accident research, econometrics, psychology, and a well-known investigation of manufacturing defects by Lambert that we will take inspiration from.

Prerequisites

We will be using the following packages within this tutorial.

  • arviz and bokeh for interactive visualizations; and
  • pandas for data manipulation.
# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import os
import warnings
from io import StringIO

import arviz as az
import beanmachine.ppl as bm
import pandas as pd
import torch
import torch.distributions as dist
from beanmachine.ppl.model import RVIdentifier
from beanmachine.tutorials.utils import hearts
from bokeh.io import output_notebook
from bokeh.plotting import gridplot, show
from IPython.display import Markdown

The next cell includes convenient configuration settings to improve the notebook presentation as well as setting a manual seed for torch, for reproducibility.

# Ignore excessive warnings from ArviZ.
warnings.simplefilter("ignore")

# Plotting settings.
az.rcParams["plot.backend"] = "bokeh"
# See McElreath for an exposition on highest density intervals, and why we use 89%.
az.rcParams["stats.hdi_prob"] = 0.89

# Manual seed for torch.
torch.manual_seed(1199)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

Data

The data set we use in this tutorial originates from a study by Berry with subsequent analysis done by Farewell and Sprott. The data consist of EKG measurements where patients' heartbeats were counted for one minute before and after the administration of a drug, see Wikipedia-EKG for a description of what an EKG is. The patients in the study suffer from an affliction called premature ventricular contraction or PVC, see Wikipedia-PVC for more information about PVC events. Acute PVC events are common, but chronic PVC events are considered risk factors for other diseases.

The data from Berry are replicated below, with the addition of one column we will use in our model as outlined in Farewell and Sprott. This column is called total and amounts to the sum of the predrug and postdrug PVC counts.

| Column name | Description |
| --- | --- |
| patient_number | Patient ID. |
| predrug | PVC event counts before administering the medication. |
| postdrug | PVC event counts after administering the medication. |
| decrease | The difference predrug - postdrug. |
| total | The sum of predrug + postdrug. |

The data only contain 12 records. Instead of loading the data from a CSV file, we will load it from the data_string below into a pandas dataframe object.

data_string = """patient_number,predrug,postdrug,decrease,total
1,6,5,1,11
2,9,2,7,11
3,17,0,17,17
4,22,0,22,22
5,7,2,5,9
6,5,1,4,6
7,5,0,5,5
8,14,0,14,14
9,9,0,9,9
10,7,0,7,7
11,9,13,-4,22
12,51,0,51,51"""
df = pd.read_csv(StringIO(data_string))
Markdown(df.set_index("patient_number").to_markdown())
| patient_number | predrug | postdrug | decrease | total |
| --- | --- | --- | --- | --- |
| 1 | 6 | 5 | 1 | 11 |
| 2 | 9 | 2 | 7 | 11 |
| 3 | 17 | 0 | 17 | 17 |
| 4 | 22 | 0 | 22 | 22 |
| 5 | 7 | 2 | 5 | 9 |
| 6 | 5 | 1 | 4 | 6 |
| 7 | 5 | 0 | 5 | 5 |
| 8 | 14 | 0 | 14 | 14 |
| 9 | 9 | 0 | 9 | 9 |
| 10 | 7 | 0 | 7 | 7 |
| 11 | 9 | 13 | -4 | 22 |
| 12 | 51 | 0 | 51 | 51 |

Below are histograms of the PVC event counts pre- and post-medication for all patients. What we observe from these histograms is a large number of zero values after the medication has been administered. These zero values will be important for the model defined in the next section, and are characteristic of zero-inflated data. If we neglected to model these zero values, we would be throwing out an important feature of our data, and our parameter estimates would be biased.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

value_counts_plot = gridplot(hearts.plot_value_counts(df))
show(value_counts_plot)

The literature classifies zero count ($y_i=0$) observations broadly into two kinds of plausible explanations, which we are interested in modeling statistically (see Farewell and Sprott).

  1. A general chance for "curing" by the intervention. In this scenario, the zero count is considered certain, and the observation is understood as a true positive event.
  2. A patient-specific chance for a "varied response level" or "relief". In this scenario, the PVC event frequency is understood to be merely reduced, and the zero count observation is therefore understood as a false positive event.

Note that we cannot assess the biological significance of the statement "cured" in this tutorial.

Models

First model - Binomial linear model

The data come in pairs $(x_i, y_i)$, where

$$\begin{aligned} i &\longrightarrow i^{th}\text{ patient (}\texttt{patient\_number}\text{)}\\ x_i &\longrightarrow \text{total count of PVC events 60 seconds before medication administration (}\texttt{predrug}\text{)}\\ y_i &\longrightarrow \text{total count of PVC events 60 seconds after medication administration (}\texttt{postdrug}\text{)}. \end{aligned}$$

Note that both $x_i$ and $y_i$ are in $\mathbb{Z}^+\cup\{0\}$ (the non-negative integers). The data come in pairs because each of the twelve patients had their heartbeats monitored for 60 seconds before and 60 seconds after the administration of an anti-arrhythmic drug. The number of PVC events within each 60 second window is assumed to originate from an independent Poisson distribution with some mean $\lambda_i$, thus

$$\begin{aligned} x_i &\sim \text{Poisson}(\lambda_i)\\ y_i &\sim \text{Poisson}(\beta\lambda_i). \end{aligned}$$

$\beta$ is a scaling factor for the rate of PVC events after the administration of the medication. We can eliminate the nuisance parameters (all the $\lambda_i$s, which only indirectly contribute to the answer of our modeling question) by using the conditional distribution of $y_i$ given the total PVC count for each patient, $t_i=x_i+y_i$. To accomplish this, first let $X_i$ and $Y_i$ denote the Poisson-distributed random variables from which $x_i$ and $y_i$ are assumed to be samples. Then our conditional distribution can be written as

$$\begin{aligned} \text{Pr}(Y_i=y_i\mid X_i+Y_i=t_i) &= \frac{\text{Pr}(X_i+Y_i=t_i,\,Y_i=y_i)}{\text{Pr}(X_i+Y_i=t_i)}\\ &= \frac{\text{Pr}(X_i=t_i-y_i)\,\text{Pr}(Y_i=y_i)}{\text{Pr}(X_i+Y_i=t_i)}. \end{aligned}$$

The denominator is the probability mass function of the sum of the independent Poisson random variables $X_i$ and $Y_i$. A property of Poisson distributions is that the sum of two independent Poisson random variables is again Poisson, with the rates added together; see Wikipedia-Poisson distribution. Using this property, our conditional probability becomes

$$\begin{aligned} \text{Pr}(Y_i=y_i\mid X_i+Y_i=t_i) &= \frac{\text{Poisson}(t_i-y_i\mid\lambda_i)\,\text{Poisson}(y_i\mid\beta\lambda_i)} {\text{Poisson}(t_i\mid\lambda_i+\beta\lambda_i)} = \frac{ \frac{\lambda_i^{t_i-y_i}}{(t_i-y_i)!}e^{-\lambda_i}\cdot \frac{(\beta\lambda_i)^{y_i}}{y_i!}e^{-\beta\lambda_i} } {\frac{(\lambda_i+\beta\lambda_i)^{t_i}}{t_i!}e^{-(\lambda_i+\beta\lambda_i)}}\\ &= \binom{t_i}{y_i} \left(\frac{\beta\lambda_i}{\lambda_i+\beta\lambda_i}\right)^{y_i} \left(\frac{\lambda_i}{\lambda_i+\beta\lambda_i}\right)^{t_i-y_i}\\ \therefore\ \text{Pr}(y_i\mid t_i, p) &= \text{Binomial}(y_i\mid t_i, p), \end{aligned}$$

where $p$ is

$$p=\frac{\beta\lambda_i}{\lambda_i+\beta\lambda_i}=\frac{\beta}{1+\beta}.$$
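As a quick sanity check of this collapse from two Poissons to a single Binomial, we can simulate paired Poisson counts, condition on a fixed total, and compare the empirical conditional frequency with the Binomial probability above. This is a standalone sketch with arbitrary values for $\lambda$ and $\beta$; it is not part of the model we fit below.

import torch
import torch.distributions as dist

torch.manual_seed(0)
lam, beta = 8.0, 0.5  # arbitrary illustrative values
n = 200_000

x = dist.Poisson(lam).sample((n,))
y = dist.Poisson(beta * lam).sample((n,))

# Condition on a particular total t = x + y and look at the frequency of y == 3.
t_fixed = 10
y_given_t = y[(x + y) == t_fixed]
empirical = (y_given_t == 3).float().mean()

# Binomial(t, beta / (1 + beta)) probability of the same event.
exact = dist.Binomial(total_count=t_fixed, probs=beta / (1 + beta)).log_prob(torch.tensor(3.0)).exp()
print(empirical, exact)  # the two values should be close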

The conditional probability distribution $\text{Pr}(y_i\mid t_i, p)$ no longer contains any of the nuisance parameters $\lambda_i$. We will use this likelihood in our first model, with an uninformative prior on $p$, leading to the generative model

$$\begin{aligned} p &\sim \text{Uniform}(0, 1)\\ y_i\mid t_i, p &\sim \text{Binomial}(t_i, p). \end{aligned}$$

NOTE We can implement this model in Bean Machine by defining random variable objects with the decorator @bm.random_variable and deterministic values using the bm.functional decorator. These functions behave differently than ordinary Python functions.

Semantics for @bm.random_variable functions:
  • They must return PyTorch Distribution objects.
  • Though they return distributions, callers actually receive a sample from the distribution. The machinery for obtaining samples from distributions is handled internally by Bean Machine.
  • Inference runs the model through many iterations. During a particular inference iteration, a distinct random variable will correspond to exactly one sampled value: calls to the same random variable function with the same arguments will receive the same sampled value within one inference iteration. This makes it easy for multiple components of your model to refer to the same logical random variable.
  • Consequently, to define distinct random variables that correspond to different sampled values during a particular inference iteration, an effective practice is to add a dummy "indexing" parameter to the function. Distinct random variables can be referred to with different values for this index.
  • Please see the documentation for more information about this decorator.
Semantics for @bm.functional functions:
  • This is a decorator that lets you treat deterministic code as if it is a Bean Machine random variable. This is used to transform the results of one or more random variables.
  • Follows the same naming practice as @bm.random_variable where variables are distinguished by their argument call values.
  • Please see the documentation for more information about this decorator.
# Define a variable with the total number of records.
N = df.shape[0]

# Create a torch tensor containing the total PVC events for each record.
t = torch.tensor(df["total"].astype(float).values)
@bm.random_variable
def first_model_p() -> RVIdentifier:
    return dist.Uniform(low=0.0, high=1.0)


@bm.random_variable
def first_model_y(i: int) -> RVIdentifier:
    return dist.Binomial(total_count=t[i], probs=first_model_p())
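As a small illustration of the naming semantics described in the note above (not needed for the analysis), calling these functions outside of inference returns RVIdentifier objects rather than samples, and different index arguments produce distinct random variables:

# Outside of inference, calling a random variable function returns an RVIdentifier
# keyed by the function and its arguments, not a sample.
print(type(first_model_p()))                 # RVIdentifier
print(first_model_y(0) == first_model_y(1))  # False: different index, different variable
print(first_model_y(0) == first_model_y(0))  # True: same function and arguments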

Prior predictive checks

Before we run inference, we will conduct prior predictive checks. These checks will determine if our priors and model can capture relevant statistics about the observed data. We will use Bean Machine and its infer method to run the prior predictive checks even though we are not actually running posterior inference. Bean Machine's infer method has the following API.

| Name | Usage |
| --- | --- |
| queries | List of @bm.random_variable targets to fit posterior distributions for. |
| observations | A dictionary of observations. |
| num_samples | Number of Monte Carlo samples to approximate the posterior distributions for the variables in queries. |
| num_chains | Number of separate inference runs to use. Multiple chains can help verify that inference ran correctly. |

When we run prior predictive checks, we pass the infer method an empty dictionary for the observations keyword. This corresponds to a target density for Bean Machine's MCMC inference that includes the model's prior density (here first_model_p) but no observation likelihood terms. We can then manually feed the resulting prior samples into the model likelihood (here lambda p, i: dist.Binomial(total_count=t[i], probs=p)) to obtain synthetic data from the prior predictive distribution.
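As a concrete sketch of that second step (the helper classes used later in this tutorial do something similar internally), one way to turn prior samples of p into synthetic postdrug counts is shown below; prior_samples stands for the object returned by the prior-only infer call in the next cell.

def prior_predictive_y(prior_samples, i: int) -> torch.Tensor:
    """Simulate postdrug counts for patient i from prior draws of p."""
    # Flatten the (chains x draws) samples of p into one vector.
    p_draws = prior_samples.get(first_model_p()).reshape(-1)
    # Feed each prior draw of p through the Binomial likelihood for patient i.
    return dist.Binomial(total_count=t[i], probs=p_draws.to(t.dtype)).sample()

# Example usage, once `first_model_samples_no_observations` exists:
# y0_prior_pred = prior_predictive_y(first_model_samples_no_observations, 0)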

first_model_num_samples = 1 if smoke_test else 500
first_model_num_adaptive_samples = 0 if smoke_test else first_model_num_samples // 2
first_model_num_chains = 1 if smoke_test else 4
# Experimental backend using the PyTorch NNC compiler.
nnc_compile = "SANDCASTLE_NEXUS" not in os.environ

# Model queries.
first_model_queries = [first_model_p()]

# Note our observations dictionary is empty.
first_model_observations = {}

# Note we are using the GlobalNoUTurnSampler; this will change in our second model.
first_model_samples_no_observations = bm.GlobalNoUTurnSampler(nnc_compile=nnc_compile).infer(
    queries=first_model_queries,
    observations=first_model_observations,
    num_samples=first_model_num_samples,
    num_chains=first_model_num_chains,
    num_adaptive_samples=first_model_num_adaptive_samples,
)
Out:

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Below is a plot of synthetic data simulated from the prior predictive distribution, along with the real observed data (magenta square) and total PVC count (brown diamond) for each patient. The histograms represent simulated data; the thick blue line below each histogram is its 89% HDI (highest density interval) region, and the white marker is its mean. HDI intervals are discussed at length in McElreath. The support defined by our prior spans the observed postdrug PVC counts as well as the total (predrug + postdrug) PVC counts. The prior for $p$ is not very informative for our model, but it does capture the relevant bounds of the data.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

# Predictive checks without observations.
first_model_predictive_checks_no_observations = hearts.Model1PredictiveChecks(
    samples_without_observations=first_model_samples_no_observations,
    data=df,
    p_query=first_model_p,
)
first_model_prior_pc = (
    first_model_predictive_checks_no_observations.plot_prior_predictive_checks()
)
show(first_model_prior_pc)

Next, we run inference using this model, and compare the prior predictive plot with a posterior predictive plot.

# Note the observations dictionary is no longer empty.
first_model_observations = {
    first_model_y(i): torch.tensor(observed)
    for i, observed in enumerate(df["postdrug"].astype(float).tolist())
}

# Run inference with observations.
first_model_samples = bm.GlobalNoUTurnSampler(nnc_compile=nnc_compile).infer(
    queries=first_model_queries,
    observations=first_model_observations,
    num_samples=first_model_num_samples,
    num_chains=first_model_num_chains,
    num_adaptive_samples=first_model_num_adaptive_samples,
)
Out:

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

Samples collected: 0%| | 0/750 [00:00<?, ?it/s]

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

# Predictive checks with observations.
first_model_predictive_checks = hearts.Model1PredictiveChecks(
    samples_with_observations=first_model_samples,
    data=df,
    p_query=first_model_p,
)

# Remove the prior plot legend since it is duplicated in the posterior plot.
first_model_prior_pc.legend.visible = False

# Posterior plot.
first_model_posterior_pc = (
    first_model_predictive_checks.plot_posterior_predictive_checks()
)

# Link plots together so zooms occur in both plots.
first_model_posterior_pc.y_range = first_model_prior_pc.y_range
first_model_posterior_pc.x_range = first_model_prior_pc.x_range

# Compare prior/posterior predictive checks.
prior_post_plots = gridplot([[first_model_prior_pc, first_model_posterior_pc]])
show(prior_post_plots)

The left plot is exactly what we saw above, and the right plot shows data simulated from the posterior predictive distribution. The mean simulated density values shift towards the data (magenta markers) in the right plot, as expected. However, the model is not capturing the observed zero values, as shown in the per-patient histograms of the right plot. This suggests that our model is not well specified and needs to be updated to capture the observed zero values.

Second model - Zero-inflated Binomial linear model

We begin our model refinement by revisiting the conditional probability derived above,

$$y_i\mid t_i,p\sim\text{Binomial}(t_i,p).$$

We will multiply $p$ by a random variable that is sampled from a Bernoulli for each patient,

$$\begin{aligned} s_i &\sim \text{Bernoulli}(1-\theta)\\ y_i\mid t_i, p, s_i &\sim \text{Binomial}(t_i,\,p\cdot s_i). \end{aligned}$$

Whenever $s_i=1$ is sampled, this model is equivalent to the previous model. When $s_i=0$, the parameter of the Binomial distribution becomes $0$ and $y_i=0$ is obtained with certainty. Hence $s_i$ decides for each patient $i$ whether their intervention effect is "relief" or "curing", and $\theta$ is a mixing parameter between these two scenarios. In particular, for $\theta=1$, all patients would always be "cured".
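To see the effect of the switch in isolation, here is a small standalone simulation with made-up values of $\theta$, $p$, and $t_i$ (not our fitted values): the Bernoulli switch adds many more zeros than the plain Binomial produces on its own.

import torch
import torch.distributions as dist

theta_demo, p_demo, t_demo, n = 0.4, 0.4, 10, 100_000

s_demo = dist.Bernoulli(1 - theta_demo).sample((n,))
y_plain = dist.Binomial(t_demo, p_demo).sample((n,))
y_zero_inflated = dist.Binomial(t_demo, p_demo * s_demo).sample()

print((y_plain == 0).float().mean())          # roughly (1 - p)^t, about 0.006
print((y_zero_inflated == 0).float().mean())  # roughly theta + (1 - theta)(1 - p)^t, about 0.40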

The model now has two mechanisms for producing a zero count, parameterized by $\theta$ and $p$. We could sample $\theta$ and $p$ from $\text{Uniform}(0, 1)$ as we did in the first model; instead, we will place priors on their log-odds, centered at zero. We accomplish this by setting the log-odds of both $p$ and $\theta$ equal to two new random variables $\alpha$ and $\delta$,

$$\alpha=\ln\left(\frac{p}{1-p}\right)\qquad\delta=\ln\left(\frac{\theta}{1-\theta}\right).$$

Both $\alpha$ and $\delta$ can be sampled from normal distributions, since the normal distribution has support over $(-\infty,+\infty)$, just like the log-odds values. To get $p$ or $\theta$ back, we calculate the inverse logit of $\alpha$ or $\delta$, respectively. $\text{logit}^{-1}$ is implemented in torch as the sigmoid method. We can now write our model as

$$\begin{aligned} \alpha &\sim \text{Normal}(0, 10)\\ \delta &\sim \text{Normal}(0, 10)\\ p &= \text{logit}^{-1}\left(\alpha\right)\\ \theta &= \text{logit}^{-1}\left(\delta\right)\\ s_i &\sim \text{Bernoulli}(1-\theta)\\ y_i\mid t_i, p, s_i &\sim \text{Binomial}(t_i, p\cdot s_i). \end{aligned}$$
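As a quick check (with an arbitrary probability value), torch.sigmoid really is the inverse of the log-odds transform used above:

import torch

p_example = torch.tensor(0.25)
alpha_example = torch.log(p_example / (1 - p_example))  # log-odds of p
print(torch.sigmoid(alpha_example))  # tensor(0.2500): recovers p
print(torch.logit(p_example))        # same log-odds computed with torch.logit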

We can code this in Bean Machine as shown below.

@bm.random_variable
def second_model_alpha() -> RVIdentifier:
    return dist.Normal(0, 10)


@bm.random_variable
def second_model_delta() -> RVIdentifier:
    return dist.Normal(0, 10)


@bm.functional
def second_model_p() -> torch.Tensor:
    return torch.sigmoid(second_model_alpha())


@bm.functional
def second_model_theta() -> torch.Tensor:
    return torch.sigmoid(second_model_delta())


@bm.random_variable
def second_model_s(i: int) -> RVIdentifier:
    return dist.Bernoulli(1 - second_model_theta())


@bm.random_variable
def second_model_y(i: int) -> RVIdentifier:
    return dist.Binomial(t[i], second_model_p() * second_model_s(i))

Prior predictive checks

We will use the SingleSiteAncestralMetropolisHastings inference method because our model mixes discrete random variables ($s_i$) with continuous random variables ($\alpha$ and $\delta$). Such models cannot be handled by Hamiltonian Monte Carlo, which rules out the GlobalNoUTurnSampler.

second_model_num_samples = 2 if smoke_test else 2000
second_model_num_chains = 1 if smoke_test else 4
second_model_num_adaptive_samples = 0 if smoke_test else second_model_num_samples // 2

# Model queries.
second_model_queries = [
    second_model_alpha(),
    second_model_delta(),
    second_model_p(),
    second_model_theta(),
] + [second_model_s(i) for i in range(N)]

# Model without observations. Note we are not using the GlobalNoUTurnSampler.
second_model_samples_no_observations = bm.SingleSiteAncestralMetropolisHastings().infer(
    queries=second_model_queries,
    observations={},
    num_samples=second_model_num_samples,
    num_chains=second_model_num_chains,
    num_adaptive_samples=second_model_num_adaptive_samples,
)
Out:

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

# Predictive checks without observations.
second_model_predictive_checks_no_observations = hearts.Model2PredictiveChecks(
    samples_without_observations=second_model_samples_no_observations,
    data=df,
    p_query=second_model_p,
    s_query_str="second_model_s",
)

second_model_prior_pc = (
    second_model_predictive_checks_no_observations.plot_prior_predictive_checks()
)
sec_model_prior_plot = gridplot([[second_model_prior_pc]])
show(sec_model_prior_plot)

The above plot successfully simulates more zeros under the given priors than our first model did. You can use the tools above the plot to zoom in on any patient. Doing so will show that the prior predictive distribution has positive density everywhere between the end points of zero and the total PVC count for a patient (the brown diamonds). The large bars near zero and $t_i$ are caused by the sigmoid used when calculating $p$ and $\theta$, and are expected.

Next we will run inference using this model and run posterior predictive checks, along with the other model analysis described below. Note that we have substantially increased the number of samples and adaptive samples compared to the first model in order for the diagnostics to show good mixing of the chains.

Inference

scnd_mdl_post_num_samples = 2 if smoke_test else 10000
scnd_mdl_post_num_chains = 1 if smoke_test else 4
scnd_mdl_post_num_adaptive_samples = 0 if smoke_test else second_model_num_samples // 2

# Create a tensor of observed values.
y_observed = torch.tensor(df["postdrug"].astype(float).values)

# Model observations is not empty.
second_model_observations = {second_model_y(i): y_observed[i] for i in range(N)}

# Model with observations. Note the GlobalNoUTurnSampler is not used for inference.
second_model_samples = bm.SingleSiteAncestralMetropolisHastings().infer(
    queries=second_model_queries,
    observations=second_model_observations,
    num_samples=scnd_mdl_post_num_samples,
    num_chains=scnd_mdl_post_num_chains,
    num_adaptive_samples=scnd_mdl_post_num_adaptive_samples,
)
Out:

Samples collected: 0%| | 0/11000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/11000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/11000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/11000 [00:00<?, ?it/s]

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

# Predictive checks with observations.
second_model_predictive_checks = hearts.Model2PredictiveChecks(
    samples_with_observations=second_model_samples,
    data=df,
    p_query=second_model_p,
    s_query_str="second_model_s",
)

# Remove the prior plot legend since it is duplicated in the posterior plot.
second_model_prior_pc.legend.visible = False

# Posterior plot.
second_model_posterior_pc = (
    second_model_predictive_checks.plot_posterior_predictive_checks()
)

# Link plots together.
second_model_posterior_pc.y_range = second_model_prior_pc.y_range
second_model_posterior_pc.x_range = second_model_prior_pc.x_range

# Compare prior/posterior predictive checks.
sec_prior_post_plot = gridplot([[second_model_prior_pc, second_model_posterior_pc]])
show(sec_prior_post_plot)

The left plot is again the prior predictive simulation for our second model, and the right plot shows data simulated from the posterior predictive distribution. You can see that this model is quite effectively capturing the observed zero data. The major caveat is that we had to use far more computational resources to capture the zeros than in our previous model. The posterior predictive plot looks like it is doing a good job of simulating data from our model, but we will do a bit more analysis below to ensure the model is working as expected.

Analysis

We begin our analysis by printing out summary statistics. Two important statistics to take note of are the $\hat{R}$ (r_hat) and effective sample size (ess) values in the dataframe below.

second_model_summary_df = az.summary(second_model_samples.to_xarray(), round_to=4)

# Rearranging for easy viewing of the dataframe.
sort_order = dict(zip(second_model_queries, range(1, len(second_model_queries) + 1)))
second_model_summary_df.reset_index(inplace=True)
second_model_summary_df["sort"] = second_model_summary_df["index"].map(sort_order)
second_model_summary_df.sort_values(by="sort", inplace=True)
second_model_summary_df.rename(columns={"index": "query"}, inplace=True)
second_model_summary_df.set_index("query", inplace=True)
second_model_summary_df.drop("sort", axis=1, inplace=True)

Markdown(second_model_summary_df.to_markdown())
| query | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| second_model_alpha() | -0.48 | 0.2843 | -0.9562 | -0.087 | 0.0103 | 0.0073 | 757.703 | 749.653 | 1.0027 |
| second_model_delta() | 0.3001 | 0.6234 | -0.6313 | 1.355 | 0.0133 | 0.0103 | 2160.44 | 2163.4 | 1.0016 |
| second_model_p() | 0.3844 | 0.0661 | 0.2776 | 0.4783 | 0.0024 | 0.0017 | 757.703 | 749.653 | 1.0029 |
| second_model_theta() | 0.5681 | 0.1398 | 0.3472 | 0.7949 | 0.003 | 0.0022 | 2160.44 | 2163.4 | 1.0016 |
| second_model_s(0,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(1,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(2,) | 0.001 | 0.032 | 0 | 0 | 0.0003 | 0.0002 | 12367.6 | 12367.6 | 1.0001 |
| second_model_s(3,) | 0 | 0 | 0 | 0 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(4,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(5,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(6,) | 0.0776 | 0.2676 | 0 | 0 | 0.0027 | 0.0019 | 9726.81 | 9726.81 | 1.0002 |
| second_model_s(7,) | 0.002 | 0.045 | 0 | 0 | 0.0004 | 0.0003 | 13813.7 | 13813.7 | 1 |
| second_model_s(8,) | 0.0172 | 0.13 | 0 | 0 | 0.0012 | 0.0009 | 11247.3 | 11247.3 | 1.0004 |
| second_model_s(9,) | 0.0339 | 0.181 | 0 | 0 | 0.0017 | 0.0012 | 10956.4 | 10956.4 | 1.0002 |
| second_model_s(10,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
| second_model_s(11,) | 0 | 0 | 0 | 0 | 0 | 0 | 40000 | 40000 | nan |

Note that there are NaN values for some $\hat{R}$ calculations in the above dataframe for some $s_i$ parameters. This occurs because there is no variance in those $s_i$ parameters. We can ignore these values, as it is entirely possible for a chain of $s_i$ to be all ones or all zeros, giving rise to zero variance within the chain and a NaN value in $\hat{R}$.

Measuring variance between- and within-chains with $\hat{R}$ (r_hat)

$\hat{R}$ is a diagnostic tool that measures the between- and within-chain variances. It is a test that indicates a lack of convergence by comparing the variance between multiple chains to the variance within each chain. If the parameters are successfully exploring the full space for each chain, then $\hat{R}\approx 1$, since the between-chain and within-chain variance should be equal. $\hat{R}$ is calculated from $N$ samples as

$$\begin{aligned} \hat{R} &= \frac{\hat{V}}{W} \\ \hat{V} &= \frac{N-1}{N} W + \frac{1}{N} B, \end{aligned}$$

where $W$ is the within-chain variance, $B$ is the between-chain variance, and $\hat{V}$ is the estimate of the posterior variance of the samples. The take-away here is that $\hat{R}$ converges to 1 when each of the chains begins to empirically approximate the same posterior distribution. We do not recommend using inference results if $\hat{R}>1.01$. More information about $\hat{R}$ can be found in the Vehtari et al paper.
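To build intuition for this diagnostic, here is a small standalone sketch using arviz on synthetic chains (shaped chain x draw); the chain values are made up and unrelated to our model.

import arviz as az
import numpy as np

rng = np.random.default_rng(0)
good_chains = rng.normal(size=(4, 1000))                           # all chains target the same distribution
bad_chains = good_chains + np.array([[0.0], [0.0], [0.0], [5.0]])  # one chain is far off

print(az.rhat(good_chains))  # close to 1
print(az.rhat(bad_chains))   # well above the 1.01 threshold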

Effective sample size (ess) diagnostic

MCMC samplers do not draw truly independent samples from the target distribution, which means that our samples are correlated. In an ideal situation all samples would be independent, but we do not have that luxury. We can, however, measure the number of effectively independent samples we draw, which is called the effective sample size. You can read more about how this value is calculated in the Vehtari et al paper. In brief, it is a measure that combines information from the R^\hat{R} value with the autocorrelation estimates within the chains.

ESS estimates come in two variants, ess_bulk and ess_tail. The former is the default, but the latter can be useful if you need good estimates of the tails of your posterior distribution. The rule of thumb for ess_bulk is for this value to be greater than 100 per chain on average. Since we ran four chains, we need ess_bulk to be greater than 400 for each parameter. The ess_tail is an estimate for effectively independent samples considering the more extreme values of the posterior. This is not the number of samples that landed in the tails of the posterior, but rather a measure of the number of effectively independent samples if we sampled the tails of the posterior. The rule of thumb for this value is also to be greater than 100 per chain on average.
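Analogously, a standalone sketch with synthetic chains shows how strong autocorrelation shrinks the effective sample size well below the raw 4 x 1,000 draws (again, the data here are made up and unrelated to our model).

import arviz as az
import numpy as np

rng = np.random.default_rng(1)
independent = rng.normal(size=(4, 1000))

# AR(1) chains with strong autocorrelation.
correlated = np.zeros((4, 1000))
for c in range(4):
    for i in range(1, 1000):
        correlated[c, i] = 0.95 * correlated[c, i - 1] + rng.normal()

print(az.ess(independent, method="bulk"))  # close to 4000
print(az.ess(correlated, method="bulk"))   # far fewer effective samples
print(az.ess(correlated, method="tail"))   # tail variant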

We use arviz to plot the evolution of the effective sample size for each of our parameters. The plots below show the estimated effective sample size as the number of draws from our model grows. The red horizontal dashed line shows the rule-of-thumb value needed for both the bulk and tail ess estimates. When the model is converging properly, both the bulk and tail lines should be roughly linear.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

keys = ["p", "θ"]
values = [second_model_p(), second_model_theta()]
parameters = dict(zip(keys, values))
plots = []
for title, parameter in parameters.items():
    data = {title: second_model_samples.get(parameter).numpy()}
    f = az.plot_ess(data, kind="evolution", show=False)[0][0]
    f.plot_width = 500
    f.plot_height = 500
    f.y_range.start = 0
    f.outline_line_color = "black"
    f.grid.grid_line_alpha = 0.2
    f.grid.grid_line_color = "grey"
    f.grid.grid_line_width = 0.2
    plots.append(f)
p_theta_plots = gridplot([plots[:2], plots[2:]])
show(p_theta_plots)

The ess diagnostics look linear and are all above the rule-of-thumb value of 400 (the red dashed line) for each parameter in our model. We continue the diagnostics by investigating the posteriors, rank plots, and autocorrelations for each parameter.

  • Rank plots are histograms of the samples over time. All samples across all chains are ranked, and then we plot the average rank for each chain at regular intervals. If the chains are mixing well, this histogram should look roughly uniform. If it looks highly irregular, that suggests chains might be getting stuck and not adequately exploring the sample space.
  • Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). We compute autocorrelation approximately, so it may sometimes exceed these bounds. In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works. A small arviz sketch for producing these plots directly follows this list.
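The helper used in the next cell assembles these diagnostics for us, but equivalent plots can be produced directly with arviz; here is a minimal sketch for the p query (the dictionary key is just a display label).

# Rank and autocorrelation plots straight from arviz for a single query.
p_draws = {"p": second_model_samples.get(second_model_p()).numpy()}  # shape: chain x draw

az.plot_rank(p_draws, show=True)      # roughly uniform bars indicate good mixing
az.plot_autocorr(p_draws, show=True)  # should decay quickly toward zero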
# Required for visualizing in Colab.
output_notebook(hide_banner=True)

second_model_diag_plot = gridplot(hearts.plot_diagnostics(second_model_samples, values))
show(second_model_diag_plot)

The first plot for each row is the posterior for the query. The second multicolored line plot is the posterior for each chain, followed by the rank plot for each chain. Finally we have a set of plots showing the autocorrelation of the samples.

All of the posteriors show high-frequency fluctuations. These fluctuations also show up in the rank plots, where we have peaks and valleys about the black dashed line, which is added as a convenient way to quickly gauge whether the rank plots are uniform within and between chains. These fluctuations indicate that the sampler is spending either too much time (peaks) or not enough time (valleys) in the corresponding regions of the sample space. This amounts to our chains not mixing sufficiently. The autocorrelation plots also show that samples are highly correlated, another indicator that the model is having a difficult time sampling from the true posterior.

Farewell and Sprott quote their posterior means and 95% confidence intervals for $\alpha$ and $\delta$ as

  • $\alpha=-0.4636\pm0.5318$ and
  • $\delta=0.3043\pm1.1643$.

Our model is in fairly good agreement with these values with posterior means and standard deviations of

Markdown(
    second_model_summary_df.loc[
        second_model_summary_df.index.astype(str).isin(
            ["second_model_alpha()", "second_model_delta()"]
        ),
        ["mean", "sd"],
    ].to_markdown()
)
| query | mean | sd |
| --- | --- | --- |
| second_model_alpha() | -0.48 | 0.2843 |
| second_model_delta() | 0.3001 | 0.6234 |

Despite this reasonable agreement with the published results, we had to run inference for a large number of samples. Even so, the posteriors have high-frequency components in their distributions and show correlated sample draws within each chain, which may indicate that our model is not as good as we can make it. We can, in fact, make a better model by marginalizing out the discrete variables $s_i$.

Third model - Marginalization of the discrete variable

Marginalizing out $s_i$ will allow us to use the GlobalNoUTurnSampler and will make the model less computationally expensive. Marginalizing lets us compute the expectation over $s_i$ exactly, so our model will have zero variance coming from the discrete variables, mitigating some of the high-frequency fluctuations we saw in model 2. Recall we had

$$\begin{aligned} s_i &\sim \text{Bernoulli}(1-\theta)\\ y_i\mid t_i, p, s_i &\sim \text{Binomial}(t_i, p\cdot s_i). \end{aligned}$$

Marginalizing out $s_i$ gives

$$\begin{aligned} \text{Pr}(s_i) &= (1-\theta)^{s_i}\theta^{1-s_i}\\ \sum_{s_i}\text{Pr}(s_i)\cdot\text{Pr}(y_i\mid t_i,p,s_i) &= \text{Pr}(s_i=0)\cdot\text{Pr}(y_i\mid t_i, p, s_i=0) +\text{Pr}(s_i=1)\cdot\text{Pr}(y_i\mid t_i, p, s_i=1)\\ \text{Pr}(y_i\mid t_i,p,\theta) &= \theta\cdot\text{Pr}(y_i\mid t_i, p, s_i=0)+(1-\theta)\cdot\text{Pr}(y_i\mid t_i, p, s_i=1). \end{aligned}$$

We can obtain a more familiar form if we split the probability into the cases $y_i=0$ and $y_i\neq0$.

$$\begin{aligned} \theta\cdot\text{Pr}(y_i\mid t_i, p, s_i=0) &= \theta\begin{cases} 1, &\quad y_i=0\\ 0, &\quad y_i\neq0 \end{cases}\\ (1-\theta)\cdot\text{Pr}(y_i\mid t_i, p, s_i=1) = (1-\theta)\cdot\text{Binomial}(y_i\mid t_i, p) &= (1-\theta)\begin{cases} (1-p)^{t_i}, &\quad y_i=0\\ \text{Binomial}(y_i\mid t_i, p), &\quad y_i\neq0 \end{cases}\\ \text{Pr}(y_i\mid t_i,p,\theta) &= \begin{cases} \theta+(1-\theta)(1-p)^{t_i}, &\quad y_i=0\\ (1-\theta)\cdot\text{Binomial}(y_i\mid t_i, p), &\quad y_i\neq0 \end{cases} \end{aligned}$$
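We can verify this case-by-case form numerically against the explicit sum over $s_i$; the following standalone check uses arbitrary values of $\theta$, $p$, and $t_i$ and is not part of the model code.

import torch
import torch.distributions as dist

theta_chk, p_chk, t_chk = torch.tensor(0.3), torch.tensor(0.4), torch.tensor(9.0)
binom = dist.Binomial(total_count=t_chk, probs=p_chk)

for y_chk in [torch.tensor(0.0), torch.tensor(3.0)]:
    # Explicit marginalization: s_i = 0 forces y_i = 0; s_i = 1 uses the Binomial.
    pr_y_given_s0 = torch.tensor(1.0) if y_chk == 0 else torch.tensor(0.0)
    explicit = theta_chk * pr_y_given_s0 + (1 - theta_chk) * binom.log_prob(y_chk).exp()
    # Case-by-case form derived above.
    if y_chk == 0:
        cases = theta_chk + (1 - theta_chk) * (1 - p_chk) ** t_chk
    else:
        cases = (1 - theta_chk) * binom.log_prob(y_chk).exp()
    print(explicit.item(), cases.item())  # each pair of numbers should match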

The likelihood for a given observation therefore takes one of two forms:

  1. when $y_i=0$: a mixture of the point mass at zero and the Binomial's probability of zero, $\theta+(1-\theta)(1-p)^{t_i}$; or
  2. when $y_i\neq0$: the Binomial probability scaled by $(1-\theta)$.

We can rewrite the conditional probability using indicator functions $\mathcal{I}(\bullet)$, where $\mathcal{I}(\bullet)$ is $1$ if $\bullet$ is true and $0$ otherwise. The indicator functions act as an if/else statement in code.

$$\ell_i=\mathcal{I}(y_i=0)\left[\theta+(1-\theta)(1-p)^{t_i}\right] +\mathcal{I}(y_i\neq0)\left[(1-\theta)\cdot\text{Binomial}(y_i\mid t_i,p)\right]$$

$\ell_i$ is just a number that should multiply the density (equivalently, $\log\ell_i$ should be added to the log-density). The simplest way to implement this in Bean Machine is to place the likelihood factor inside an auxiliary Bernoulli observation. If we look at the probability mass function (PMF) of a Bernoulli with $\ell_i$ as its probability, we have

$$\text{Bernoulli}(d\mid\ell_i)=\ell_i^{\,d}(1-\ell_i)^{1-d}.$$

If we set $d\equiv1$, then the PMF becomes

$$\text{Bernoulli}(d=1\mid\ell_i)=\ell_i.$$
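In other words, observing $d=1$ from a $\text{Bernoulli}(\ell_i)$ contributes exactly $\log\ell_i$ to the joint log-density, which is how we "increment" the log probability by an arbitrary likelihood factor. A one-line check with an arbitrary value of $\ell$:

import torch
import torch.distributions as dist

ell_example = torch.tensor(0.37)
log_increment = dist.Bernoulli(probs=ell_example).log_prob(torch.tensor(1.0))
print(log_increment, torch.log(ell_example))  # identical values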

Our full generative model can now be expressed as

$$\begin{aligned} \alpha &\sim \text{Normal}(0, 10)\\ \delta &\sim \text{Normal}(0, 10)\\ p &= \text{logit}^{-1}\left(\alpha\right)\\ \theta &= \text{logit}^{-1}\left(\delta\right)\\ \ell_i\mid y_i,t_i,p,\theta &\sim \delta_{\mathcal{I}(y_i=0)\left[\theta+(1-\theta)(1-p)^{t_i}\right] +\mathcal{I}(y_i\neq0)\left[(1-\theta)\cdot\text{Binomial}(y_i\mid t_i,p)\right]}\\ d_i\mid\ell_i &\sim \text{Bernoulli}(\ell_i), \end{aligned}$$

where $d_i$ is a dummy variable and an auxiliary observation that will always be $1$.

In the previous model we had samples from $\text{Pr}(s_i\mid y_i,\theta,p,t_i)$ to use when simulating $y_i$. A challenge that arises when using a model where we have marginalized out the discrete variables is that it becomes less clear how to simulate from the posterior predictive distribution. We tackle this challenge by noticing that $\text{Pr}(s_i,y_i\mid\theta,p,t_i)$, which is proportional to $\text{Pr}(s_i\mid y_i,t_i,p,\theta)$, is computed as part of our likelihood calculation every iteration. We can capture this information using @bm.functional and query for it during posterior inference. We then use these probabilities to simulate $s_i$ from a Bernoulli distribution, passing the probability as an argument. These samples of $s_i$ can then be used exactly as they were in the previous model. This is a trade-off for marginalizing out the discrete variable; we still have to keep track of it if we want to simulate data, but we will see that running inference is a less computationally expensive task.

@bm.random_variable
def alpha() -> RVIdentifier:
    return dist.Normal(0, 10)


@bm.random_variable
def delta() -> RVIdentifier:
    return dist.Normal(0, 10)


@bm.functional
def p():
    return torch.sigmoid(alpha())


@bm.functional
def theta():
    return torch.sigmoid(delta())


@bm.functional
def ell0(i: int):
    # Likelihood factor when y_i == 0: theta + (1 - theta) * (1 - p)^t_i.
    return theta() + (1.0 - theta()) * torch.pow(1.0 - p(), t[i])


@bm.functional
def ell1(i: int):
    # Likelihood factor when y_i != 0: (1 - theta) * Binomial(y_i | t_i, p).
    return (1.0 - theta()) * dist.Binomial(t[i], p()).log_prob(y_observed[i]).exp()


@bm.functional
def s(i: int):
    # Unnormalized weight for s_i = 0 (only nonzero when y_i == 0).
    s0 = theta() if y_observed[i] == 0 else torch.tensor(0.0)
    # Unnormalized weight for s_i = 1.
    s1 = ell1(i)
    return torch.tensor([s0, s1])


@bm.random_variable
def d(i: int) -> RVIdentifier:
    # Auxiliary Bernoulli: observing 1 adds log(ell) to the joint log-density.
    ell = ell0(i) if y_observed[i] == 0 else ell1(i)
    return dist.Bernoulli(ell)

Inference

We set all the auxiliary observations d(i) (corresponding to $d_i$) to $1$ and use y_observed[i] (corresponding to $y_i$) inside the model to compute the likelihood factor.

thrd_mdl_num_samples = 2 if smoke_test else 4000
thrd_mdl_num_chains = 1 if smoke_test else 4
thrd_mdl_num_adaptive_samples = 0 if smoke_test else thrd_mdl_num_samples // 2

queries = [alpha(), delta(), p(), theta()] + [s(i) for i in range(N)]

# Observations are all equal to 1.
observations = {d(i): torch.tensor(1.0) for i in range(df.shape[0])}

# NNC compilation gets into some issues with this model so we will turn it off for now
samples = bm.GlobalNoUTurnSampler(nnc_compile=False).infer(
    queries=queries,
    observations=observations,
    num_samples=thrd_mdl_num_samples,
    num_chains=thrd_mdl_num_chains,
    num_adaptive_samples=thrd_mdl_num_adaptive_samples,
)
Out:

Samples collected: 0%| | 0/6000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/6000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/6000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/6000 [00:00<?, ?it/s]

Below we plot the posterior predictive checks for model 3. We are able to recover the posterior predictive distributions we saw in model 2 because we kept track of $\text{Pr}(s_i\mid y_i,t_i,p,\theta)$ using Bean Machine. Recall that our bm.functional for s(i) is a bookkeeping device that lets us reproduce model 2's ability to run posterior predictive checks. If we have successfully tracked $s_i$, then the posterior predictive checks from models 2 and 3 should be nearly identical.
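For reference, this bookkeeping can be unpacked by hand roughly as follows (a sketch of what a predictive-check helper might do, using the samples, p, s, and t objects defined above): the two entries returned by s(i) are unnormalized weights for $s_i=0$ and $s_i=1$, so we normalize them, draw $s_i$ from a Bernoulli, and then draw $y_i$ from the Binomial.

def simulate_patient(i: int) -> torch.Tensor:
    """Hypothetical posterior predictive simulation of postdrug counts for patient i."""
    p_draws = samples.get(p()).reshape(-1)            # posterior draws of p
    s_weights = samples.get(s(i)).reshape(-1, 2)      # unnormalized [Pr(s_i=0, y_i), Pr(s_i=1, y_i)]
    prob_s1 = s_weights[:, 1] / s_weights.sum(dim=1)  # Pr(s_i = 1 | y_i, t_i, p, theta)
    s_draws = dist.Bernoulli(prob_s1).sample()
    return dist.Binomial(total_count=t[i], probs=(p_draws * s_draws).to(t.dtype)).sample()

# Example: simulated postdrug counts for patient 3, who had zero postdrug events.
# simulated_y = simulate_patient(2)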

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

# Posterior predictive checks.
third_model_predictive_checks = hearts.Model3PredictiveChecks(
    samples_with_observations=samples,
    data=df,
    p_query=p,
    s_query_str="s",
)

# Remove the second model posterior plot legend since it is duplicated in the third
# model's posterior plot.
second_model_posterior_pc.legend.visible = False

# Posterior plot.
third_model_posterior_pc = (
    third_model_predictive_checks.plot_posterior_predictive_checks()
)

# Link plots together.
second_model_posterior_pc.y_range = third_model_posterior_pc.y_range
second_model_posterior_pc.x_range = third_model_posterior_pc.x_range

# Compare the prior/posterior predictive plots.
sec_thrd_model_post_plot = gridplot([[second_model_posterior_pc, third_model_posterior_pc]])
show(sec_thrd_model_post_plot)

Analysis

Below is a print out of summary statistics for the queries.

summary_df = az.summary(samples.to_xarray(), round_to=4)
summary_df.reset_index(inplace=True)
summary_df["sort"] = summary_df["index"].map(sort_order)
summary_df.dropna(subset=["sort"], axis=0, inplace=True)
summary_df.sort_values(by="sort", inplace=True)
summary_df.rename(columns={"index": "query"}, inplace=True)
summary_df.set_index("query", inplace=True)
summary_df.drop("sort", axis=1, inplace=True)

Markdown(summary_df.to_markdown())
| query | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |

The ess and $\hat{R}$ values look excellent. Next we look at how the effective sample sizes evolved during sampling.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

keys = ["p", "θ"]
values = [p(), theta()]
parameters = dict(zip(keys, values))
plots = []
for title, parameter in parameters.items():
    data = {title: samples.get(parameter).numpy()}
    f = az.plot_ess(data, kind="evolution", show=False)[0][0]
    f.plot_width = 400
    f.plot_height = 400
    f.y_range.start = 0
    f.outline_line_color = "black"
    f.grid.grid_line_alpha = 0.2
    f.grid.grid_line_color = "grey"
    f.grid.grid_line_width = 0.2
    plots.append(f)
thrd_p_theta_plots = gridplot([plots[:2], plots[2:]])
show(thrd_p_theta_plots)

We are above the rule-of-thumb of greater than 400 for the effective sample size plots, and all plots also look linear. Below we look at the posteriors, rank plots, and autocorrelation plots for the queries of the model.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

diag_plots = gridplot(hearts.plot_diagnostics(samples, values))
show(diag_plots)

In comparison to the second model, all of the above posteriors look smoother, and the autocorrelation plots show less correlation between successive samples. The third model is exploring the parameter space more effectively (and efficiently) than the second model.

Discussion

The three plots below show the progression we took to create a zero-inflated model. Model 1 was incapable of simulating the observed zero values, while model 2 did a great job of simulating zeros conditioned on the data $y_i=0$. Model 2 took a lot of computational resources to fit, and had several issues with the parameters' posterior distributions and autocorrelation plots, which led us to create model 3. Model 3 used fewer computational resources, reproduced the prior and posterior predictive plots of model 2, and explored the parameter space more efficiently than model 2. The extra work of marginalizing out the discrete variable and keeping tabs on $\text{Pr}(s_i\mid y_i,\theta,p,t_i)$ for simulating data was worth the effort.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

first_model_posterior_pc.legend.visible = False
second_model_posterior_pc.legend.visible = False
plots = [first_model_posterior_pc, second_model_posterior_pc, third_model_posterior_pc]
frst_scnd_thrd_plots = gridplot([plots])
show(frst_scnd_thrd_plots)

Comparing our third model with the results of Farewell and Sprott shows we are still in good agreement:

  • $\alpha=-0.4636\pm0.5318$ and
  • $\delta=0.3043\pm1.1643$.
Markdown(
    summary_df.loc[
        summary_df.index.astype(str).isin(["alpha()", "delta()"]),
        ["mean", "sd"],
    ].to_markdown()
)
| query | mean | sd |
| --- | --- | --- |

Rather than quote a 95% confidence interval for our findings, we will plot the 89% highest density interval for $\delta$ and $\alpha$ below. Why 89%? Mostly to keep you from thinking about p-values and to encourage you to think about the posterior densities instead. See McElreath for a good discussion of HDI intervals.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

a = az.plot_posterior({"α": samples.get(alpha()).numpy()}, show=False)[0][0]
d = az.plot_posterior({"δ": samples.get(delta()).numpy()}, show=False)[0][0]
a_d_plots = gridplot([[a, d]])
show(a_d_plots)

We plot the marginal posterior distributions for $\theta$, $p$, and $\beta$ below. Recall that $\beta=\exp(\alpha)$ is the coefficient of the Poisson mean for PVC events after the medication had been administered.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

th = az.plot_posterior({"θ": samples.get(theta()).numpy()}, show=False)[0][0]
p_ = az.plot_posterior({"p": samples.get(p()).numpy()}, show=False)[0][0]
b = az.plot_posterior({"β": samples.get(alpha()).exp().numpy()}, show=False)[0][0]
th_p_b_plots = gridplot([[p_, th, b]])
show(th_p_b_plots)

The above plot shows that the marginal posterior mean for $\theta$ is similar to the proportion of zero postdrug counts in the observed data, calculated below.

round(df[df["postdrug"] == 0].shape[0] / df.shape[0], 4)
Out:

0.5833

As a final consistency check, the approximate posterior mean for $p$ can be used to estimate counts for patients that did not have a zero value for postdrug PVC events. As we see below, these counts are consistent with the observed pre- and postdrug counts.

pp = samples.get(p()).numpy().mean()
posterior_df = df.copy()
posterior_df["tp"] = df[df["postdrug"] != 0]["total"] * pp
posterior_df["t(1 - p)"] = df[df["postdrug"] != 0]["total"] * (1 - pp)
Markdown(
    posterior_df[posterior_df["postdrug"] != 0][
        ["patient_number", "predrug", "t(1 - p)", "postdrug", "tp"]
    ].astype(int).to_markdown()
)
|   | patient_number | predrug | t(1 - p) | postdrug | tp |
| --- | --- | --- | --- | --- | --- |
| 0 | 1 | 6 | 6 | 5 | 4 |
| 1 | 2 | 9 | 6 | 2 | 4 |
| 4 | 5 | 7 | 5 | 2 | 3 |
| 5 | 6 | 5 | 3 | 1 | 2 |
| 10 | 11 | 9 | 13 | 13 | 8 |

Conclusion

We learned several things in this tutorial.

  • What zero-inflated data look like.
  • How to remove nuisance parameters.
  • How to add more mass to the likelihood around zero.
  • An example showing why marginalization of discrete variables from your model is a good practice.
  • How to define models with complex likelihood factors in Bean Machine using auxiliary Bernoulli observations.

References