
Gaussian Process Regression (with GPyTorch)

This tutorial assumes familiarity with the following:

  1. Bean Machine modeling and inference
  2. Gaussian Processes
  3. GPyTorch

A Gaussian Process (GP) is a stochastic process commonly used in Bayesian non-parametrics, in which any finite collection of random variables follows a multivariate Gaussian distribution. A GP is fully defined by a mean function and a covariance (kernel) function:

$$f \sim \mathcal{GP}\left(\mu(x), \mathbf{K}_f(x, x')\right)$$

where $x, x' \in \mathbf{X}$ are two data points (e.g., train and test), $\mu$ is the mean function (usually taken to be zero or constant), and $\mathbf{K}_f$ is the kernel function, which computes a covariance given two data points and a distance metric.
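
Concretely, writing $X$ for the training inputs and $X_*$ for the test inputs (notation used only in this aside), the GP prior says the corresponding function values are jointly Gaussian:

$$\begin{bmatrix} f(X) \\ f(X_*) \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} \mu(X) \\ \mu(X_*) \end{bmatrix}, \begin{bmatrix} \mathbf{K}_f(X, X) & \mathbf{K}_f(X, X_*) \\ \mathbf{K}_f(X_*, X) & \mathbf{K}_f(X_*, X_*) \end{bmatrix}\right)$$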

The aim is then to fit a posterior over functions. GPs let us learn a distribution over functions given our observed data and predict unseen data with well-calibrated uncertainty; they are commonly used in Bayesian optimization as surrogate models when maximizing an objective. For a thorough introduction to Gaussian processes, please see [1].
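
For reference, under a Gaussian noise model $y = f(x) + \epsilon$ with $\epsilon \sim \mathcal{N}(0, \sigma_n^2)$ and a zero mean function, conditioning the joint prior above on the observed targets $y$ gives the closed-form posterior predictive at the test inputs (see [1]):

$$f(X_*) \mid X, y, X_* \sim \mathcal{N}\Big(\mathbf{K}_*^\top (\mathbf{K} + \sigma_n^2 I)^{-1} y,\ \mathbf{K}_{**} - \mathbf{K}_*^\top (\mathbf{K} + \sigma_n^2 I)^{-1} \mathbf{K}_*\Big)$$

where $\mathbf{K} = \mathbf{K}_f(X, X)$, $\mathbf{K}_* = \mathbf{K}_f(X, X_*)$, and $\mathbf{K}_{**} = \mathbf{K}_f(X_*, X_*)$.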

With a PPL such as Bean Machine, we can be Bayesian about the parameters we care about, i.e., learn posterior distributions over those parameters rather than point estimates.

# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import copy
import math
import os
import warnings
from functools import partial

import arviz as az
import beanmachine
import beanmachine.ppl as bm
import beanmachine.ppl.experimental.gp as bgp
import gpytorch
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.distributions as dist
from beanmachine.ppl.experimental.gp.models import SimpleGP
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
from IPython.display import Markdown

The next cell includes convenient configuration settings to improve the notebook presentation as well as setting a manual seed for reproducibility.

# Eliminate excess UserWarnings from Python.
warnings.filterwarnings("ignore")

# Manual seed
torch.manual_seed(123)

# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

# Tool versions
print("pytorch version: ", torch.__version__)
print("gpytorch version: ", gpytorch.__version__)
Out:

pytorch version: 1.13.0a0+fb

gpytorch version: Unknown

Let's use some simple cyclic data:

x_train = torch.linspace(0, 1, 11)
y_train = torch.sin(x_train * (2 * math.pi)) + torch.randn(x_train.shape) * 0.2
x_test = torch.linspace(0, 1, 51).unsqueeze(-1)

with torch.no_grad():
    plt.scatter(x_train.numpy(), y_train.numpy())
    plt.show()

Since this data has a periodic trend to it, we will use a Periodic Kernel:

$$k(x, x') = \sigma^2 \exp\Big(-\frac{2}{\ell}\sin^2\Big(\pi\frac{|x - x'|}{p}\Big)\Big)$$

where $p$, $\ell$, and $\sigma^2$ are the periodicity, length scale, and output scale of the function, respectively; these are the (hyper)parameters of the kernel that we want to learn.
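
To make the formula concrete, here is a minimal sketch that evaluates this kernel directly with torch. It is only an illustration of the equation above, not GPyTorch's internal PeriodicKernel, whose parameterization may differ slightly.

import math

import torch


def periodic_kernel(x1, x2, outputscale, lengthscale, period):
    """Evaluate k(x, x') = sigma^2 * exp(-(2 / l) * sin^2(pi * |x - x'| / p))."""
    # Pairwise absolute distances between two sets of 1-d inputs.
    dists = (x1.unsqueeze(-1) - x2.unsqueeze(0)).abs()
    return outputscale * torch.exp(-(2.0 / lengthscale) * torch.sin(math.pi * dists / period) ** 2)


# For example, the kernel matrix over the training inputs:
# K = periodic_kernel(x_train, x_train, outputscale=1.0, lengthscale=0.5, period=1.0)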

Maximum Likelihood Estimation (with GPyTorch)

GPyTorch's exact inference algorithms allow you to compute maximum likelihood estimates (MLE) of the kernel parameters. Since SimpleGP extends GPyTorch's ExactGP model, you can use GPyTorch to optimize the model. Let's try that, closely following the GPyTorch regression tutorial.
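
The objective maximized below is the exact GP marginal log likelihood (the quantity that ExactMarginalLogLikelihood evaluates). With Gaussian noise variance $\sigma_n^2$, it has the familiar closed form (see [1]):

$$\log p(y \mid X, \theta) = -\tfrac{1}{2}(y - \mu)^\top (\mathbf{K}_\theta + \sigma_n^2 I)^{-1} (y - \mu) - \tfrac{1}{2}\log\left|\mathbf{K}_\theta + \sigma_n^2 I\right| - \tfrac{n}{2}\log 2\pi$$

where $\mu$ is the mean vector on the training inputs, $\mathbf{K}_\theta$ is the kernel matrix under hyperparameters $\theta$, and $n$ is the number of training points.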

Regression = SimpleGP
kernel = gpytorch.kernels.ScaleKernel(base_kernel=gpytorch.kernels.PeriodicKernel())
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mean = gpytorch.means.ConstantMean()
gp = Regression(x_train, y_train, mean, kernel, likelihood)
optimizer = torch.optim.Adam(
    gp.parameters(), lr=0.1
)  # Includes GaussianLikelihood parameters
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)
gp.eval()  # this converts the BM model into a gpytorch model
num_iters = 1 if smoke_test else 100

for i in range(num_iters):
    optimizer.zero_grad()
    output = gp(x_train)
    loss = -mll(output, y_train)
    loss.backward()
    if i % 10 == 0:
        print(
            "Iter %d/%d - Loss: %.3f"
            % (
                i + 1,
                num_iters,
                loss.item(),
            )
        )
    optimizer.step()
Out:

Iter 1/100 - Loss: 1.082

Iter 11/100 - Loss: 0.504

Iter 21/100 - Loss: 0.040

Iter 31/100 - Loss: -0.385

Iter 41/100 - Loss: -0.755

Iter 51/100 - Loss: -0.939

Iter 61/100 - Loss: -0.947

Iter 71/100 - Loss: -0.966

Iter 81/100 - Loss: -0.960

Iter 91/100 - Loss: -0.954

with torch.no_grad():
    observed_pred = likelihood(gp(x_test))

    # Initialize plot
    f, ax = plt.subplots(1, 1, figsize=(4, 3))

    # Get upper and lower confidence bounds
    lower, upper = observed_pred.confidence_region()
    # Plot training data as black stars
    ax.plot(x_train.numpy(), y_train.numpy(), "k*")
    # Plot predictive means as blue line
    ax.plot(x_test.squeeze().numpy(), observed_pred.mean.numpy().T, "b")
    # Shade between the lower and upper confidence bounds
    ax.fill_between(x_test.squeeze().numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
    ax.set_ylim([-1, 1])
    ax.legend(["Observed Data", "Mean", "Confidence"])

Not bad! Our GP fits this simple function fairly well. However, we've only captured data uncertainty, not parameter uncertainty, and calibrating parameter uncertainty can often lead to better predictive performance. In the next section, we'll do just that using Bean Machine's NUTS algorithm.

Fully Bayesian Inference with Bean Machine

Let's reuse the same model, but this time use Bean Machine to learn posteriors over the parameters. In train mode, SimpleGP is a thin wrapper around gpytorch.models.ExactGP that lifts the model's __call__ method to Bean Machine.
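
Conceptually, the model being wrapped is an ordinary GPyTorch exact GP. As a rough sketch only (not Bean Machine's actual SimpleGP implementation), an equivalent plain GPyTorch model with the same constructor signature might look like this:

class SketchGP(gpytorch.models.ExactGP):
    """Illustrative stand-in with the same signature as the Regression model above."""

    def __init__(self, x_train, y_train, mean, kernel, likelihood):
        super().__init__(x_train, y_train, likelihood)
        self.mean = mean
        self.kernel = kernel

    def forward(self, x):
        # The GP prior at x: a MultivariateNormal defined by the mean and kernel.
        return MultivariateNormal(self.mean(x), self.kernel(x))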

As before, we'll create kernel and likelihood objects. This time, we'll include priors on our parameters $p$, $\sigma^2$, and $\ell$.

from gpytorch.priors import UniformPrior


mean = gpytorch.means.ConstantMean(constant_prior=UniformPrior(-1, 1))
kernel = gpytorch.kernels.ScaleKernel(
    base_kernel=gpytorch.kernels.PeriodicKernel(
        period_length_prior=UniformPrior(0.05, 2.5),
        lengthscale_prior=UniformPrior(0.01, 0.5),
    ),
    outputscale_prior=UniformPrior(1.0, 2.0),
)
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=UniformPrior(0.05, 0.3))

gp = Regression(x_train, y_train, mean, kernel, likelihood)

Now we can run inference as we would with any other Bean Machine model.

from beanmachine.ppl.experimental.gp import make_prior_random_variables, bm_sample_from_prior


num_samples = 1 if smoke_test else 100
num_adaptive_samples = 0 if smoke_test else num_samples // 2
num_chains = 1

name_to_rv = make_prior_random_variables(gp)


@bm.random_variable
def y():
    sampled_model = bm_sample_from_prior(gp.to_pyro_random_module(), name_to_rv)
    return sampled_model.likelihood(sampled_model(x_train))


queries = list(name_to_rv.values())
obs = {y(): y_train}

nuts = bm.GlobalNoUTurnSampler(nnc_compile=False)
samples = nuts.infer(
    queries=queries,
    observations=obs,
    num_samples=num_samples,
    num_adaptive_samples=num_adaptive_samples,
    num_chains=num_chains,
)

Let's take a look at how well our model fit. We will plot samples from our posterior as well as the predictives generated from our GP.

summary_df = az.summary(samples.to_inference_data())
Markdown(summary_df.to_markdown())
Out:

Shape validation failed: input_shape: (1, 100), minimum_shape: (chains=2, draws=4)

|                                                |   mean |    sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|:-----------------------------------------------|-------:|------:|-------:|--------:|----------:|--------:|---------:|---------:|------:|
| kernel.base_kernel.lengthscale_prior()[0, 0]   |  1.223 | 0.484 |  0.453 |   2.068 |     0.059 |   0.043 |       71 |       78 |   nan |
| kernel.base_kernel.period_length_prior()[0, 0] |  1.715 | 0.38  |  0.985 |   2.346 |     0.049 |   0.035 |       53 |       27 |   nan |
| kernel.outputscale_prior()                     |  1.447 | 0.265 |  1.009 |   1.876 |     0.04  |   0.028 |       46 |       55 |   nan |
| likelihood.noise_covar.noise_prior()[0]        |  0.092 | 0.047 |  0.05  |   0.184 |     0.006 |   0.004 |       21 |       21 |   nan |
| mean.mean_prior()                              | -0.033 | 0.574 | -0.982 |   0.881 |     0.09  |   0.092 |       39 |       34 |   nan |
lengthscale_samples = samples.get_chain(0)[name_to_rv["kernel.base_kernel.lengthscale_prior"]]
outputscale_samples = samples.get_chain(0)[name_to_rv["kernel.outputscale_prior"]]
period_length_samples = samples.get_chain(0)[name_to_rv["kernel.base_kernel.period_length_prior"]]
mean_samples = samples.get_chain(0)[name_to_rv["mean.mean_prior"]]
noise_samples = samples.get_chain(0)[name_to_rv["likelihood.noise_covar.noise_prior"]]

if not smoke_test:
    plt.figure(figsize=(8, 5))
    sns.distplot(lengthscale_samples, label="lengthscale")
    sns.distplot(outputscale_samples, label="outputscale")
    sns.distplot(period_length_samples[: int(num_samples / 2)], label="periodlength")
    plt.legend()
    plt.title("Posterior Empirical Distribution", fontsize=18)
    plt.tight_layout()
    plt.show()

To generate predictions, we will convert our model to a GPyTorch model by running it in eval mode. We load our posterior samples with a Python dict, keyed on the parameter namespace and valued with the torch tensor of samples. Note the unsqueezes to allow broadcasting of the data dimension to the right.

gp.eval()  # converts to GPyTorch model in eval mode
gp.bm_load_samples(
    {
        "kernel.outputscale_prior": outputscale_samples,
        "kernel.base_kernel.lengthscale_prior": lengthscale_samples,
        "kernel.base_kernel.period_length_prior": period_length_samples,
        "likelihood.noise_covar.noise_prior": noise_samples,
        "mean.mean_prior": mean_samples,
    }
)
expanded_test_x = x_test.unsqueeze(0).repeat(num_samples, 1, 1)
output = gp(expanded_test_x)

Now let's plot a few predictive samples from our GP. As you can see, we can draw different kernels, each of which parameterizes a multivariate normal.

if not smoke_test:
    with torch.no_grad():
        f, ax = plt.subplots(1, 1, figsize=(8, 5))
        ax.plot(x_train.numpy(), y_train.numpy(), "k*", zorder=10)
        ax.plot(
            x_test.numpy(),
            output.mean.median(0)[0].detach().numpy(),
            "b",
            linewidth=1.5,
        )
        for i in range(min(20, num_samples)):
            ax.plot(
                x_test.numpy(),
                output.mean[i].detach().numpy(),
                "gray",
                linewidth=0.3,
                alpha=0.8,
            )
        ax.legend(["Observed Data", "Median", "Sampled Means"])

References

[1] Rasmussen, Carl E. and Williams, Christopher K. I. Gaussian Processes for Machine Learning. MIT Press, 2006.