Maximum likelihood estimation and maximum a posteriori inference
Tutorial: Maximum likelihood and maximum a posteriori point estimation
This tutorial demonstrates how maximum likelihood estimation (MLE) and maximum a posteriori
(MAP) inference problems can be specified and solved using Bean Machine's variational
inference. MLE is treated as a special case of MAP in which the prior is uninformative, and
MAP is treated as a special case of VI point estimation (i.e. VI with Delta guide
distributions).
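Concretely, writing $\mathcal{D}$ for the observed data and $\beta$ for the parameters, the two point estimates maximize closely related objectives; when the prior $p(\beta)$ is constant, the MAP objective reduces to the MLE objective:
$$
\hat{\beta}_{\text{MLE}} = \arg\max_\beta \log p(\mathcal{D} \mid \beta),
\qquad
\hat{\beta}_{\text{MAP}} = \arg\max_\beta \big[\log p(\mathcal{D} \mid \beta) + \log p(\beta)\big]
$$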
Prerequisites
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import beanmachine.ppl as bm
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
from beanmachine.ppl.distributions import Flat
from beanmachine.ppl.inference.vi import MAP
import os
# Plotting settings
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)
# Manual seed
bm.seed(11)
torch.manual_seed(11)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
Data
Consider data following the ordinary least squares (OLS) model $Y = \beta_1 X + \beta_0 + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, where the true coefficients are $\beta = (\beta_1, \beta_0) = (2, 5)$.
N = 900
true_beta = torch.tensor([2.0, 5.0])
@bm.random_variable
def X():
    return dist.Normal(0, 1).expand((N,1))
# Fixed true coefficients used only for simulating data.
def beta():
    return true_beta
@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)
data = bm.simulate([X(), Y()], num_samples=1).get_chain(0)
data_X, data_Y = data[X()], data[Y()]
xs = torch.linspace(-3, 3, steps=100)
fig, ax = plt.subplots()
ax.scatter(data_X, data_Y, label='data')
ax.legend()
Maximum Likelihood Estimation (MLE)
By placing a Flat uninformative prior over $\beta$, the prior has no effect and the
posterior becomes proportional to the likelihood. Therefore, we can use MAP to search for
the maximum likelihood estimate.
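In symbols, with a constant prior $p(\beta) \propto 1$, the posterior and the likelihood share the same maximizer:
$$
p(\beta \mid X, Y) \propto p(Y \mid X, \beta)\, p(\beta) \propto p(Y \mid X, \beta)
$$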
@bm.random_variable
def beta():
    return Flat(shape=(2,1))
@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)
v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1500)
with v_world:
    print(beta())
with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta(), color='yellow', label='MLE')
ax.legend()
fig
MAP performs gradient descent on a quadratic potential centered at the standard OLS
estimator, which is known to be given by the Moore-Penrose pseudo-inverse $X^{+}$ of the
design matrix $X$ (here the $N \times 2$ matrix with an appended column of ones):
$$
\hat{\beta}_{\text{OLS}} = X^{+} Y = (X^\top X)^{-1} X^\top Y
$$
beta_ols = torch.linalg.pinv(torch.cat([data_X, torch.ones((N,1))], dim=1)) @ data_Y
print(beta_ols)
ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_ols, color='red', label='OLS')
ax.legend()
fig
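As a quick sanity check, the VI point estimate obtained above should roughly agree with the closed-form OLS solution, assuming the optimization has converged:
with v_world, torch.no_grad():
    beta_mle = beta().squeeze()  # point estimate from MAP with a Flat prior, i.e. the MLE
# The maximum absolute difference should be small relative to the coefficient scale.
print((beta_mle - beta_ols).abs().max())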
MAP Inference
To incorporate prior beliefs, such as $\beta \sim \mathcal{N}(0, \sigma^2 I)$ in this example, the Flat
uninformative prior is replaced with a prior distribution that, with high precision
($\sigma = 0.1$ is small), assigns the majority of probability mass near zero. This
prior results in a posterior which "shrinks" estimates away from the MLE and
back towards the prior (in this case, both regression coefficients $\beta$ are
shrunk towards $0$).
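Writing $X$ again for the design matrix with an intercept column, and using the unit observation noise from the model above, the MAP objective makes the shrinkage explicit: the Normal prior contributes a quadratic penalty on $\beta$:
$$
\hat{\beta}_{\text{MAP}} = \arg\min_\beta \left[ \tfrac{1}{2} \lVert Y - X\beta \rVert^2 + \tfrac{1}{2\sigma^2} \lVert \beta \rVert^2 \right]
$$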
sigma = 0.1
@bm.random_variable
def beta():
    return dist.Normal(0, sigma).expand((2,1))
@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)
v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1000)
with v_world:
    print(beta())
with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta(), color='cyan', label='MAP')
ax.legend()
fig
It turns out that some MAP estimates correspond to frequentist regularization techniques. In the case of linear regression, MAP inference with a Normal prior is equivalent to Tikhonov or $L_2$ (ridge) regularization, where the regularization parameter is the inverse of the prior variance, $\lambda = \sigma^{-2}$:
$$
\hat{\beta}_{L_2} = (X^\top X + \lambda I)^{-1} X^\top Y, \qquad \lambda = \sigma^{-2}
$$
X_full = torch.cat([data_X, torch.ones((N,1))], dim=1)
beta_l2 = (torch.linalg.inv(X_full.T @ X_full + sigma**-2 * torch.eye(2)) @ X_full.T @ data_Y)
print(beta_l2)
ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_l2, color='magenta', label='L2')
ax.legend()
fig
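As with the MLE above, the MAP point estimate returned by VI should roughly match this closed-form ridge solution once optimization has converged:
with v_world, torch.no_grad():
    beta_map = beta().squeeze()  # point estimate under the Normal(0, sigma) prior
# The maximum absolute difference should again be small.
print((beta_map - beta_l2).abs().max())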