
Maximum likelihood estimation and maximum a posteriori inference

Tutorial: Maximum likelihood and maximum a posteriori point estimation

This tutorial demonstrates how maximum likelihood estimation (MLE) and maximum a posteriori (MAP) inference problems can be specified and solved using Bean Machine's variational inference. MLE is treated as a special case of MAP in which the prior is uninformative, and MAP is treated as a special case of VI point estimation (i.e., with Delta guide distributions).

Prerequisites

# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import beanmachine.ppl as bm
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
from beanmachine.ppl.distributions import Flat
from beanmachine.ppl.inference.vi import MAP
import os

# Plotting settings
plt.rc("figure", figsize=[8, 6])
plt.rc("font", size=14)
plt.rc("lines", linewidth=2.5)

# Manual seed
bm.seed(11)
torch.manual_seed(11)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

Data

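Consider data generated from the linear regression model $Y \sim N([X; 1]^\top \beta_{\text{true}}, 1.0)$, with a single standard-normal covariate $X$, an intercept, true coefficients $\beta_{\text{true}} = (2, 5)^\top$ (slope 2, intercept 5), and Gaussian noise with standard deviation 1.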

N = 900
true_beta = torch.tensor([2.0, 5.0])

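@bm.random_variable
def X():
    # N standard-normal covariates, shape (N, 1)
    return dist.Normal(0, 1).expand((N,1))

def beta():
    # Plain function, not a random variable: fixes the coefficients at their true values
    return true_beta

@bm.random_variable
def Y():
    # Design matrix [X, 1] times the coefficients, plus unit-variance Gaussian noise
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)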

data = bm.simulate([X(), Y()], num_samples=1).get_chain(0)
data_X, data_Y = data[X()], data[Y()]
xs = torch.linspace(-3, 3, steps=100)

fig, ax = plt.subplots()
ax.scatter(data_X, data_Y, label='data')
ax.legend()

Maximum Likelihood Estimation (MLE)

By placing a Flat (uninformative, improper) prior over $\beta$, the prior contributes only a constant to the log posterior, so the posterior is proportional to the likelihood and both have the same maximizer. Therefore, we can use MAP to search for the MLE.
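To see why a constant log prior leaves the maximizer unchanged, here is a minimal pure-Python sketch (not Bean Machine code). It assumes hypothetical toy data drawn from $y = 2x + 5 + \varepsilon$ with unit-variance noise, and compares the grid maximizer of the log-likelihood with that of the log-likelihood plus a constant "flat" log prior; the grid range, spacing, and the constant itself are arbitrary choices.

```python
import math
import random

# Hypothetical toy data: y = 2 x + 5 + unit-variance Gaussian noise.
random.seed(0)
toy_x = [random.gauss(0.0, 1.0) for _ in range(100)]
toy_y = [2.0 * x + 5.0 + random.gauss(0.0, 1.0) for x in toy_x]


def log_likelihood(slope, intercept):
    """Gaussian log-likelihood of the toy data with noise standard deviation 1."""
    return sum(
        -0.5 * (y - (slope * x + intercept)) ** 2 - 0.5 * math.log(2 * math.pi)
        for x, y in zip(toy_x, toy_y)
    )


def log_flat_prior(slope, intercept):
    """A flat prior has the same log density everywhere; the constant is arbitrary."""
    return -7.0


# Grid of candidate (slope, intercept) pairs with spacing 0.1.
grid = [(s / 10, b / 10) for s in range(0, 41) for b in range(30, 71)]

mle = max(grid, key=lambda p: log_likelihood(*p))
map_flat = max(grid, key=lambda p: log_likelihood(*p) + log_flat_prior(*p))
print(mle, map_flat)  # the same grid point
```

Adding a constant shifts every candidate's score equally, so the ranking, and hence the argmax, cannot change.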

@bm.random_variable
def beta():
    # Flat (uninformative, improper) prior over the (2, 1) coefficient vector
    return Flat(shape=(2,1))

@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)


v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1500)
with v_world:
    print(beta())
Out:

tensor([[1.9999],
        [4.9858]], requires_grad=True)

with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta(), color='yellow', label='MLE')
ax.legend()
fig

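With the Flat prior, MAP minimizes, by gradient-based optimization, a quadratic potential (the negative log-likelihood $\frac{1}{2}\lVert y - X\beta \rVert^2$, up to a constant) centered at the standard OLS estimator. Writing $X$ for the design matrix with a column of ones appended, and assuming it has full column rank, the OLS estimator is given by the Moore-Penrose pseudo-inverse: $\hat\beta_{OLS} = X^\dagger y = (X^\top X)^{-1} X^\top y$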
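As a sanity check on that claim, the following sketch (plain Python with hypothetical toy data; not Bean Machine's optimizer) solves the $2 \times 2$ normal equations $(X^\top X)\beta = X^\top y$ for a design with rows $[x_i, 1]$, and separately runs plain gradient descent on the quadratic potential. The step size and iteration count are assumptions that work for standardized $x$.

```python
import random

# Hypothetical toy data: y = 2 x + 5 + unit-variance Gaussian noise.
random.seed(1)
toy_x = [random.gauss(0.0, 1.0) for _ in range(500)]
toy_y = [2.0 * x + 5.0 + random.gauss(0.0, 1.0) for x in toy_x]


def ols_closed_form(xs, ys):
    """Solve (X^T X) beta = X^T y for design rows [x, 1] using the explicit 2x2 inverse."""
    m = len(xs)
    sxx, sx = sum(x * x for x in xs), sum(xs)              # X^T X = [[sxx, sx], [sx, m]]
    sxy, sy = sum(x * y for x, y in zip(xs, ys)), sum(ys)  # X^T y = [sxy, sy]
    det = sxx * m - sx * sx
    return (m * sxy - sx * sy) / det, (sxx * sy - sx * sxy) / det


def ols_gradient_descent(xs, ys, lr=0.1, steps=500):
    """Plain gradient descent on (1 / 2m) * sum of squared residuals.

    Dividing the potential by m does not move its minimizer; it only rescales the step.
    """
    m = len(xs)
    slope, intercept = 0.0, 0.0
    for _ in range(steps):
        residuals = [y - (slope * x + intercept) for x, y in zip(xs, ys)]
        # Negative gradient w.r.t. slope is mean(r * x); w.r.t. intercept it is mean(r).
        slope += lr * sum(r * x for r, x in zip(residuals, xs)) / m
        intercept += lr * sum(residuals) / m
    return slope, intercept


print(ols_closed_form(toy_x, toy_y))
print(ols_gradient_descent(toy_x, toy_y))  # agrees with the closed form to many decimals
```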


beta_ols = torch.linalg.pinv(torch.cat([data_X, torch.ones((N,1))], dim=1)) @ data_Y
print(beta_ols)
Out:

tensor([2.0001, 5.0015])

ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_ols, color='red', label='OLS')
ax.legend()
fig

MAP Inference

To incorporate prior beliefs, such as $\beta \sim N(0, \sigma^2)$ in this example, the Flat uninformative prior is replaced with an informative prior. With a small standard deviation ($\sigma = 0.1$), i.e. high precision, this prior concentrates most of its probability mass near zero. The resulting posterior "shrinks" the estimates away from the MLE and towards the prior mean (here, both regression coefficients $\beta_i$ are shrunk towards 0).
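The shrinkage can be seen in closed form in a simpler, hypothetical intercept-only model $y_i \sim N(\beta, 1)$ with prior $\beta \sim N(0, \sigma^2)$. Setting the derivative of the log posterior, $\sum_i (y_i - \beta) - \beta/\sigma^2$, to zero gives $\hat\beta_{\text{MAP}} = \sum_i y_i / (n + \sigma^{-2})$, i.e. the sample mean (the MLE) multiplied by $n / (n + \sigma^{-2})$. For $n = 900$ and $\sigma = 0.1$ that factor is $900 / 1000 = 0.9$. A minimal pure-Python check on simulated data, comparing the closed form with gradient ascent (the step size is an assumption tied to these values of $n$ and $\sigma$):

```python
import random


def map_closed_form(ys, sigma):
    """MAP of beta for y_i ~ N(beta, 1) with prior beta ~ N(0, sigma^2)."""
    return sum(ys) / (len(ys) + sigma ** -2)


def map_gradient_ascent(ys, sigma, lr=5e-4, steps=200):
    """Gradient ascent on log p(y | beta) + log p(beta), additive constants dropped.

    The log posterior has curvature -(n + sigma^-2), so lr must stay below
    2 / (n + sigma^-2); the default assumes n = 900 and sigma = 0.1.
    """
    beta = 0.0
    total = sum(ys)
    for _ in range(steps):
        grad = (total - len(ys) * beta) - beta / sigma ** 2
        beta += lr * grad
    return beta


random.seed(2)
toy_y = [random.gauss(5.0, 1.0) for _ in range(900)]
sample_mean = sum(toy_y) / len(toy_y)  # the MLE under a flat prior
toy_map = map_closed_form(toy_y, sigma=0.1)
print(sample_mean, toy_map, toy_map / sample_mean)  # ratio ≈ 900 / 1000 = 0.9
```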

sigma = 0.1

@bm.random_variable
def beta():
    # Normal prior with standard deviation sigma on each coefficient; small sigma pulls estimates towards 0
    return dist.Normal(0, sigma).expand((2,1))

@bm.random_variable
def Y():
    return dist.Normal((torch.cat([X(), torch.ones((N,1))], dim=1) @ beta()).squeeze(), 1.0)

v_world = MAP(
    queries=[beta()],
    observations={
        X(): data_X,
        Y(): data_Y,
    },
).infer(num_steps=1000)
with v_world:
    print(beta())
Out:

tensor([[1.7862],
        [4.4997]], requires_grad=True)

with v_world, torch.no_grad():
    ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta(), color='cyan', label='MAP')
ax.legend()
fig

It turns out that some MAP estimates correspond to frequentist regularization techniques. In linear regression, MAP inference with a Normal prior is equivalent to Tikhonov ($L_2$, or ridge) regularization with regularization parameter $\lambda = \sigma^{-2}$, the reciprocal of the prior variance (the observation noise here has unit variance): $\hat{\beta}_{L_2,\lambda} = (X^\top X + \sigma^{-2} I)^{-1} X^\top y$
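The equivalence can be checked numerically with a small pure-Python sketch on hypothetical toy data (not Bean Machine code): the closed-form ridge solution with $\lambda = \sigma^{-2}$ should match plain gradient descent on the negative log posterior $\frac{1}{2}\lVert y - X\beta\rVert^2 + \lVert\beta\rVert^2 / (2\sigma^2)$, and setting $\lambda = 0$ recovers OLS. The step size $1/\operatorname{tr}(X^\top X + \lambda I)$ is an assumption, chosen because it never exceeds $1/\lambda_{\max}$ of the Hessian, which keeps gradient descent on this quadratic stable.

```python
import random


def ridge_closed_form(xs, ys, lam):
    """(X^T X + lam I)^{-1} X^T y for design rows [x, 1]; lam = 0 gives OLS."""
    a = sum(x * x for x in xs) + lam  # top-left entry of X^T X + lam I
    b = sum(xs)                       # off-diagonal entries
    d = len(xs) + lam                 # bottom-right entry
    sxy, sy = sum(x * y for x, y in zip(xs, ys)), sum(ys)
    det = a * d - b * b
    return (d * sxy - b * sy) / det, (a * sy - b * sxy) / det


def map_gradient_descent(xs, ys, sigma, steps=500):
    """Gradient descent on 0.5 * sum(r_i^2) + (slope^2 + intercept^2) / (2 sigma^2),
    the negative log posterior (up to a constant) with unit noise and N(0, sigma^2) priors."""
    lam = sigma ** -2
    # 1 / trace(X^T X + lam I) never exceeds 1 / lambda_max, so the iteration is stable.
    lr = 1.0 / (sum(x * x for x in xs) + len(xs) + 2 * lam)
    slope, intercept = 0.0, 0.0
    for _ in range(steps):
        residuals = [y - (slope * x + intercept) for x, y in zip(xs, ys)]
        grad_slope = -sum(r * x for r, x in zip(residuals, xs)) + lam * slope
        grad_intercept = -sum(residuals) + lam * intercept
        slope -= lr * grad_slope
        intercept -= lr * grad_intercept
    return slope, intercept


random.seed(3)
toy_x = [random.gauss(0.0, 1.0) for _ in range(300)]
toy_y = [2.0 * x + 5.0 + random.gauss(0.0, 1.0) for x in toy_x]
prior_sd = 0.1

toy_ols = ridge_closed_form(toy_x, toy_y, lam=0.0)
toy_ridge = ridge_closed_form(toy_x, toy_y, lam=prior_sd ** -2)
toy_map = map_gradient_descent(toy_x, toy_y, prior_sd)
print(toy_ols)
print(toy_ridge)
print(toy_map)  # matches toy_ridge up to floating-point error
```

Penalizing the intercept as well mirrors the model above, where the same Normal prior is placed on both coefficients.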

X_full = torch.cat([data_X, torch.ones((N,1))], dim=1)
beta_l2 = (torch.linalg.inv(X_full.T @ X_full + sigma**-2 * torch.eye(2)) @ X_full.T @ data_Y)
print(beta_l2)
Out:

tensor([1.7862, 4.4998])

ax.plot(xs, torch.cat([xs.unsqueeze(1), torch.ones((len(xs),1))], dim=1) @ beta_l2, color='magenta', label='L2')
ax.legend()
fig