
Tutorial: Variational inference in a generalized linear mixed model

Adapted from TensorFlow Probability's Linear Mixed Effects Model tutorial. This tutorial demonstrates how to use Bean Machine's variational inference to perform uncertainty-aware estimation in generalized linear models with both fixed and random effects.

Prerequisites

# Install Bean Machine if running in Colab.
import sys

if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import beanmachine.ppl as bm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.vi import ADVI
from beanmachine.tutorials.utils.radon import load_data
import os

# Plotting settings
sns.set_context('notebook')
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)

# Manual seed
bm.seed(11)
torch.manual_seed(11)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

Data

We will apply variational inference (VI) to a regression model defined on the radon dataset.

df = load_data()
df.head()
   county_index  county  floor  activity  log_activity      Uppm  log_Uppm
0             0  AITKIN      1       2.2      0.832909  0.502054 -0.689048
1             0  AITKIN      0       2.2      0.832909  0.502054 -0.689048
2             0  AITKIN      0       2.9      1.09861   0.502054 -0.689048
3             0  AITKIN      0       1        0.09531   0.502054 -0.689048
4             1   ANOKA      0       2.8      1.06471   0.428565 -0.847313

Let's visualize how the observations are distributed across floors and counties. Both floor values have plenty of data, while many counties have very little.

fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 5]}, figsize=(12, 3))
df['floor'].value_counts().plot(kind='bar', ax=ax[0])
df['county'].value_counts().plot(kind='bar', ax=ax[1])
plt.show()

Model specification

Since many counties have little data, we model the county effect as a random effect within a generalized linear mixed model (GLMM) to avoid overfitting it:

$$\log \text{radon}_j \sim c + \text{floor\_effect}_j + \mathcal{N}(\text{county\_effect}_j, \text{county\_scale})$$

Note that a single scale is shared across all counties and the random effect is normally distributed; more general hierarchical linear mixed effects models relax these assumptions.
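Partial pooling is what guards against overfitting here. As a simplified illustration (an assumption, not how this tutorial fits the model), hold the intercept, floor effect, and county_scale fixed, and use unit observation noise as the model code does. The posterior mean of a county effect is then that county's summed fixed-effect residual divided by the posterior precision 1 / county_scale² + n_j. As a result, counties with few observations are shrunk toward zero. The helper name `county_effect_posterior_mean` and the numbers are hypothetical.

```python
import math


def county_effect_posterior_mean(residuals, county_scale, noise_scale=1.0):
    """Conjugate posterior mean of one county effect b_j (hypothetical helper).

    Assumes b_j ~ Normal(0, county_scale) and each fixed-effect residual
    ~ Normal(b_j, noise_scale), with all other parameters held fixed.
    """
    precision = 1.0 / county_scale**2 + len(residuals) / noise_scale**2
    return (sum(residuals) / noise_scale**2) / precision


# Same average residual (0.8) in both counties, very different amounts of data.
small_county = county_effect_posterior_mean([0.8], county_scale=0.34)
large_county = county_effect_posterior_mean([0.8] * 100, county_scale=0.34)
print(small_county, large_county)
```

The single-observation county is pulled most of the way toward zero. The county with 100 observations stays close to its sample mean of 0.8.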

features = df[['county_index', 'floor']].astype(int)
labels = df[['log_activity']].astype(np.float32).values.flatten()
floor = torch.tensor(features.floor.values)
county_index = torch.tensor(features.county_index.values)

@bm.random_variable
def county_scale():
    # Scale of the county random effects, shared across all counties.
    return dist.HalfNormal(scale=1.)

@bm.random_variable
def intercept():
    return dist.Normal(loc=0., scale=1.)

@bm.random_variable
def floor_weight():
    return dist.Normal(loc=0., scale=1.)

@bm.random_variable
def county_prior():
    # One zero-mean random effect per county, treated as a single vector-valued variable.
    return dist.Independent(dist.Normal(
        loc=torch.zeros(county_index.unique().numel()),
        scale=county_scale(),
    ), 1)


@bm.random_variable
def linear_response():
    fixed_effect = intercept() + floor_weight() * floor
    # Look up the random effect of each observation's county.
    random_effect = torch.gather(county_prior(), 0, county_index)
    # Observation noise is fixed at scale 1.
    return dist.Independent(dist.Normal(
        loc=fixed_effect + random_effect,
        scale=1.,
    ), 1)
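As a sanity check on the specification, the same generative process can be written as a dependency-free forward simulation. This plain-Python sketch is not part of the tutorial's Bean Machine code. The function name `simulate_log_radon` and the toy five-house design are hypothetical. Drawing from it shows what kind of data the priors imply before conditioning on any observations.

```python
import random


def simulate_log_radon(floor, county_index, num_counties, seed=0):
    """Draw one prior-predictive dataset from the GLMM above (hypothetical helper).

    Mirrors the model: county_scale ~ HalfNormal(1); intercept and
    floor_weight ~ Normal(0, 1); one Normal(0, county_scale) effect per county;
    log radon ~ Normal(fixed effect + county effect, 1).
    """
    rng = random.Random(seed)
    county_scale = abs(rng.gauss(0.0, 1.0))  # |Normal(0, 1)| is HalfNormal(1)
    intercept = rng.gauss(0.0, 1.0)
    floor_weight = rng.gauss(0.0, 1.0)
    county_effect = [rng.gauss(0.0, county_scale) for _ in range(num_counties)]
    return [
        rng.gauss(intercept + floor_weight * f + county_effect[c], 1.0)
        for f, c in zip(floor, county_index)
    ]


# Hypothetical toy design: five houses spread over three counties.
draws = simulate_log_radon(floor=[1, 0, 0, 0, 1], county_index=[0, 0, 1, 2, 2], num_counties=3)
print(draws)
```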

Variational inference

We will use ADVI (automatic differentiation variational inference) to approximate the posterior. This method fits a mean-field guide distribution, a product of independent normals, by using gradient-based optimization to minimize a divergence between the guide and the posterior (for ADVI, the KL divergence, which is equivalent to maximizing the ELBO). The result is a distributional approximation that captures both the location of each parameter and the uncertainty in its estimate.
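The core idea can be illustrated on a toy conjugate model where the exact posterior is known. The sketch below is a simplified stand-in, not Bean Machine's implementation, and it rests on several assumptions:

* The model is mu ~ Normal(0, 1) with y_i ~ Normal(mu, 1).
* The guide is a Normal parameterized by (m, log s).
* ELBO gradients are reparameterized Monte Carlo estimates.
* The optimizer is plain stochastic gradient ascent with a decaying step size, used in place of Adam.

The function name `advi_normal_mean` is hypothetical.

```python
import math
import random


def advi_normal_mean(y, num_steps=2000, num_samples=10, lr=0.01, seed=0):
    """Fit q(mu) = Normal(m, s) to p(mu | y) for mu ~ N(0, 1), y_i ~ N(mu, 1)."""
    rng = random.Random(seed)
    m, log_s = 0.0, 0.0
    for t in range(num_steps):
        s = math.exp(log_s)
        grad_m = grad_log_s = 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            mu = m + s * eps  # reparameterized draw from the guide
            # Derivative of log p(y, mu) = log N(mu; 0, 1) + sum_i log N(y_i; mu, 1).
            dlogp = sum(yi - mu for yi in y) - mu
            grad_m += dlogp
            grad_log_s += dlogp * s * eps  # chain rule through mu = m + s * eps
        grad_m /= num_samples
        # The entropy of Normal(m, s) is log s + const, so it adds 1 to the log-scale gradient.
        grad_log_s = grad_log_s / num_samples + 1.0
        step = lr / (1.0 + t / 500.0)  # decaying step size
        m += step * grad_m  # gradient ascent on the ELBO
        log_s += step * grad_log_s
    return m, math.exp(log_s)


y = [1.5 + 0.5 * math.sin(i) for i in range(20)]
m, s = advi_normal_mean(y)
exact_mean, exact_sd = sum(y) / (len(y) + 1), 1.0 / math.sqrt(len(y) + 1)
print(m, exact_mean)
print(s, exact_sd)
```

Because the exact posterior here is Normal, a well-converged guide should match both its mean and its standard deviation. In the radon model, the mean-field Normal guide is only an approximation.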

losses = []
v_world = ADVI(
    queries=[
        county_prior(),
        floor_weight(),
        intercept(),
        county_scale(),
    ],
    observations={
        linear_response(): torch.tensor(labels),
    },
    optimizer=lambda params: torch.optim.Adam(params, lr=5e-2),
).infer(
    num_steps=400,
    step_callback=lambda it, loss, vi_cls: losses.append(loss.item()),
)


We can check for convergence by plotting the loss and verifying that it decreases and then levels off:

fig, ax = plt.subplots()
sns.lineplot(
    ax=ax,
    data=losses,
).set(
    yscale='log',
    xlabel='Step',
    ylabel='-ELBO',
)
plt.show()
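Visual inspection can be complemented by a simple numeric check. The helper below is a hypothetical heuristic, not part of Bean Machine. It compares the average of the last `window` recorded losses (for example, the `losses` list filled by `step_callback` above) with the average of the window before it. It reports convergence when the relative change falls below `rtol`. Averaging over windows smooths out the step-to-step noise of the stochastic ELBO estimates.

```python
import math


def has_converged(losses, window=50, rtol=1e-3):
    """Hypothetical plateau check on a sequence of recorded losses."""
    if len(losses) < 2 * window:
        return False  # not enough history to compare two windows
    recent = sum(losses[-window:]) / window
    previous = sum(losses[-2 * window:-window]) / window
    return abs(previous - recent) <= rtol * abs(previous)


# Synthetic loss curves for illustration only.
saturating = [500 + 100 * math.exp(-t / 20) for t in range(400)]
still_falling = [1000 - t for t in range(400)]
print(has_converged(saturating), has_converged(still_falling))
```

With the tutorial's run, `has_converged(losses)` could be called after `infer` returns. The window and tolerance are judgment calls that depend on how noisy the loss is.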

Inspecting posterior approximations

Our variational approximations for the fixed effects (the intercept and the floor weight) are:

v_world.get_guide_distribution(intercept())
Out:

Normal(loc: 1.4857617616653442, scale: 0.1068837121129036)

v_world.get_guide_distribution(floor_weight())
Out:

Normal(loc: -0.6434646248817444, scale: 0.08476858586072922)

Since county_scale() has a HalfNormal prior with constrained (nonnegative) support, its variational approximation follows Kucukelbir et al. (2016, https://arxiv.org/abs/1603.00788): it is the pushforward of a Normal under a bijection onto that support. This is implemented using a torch.distributions.TransformedDistribution.
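The change-of-variables formula behind such a guide can be sketched in plain Python. The sketch assumes that the bijection is the exponential map, which is what `torch.distributions.biject_to(constraints.positive)` returns; the exact transform Bean Machine selects for HalfNormal is an assumption here. If z ~ Normal(loc, scale) and x = exp(z), then log q(x) = log N(log x; loc, scale) - log x, where the last term is the log-Jacobian of the inverse map. The values of `loc` and `scale` below are illustrative.

```python
import math


def normal_logpdf(z, loc, scale):
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exp_pushforward_logpdf(x, loc, scale):
    """Log-density of x = exp(z) with z ~ Normal(loc, scale), defined for x > 0."""
    z = math.log(x)  # inverse of the (assumed) exp bijection
    return normal_logpdf(z, loc, scale) - z  # subtract log|dx/dz| = log x = z


# Midpoint-rule check that the pushforward density integrates to about 1 over (0, 5].
loc, scale, n = -1.0, 0.3, 50_000
width = 5.0 / n
total = width * sum(
    math.exp(exp_pushforward_logpdf((i + 0.5) * width, loc, scale)) for i in range(n)
)
print(total)
```

In torch, `TransformedDistribution(Normal(loc, scale), [ExpTransform()])` represents the same distribution, and its `log_prob` applies the same Jacobian correction.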

v_world.get_guide_distribution(county_scale())
Out:

TransformedDistribution()

To summarize its shape, we compute Monte Carlo estimates of its first two moments:

scale_prior_sample = v_world.get_guide_distribution(county_scale()).sample((10_000,))
print(
scale_prior_sample.mean(),
scale_prior_sample.var(),
)
Out:

tensor(0.3378) tensor(0.0023)
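If the guide is a Normal pushed forward through exp (the same assumption as for the bijection above), the first two moments also have a closed form, because exp(z) is log-normal:

* mean = exp(loc + scale²/2)
* variance = (exp(scale²) - 1) · exp(2·loc + scale²)

The sketch below compares these formulas with a Monte Carlo estimate. It uses illustrative parameters (loc = 0, scale = 0.5), which are not the fitted values. Sampling, as done above, works for any transform, whereas the closed form relies on the exp assumption.

```python
import math
import random
import statistics

loc, scale = 0.0, 0.5  # illustrative guide parameters, not fitted values
rng = random.Random(0)
samples = [math.exp(rng.gauss(loc, scale)) for _ in range(100_000)]

# Monte Carlo estimates of the first two moments.
mc_mean, mc_var = statistics.fmean(samples), statistics.variance(samples)
# Closed-form log-normal moments.
exact_mean = math.exp(loc + scale**2 / 2)
exact_var = (math.exp(scale**2) - 1) * math.exp(2 * loc + scale**2)
print(mc_mean, exact_mean)
print(mc_var, exact_var)
```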

Visualizing results

Below we visualize the posterior county random effects estimated with ADVI. Because ADVI uses Gaussian guide distributions, each estimate comes with a measure of uncertainty, so the plot also shows one standard deviation around each county's posterior mean.

county_counts = (df.groupby(by=['county', 'county_index'], observed=True)
                   .agg('size')
                   .sort_values(ascending=False)
                   .reset_index(name='count'))

means = v_world.get_guide_distribution(county_prior()).base_dist.mean.detach().numpy()
stds = v_world.get_guide_distribution(county_prior()).base_dist.stddev.detach().numpy()

fig, ax = plt.subplots(figsize=(20, 5))

for idx, row in county_counts.iterrows():
    mid = means[row.county_index]
    std = stds[row.county_index]
    ax.vlines(idx, mid - std, mid + std, linewidth=3)
    ax.plot(idx, mid, 'ko', mfc='w', mew=2, ms=7)

ax.set(
    xticks=np.arange(len(county_counts)),
    xlim=(-1, len(county_counts)),
    ylabel="County effect",
    title=r"Estimates of county effects on log radon levels (mean $\pm$ 1 std. dev.)",
)
ax.set_xticklabels(county_counts.county, rotation=90);

A desirable property of an uncertainty quantification method is that uncertainty decreases as more data becomes available. The plot below checks this trend by plotting each county's estimated posterior uncertainty (the standard deviation of its ADVI Normal approximation) against the number of observations in that county. The number of observations in a county affects the VI objective implicitly through the data likelihood. The plot shows that the ELBO-maximizing approximation generally assigns lower uncertainty to county random effects with more data.
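The trend has a simple analytic counterpart in a simplified, conjugate version of the model. This is an assumption for illustration: the fixed effects and county_scale are held fixed, and the observation noise has unit scale as in the model code. A county with n_j observations then has posterior standard deviation 1 / sqrt(1 / county_scale² + n_j). This value decreases monotonically with n_j and equals county_scale when n_j = 0. The helper name `county_posterior_sd` is hypothetical. The scale value 0.34 is chosen only because it is near the Monte Carlo mean reported above.

```python
import math


def county_posterior_sd(n_obs, county_scale, noise_scale=1.0):
    """Posterior std. dev. of one county effect in a conjugate Normal model (hypothetical)."""
    precision = 1.0 / county_scale**2 + n_obs / noise_scale**2
    return 1.0 / math.sqrt(precision)


counts = (0, 1, 5, 20, 100)
sds = [county_posterior_sd(n, county_scale=0.34) for n in counts]
for n, sd in zip(counts, sds):
    print(n, round(sd, 3))
```

This sketch ignores uncertainty in the fixed effects and in county_scale, which ADVI does account for. It therefore describes the trend rather than reproducing the fitted values.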

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(np.log1p(county_counts['count']), stds[county_counts.county_index], 'o')
ax.set(
    ylabel='Posterior std. deviation',
    xlabel='County log-count',
    title='Having more observations generally\nlowers estimation uncertainty'
);