Gaussian Process Regression (with GPyTorch)
This tutorial assumes familiarity with the following:
- Bean Machine modeling and inference
- Gaussian Processes
- GPyTorch
A Gaussian Process (GP) is a stochastic process commonly used in Bayesian non-parametrics: any finite collection of its random variables follows a multivariate Gaussian distribution. A GP is fully defined by a mean and a covariance function:

$$f(x) \sim \mathcal{GP}\big(m(x), k(x, x')\big)$$

where $x$ and $x'$ are two data points (e.g., train and test), $m$ is the mean function (usually taken to be zero or constant), and $k$ is the kernel function, which computes a covariance given two data points and a distance metric.
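To make this concrete, here is a minimal sketch (not taken from the tutorial's code) showing that the function values at a finite set of inputs are jointly multivariate Gaussian, with the covariance matrix built from a kernel; the squared-exponential kernel below is just an illustrative choice.
import torch

def rbf_kernel(x1, x2, lengthscale=0.2):
    # Squared-exponential kernel: k(x, x') = exp(-|x - x'|^2 / (2 * lengthscale^2)).
    sqdist = (x1.unsqueeze(-1) - x2.unsqueeze(0)) ** 2
    return torch.exp(-0.5 * sqdist / lengthscale**2)

xs = torch.linspace(0, 1, 5)                    # five input locations
cov = rbf_kernel(xs, xs) + 1e-6 * torch.eye(5)  # kernel matrix plus jitter for stability
prior = torch.distributions.MultivariateNormal(torch.zeros(5), cov)
f_sample = prior.sample()                       # one draw of f at these five inputs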
The aim, then, is to fit a posterior over functions. GPs let us learn a distribution over functions given our observed data and predict unseen data with well-calibrated uncertainty, and they are commonly used in Bayesian optimization as surrogate models for maximizing an objective. For a thorough introduction to Gaussian processes, please see [1].
With a PPL such as Bean Machine, we can be Bayesian about the parameters we care about, i.e., learn posterior distributions over these parameters rather than point estimates.
# Install Bean Machine in Colab if using Colab.
import sys

if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
    !pip install beanmachine
import copy
import math
import os
import warnings
from functools import partial
import arviz as az
import beanmachine
import beanmachine.ppl as bm
import beanmachine.ppl.experimental.gp as bgp
import gpytorch
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.distributions as dist
from beanmachine.ppl.experimental.gp.models import SimpleGP
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
from IPython.display import Markdown
The next cell includes convenient configuration settings to improve the notebook presentation as well as setting a manual seed for reproducibility.
# Eliminate excess UserWarnings from Python.
warnings.filterwarnings("ignore")
# Manual seed
torch.manual_seed(123)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
# Tool versions
print("pytorch version: ", torch.__version__)
print("gpytorch version: ", gpytorch.__version__)
Let's use some simple cyclic data:
x_train = torch.linspace(0, 1, 11)
y_train = torch.sin(x_train * (2 * math.pi)) + torch.randn(x_train.shape) * 0.2
x_test = torch.linspace(0, 1, 51).unsqueeze(-1)
with torch.no_grad():
    plt.scatter(x_train.numpy(), y_train.numpy())
    plt.show()
Since this data has a periodic trend to it, we will use a Periodic Kernel:
$$k(x, x') = \sigma^2 \exp\left(-\frac{2\sin^2\big(\pi |x - x'| / p\big)}{\ell^2}\right)$$

where $p$, $\ell$, and $\sigma^2$ are the periodicity, length scale, and output scale of the function respectively: the (hyper)parameters of the kernel we want to learn.
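To see what these hyperparameters do, you can evaluate the kernel directly. The sketch below implements the textbook form of the periodic kernel in plain torch; GPyTorch's PeriodicKernel uses an equivalent parameterization, though the exact way the length scale enters may differ slightly.
import math
import torch

def periodic_kernel(x1, x2, period=1.0, lengthscale=0.5, outputscale=1.0):
    # k(x, x') = sigma^2 * exp(-2 * sin^2(pi * |x - x'| / p) / l^2)
    diff = (x1.unsqueeze(-1) - x2.unsqueeze(0)).abs()
    return outputscale * torch.exp(-2 * torch.sin(math.pi * diff / period) ** 2 / lengthscale**2)

xs = torch.linspace(0, 1, 5)
print(periodic_kernel(xs, xs))  # entries depend only on |x - x'| and repeat with the period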
Maximum Likelihood Estimation (with GPyTorch)
GPyTorch's exact inference algorithms allow you to compute maximum likelihood estimates (MLE) of kernel parameters. Since a SimpleGP extends a GPyTorch ExactGP model, you can use GPyTorch to optimize the model. Let's try that, closely following the GPyTorch regression tutorial.
Regression = SimpleGP
kernel = gpytorch.kernels.ScaleKernel(base_kernel=gpytorch.kernels.PeriodicKernel())
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mean = gpytorch.means.ConstantMean()
gp = Regression(x_train, y_train, mean, kernel, likelihood)
optimizer = torch.optim.Adam(
    gp.parameters(), lr=0.1
)  # Includes GaussianLikelihood parameters
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)
gp.eval() # this converts the BM model into a gpytorch model
num_iters = 1 if smoke_test else 100
for i in range(num_iters):
    optimizer.zero_grad()
    output = gp(x_train)
    loss = -mll(output, y_train)
    loss.backward()
    if i % 10 == 0:
        print("Iter %d/%d - Loss: %.3f" % (i + 1, num_iters, loss.item()))
    optimizer.step()
with torch.no_grad():
    observed_pred = likelihood(gp(x_test))
# Initialize plot
f, ax = plt.subplots(1, 1, figsize=(4, 3))
# Get upper and lower confidence bounds
lower, upper = observed_pred.confidence_region()
# Plot training data as black stars
ax.plot(x_train.numpy(), y_train.numpy(), "k*")
# Plot predictive means as blue line
ax.plot(x_test.squeeze().numpy(), observed_pred.mean.numpy().T, "b")
# Shade between the lower and upper confidence bounds
ax.fill_between(x_test.squeeze().numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
ax.set_ylim([-1, 1])
ax.legend(["Observed Data", "Mean", "Confidence"])
Not bad! Our GP fits this simple function fairly well. However, we've only captured data uncertainty, not parameter uncertainty. Calibrating parameter uncertainty can often lead to better predictive performance. In the next section, we'll do just that using Bean Machine's NUTS algorithm.
Fully Bayesian Inference with Bean Machine
Let's reuse the same model, but this time, use Bean Machine to learn posteriors over the parameters. In train mode, SimpleGP is a simple wrapper around gpytorch.models.ExactGP that lifts the model's __call__ method to BM.
As before, we'll create kernel and likelihood objects. This time, we'll include priors on our parameters $p$, $\ell$, and $\sigma^2$, as well as on the constant mean and the observation noise.
from gpytorch.priors import UniformPrior
mean = gpytorch.means.ConstantMean(constant_prior=UniformPrior(-1, 1))
kernel = gpytorch.kernels.ScaleKernel(
    base_kernel=gpytorch.kernels.PeriodicKernel(
        period_length_prior=UniformPrior(0.05, 2.5),
        lengthscale_prior=UniformPrior(0.01, 0.5),
    ),
    outputscale_prior=UniformPrior(1.0, 2.0),
)
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=UniformPrior(0.05, 0.3))
gp = Regression(x_train, y_train, mean, kernel, likelihood)
Now we can run inference as we would with any other Bean Machine model.
from beanmachine.ppl.experimental.gp import make_prior_random_variables, bm_sample_from_prior
num_samples = 1 if smoke_test else 100
num_adaptive_samples = 0 if smoke_test else num_samples // 2
num_chains = 1
name_to_rv = make_prior_random_variables(gp)
@bm.random_variable
def y():
    sampled_model = bm_sample_from_prior(gp.to_pyro_random_module(), name_to_rv)
    return sampled_model.likelihood(sampled_model(x_train))
queries = list(name_to_rv.values())
obs = {y(): y_train}
nuts = bm.GlobalNoUTurnSampler(nnc_compile=False)
samples = nuts.infer(
queries=queries,
observations=obs,
num_samples=num_samples,
num_adaptive_samples=num_adaptive_samples,
num_chains=num_chains,
)
Let's take a look at how our model fit. We will plot the samples of our posterior as well as the predictives generated from our GP.
summary_df = az.summary(samples.to_inference_data())
Markdown(summary_df.to_markdown())
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| kernel.base_kernel.lengthscale_prior()[0, 0] | 1.223 | 0.484 | 0.453 | 2.068 | 0.059 | 0.043 | 71 | 78 | nan |
| kernel.base_kernel.period_length_prior()[0, 0] | 1.715 | 0.38 | 0.985 | 2.346 | 0.049 | 0.035 | 53 | 27 | nan |
| kernel.outputscale_prior() | 1.447 | 0.265 | 1.009 | 1.876 | 0.04 | 0.028 | 46 | 55 | nan |
| likelihood.noise_covar.noise_prior()[0] | 0.092 | 0.047 | 0.05 | 0.184 | 0.006 | 0.004 | 21 | 21 | nan |
| mean.mean_prior() | -0.033 | 0.574 | -0.982 | 0.881 | 0.09 | 0.092 | 39 | 34 | nan |
lengthscale_samples = samples.get_chain(0)[name_to_rv['kernel.base_kernel.lengthscale_prior']]
outputscale_samples = samples.get_chain(0)[name_to_rv['kernel.outputscale_prior']]
period_length_samples = samples.get_chain(0)[name_to_rv['kernel.base_kernel.period_length_prior']]
mean_samples = samples.get_chain(0)[name_to_rv['mean.mean_prior']]
noise_samples = samples.get_chain(0)[name_to_rv['likelihood.noise_covar.noise_prior']]
if not smoke_test:
    plt.figure(figsize=(8, 5))
    sns.distplot(lengthscale_samples, label="lengthscale")
    sns.distplot(outputscale_samples, label="outputscale")
    sns.distplot(period_length_samples[: int(num_samples / 2)], label="periodlength")
    plt.legend()
    plt.title("Posterior Empirical Distribution", fontsize=18)
    plt.tight_layout()
    plt.show()
To generate predictions, we convert our model to a GPyTorch model by running in eval mode. We load our posterior samples with a Python dict, keyed by parameter namespace, with the torch tensors of samples as values. Note the unsqueezes, which allow broadcasting of the data dimension to the right.
gp.eval()  # converts to a GPyTorch model in eval mode
gp.bm_load_samples(
{
"kernel.outputscale_prior": outputscale_samples,
"kernel.base_kernel.lengthscale_prior": lengthscale_samples,
"kernel.base_kernel.period_length_prior": period_length_samples,
"likelihood.noise_covar.noise_prior": noise_samples,
"mean.mean_prior": mean_samples,
}
)
expanded_test_x = x_test.unsqueeze(0).repeat(num_samples, 1, 1)
output = gp(expanded_test_x)
Now let's plot a few predictive samples from our GP. As you can see, we can draw different kernel hyperparameters, each of which parameterizes a multivariate normal.
if not smoke_test:
    with torch.no_grad():
        f, ax = plt.subplots(1, 1, figsize=(8, 5))
        ax.plot(x_train.numpy(), y_train.numpy(), "k*", zorder=10)
        ax.plot(
            x_test.numpy(),
            output.mean.median(0)[0].detach().numpy(),
            "b",
            linewidth=1.5,
        )
        for i in range(min(20, num_samples)):
            ax.plot(
                x_test.numpy(),
                output.mean[i].detach().numpy(),
                "gray",
                linewidth=0.3,
                alpha=0.8,
            )
        ax.legend(["Observed Data", "Median", "Sampled Means"])
References
[1] Rasmussen, Carl Edward and Williams, Christopher K. I. Gaussian Processes for Machine Learning. MIT Press, 2006.