Tutorial: Variational inference in a generalized linear mixed model
Adapted from TensorFlow Probability's Linear Mixed Effects Model tutorial. This tutorial demonstrates how to use Bean Machine's variational inference to perform uncertainty-aware estimation in generalized linear models with both fixed and random effects.
Prerequisites
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import beanmachine.ppl as bm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.vi import ADVI
from beanmachine.tutorials.utils.radon import load_data
import os
# Plotting settings
sns.set_context('notebook')
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)
# Manual seed
bm.seed(11)
torch.manual_seed(11)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
Data
We will consider variational inference (VI) against a regression model defined on the radon dataset.
df = load_data()
df.head()
|   | county_index | county | floor | activity | log_activity | Uppm | log_Uppm |
|---|---|---|---|---|---|---|---|
| 0 | 0 | AITKIN | 1 | 2.2 | 0.832909 | 0.502054 | -0.689048 |
| 1 | 0 | AITKIN | 0 | 2.2 | 0.832909 | 0.502054 | -0.689048 |
| 2 | 0 | AITKIN | 0 | 2.9 | 1.09861 | 0.502054 | -0.689048 |
| 3 | 0 | AITKIN | 0 | 1 | 0.09531 | 0.502054 | -0.689048 |
| 4 | 1 | ANOKA | 0 | 2.8 | 1.06471 | 0.428565 | -0.847313 |
Let's visualize the distributions over `floor` and `county`. There are two `floor` values, both with lots of data, and many counties with little data.
fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 5]}, figsize=(12, 3))
df['floor'].value_counts().plot(kind='bar', ax=ax[0])
df['county'].value_counts().plot(kind='bar', ax=ax[1])
fig.show()
Model specification
Since many counties have little data, to avoid overfitting the county effects we will model them as random effects within a GLMM.
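Written out, with $c[i]$ denoting the county of observation $i$, the model implemented in the code below is:

$$
\begin{aligned}
\sigma_{\text{county}} &\sim \text{HalfNormal}(1), \\
\alpha &\sim \mathcal{N}(0, 1), \qquad \beta_{\text{floor}} \sim \mathcal{N}(0, 1), \\
u_c &\sim \mathcal{N}\!\left(0, \sigma_{\text{county}}^2\right) \quad \text{for each county } c, \\
y_i &\sim \mathcal{N}\!\left(\alpha + \beta_{\text{floor}} \, \text{floor}_i + u_{c[i]},\; 1\right).
\end{aligned}
$$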
Note that the scale here is global across all counties and the random effect is normal; the hierarchical linear mixed effects models we will look at later will generalize this.
features = df[['county_index', 'floor']].astype(int)
labels = df[['log_activity']].astype(np.float32).values.flatten()
floor = torch.tensor(features.floor.values)
county_index = torch.tensor(features.county_index.values)

# Global scale shared by all county random effects.
@bm.random_variable
def county_scale():
    return dist.HalfNormal(scale=1.)

# Fixed effects: a common intercept and a weight on the floor indicator.
@bm.random_variable
def intercept():
    return dist.Normal(loc=0., scale=1.)

@bm.random_variable
def floor_weight():
    return dist.Normal(loc=0., scale=1.)

# One zero-mean normal random effect per county, sharing the global scale.
@bm.random_variable
def county_prior():
    return dist.Independent(dist.Normal(
        loc=torch.zeros(county_index.unique().numel()),
        scale=county_scale(),
    ), 1)

# Observation model: fixed effects plus each observation's county random effect.
@bm.random_variable
def linear_response():
    fixed_effect = intercept() + floor_weight() * floor
    random_effect = torch.gather(county_prior(), 0, county_index)
    return dist.Independent(dist.Normal(
        loc=fixed_effect + random_effect,
        scale=1.,
    ), 1)
Variational inference
We will use ADVI to approximate the posterior. This method fits a mean-field guide distribution (a product of independent normals) by gradient descent on a divergence measure between the guide and the posterior. It returns a distributional approximation that captures both the location of each parameter and the uncertainty in the estimate.
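More precisely, for latent variables $z$, observations $y$, and a guide $q_\phi$, ADVI ascends the evidence lower bound (ELBO):

$$
\mathrm{ELBO}(\phi)
= \mathbb{E}_{q_\phi(z)}\!\left[\log p(y, z) - \log q_\phi(z)\right]
= \log p(y) - \mathrm{KL}\!\left(q_\phi(z) \,\|\, p(z \mid y)\right).
$$

Since $\log p(y)$ does not depend on $\phi$, maximizing the ELBO is equivalent to minimizing the KL divergence from the guide to the posterior; the loss reported during optimization below is the negative ELBO.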
losses = []
v_world = ADVI(
    queries=[
        county_prior(),
        floor_weight(),
        intercept(),
        county_scale(),
    ],
    observations={
        linear_response(): torch.tensor(labels),
    },
    optimizer=lambda params: torch.optim.Adam(params, lr=5e-2),
).infer(
    num_steps=400,
    step_callback=lambda it, loss, vi_cls: losses.append(loss.item())
)
We can verify convergence by plotting the loss and checking that it decreases and then saturates:
fig, ax = plt.subplots()
sns.lineplot(
    ax=ax,
    data=losses,
).set(
    yscale='log',
    xlabel='Step',
    ylabel='-ELBO',
)
fig.show()
Inspecting posterior approximations
Our variational approximations for the fixed effects (the intercept and the floor weight) are:
v_world.get_guide_distribution(intercept())
v_world.get_guide_distribution(floor_weight())
Since `county_scale()` is a `HalfNormal` with constrained support, following (Kucukelbir et al. 2016, https://arxiv.org/abs/1603.00788) its variational approximation is the pushforward of a Normal under a support-transforming bijection. This is implemented using a `torch.distributions.TransformedDistribution`.
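As a standalone illustration (a sketch using the `torch` and `dist` imports from the prerequisites, not Bean Machine's exact internal construction), a distribution with positive support can be built in PyTorch by pushing an unconstrained Normal through an exponential bijection:

# Unconstrained base distribution over the whole real line.
base = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
# Pushforward under exp: the support becomes (0, inf), matching HalfNormal's support.
positive_guide = dist.TransformedDistribution(base, [dist.ExpTransform()])
print(positive_guide.sample((5,)))  # all samples are positive
print(positive_guide.log_prob(torch.tensor(0.5)))  # log-density includes the Jacobian correction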
v_world.get_guide_distribution(county_scale())
To get a sense of its shape, we Monte-Carlo approximate its first two moments:
# Sample from the guide (the posterior approximation) over the county scale.
scale_guide_samples = v_world.get_guide_distribution(county_scale()).sample((10_000,))
print(
    scale_guide_samples.mean(),
    scale_guide_samples.var(),
)
Visualizing results
Below we visualize the posterior county random effects estimated using ADVI. Since ADVI uses Gaussian guide distributions, uncertainty is quantified automatically; the plot below shows the mean and one standard deviation for each county random effect.
county_counts = (df.groupby(by=['county', 'county_index'], observed=True)
                 .agg('size')
                 .sort_values(ascending=False)
                 .reset_index(name='count'))
means = v_world.get_guide_distribution(county_prior()).base_dist.mean.detach().numpy()
stds = v_world.get_guide_distribution(county_prior()).base_dist.stddev.detach().numpy()
fig, ax = plt.subplots(figsize=(20, 5))
for idx, row in county_counts.iterrows():
    mid = means[row.county_index]
    std = stds[row.county_index]
    ax.vlines(idx, mid - std, mid + std, linewidth=3)
    ax.plot(idx, means[row.county_index], 'ko', mfc='w', mew=2, ms=7)
ax.set(
    xticks=np.arange(len(county_counts)),
    xlim=(-1, len(county_counts)),
    ylabel="County effect",
    title=r"Estimates of county effects on log radon levels (mean $\pm$ 1 std. dev.)",
)
ax.set_xticklabels(county_counts.county, rotation=90);
One desirable property of an uncertainty quantification method is that uncertainty should decrease as more data is observed. The plot below verifies this trend by plotting the estimated posterior uncertainty (i.e., the standard deviation of the ADVI Normal approximation) against the number of observations in each county. The size of each county implicitly affects the VI objective through the data likelihood, and the plot shows that an ELBO-maximizing approximation generally assigns lower uncertainty to county random effects with more data.
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(np.log1p(county_counts['count']), stds[county_counts.county_index], 'o')
ax.set(
    ylabel='Posterior std. deviation',
    xlabel='County log(1 + count)',
    title='Having more observations generally\nlowers estimation uncertainty'
);