Maximum likelihood estimation and maximum a posteriori inference
Tutorial: Maximum likelihood and maximum a posteriori point estimation
This tutorial demonstrates how maximum likelihood estimation (MLE) and maximum a posteriori (MAP) inference problems can be specified and solved using Bean Machine's variational inference. MLE is treated as a special case of MAP where priors are uninformative, and MAP is treated as a special case of VI point estimation (i.e. with Delta guide distributions).
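As a reminder of the two estimators being compared, for a model with parameters $\theta$, prior $p(\theta)$, and data $\mathcal{D}$:
$$
\hat{\theta}_{\text{MLE}} = \arg\max_{\theta}\, p(\mathcal{D} \mid \theta),
\qquad
\hat{\theta}_{\text{MAP}} = \arg\max_{\theta}\, p(\theta \mid \mathcal{D}) = \arg\max_{\theta}\, p(\mathcal{D} \mid \theta)\, p(\theta).
$$
When the prior is flat (constant in $\theta$), the two objectives differ only by a multiplicative constant and the estimators coincide, which is the reduction used throughout this tutorial.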
Prerequisites
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import beanmachine.ppl as bm
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
from beanmachine.ppl.distributions import Flat
from beanmachine.ppl.inference.vi import MAP
import os
# Plotting settings
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)
# Manual seed
bm.seed(11)
torch.manual_seed(11)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
Data
Consider data following the OLS model
$$
Y_i = \beta_1 X_i + \beta_2 + \epsilon_i, \qquad \epsilon_i \overset{\text{iid}}{\sim} \mathcal{N}(0, 1), \qquad i = 1, \dots, N,
$$
with true coefficients $\beta = (2, 5)$.
N = 900
true_beta = torch.tensor([2.0, 5.0])
# Covariates: N independent draws from a standard Normal.
@bm.random_variable
def X():
    return dist.Normal(0, 1).expand((N, 1))

# For data simulation, the coefficients are fixed at their true values.
def beta():
    return true_beta

# Response: linear model (slope and intercept) with unit-variance Gaussian noise.
@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N, 1))], dim=1) @ beta()).squeeze(), 1.0)
data = bm.simulate([X(), Y()], num_samples=1).get_chain(0)
data_X, data_Y = data[X()], data[Y()]
xs = torch.linspace(-3, 3, steps=100)
fig, ax = plt.subplots()
ax.scatter(data_X, data_Y, label='data')
ax.legend()
Maximum Likelihood Estimation (MLE)
By placing a Flat uninformative prior over $\beta$, the prior has no effect and the posterior becomes proportional to the likelihood. Therefore, we can use MAP to search for the MLE.
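Concretely, with a flat prior the (unnormalized) log posterior is just the Gaussian log likelihood, which for this model is quadratic in $\beta$:
$$
\log p(\beta \mid X, Y) = -\frac{1}{2}\sum_{i=1}^{N}\bigl(y_i - \beta_1 x_i - \beta_2\bigr)^2 + \text{const},
$$
so maximizing the posterior is exactly the ordinary least squares problem.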
@bm.random_variable
def beta():
    # Flat (improper, uninformative) prior over the slope and intercept.
    return Flat(shape=(2, 1))

@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N, 1))], dim=1) @ beta()).squeeze(), 1.0)
v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1500)
with v_world:
    print(beta())
with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs), 1))], dim=1) @ beta(), color='yellow', label='MLE')
ax.legend()
fig
MAP performs gradient descent on a quadratic potential centered at the standard OLS estimator, which is known to be given by the Moore-Penrose pseudo-inverse:
$$
\hat{\beta}_{\text{OLS}} = (X^\top X)^{-1} X^\top Y = X^{+} Y,
$$
where $X$ is the design matrix (the covariates with an appended column of ones for the intercept) and $X^{+}$ is its Moore-Penrose pseudo-inverse.
beta_ols = torch.linalg.pinv(torch.cat([data_X, torch.ones((N,1))], dim=1)) @ data_Y
print(beta_ols)
ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_ols, color='red', label='OLS')
ax.legend()
fig
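As a quick sanity check (an addition for illustration, not part of the original workflow), the point estimate recovered by MAP under the Flat prior should agree with the closed-form OLS solution up to optimization error:
# Compare the variational point estimate to the closed-form OLS solution.
# The tolerance is a loose assumption about how close the optimizer is to convergence.
with v_world, torch.no_grad():
    assert torch.allclose(beta().squeeze(), beta_ols, atol=0.1)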
MAP Inference
To incorporate prior beliefs, such as $\beta \approx 0$ in this example, the Flat uninformative prior is replaced with a $\mathcal{N}(0, \sigma^2)$ prior which, at high precision ($\sigma$ is small), assigns the majority of its probability mass near zero. This prior results in a posterior which "shrinks" estimates away from the MLE and back towards the prior; in this case, both regression coefficients are shrunk towards $0$.
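To see where the shrinkage comes from, note that (up to additive constants) the log posterior under this prior adds a quadratic penalty on $\beta$ to the least squares objective:
$$
\log p(\beta \mid X, Y) = -\frac{1}{2}\sum_{i=1}^{N}\bigl(y_i - \beta_1 x_i - \beta_2\bigr)^2 - \frac{1}{2\sigma^2}\lVert \beta \rVert_2^2 + \text{const},
$$
so a smaller $\sigma$ imposes a heavier penalty and pulls the estimates more strongly towards zero.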
sigma = 0.1

@bm.random_variable
def beta():
    # Normal prior concentrated near zero (high precision since sigma is small).
    return dist.Normal(0, sigma).expand((2, 1))

@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N, 1))], dim=1) @ beta()).squeeze(), 1.0)
v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1000)
with v_world:
    print(beta())
with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs), 1))], dim=1) @ beta(), color='cyan', label='MAP')
ax.legend()
fig
It turns out some MAP estimates correspond to frequentist regularization techniques. In the case of linear regression, MAP inference with a Normal prior is equivalent to Tikhonov (ridge / $L_2$) regularization, where the regularization parameter is determined by the prior variance:
$$
\hat{\beta}_{\text{MAP}} = \bigl(X^\top X + \sigma^{-2} I\bigr)^{-1} X^\top Y,
$$
with regularization strength $\lambda = \sigma^{-2}$ (the observation noise has unit variance here).
X_full = torch.cat([data_X, torch.ones((N,1))], dim=1)
beta_l2 = (torch.linalg.inv(X_full.T @ X_full + sigma**-2 * torch.eye(2)) @ X_full.T @ data_Y)
print(beta_l2)
ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_l2, color='magenta', label='L2')
ax.legend()
fig
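As before, a quick check (added for illustration; not part of the original tutorial) confirms that the VI point estimate under the Normal prior lands close to the closed-form ridge solution:
# Compare the variational MAP estimate to the closed-form Tikhonov / L2 solution.
# The tolerance is a loose assumption about optimizer convergence.
with v_world, torch.no_grad():
    assert torch.allclose(beta().squeeze(), beta_l2, atol=0.1)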