Zero-inflated count data
Tutorial: Marginalizing discrete variables in zero-inflated count data
In this tutorial we investigate data that originate from Berry and were analyzed by Farewell and Sprott, from a study of the efficacy of a medication that helps prevent irregular heartbeats. Counts of patients' irregular heartbeats were observed for 60 seconds before the drug was administered and for 60 seconds after the medication was taken. A large percentage of records show zero irregular heartbeats in the 60 seconds after taking the medication. There are more observed zeros than we would expect if we were sampling from one of the common discrete statistical distributions (see Wikipedia-Distributions for a list of common discrete distributions). The problem we face is modeling these zero counts so that we can appropriately quantify the medication's impact on reducing irregular heartbeats.
Learning outcomes
On completion of this tutorial, you should be able to:
- build and run a zero-inflated count data model with Bean Machine;
- increment log probabilities using Bean Machine;
- run diagnostics and understand what Bean Machine is doing.
Problem
We will address how to model response data that contain more zeros than would be expected under a common discrete statistical distribution. If your data have a large number of zeros in them, they are called zero-inflated. Recognizing when your data are zero-inflated is straightforward using a histogram, as we will see below in the data section. Zero inflation occurs in many different scientific fields and application areas. Examples include social science, traffic accident research, econometrics, psychology, and a well-known investigation by Lambert of manufacturing defects, from which we take inspiration.
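To make "zero-inflated" concrete, here is a tiny, self-contained simulation (toy numbers, unrelated to the heart study data) comparing how often a plain Poisson produces zeros versus a zero-inflated mixture.
import torch
import torch.distributions as dist

torch.manual_seed(0)

rate, zero_prob, n = 5.0, 0.4, 10_000
poisson_counts = dist.Poisson(rate).sample((n,))
force_zero = dist.Bernoulli(zero_prob).sample((n,))        # 1 -> structural zero
zero_inflated_counts = poisson_counts * (1.0 - force_zero)

print((poisson_counts == 0).float().mean().item())        # ~ exp(-5) ≈ 0.007
print((zero_inflated_counts == 0).float().mean().item())  # ~ 0.4 + 0.6 * exp(-5) ≈ 0.40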
Prerequisites
We will be using the following packages within this tutorial.
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
!pip install beanmachine
import os
import warnings
from io import StringIO
import arviz as az
import beanmachine.ppl as bm
import pandas as pd
import torch
import torch.distributions as dist
from beanmachine.ppl.model import RVIdentifier
from beanmachine.tutorials.utils import hearts
from bokeh.io import output_notebook
from bokeh.layouts import gridplot
from bokeh.plotting import show
from IPython.display import Markdown
The next cell includes convenient configuration settings to improve the notebook
presentation, and sets a manual seed for torch for reproducibility.
# Ignore excessive warnings from ArviZ.
warnings.simplefilter("ignore")
# Plotting settings.
az.rcParams["plot.backend"] = "bokeh"
# See McElreath for an exposition on highest density intervals, and why we use 89%.
az.rcParams["stats.hdi_prob"] = 0.89
# Manual seed for torch.
torch.manual_seed(1199)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
Data
The data set we use in this tutorial originates from a study by Berry with subsequent analysis done by Farewell and Sprott. The data consist of EKG measurements where patients' heartbeats were counted for one minute before and after the administration of a drug, see Wikipedia-EKG for a description of what an EKG is. The patients in the study suffer from an affliction called premature ventricular contraction or PVC, see Wikipedia-PVC for more information about PVC events. Acute PVC events are common, but chronic PVC events are considered risk factors for other diseases.
The data from Berry are replicated below, with the addition of one column we will use in
our model as outlined in Farewell and Sprott. This column is called total
and amounts
to the sum of the predrug
and postdrug
PVC counts.
Column name | Description |
---|---|
patient_number | Patient ID. |
predrug | PVC event counts before administering the medication. |
postdrug | PVC event counts after administering the medication. |
decrease | The difference predrug - postdrug . |
total | The sum of predrug + postdrug . |
The data only contain 12 records. Instead of loading the data from a CSV file, we will
load it from the data_string
below into a pandas dataframe object.
data_string = """patient_number,predrug,postdrug,decrease,total
1,6,5,1,11
2,9,2,7,11
3,17,0,17,17
4,22,0,22,22
5,7,2,5,9
6,5,1,4,6
7,5,0,5,5
8,14,0,14,14
9,9,0,9,9
10,7,0,7,7
11,9,13,-4,22
12,51,0,51,51"""
df = pd.read_csv(StringIO(data_string))
Markdown(df.set_index("patient_number").to_markdown())
patient_number | predrug | postdrug | decrease | total |
---|---|---|---|---|
1 | 6 | 5 | 1 | 11 |
2 | 9 | 2 | 7 | 11 |
3 | 17 | 0 | 17 | 17 |
4 | 22 | 0 | 22 | 22 |
5 | 7 | 2 | 5 | 9 |
6 | 5 | 1 | 4 | 6 |
7 | 5 | 0 | 5 | 5 |
8 | 14 | 0 | 14 | 14 |
9 | 9 | 0 | 9 | 9 |
10 | 7 | 0 | 7 | 7 |
11 | 9 | 13 | -4 | 22 |
12 | 51 | 0 | 51 | 51 |
Below are histograms of the PVC event counts pre- and post-medication for all patients. What we observe from these histograms is a lot of zero values after the medication has been administered. These zero values will be important for our model defined in the next section, and are characteristic of zero-inflated data. If we disregarded modeling these zero values we would be throwing out an important feature of our data, and our parameter estimates would be biased.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
value_counts_plot = gridplot(hearts.plot_value_counts(df))
show(value_counts_plot)
The literature classifies zero count observations broadly into two kinds of plausible explanations, which we are interested in modeling statistically (see Farewell and Sprott).
- A general chance for "curing" by the intervention. In this scenario, the zero count is considered certain, and the observation is understood as a true positive event.
- A patient-specific chance for a "varied response level" or "relief". In this scenario, the PVC event frequency is understood to be merely reduced, and the zero count observation is therefore understood as a false positive event.
Note that we cannot assess the biological significance of the statement "cured" in this tutorial.
Models
First model - Binomial linear model
The data come in pairs $(x_i, y_i)$, where $x_i$ is the predrug PVC count and $y_i$ is the postdrug PVC count for patient $i$.
Note that both the $x_i$ and $y_i$ data are in $\mathbb{Z}_{\geq 0}$ (the nonnegative integers). The data come in pairs because each of the twelve patients had their heartbeats monitored for 60 seconds before and 60 seconds after the administration of an antiarrhythmic drug. The numbers of PVC events within the two 60 second windows are assumed to originate from independent Poisson distributions with some mean $\lambda_i$, thus

$$x_i \sim \text{Poisson}(\lambda_i), \qquad y_i \sim \text{Poisson}(\beta\lambda_i).$$
$\beta$ is a scaling factor for the rate of PVC events after the administration of the medication. We can eliminate the nuisance parameters (all the $\lambda_i$s, which only indirectly contribute to the answer of our modeling question) by using the conditional distribution of $y_i$ given the total PVC count for the patient, $t_i = x_i + y_i$. To accomplish this, first let $X_i$ and $Y_i$ denote the Poisson-distributed random variables from which $x_i$ and $y_i$ are assumed to be samples. Then our conditional distribution can be written as

$$P(Y_i = y_i \mid X_i + Y_i = t_i) = \frac{P(Y_i = y_i)\,P(X_i = t_i - y_i)}{P(X_i + Y_i = t_i)}.$$
The denominator is the probability mass function of the sum of the independent Poisson random variables $X_i$ and $Y_i$. A property of Poisson distributions is that the sum of two independent Poisson random variables is again Poisson, with the rate parameters added together; see Wikipedia-Poisson distribution. Using this property, our conditional probability becomes

$$P(Y_i = y_i \mid X_i + Y_i = t_i) = \binom{t_i}{y_i} p^{y_i} (1 - p)^{t_i - y_i},$$
where $p$ is

$$p = \frac{\beta\lambda_i}{\lambda_i + \beta\lambda_i} = \frac{\beta}{1 + \beta}.$$
The conditional probability distribution no longer contains any of the nuisance parameters $\lambda_i$. We will use this likelihood in our first model, with an uninformative prior on $p$, leading to the generative model

$$p \sim \text{Uniform}(0, 1), \qquad y_i \sim \text{Binomial}(t_i, p).$$
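As a quick sanity check of this derivation, here is a small Monte Carlo simulation (toy values of $\lambda$ and $\beta$, unrelated to the study data): conditioning simulated Poisson pairs on their total should reproduce the Binomial mean $t\,\beta/(1+\beta)$.
# Toy check: conditioned on the total, y should follow Binomial(t, beta / (1 + beta)).
lam, beta_ = 8.0, 0.5
x_sim = dist.Poisson(lam).sample((200_000,))
y_sim = dist.Poisson(beta_ * lam).sample((200_000,))
t_fixed = 12.0
y_given_t = y_sim[(x_sim + y_sim) == t_fixed]
print(y_given_t.mean().item())          # empirical E[y | t = 12]
print(t_fixed * beta_ / (1.0 + beta_))  # Binomial mean t * p = 4.0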
NOTE
We can implement this model in Bean Machine by defining random variable
objects with the decorator @bm.random_variable
and deterministic values using the
bm.functional
decorator. These functions behave differently than ordinary Python
functions.
@bm.random_variable functions:
- They must return PyTorch Distribution objects.
- Though they return distributions, callees actually receive a sample from the distribution. The machinery for obtaining samples from distributions is handled internally by Bean Machine.
- Inference runs the model through many iterations. During a particular inference iteration, a distinct random variable will correspond to exactly one sampled value: calls to the same random variable function with the same arguments will receive the same sampled value within one inference iteration. This makes it easy for multiple components of your model to refer to the same logical random variable.
- Consequently, to define distinct random variables that correspond to different sampled values during a particular inference iteration, an effective practice is to add a dummy "indexing" parameter to the function. Distinct random variables can be referred to with different values for this index.
- Please see the documentation for more information about this decorator.

@bm.functional functions:
- This is a decorator that lets you treat deterministic code as if it were a Bean Machine random variable. It is used to transform the results of one or more random variables.
- It follows the same naming practice as @bm.random_variable, where variables are distinguished by their argument call values.
- Please see the documentation for more information about this decorator.
# Define a variable with the total number of records.
N = df.shape[0]
# Create a torch tensor containing the total PVC events for each record.
t = torch.tensor(df["total"].astype(float).values)
@bm.random_variable
def first_model_p() -> RVIdentifier:
return dist.Uniform(low=0.0, high=1.0)
@bm.random_variable
def first_model_y(i: int) -> RVIdentifier:
return dist.Binomial(total_count=t[i], probs=first_model_p())
Prior predictive checks
Before we run inference, we will conduct prior predictive checks. These checks will
determine if our priors and model can capture relevant statistics about the observed
data. We will use Bean Machine and its infer
method to run the prior predictive checks
even though we are not actually running posterior inference. Bean Machine's infer
method has the following API.
Name | Usage |
---|---|
queries | List of @bm.random_variable targets to fit posterior distributions for. |
observations | A dictionary of observations. |
num_samples | Number of Monte Carlo samples to approximate the posterior distributions for the variables in queries. |
num_chains | Number of separate inference runs to use. Multiple chains can help verify that inference ran correctly. |
When we run prior predictive checks, we pass the infer
method an empty dictionary for
the observations
keyword. This corresponds to a target density for Bean Machine's MCMC inference that contains no observation likelihood terms, only the model's prior density (here, the prior on first_model_p). We can then manually feed the
prior samples thus obtained into the model likelihood (here
lambda p, i: dist.Binomial(total_count=t[i], probs=p)
), to obtain synthetic data from
the prior predictive distribution.
first_model_num_samples = 1 if smoke_test else 500
first_model_num_adaptive_samples = 0 if smoke_test else first_model_num_samples // 2
first_model_num_chains = 1 if smoke_test else 4
# Experimental backend using the PyTorch NNC compiler.
nnc_compile = "SANDCASTLE_NEXUS" not in os.environ
# Model queries.
first_model_queries = [first_model_p()]
# Note our observations dictionary is empty.
first_model_observations = {}
# Note we are using the GlobalNoUTurnSampler, this will change in our second model.
first_model_samples_no_observations = bm.GlobalNoUTurnSampler(nnc_compile=nnc_compile).infer(
queries=first_model_queries,
observations=first_model_observations,
num_samples=first_model_num_samples,
num_chains=first_model_num_chains,
num_adaptive_samples=first_model_num_adaptive_samples,
)
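The hearts plotting helper used in the next cell performs this simulation internally; a rough manual sketch for patient $i = 0$ (assuming .get returns an array of prior draws of $p$) would be:
# Rough sketch: push prior draws of p through the Binomial likelihood for patient 0.
p_prior_draws = first_model_samples_no_observations.get(first_model_p())
y_prior_sim_patient0 = dist.Binomial(total_count=t[0], probs=p_prior_draws).sample()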
Below is a plot of synthetic data simulated from the prior predictive distribution, along with the real observed data (magenta square) and total PVC count data (brown diamond) for each patient. The histograms represent simulated data, where the thick blue line below the histogram is its 89% HDI (highest density interval) region, and the white marker is its mean. HDI intervals are discussed at length in McElreath. The support defined by our prior spans the observed values for postdrug PVC counts as well as the total (predrug + postdrug) PVC counts. The prior for $p$ is not very informative for our model, but it does capture the relevant bounds of the data.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Predictive checks without observations.
first_model_predictive_checks_no_observations = hearts.Model1PredictiveChecks(
samples_without_observations=first_model_samples_no_observations,
data=df,
p_query=first_model_p,
)
first_model_prior_pc = (
first_model_predictive_checks_no_observations.plot_prior_predictive_checks()
)
show(first_model_prior_pc)
Next, we run inference using this model, and compare the prior predictive plot with a posterior predictive plot.
# Note the observations dictionary is no longer empty.
first_model_observations = {
first_model_y(i): torch.tensor(observed)
for i, observed in enumerate(df["postdrug"].astype(float).tolist())
}
# Run inference with observations.
first_model_samples = bm.GlobalNoUTurnSampler(nnc_compile=nnc_compile).infer(
queries=first_model_queries,
observations=first_model_observations,
num_samples=first_model_num_samples,
num_chains=first_model_num_chains,
num_adaptive_samples=first_model_num_adaptive_samples,
)
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Predictive checks with observations.
first_model_predictive_checks = hearts.Model1PredictiveChecks(
samples_with_observations=first_model_samples,
data=df,
p_query=first_model_p,
)
# Remove the prior plot legend since it is duplicated in the posterior plot.
first_model_prior_pc.legend.visible = False
# Posterior plot.
first_model_posterior_pc = (
first_model_predictive_checks.plot_posterior_predictive_checks()
)
# Link plots together so zooms occur in both plots.
first_model_posterior_pc.y_range = first_model_prior_pc.y_range
first_model_posterior_pc.x_range = first_model_prior_pc.x_range
# Compare prior/posterior predictive checks.
prior_post_plots = gridplot([[first_model_prior_pc, first_model_posterior_pc]])
show(prior_post_plots)
The left plot is exactly what we saw above, and the right plot is simulating data using the posterior predictive. The mean simulated density values shift towards data (magenta markers) in the right plot as expected. However, the model is not capturing the observed zero values as shown in the histograms of the right plot for each patient. This suggests that our model is not specified well and needs to be updated to capture the observed zero values.
Second model - Zero-inflated Binomial linear model
We begin our model refinement by revisiting the conditional probability derived above,

$$P(Y_i = y_i \mid X_i + Y_i = t_i) = \binom{t_i}{y_i} p^{y_i} (1 - p)^{t_i - y_i}.$$

We will multiply $p$ by a random variable $s_i$ that is sampled from a Bernoulli distribution for each patient,

$$s_i \sim \text{Bernoulli}(1 - \theta), \qquad y_i \sim \text{Binomial}(t_i, p \cdot s_i).$$
Whenever $s_i = 1$ is sampled, this model is equivalent to the previous model. When $s_i = 0$, the parameter of the Binomial distribution becomes $0$ and $y_i = 0$ is obtained with certainty. Hence $s_i$ decides for each patient whether their intervention effect is "relief" or "curing", and $\theta$ is a mixing parameter between these two scenarios. In particular, for $\theta = 1$, all patients would always be "cured".
The model now has two mechanisms for causing a zero count, parameterized by $p$ and $\theta$. We could sample $p$ and $\theta$ from $\text{Uniform}(0, 1)$ as we did in the first model; however, we will instead place their priors on the log-odds scale, centered about zero. We accomplish this by setting the log-odds of $p$ and $\theta$ equal to two new random variables $\alpha$ and $\delta$,

$$\alpha = \log\left(\frac{p}{1 - p}\right), \qquad \delta = \log\left(\frac{\theta}{1 - \theta}\right).$$
Both $\alpha$ and $\delta$ can be sampled from normal distributions, since the normal distribution has support over $(-\infty, \infty)$, just like the log-odds values. To get $p$ or $\theta$ back, we calculate the inverse logit of $\alpha$ or $\delta$. The inverse logit, $\text{logit}^{-1}(x) = 1 / (1 + e^{-x})$, is available in torch as the sigmoid method. We can now write our model as

$$\begin{aligned}
\alpha &\sim \text{Normal}(0, 10) \\
\delta &\sim \text{Normal}(0, 10) \\
p &= \text{logit}^{-1}(\alpha) \\
\theta &= \text{logit}^{-1}(\delta) \\
s_i &\sim \text{Bernoulli}(1 - \theta) \\
y_i &\sim \text{Binomial}(t_i, p \cdot s_i).
\end{aligned}$$
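As a small check of the transform (toy value, not a model quantity), torch.sigmoid recovers a probability from its log-odds:
prob = torch.tensor(0.25)
log_odds = torch.log(prob / (1.0 - prob))  # logit(prob)
assert torch.isclose(torch.sigmoid(log_odds), prob)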
We can code this in Bean Machine as shown below.
@bm.random_variable
def second_model_alpha() -> RVIdentifier:
return dist.Normal(0, 10)
@bm.random_variable
def second_model_delta() -> RVIdentifier:
return dist.Normal(0, 10)
@bm.functional
def second_model_p() -> torch.Tensor:
return torch.sigmoid(second_model_alpha())
@bm.functional
def second_model_theta() -> torch.Tensor:
return torch.sigmoid(second_model_delta())
@bm.random_variable
def second_model_s(i: int) -> RVIdentifier:
return dist.Bernoulli(1 - second_model_theta())
@bm.random_variable
def second_model_y(i: int) -> RVIdentifier:
return dist.Binomial(t[i], second_model_p() * second_model_s(i))
Prior predictive checks
We will use the SingleSiteAncestralMetropolisHastings inference method because our model mixes the discrete random variables $s_i$ with the continuous random variables $\alpha$ and $\delta$. Discrete random variables cannot be handled by Hamiltonian Monte Carlo, which prohibits the use of the GlobalNoUTurnSampler for this model.
second_model_num_samples = 2 if smoke_test else 2000
second_model_num_chains = 1 if smoke_test else 4
second_model_num_adaptive_samples = 0 if smoke_test else second_model_num_samples // 2
# Model queries.
second_model_queries = [
second_model_alpha(),
second_model_delta(),
second_model_p(),
second_model_theta(),
] + [second_model_s(i) for i in range(N)]
# Model without observations. Note we are not using the GlobalNoUTurnSampler.
second_model_samples_no_observations = bm.SingleSiteAncestralMetropolisHastings().infer(
queries=second_model_queries,
observations={},
num_samples=second_model_num_samples,
num_chains=second_model_num_chains,
num_adaptive_samples=second_model_num_adaptive_samples,
)
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Predictive checks without observations.
second_model_predictive_checks_no_observations = hearts.Model2PredictiveChecks(
samples_without_observations=second_model_samples_no_observations,
data=df,
p_query=second_model_p,
s_query_str="second_model_s",
)
second_model_prior_pc = (
second_model_predictive_checks_no_observations.plot_prior_predictive_checks()
)
sec_model_prior_plot = gridplot([[second_model_prior_pc]])
show(sec_model_prior_plot)
The above plot successfully simulates more zeros under the given priors than our first model did. You can use the tools above the plot to zoom in on any patient. Doing so will show that the prior predictive distribution has positive density everywhere between the end points of zero and the total PVC count for a patient (the brown diamonds). The large bars near zero and near the totals are caused by the sigmoid used when calculating $p$ and $\theta$ (with the wide $\text{Normal}(0, 10)$ priors, the sigmoid pushes most of the prior mass toward 0 and 1), and are expected.
Next we will run inference using this model and run posterior predictive checks along with other model analysis described below. You will note that we have increased the number of samples and adaptive samples a lot compared to the first model in order for the diagnostics to show good mixing of the chains.
Inference
scnd_mdl_post_num_samples = 2 if smoke_test else 10000
scnd_mdl_post_num_chains = 1 if smoke_test else 4
scnd_mdl_post_num_adaptive_samples = 0 if smoke_test else second_model_num_samples // 2
# Create a tensor of observed values.
y_observed = torch.tensor(df["postdrug"].astype(float).values)
# Model observations is not empty.
second_model_observations = {second_model_y(i): y_observed[i] for i in range(N)}
# Model with observations. Note the GlobalNoUTurnSampler is not used for inference.
second_model_samples = bm.SingleSiteAncestralMetropolisHastings().infer(
queries=second_model_queries,
observations=second_model_observations,
num_samples=scnd_mdl_post_num_samples,
num_chains=scnd_mdl_post_num_chains,
num_adaptive_samples=scnd_mdl_post_num_adaptive_samples,
)
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Predictive checks with observations.
second_model_predictive_checks = hearts.Model2PredictiveChecks(
samples_with_observations=second_model_samples,
data=df,
p_query=second_model_p,
s_query_str="second_model_s",
)
# Remove the prior plot legend since it is duplicated in the posterior plot.
second_model_prior_pc.legend.visible = False
# Posterior plot.
second_model_posterior_pc = (
second_model_predictive_checks.plot_posterior_predictive_checks()
)
# Link plots together.
second_model_posterior_pc.y_range = second_model_prior_pc.y_range
second_model_posterior_pc.x_range = second_model_prior_pc.x_range
# Compare prior/posterior predictive checks.
sec_prior_post_plot = gridplot([[second_model_prior_pc, second_model_posterior_pc]])
show(sec_prior_post_plot)
The left plot is again the prior predictive simulation for our second model, and the right plot is the posterior predictive simulated data. You can see that this model is quite effectively capturing the observed zero data. The major caveat here is that we had to use a lot more computational resources in order to capture the zeros than in our previous model. The posterior predictive plot looks like it is doing a good job of simulating data from our model, but we will do a bit more analysis below to ensure the model is working as expected.
Analysis
We begin our analysis by printing out summary statistics. Two important statistics to take note of are the $\hat{R}$ and effective sample size (ess) values in the dataframe below.
second_model_summary_df = az.summary(second_model_samples.to_xarray(), round_to=4)
# Rearranging for easy viewing of the dataframe.
sort_order = dict(zip(second_model_queries, range(1, len(second_model_queries) + 1)))
second_model_summary_df.reset_index(inplace=True)
second_model_summary_df["sort"] = second_model_summary_df["index"].map(sort_order)
second_model_summary_df.sort_values(by="sort", inplace=True)
second_model_summary_df.rename(columns={"index": "query"}, inplace=True)
second_model_summary_df.set_index("query", inplace=True)
second_model_summary_df.drop("sort", axis=1, inplace=True)
Markdown(second_model_summary_df.to_markdown())
query | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
second_model_alpha() | -0.48 | 0.2843 | -0.9562 | -0.087 | 0.0103 | 0.0073 | 757.703 | 749.653 | 1.0027 |
second_model_delta() | 0.3001 | 0.6234 | -0.6313 | 1.355 | 0.0133 | 0.0103 | 2160.44 | 2163.4 | 1.0016 |
second_model_p() | 0.3844 | 0.0661 | 0.2776 | 0.4783 | 0.0024 | 0.0017 | 757.703 | 749.653 | 1.0029 |
second_model_theta() | 0.5681 | 0.1398 | 0.3472 | 0.7949 | 0.003 | 0.0022 | 2160.44 | 2163.4 | 1.0016 |
second_model_s(0,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(1,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(2,) | 0.001 | 0.032 | 0 | 0 | 0.0003 | 0.0002 | 12367.6 | 12367.6 | 1.0001 |
second_model_s(3,) | 0 | 0 | 0 | 0 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(4,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(5,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(6,) | 0.0776 | 0.2676 | 0 | 0 | 0.0027 | 0.0019 | 9726.81 | 9726.81 | 1.0002 |
second_model_s(7,) | 0.002 | 0.045 | 0 | 0 | 0.0004 | 0.0003 | 13813.7 | 13813.7 | 1 |
second_model_s(8,) | 0.0172 | 0.13 | 0 | 0 | 0.0012 | 0.0009 | 11247.3 | 11247.3 | 1.0004 |
second_model_s(9,) | 0.0339 | 0.181 | 0 | 0 | 0.0017 | 0.0012 | 10956.4 | 10956.4 | 1.0002 |
second_model_s(10,) | 1 | 0 | 1 | 1 | 0 | 0 | 40000 | 40000 | nan |
second_model_s(11,) | 0 | 0 | 0 | 0 | 0 | 0 | 40000 | 40000 | nan |
Note that there are NaN values for some calculations in the above dataframe for some parameters. This occurs because there is no variance in those parameters. We can ignore these values, as it is entirely possible for a chain of $s_i$ samples to be all ones or all zeros, giving rise to zero variance within the chain and a NaN value for $\hat{R}$.
Measuring variance between- and within-chains with $\hat{R}$
$\hat{R}$ is a diagnostic tool that measures the between- and within-chain variances. It is a test that indicates a lack of convergence by comparing the variance between multiple chains to the variance within each chain. If the parameters are successfully exploring the full space for each chain, then $\hat{R} \approx 1$, since the between-chain and within-chain variances should be equal. $\hat{R}$ is calculated from $N$ samples per chain as

$$\hat{R} = \sqrt{\frac{\widehat{\text{var}}^{+}}{W}}, \qquad \widehat{\text{var}}^{+} = \frac{N - 1}{N} W + \frac{1}{N} B,$$

where $W$ is the within-chain variance, $B$ is the between-chain variance, and $\widehat{\text{var}}^{+}$ is the estimate of the posterior variance of the samples. The take-away here is that $\hat{R}$ converges to 1 when each of the chains begins to empirically approximate the same posterior distribution. We do not recommend using inference results if $\hat{R} > 1.01$. More information about $\hat{R}$ can be found in the Vehtari et al paper.
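For a concrete sense of the formula, here is a minimal NumPy sketch of the classical split-$\hat{R}$ (ArviZ reports the rank-normalized variant from Vehtari et al, so its values can differ slightly):
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Classical split R-hat for draws of shape (num_chains, num_draws)."""
    m, n = chains.shape
    halves = chains.reshape(2 * m, n // 2)            # split each chain in half
    w = halves.var(axis=1, ddof=1).mean()             # within-chain variance W
    b = (n // 2) * halves.mean(axis=1).var(ddof=1)    # between-chain variance B
    var_plus = (n // 2 - 1) / (n // 2) * w + b / (n // 2)
    return float(np.sqrt(var_plus / w))

# Toy example: four well-mixed chains of standard normal draws should give R-hat near 1.
print(split_rhat(np.random.default_rng(0).normal(size=(4, 1000))))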
Effective sample size diagnostic
MCMC samplers do not draw truly independent samples from the target distribution, which means that our samples are correlated. In an ideal situation all samples would be independent, but we do not have that luxury. We can, however, measure the number of effectively independent samples we draw, which is called the effective sample size. You can read more about how this value is calculated in the Vehtari et al paper. In brief, it is a measure that combines information from the value with the autocorrelation estimates within the chains.
ESS estimates come in two variants, ess_bulk and ess_tail. The former is the default, but the latter can be useful if you need good estimates of the tails of your posterior distribution. The rule of thumb for ess_bulk is for this value to be greater than 100 per chain on average. Since we ran four chains, we need ess_bulk to be greater than 400 for each parameter. The ess_tail is an estimate for effectively independent samples considering the more extreme values of the posterior. This is not the number of samples that landed in the tails of the posterior, but rather a measure of the number of effectively independent samples if we sampled the tails of the posterior. The rule of thumb for this value is also to be greater than 100 per chain on average.
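ArviZ also exposes both estimates directly. A minimal sketch using the draws of $p$ from our second model (the evolution plots below show the same information accumulated over successive draws):
p_draws = second_model_samples.get(second_model_p()).numpy()  # (chains, draws)
print(az.ess(p_draws, method="bulk"))  # rule of thumb: > 100 per chain, so > 400 here
print(az.ess(p_draws, method="tail"))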
We use arviz to plot the evolution of the effective sample size for each of our parameters. The plots below show the estimated effective sample size as a function of the number of draws from our model. The red horizontal dashed line shows the rule-of-thumb value needed for both the bulk and tail ess estimates. When the model is converging properly, both the bulk and tail lines should be roughly linear.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
keys = ["p", "θ"]
values = [second_model_p(), second_model_theta()]
parameters = dict(zip(keys, values))
plots = []
for title, parameter in parameters.items():
data = {title: second_model_samples.get(parameter).numpy()}
f = az.plot_ess(data, kind="evolution", show=False)[0][0]
f.plot_width = 500
f.plot_height = 500
f.y_range.start = 0
f.outline_line_color = "black"
f.grid.grid_line_alpha = 0.2
f.grid.grid_line_color = "grey"
f.grid.grid_line_width = 0.2
plots.append(f)
p_theta_plots = gridplot([plots[:2], plots[2:]])
show(p_theta_plots)
The ess diagnostics look linear and are all above the rule-of-thumb value of 400 (the red dashed line) for each parameter in our model. We continue the diagnostics by investigating the posteriors, rank plots, and autocorrelations for each parameter.
- Rank plots are a histogram of the samples over time. All samples across all chains are ranked and then we plot the average rank for each chain on regular intervals. If the chains are mixing well this histogram should look roughly uniform. If it looks highly irregular that suggests chains might be getting stuck and not adequately exploring the sample space.
- Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). We compute autocorrelation approximately, so it may sometimes exceed these bounds. In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works.
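The hearts.plot_diagnostics helper used in the next cell produces these plots for us; if you want to build comparable rank and autocorrelation plots directly from ArviZ, a minimal sketch for $p$ is:
output_notebook(hide_banner=True)
p_dict = {"p": second_model_samples.get(second_model_p()).numpy()}
rank_fig = az.plot_rank(p_dict, show=False)[0][0]
acorr_fig = az.plot_autocorr(p_dict, show=False)[0][0]  # first chain only
show(gridplot([[rank_fig, acorr_fig]]))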
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
second_model_diag_plot = gridplot(hearts.plot_diagnostics(second_model_samples, values))
show(second_model_diag_plot)
The first plot for each row is the posterior for the query. The second multicolored line plot is the posterior for each chain, followed by the rank plot for each chain. Finally we have a set of plots showing the autocorrelation of the samples.
All of the posteriors show high frequency fluctuations. These fluctuations are also visible in the rank plots, where we see peaks and valleys about the black dashed line (added as a convenient way to quickly gauge whether the rank plots are uniform within- and between-chains). These fluctuations indicate that the sampler is spending either too much time (peaks) or too little time (valleys) in the corresponding regions of the sample space, which amounts to our chains not mixing sufficiently. The autocorrelation plots also show that samples are highly correlated, another indicator that the model is having a difficult time sampling from the true posterior.
Farewell and Sprott quote posterior means and 95% confidence intervals for $\alpha$ and $\delta$ in their paper. Our model is in fairly good agreement with their reported values, with posterior means and standard deviations of
Markdown(
second_model_summary_df.loc[
second_model_summary_df.index.astype(str).isin(
["second_model_alpha()", "second_model_delta()"]
),
["mean", "sd"],
].to_markdown()
)
query | mean | sd |
---|---|---|
second_model_alpha() | -0.48 | 0.2843 |
second_model_delta() | 0.3001 | 0.6234 |
Despite the reasonable similarity to the published results, we had to run inference for a large number of samples. Even so, the posteriors have high frequency components in their distributions and show correlated sample draws within each chain, which may indicate that our model is not as good as we can make it. We can in fact make a better model by marginalizing out the discrete variables $s_i$.
Third model - Marginalization of the discrete variable
Marginalizing out $s_i$ will allow us to use the GlobalNoUTurnSampler as well as make a less computationally expensive model. Marginalizing lets us compute the expectation over $s_i$ exactly, so our model will have zero variance coming from the discrete variables, mitigating some of the high frequency fluctuations we saw in model 2. Recall we had

$$s_i \sim \text{Bernoulli}(1 - \theta), \qquad y_i \sim \text{Binomial}(t_i, p \cdot s_i).$$

Marginalizing out $s_i$ gives

$$P(y_i \mid t_i) = \sum_{s_i \in \{0, 1\}} P(y_i \mid t_i, s_i)\, P(s_i).$$
We can obtain a more familiar form if we split the probability into the cases $y_i = 0$ and $y_i > 0$.

$$P(y_i \mid t_i) = \begin{cases} \theta + (1 - \theta)(1 - p)^{t_i} & \text{if } y_i = 0 \\ (1 - \theta)\binom{t_i}{y_i} p^{y_i} (1 - p)^{t_i - y_i} & \text{if } y_i > 0 \end{cases}$$
This model can therefore produce a zero count either through
- the "curing" mechanism, which contributes probability $\theta$; or
- the Binomial distribution when $s_i = 1$, which contributes probability $(1 - \theta)(1 - p)^{t_i}$.
We can rewrite the conditional probability using indicator functions $\mathbb{1}(\cdot)$, where $\mathbb{1}(A)$ is $1$ if $A$ is True and $0$ otherwise. The indicator functions act as an if/else statement in code:

$$P(y_i \mid t_i) = \ell_i = \mathbb{1}(y_i = 0)\left[\theta + (1 - \theta)(1 - p)^{t_i}\right] + \mathbb{1}(y_i > 0)\,(1 - \theta)\binom{t_i}{y_i} p^{y_i} (1 - p)^{t_i - y_i}.$$
$\ell_i$ is just a number to add to the log-density (or, equivalently, a factor to multiply into the density). The simplest way to implement adding $\log \ell_i$ to the log-density is to place the likelihood factor inside an auxiliary Bernoulli observation. If we look at the probability mass function (PMF) of a Bernoulli distribution that uses $\ell_i$ as its probability, we have

$$P(d_i \mid \ell_i) = \ell_i^{d_i} (1 - \ell_i)^{1 - d_i}.$$
If we set $d_i = 1$, then the PMF becomes

$$P(d_i = 1 \mid \ell_i) = \ell_i.$$
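This is the trick we will use to increment the model's log probability: observing $d_i = 1$ from a Bernoulli with probability $\ell_i$ adds exactly $\log \ell_i$ to the joint log density. A quick check with a toy value (not a model quantity):
ell = torch.tensor(0.37)
log_increment = dist.Bernoulli(probs=ell).log_prob(torch.tensor(1.0))
assert torch.isclose(log_increment, torch.log(ell))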
Our full generative model can now be expressed as

$$\begin{aligned}
\alpha &\sim \text{Normal}(0, 10) \\
\delta &\sim \text{Normal}(0, 10) \\
p &= \text{logit}^{-1}(\alpha) \\
\theta &= \text{logit}^{-1}(\delta) \\
\ell_i &= \mathbb{1}(y_i = 0)\left[\theta + (1 - \theta)(1 - p)^{t_i}\right] + \mathbb{1}(y_i > 0)\,(1 - \theta)\binom{t_i}{y_i} p^{y_i} (1 - p)^{t_i - y_i} \\
d_i &\sim \text{Bernoulli}(\ell_i),
\end{aligned}$$

where $d_i$ is a dummy variable and an auxiliary observation that will always be $1$.
In the previous model we had samples of $s_i$ from $\text{Bernoulli}(1 - \theta)$ to use when simulating $y_i$. A challenge that arises when using a model where we have marginalized out the discrete variables is that it becomes less clear how to simulate from the posterior predictive distribution. We tackle this challenge by noticing that $P(y_i, s_i \mid t_i)$ is computed as part of our likelihood calculation every iteration, and that it is proportional to $P(s_i \mid y_i, t_i)$. We can capture this information using @bm.functional and query for it during posterior inference. We then use these (normalized) probabilities to simulate $s_i$ with a Bernoulli distribution, passing the probability as an argument. These samples of $s_i$ can then be used identically to how they were used in the previous model. This is a trade-off for marginalizing out the discrete variable; we still have to keep a catalog of it if we want to simulate data, but we will see that running inference is a less computationally expensive task.
@bm.random_variable
def alpha() -> RVIdentifier:
return dist.Normal(0, 10)
@bm.random_variable
def delta() -> RVIdentifier:
return dist.Normal(0, 10)
@bm.functional
def p():
return torch.sigmoid(alpha())
@bm.functional
def theta():
return torch.sigmoid(delta())
@bm.functional
def ell0(i: int):
return theta() + (1.0 - theta()) * torch.pow(1.0 - p(), t[i])
@bm.functional
def ell1(i: int):
return (1.0 - theta()) * dist.Binomial(t[i], p()).log_prob(y_observed[i]).exp()
@bm.functional
def s(i: int):
# s_i = 0
s0 = theta() if y_observed[i] == 0 else torch.tensor(0.0)
# s_i = 1
s1 = ell1(i)
return torch.tensor([s0, s1])
@bm.random_variable
def d(i: int) -> RVIdentifier:
ell = ell0(i) if y_observed[i] == 0 else ell1(i)
return dist.Bernoulli(ell)
Inference
We set all the auxiliary observations to $1$ (corresponding to $d_i = 1$) and use y_observed[i] (corresponding to $y_i$) inside the model to compute the likelihood factor $\ell_i$.
thrd_mdl_num_samples = 2 if smoke_test else 4000
thrd_mdl_num_chains = 1 if smoke_test else 4
thrd_mdl_num_adaptive_samples = 0 if smoke_test else thrd_mdl_num_samples // 2
queries = [alpha(), delta(), p(), theta()] + [s(i) for i in range(N)]
# Observations are all equal to 1.
observations = {d(i): torch.tensor(1.0) for i in range(df.shape[0])}
# NNC compilation gets into some issues with this model so we will turn it off for now
samples = bm.GlobalNoUTurnSampler(nnc_compile=False).infer(
queries=queries,
observations=observations,
num_samples=thrd_mdl_num_samples,
num_chains=thrd_mdl_num_chains,
num_adaptive_samples=thrd_mdl_num_adaptive_samples,
)
Below we plot the posterior predictive checks for model 3. We are able to recover the posterior predictive distributions we saw in model 2 because we kept track of $P(y_i, s_i \mid t_i)$ using Bean Machine. Recall that our bm.functional object for s_i is a piece of bookkeeping that lets us reproduce model 2's ability to run posterior predictive checks. If we have successfully tracked $s_i$, then the posterior predictive checks from models 2 and 3 should be nearly identical.
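Concretely, the normalization step looks roughly like the following sketch for patient 0 (roughly what the hearts helper in the next cell needs to do internally when it simulates $s_i$):
joint = samples.get(s(0))                    # last dim: [P(y_0, s_0 = 0), P(y_0, s_0 = 1)]
prob_s1 = joint[..., 1] / joint.sum(dim=-1)  # P(s_0 = 1 | y_0, t_0)
s0_draws = dist.Bernoulli(probs=prob_s1).sample()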
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Posterior predictive checks.
third_model_predictive_checks = hearts.Model3PredictiveChecks(
samples_with_observations=samples,
data=df,
p_query=p,
s_query_str="s",
)
# Remove the second model posterior plot legend since it is duplicated in the third
# model's posterior plot.
second_model_posterior_pc.legend.visible = False
# Posterior plot.
third_model_posterior_pc = (
third_model_predictive_checks.plot_posterior_predictive_checks()
)
# Link plots together.
second_model_posterior_pc.y_range = third_model_posterior_pc.y_range
second_model_posterior_pc.x_range = third_model_posterior_pc.x_range
# Compare the prior/posterior predictive plots.
sec_thrd_model_post_plot = gridplot([[second_model_posterior_pc, third_model_posterior_pc]])
show(sec_thrd_model_post_plot)
Analysis
Below is a print out of summary statistics for the queries.
summary_df = az.summary(samples.to_xarray(), round_to=4)
summary_df.reset_index(inplace=True)
summary_df["sort"] = summary_df["index"].map(sort_order)
summary_df.dropna(subset=["sort"], axis=0, inplace=True)
summary_df.sort_values(by="sort", inplace=True)
summary_df.rename(columns={"index": "query"}, inplace=True)
summary_df.set_index("query", inplace=True)
summary_df.drop("sort", axis=1, inplace=True)
Markdown(summary_df.to_markdown())
query | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|
The $\hat{R}$ and effective sample size values look excellent. Next we look at how the effective sample sizes evolved during sampling.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
keys = ["p", "θ"]
values = [p(), theta()]
parameters = dict(zip(keys, values))
plots = []
for title, parameter in parameters.items():
data = {title: samples.get(parameter).numpy()}
f = az.plot_ess(data, kind="evolution", show=False)[0][0]
f.plot_width = 400
f.plot_height = 400
f.y_range.start = 0
f.outline_line_color = "black"
f.grid.grid_line_alpha = 0.2
f.grid.grid_line_color = "grey"
f.grid.grid_line_width = 0.2
plots.append(f)
thrd_p_theta_plots = gridplot([plots[:2], plots[2:]])
show(thrd_p_theta_plots)
We are above the rule-of-thumb of greater than 400 for the effective sample size plots, and all plots also look linear. Below we look at the posteriors, rank plots, and autocorrelation plots for the queries of the model.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
diag_plots = gridplot(hearts.plot_diagnostics(samples, values))
show(diag_plots)
In comparison to the second model all the above posteriors look smoother, and the autocorrelation plots also show less autocorrelation when sampling. The third model is exploring the parameter space more effectively (and efficiently) than the second model.
Discussion
The three plots below show the progression we took in creating a zero-inflated model. Model 1 was incapable of simulating the observed zero values, while model 2 did a great job of simulating zeros conditioned on the data. Model 2, however, took a lot of computational resources, and issues with its parameter posteriors and autocorrelation plots led us to create model 3. Model 3 used fewer computational resources, reproduced the prior and posterior predictive plots of model 2, and explored the parameter space more efficiently than model 2. The extra work of marginalizing out the discrete variable $s_i$ and keeping track of $P(y_i, s_i \mid t_i)$ for simulating data was worth the effort.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
first_model_posterior_pc.legend.visible = False
second_model_posterior_pc.legend.visible = False
plots = [first_model_posterior_pc, second_model_posterior_pc, third_model_posterior_pc]
frst_scnd_thrd_plots = gridplot([plots])
show(frst_scnd_thrd_plots)
Comparing our third model with the results of Farewell and Sprott shows we are still in good agreement.
Markdown(
summary_df.loc[
summary_df.index.astype(str).isin(["alpha()", "delta()"]),
["mean", "sd"],
].to_markdown()
)
query | mean | sd |
---|
Rather than quote a 95% confidence interval for our findings, we will plot the 89% highest density intervals for $\alpha$ and $\delta$ below. Why do we use 89%? Mostly to keep you from thinking in terms of p-values and to encourage thinking about the posterior densities. See McElreath for a good discussion of HDI intervals.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
a = az.plot_posterior({"α": samples.get(alpha()).numpy()}, show=False)[0][0]
d = az.plot_posterior({"δ": samples.get(delta()).numpy()}, show=False)[0][0]
a_d_plots = gridplot([[a, d]])
show(a_d_plots)
We plot the marginal posterior distributions for $p$, $\theta$, and $\beta$ below. Recall that $\beta$ is the coefficient of the Poisson mean for PVC events after the medication had been administered.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
th = az.plot_posterior({"θ": samples.get(theta()).numpy()}, show=False)[0][0]
p_ = az.plot_posterior({"p": samples.get(p()).numpy()}, show=False)[0][0]
b = az.plot_posterior({"β": samples.get(alpha()).exp().numpy()}, show=False)[0][0]
th_p_b_plots = gridplot([[p_, th, b]])
show(th_p_b_plots)
The above plot shows that the marginal posterior mean for $\theta$ is similar to the proportion of zero postdrug counts in the observed data, calculated below.
round(df[df["postdrug"] == 0].shape[0] / df.shape[0], 4)
As a final consistency check, the approximate posterior mean for $p$ can be used to estimate counts for patients that did not have a zero value for postdrug PVC events. As we see below, these estimates ($t_i p$ alongside postdrug and $t_i(1 - p)$ alongside predrug) are consistent with the observed pre- and postdrug counts.
pp = samples.get(p()).numpy().mean()
posterior_df = df.copy()
posterior_df["tp"] = df[df["postdrug"] != 0]["total"] * pp
posterior_df["t(1 - p)"] = df[df["postdrug"] != 0]["total"] * (1 - pp)
Markdown(
posterior_df[posterior_df["postdrug"] != 0][
["patient_number", "predrug", "t(1 - p)", "postdrug", "tp"]
].astype(int).to_markdown()
)
patient_number | predrug | t(1 - p) | postdrug | tp | |
---|---|---|---|---|---|
0 | 1 | 6 | 6 | 5 | 4 |
1 | 2 | 9 | 6 | 2 | 4 |
4 | 5 | 7 | 5 | 2 | 3 |
5 | 6 | 5 | 3 | 1 | 2 |
10 | 11 | 9 | 13 | 13 | 8 |
Conclusion
We learned several things in this tutorial.
- What zero-inflated data look like.
- How to remove nuisance parameters.
- How to add more mass to the likelihood around zero.
- An example showing why marginalization of discrete variables from your model is a good practice.
- How to define models with complex likelihood factors in Bean Machine using auxiliary Bernoulli observations.
References
- Berry DA (1987) Logarithmic transformations in ANOVA. Biometrics 43(2) 439–456. doi: 10.2307/2531826.
- Farewell, VT and Sprott DA (1988) The use of a mixture model in the analysis of count data. Biometrics 44(4) 1191–1194. doi: 10.2307/2531746.
- Lambert D (1992) Zero-inflated Poisson regression, with an application to defects in manufacturing. Technometrics 34(1) 1–14. doi: 10.2307/1269547.
- McElreath R (2020) Statistical Rethinking: A Bayesian Course with Examples in R and Stan 2nd edition. Chapman and Hall/CRC. doi: 10.1201/9780429029608.
- Vehtari A, Gelman A, Simpson D, Carpenter B, Bürkner PC (2021) Rank-normalization, folding, and localization: An improved $\hat{R}$ for assessing convergence of MCMC (with discussion). Bayesian Analysis 16(2) 667–718. doi: 10.1214/20-BA1221.
- Wikipedia-Distributions https://en.wikipedia.org/wiki/List_of_probability_distributions.
- Wikipedia-EKG https://en.wikipedia.org/wiki/Ekg.
- Wikipedia-Nuisance Parameter https://en.wikipedia.org/wiki/Nuisance_parameter.
- Wikipedia-Poisson distribution https://en.wikipedia.org/wiki/Poisson_distribution#General.
- Wikipedia-PVC https://en.wikipedia.org/wiki/Premature_ventricular_contraction.