
Tutorial: Robust Linear Regression

This tutorial demonstrates modeling and running inference on a robust linear regression model in Bean Machine. It shows how a simple modification to the standard regression model, replacing the Normal noise model with a heavy-tailed one, makes the regression more robust to outliers, and it illustrates how to modify a base model more generally.

Problem

In this classical extension of the linear regression problem, the goal is still to estimate an unobserved response variable from an observed covariate. The twist on the basic linear regression model is that we now believe errors can occasionally be large, and we would like the model to be robust to such error sources through an appropriate observation model. We'll construct a Bayesian model for this problem, which will yield not only point estimates but also measures of uncertainty in our predictions. We will also explore different inference procedures in this tutorial.

We'll restrict this tutorial to the univariate case, to aid with clarity and visualization. We will follow the outline of the tutorial by Adrian Baez-Ortega for comparison.

Prerequisites

We will be using the following packages within this tutorial.

Let's code this in Bean Machine! Import the Bean Machine library and some fundamental PyTorch classes.

# Install Bean Machine in Colab if using Colab.
import sys


if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
!pip install beanmachine
import os

import arviz as az
import beanmachine.ppl as bm
import numpy as np
import pandas as pd
import sklearn.model_selection
import torch
import torch.distributions as dist
from beanmachine.tutorials.utils import plots
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.palettes import Colorblind3
from bokeh.plotting import gridplot, show
from IPython.display import Markdown

The next cell includes convenient configuration settings to improve the notebook presentation, and sets a manual seed for torch for reproducibility.

# Plotting settings
az.rcParams["plot.backend"] = "bokeh"
az.rcParams["stats.hdi_prob"] = 0.89

# Manual seed
torch.manual_seed(11)

# Other settings for the notebook
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ

Model

We're interested in predicting a response variable $y$ given an observed covariate $x$:

  • $y = \beta x + \alpha + \text{error}$

If we assume the error stems from a heavy-tailed distribution, for instance a Student's T, we can reframe this as:

  • $y \sim \mathcal{T}(\beta x + \alpha, \sigma, \nu)$

Here, $\beta$ is a coefficient for $x$, $\alpha$ is a bias term, $\sigma$ is a term indicating a deviation, and $\nu$ is a term describing degrees of freedom. Specifically:

  • $N \in \mathbb{Z}^+$ is the size of the training data.
  • $x_i \in \mathbb{R}$ is the observed covariate.
  • $\beta \in \mathbb{R}$ is the coefficient for $x$. We'll use a prior of $\mathcal{N}(0, 1000)$.
  • $\alpha \in \mathbb{R}$ is the bias term. We'll use a prior of $\mathcal{N}(0, 1000)$.
  • $\sigma \in \mathbb{R}^+$ is the error deviation. We'll use a Half-Normal prior $\text{H}\mathcal{N}(0, 1000)$.
  • $\nu \in \mathbb{R}^+$ is the degrees of freedom of the Student's T. We'll use a Gamma prior $\text{G}(2, 0.1)$.
  • $y_i \stackrel{iid}{\sim} \mathcal{T}(\beta x_i + \alpha, \sigma, \nu) \in \mathbb{R}$ is the prediction.

We are interested in fitting posterior distributions for $\beta$, $\alpha$, $\sigma$, and $\nu$ given a collection of training data $\{(x_i, y_i)\}_{i=1}^N$.
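To see why a heavy-tailed noise model confers robustness, compare tail log probabilities under the two candidate error distributions (an illustrative aside, not part of the original tutorial; the printed values are approximate):

# A point six standard deviations from the mean is astronomically unlikely
# under a Normal, but only mildly surprising under a Student's T with few
# degrees of freedom, so a single outlier exerts far less pull on the fit.
print(dist.Normal(0.0, 1.0).log_prob(torch.tensor(6.0)))        # ≈ -18.9
print(dist.StudentT(3.0, 0.0, 1.0).log_prob(torch.tensor(6.0)))  # ≈ -6.1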

Let's visualize the Gamma distribution that we used as our prior for $\nu$:

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

concentration = 2
rate = 0.1
x = torch.arange(0, 5, 0.01)
y = dist.Gamma(concentration, rate).log_prob(x).exp()
cds = ColumnDataSource({"x": x.tolist(), "y": y.tolist()})

gamma_prior_plot = plots.line_plot(
    plot_sources=[cds],
    tooltips=[[("Density", "@y{0.000}"), ("ν", "@x{0.000}")]],
    figure_kwargs={
        "x_axis_label": "nu",
        "y_axis_label": "density",
        "title": f"Γ({concentration}, {rate}) prior",
    },
    plot_kwargs={"line_width": 2, "hover_line_color": "orange"},
)
show(gamma_prior_plot)

We can implement this model in Bean Machine by defining random variable objects with the @bm.random_variable decorator. These functions behave differently than ordinary Python functions.

Semantics for @bm.random_variable functions:
  • They must return PyTorch Distribution objects.
  • Though they return distributions, callers actually receive samples from the distribution. The machinery for obtaining samples from distributions is handled internally by Bean Machine.
  • Inference runs the model through many iterations. During a particular inference iteration, a distinct random variable will correspond to exactly one sampled value: calls to the same random variable function with the same arguments will receive the same sampled value within one inference iteration. This makes it easy for multiple components of your model to refer to the same logical random variable.
  • Consequently, to define distinct random variables that correspond to different sampled values during a particular inference iteration, an effective practice is to add a dummy "indexing" parameter to the function. Distinct random variables can then be referred to with different values for this index (see the sketch after this list).
  • Please see the documentation for more information about this decorator.
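For instance, a hypothetical per-datapoint noise variable (a sketch for illustration only; `noise` is not part of this tutorial's model) could be indexed like this:

@bm.random_variable
def noise(i):
    # noise(0) and noise(1) are distinct random variables, while repeated
    # calls to noise(0) within one inference iteration see the same sample.
    return dist.Normal(0, 1)

With those semantics in mind, here is the model: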
@bm.random_variable
def beta():
    """
    Regression Coefficient
    """
    return dist.Normal(0, 1000)


@bm.random_variable
def alpha():
    """
    Regression Bias/Offset
    """
    return dist.Normal(0, 1000)


@bm.random_variable
def sigma_regressor():
    """
    Deviation parameter for Student's T
    Controls the magnitude of the errors.
    """
    return dist.HalfNormal(1000)


@bm.random_variable
def df_nu():
    """
    Degrees of Freedom of a Student's T
    See https://en.wikipedia.org/wiki/Student%27s_t-distribution for its effect
    """
    return dist.Gamma(2, 0.1)


@bm.random_variable
def y_robust(X):
    """
    Heavy-Tailed Noise model for regression utilizing StudentT
    Student's T : https://en.wikipedia.org/wiki/Student%27s_t-distribution
    """
    return dist.StudentT(df=df_nu(), loc=beta() * X + alpha(), scale=sigma_regressor())

Data

With the model defined, we need to collect some observed data in order to learn about values of interest in our model.

In this case, we will observe a few samples of inputs and outputs. For demonstrative purposes, we will use a synthetically generated dataset of observed values. In practice, you would gather a collection of covariate and response variables, and then you could construct a model to predict a new, unobserved response variable from a new, observed covariate.

For our synthetic dataset, we will assume the following parameters to the relationship between inputs and outputs.

true_beta = 2.0
true_alpha = 5.0
true_epsilon = 1.0
N = 200

X = X_clean = dist.Normal(0, 1).expand([N, 1]).sample()
Y = Y_clean = dist.Normal(true_beta * X + true_alpha, true_epsilon).sample()

We can visualize the data as follows:

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

clean_cds = ColumnDataSource(
    {
        "x": X.flatten().tolist(),
        "y": Y.flatten().tolist(),
        "color": ["steelblue"] * len(X.flatten()),
    }
)
clean_tips = [("y", "@y{0.000}"), ("x", "@x{0.000}")]
synthetic_clean_data_plot = plots.scatter_plot(
    plot_sources=[clean_cds],
    tooltips=[clean_tips],
    figure_kwargs={
        "title": "Synthetic clean data",
        "x_axis_label": "x",
        "y_axis_label": "y",
        "plot_width": 500,
        "plot_height": 500,
        "x_range": [-6, 6],
        "y_range": [-11, 16],
    },
    legend_items="Clean data",
)
synthetic_clean_data_plot.legend.location = "bottom_left"
show(synthetic_clean_data_plot)

We will now corrupt this data with some extreme outliers.

# Note: this assignment aliases the clean tensors rather than copying them,
# so the corruption below also flows into X and Y, which we train on.
X_corr, Y_corr = X_clean, Y_clean
X_corr[0], Y_corr[0] = 5, 15
X_corr[1], Y_corr[1] = -5, -10
X_corr[2], Y_corr[2] = -3, 10

Let's visualize the data now again after adding outliers.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

corrupted_cds = ColumnDataSource(
    {"x": [5, -5, -3], "y": [15, -10, 10], "color": ["orange"] * 3}
)
corrupted_tips = [("y", "@y{0.000}"), ("x", "@x{0.000}")]
synthetic_corrupted_data_plot = plots.scatter_plot(
    plot_sources=[clean_cds, corrupted_cds],
    tooltips=[clean_tips, corrupted_tips],
    figure_kwargs={
        "title": "Synthetic corrupted data",
        "x_axis_label": "x",
        "y_axis_label": "y",
        "plot_width": 500,
        "plot_height": 500,
        "x_range": synthetic_clean_data_plot.x_range,
        "y_range": synthetic_clean_data_plot.y_range,
    },
    legend_items=["Clean data", "Corrupted data"],
    plot_kwargs={"fill_color": "color"},
)

# Compare clean vs corrupted data.
synthetic_clean_data_plot.legend.visible = False
synthetic_corrupted_data_plot.legend.location = "bottom_right"
synthetic_data_plot = gridplot(
    [[synthetic_clean_data_plot, synthetic_corrupted_data_plot]]
)
show(synthetic_data_plot)

Let's split the dataset into a training and test set, which we'll use later to evaluate predictive performance.

X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X, Y)

Our inference algorithms expect observations in the form of a dictionary. This dictionary should consist of @bm.random_variable invocations as keys, and tensor data as values.

You can see this in the code snippet below, where we bind the observed values to a key representing the random variable that was observed.

observations = {y_robust(X_train): Y_train}

Inference: Take 1

Inference is the process of combining a model with data to obtain insights, in the form of probability distributions over values of interest. Bean Machine offers a powerful and general inference framework to enable fitting arbitrary models to data.

As a starting point for running inference, we will use Bean Machine's GlobalNoUTurnSampler. The No-U-Turn Sampler (NUTS) is an adaptive variant of Hamiltonian Monte Carlo: it uses gradients of the joint density to propose new values for all of the unobserved random variables jointly, and it tunes its step size and trajectory length automatically during the adaptation phase.

Running inference takes a few arguments, summarized below:

| Name | Usage |
| :--- | :--- |
| `queries` | A list of `@bm.random_variable` targets to fit posterior distributions for. |
| `observations` | The dictionary of observations we built above. |
| `num_samples` | The number of Monte Carlo samples used to approximate the posterior distributions for the variables in `queries`. |
| `num_adaptive_samples` | The number of warmup samples used to tune the sampler before the reported draws are collected. |
| `num_chains` | The number of separate inference runs to use. Multiple chains can help verify that inference ran correctly. |

Let's run inference:

queries = [beta(), alpha(), sigma_regressor(), df_nu()]
num_samples = 2 if smoke_test else 2000
num_chains = 1 if smoke_test else 4
num_adaptive_samples = 0 if smoke_test else num_samples // 2
# Experimental backend using the PyTorch NNC compiler.
nnc_compile = "SANDCASTLE_NEXUS" not in os.environ

samples = bm.GlobalNoUTurnSampler(nnc_compile=nnc_compile).infer(
    queries=queries,
    observations=observations,
    num_samples=num_samples,
    num_adaptive_samples=num_adaptive_samples,
    num_chains=num_chains,
)
Out:

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Samples collected: 0%| | 0/3000 [00:00<?, ?it/s]

Out:

/Users/jpchen/beanmachine/src/beanmachine/ppl/experimental/nnc/utils.py:21: UserWarning: The support of NNC compiler is experimental and the API is subject to change in the future releases of Bean Machine. For questions regarding NNC, please checkout the functorch project (https://github.com/pytorch/functorch).

"The support of NNC compiler is experimental and the API is subject to"

Analysis

samples now contains our inference results.

beta_marginal = samples[beta()].flatten(start_dim=0, end_dim=1).detach()
alpha_marginal = samples[alpha()].flatten(start_dim=0, end_dim=1).detach()
sigma_marginal = samples[sigma_regressor()].flatten(start_dim=0, end_dim=1).detach()
nu_marginal = samples[df_nu()].flatten(start_dim=0, end_dim=1).detach()

print(
    f"β marginal: {beta_marginal}\n"
    f"α marginal: {alpha_marginal}\n"
    f"σ marginal: {sigma_marginal}\n"
    f"ν marginal: {nu_marginal}"
)
Out:

β marginal: tensor([1.9131, 1.9169, 1.8697, ..., 1.9127, 1.9227, 2.0932])

α marginal: tensor([5.1544, 5.1075, 5.0664, ..., 5.1780, 5.1720, 4.9826])

σ marginal: tensor([1.0382, 0.8749, 0.9554, ..., 0.9076, 0.9179, 0.8641])

ν marginal: tensor([6.0199, 8.3381, 6.1413, ..., 4.2614, 4.5714, 5.3702])

Next, let's visualize the inferred random variables.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

alpha_beta_joint_plot = plots.plot_marginal(
    queries=[alpha(), beta()],
    samples=samples,
    true_values=[true_alpha, true_beta],
    bandwidth=0.1,
    joint_plot_title="α-β marginal",
)
show(alpha_beta_joint_plot)

We seem to have faithfully recovered $\alpha$ and $\beta$. Let's see what the model inferred for the error scale $\sigma$.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

sigma_plot = plots.plot_marginal(
    queries=[sigma_regressor()],
    samples=samples,
    true_values=[true_epsilon],
    n_bins=100,
    bandwidth=0.025,
)
show(sigma_plot)

We seem to have recovered a reasonably good estimate of the error scale as well.
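As a quick numeric sanity check (an addition beyond the tutorial's plots), we can compare the posterior means against the values used to generate the data:

# Posterior means should land near the parameters that generated the data.
print(f"beta:  posterior mean {beta_marginal.mean().item():.3f}, true value {true_beta}")
print(f"alpha: posterior mean {alpha_marginal.mean().item():.3f}, true value {true_alpha}")
print(f"sigma: posterior mean {sigma_marginal.mean().item():.3f}, true value {true_epsilon}")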

We can also compute the log probability on the held-out test data. This isn't particularly useful on its own, but is useful for comparing different approaches. Thus, here, we will also plot a baseline for comparison: the log probability implied on the test dataset by the ground truth parameters. Note that the code below scores the posterior draws under a Normal likelihood, so that they are directly comparable to this baseline.

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

test_data_y = (
    dist.Normal(
        (
            X_test @ beta_marginal[:num_samples].unsqueeze(0)
            + alpha_marginal[:num_samples].unsqueeze(0)
        ),
        sigma_marginal[:num_samples],
    )
    .log_prob(Y_test)
    .sum(dim=0)
)
test_data_x = list(range(len(test_data_y.tolist())))
test_data_cds = ColumnDataSource({"x": test_data_x, "y": test_data_y.tolist()})

ground_truth_y = (
    dist.Normal(X_test * true_beta + true_alpha, true_epsilon)
    .log_prob(Y_test)
    .sum(dim=0)
    .item()
)
ground_truth_cds = ColumnDataSource(
    {"x": test_data_x, "y": [ground_truth_y] * len(test_data_x)}
)

test_data_nuts_plot = plots.line_plot(
    plot_sources=[ground_truth_cds, test_data_cds],
    labels=[f"Ground truth = {ground_truth_y:.2f}", "NUTS"],
    figure_kwargs={
        "y_axis_label": "Log probability",
        "x_axis_label": "Draw",
        "title": "Log probability on test data",
    },
    plot_kwargs={"line_width": 3, "line_alpha": 0.7},
)
test_data_nuts_plot.legend.location = "bottom_right"

show(test_data_nuts_plot)

The log probability under the posterior samples closely tracks the baseline implied by the ground-truth parameters.

ArviZ provides helpful statistics about the results of inference, which we show below.

Markdown(az.summary(samples.to_inference_data(), round_to=4).to_markdown())
|                   | mean   | sd     | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat  |
| :---------------- | :----- | :----- | :------- | :-------- | :-------- | :------ | :------- | :------- | :----- |
| sigma_regressor() | 0.9126 | 0.0761 | 0.7942   | 1.0356    | 0.0011    | 0.0008  | 5054.08  | 4966.7   | 1.0005 |
| alpha()           | 5.0376 | 0.0861 | 4.904    | 5.1769    | 0.0011    | 0.0008  | 5680     | 4716.97  | 1.0003 |
| df_nu()           | 5.4594 | 1.723  | 2.8927   | 7.6937    | 0.024     | 0.0172  | 5350.48  | 4434.72  | 1.0008 |
| beta()            | 2.0052 | 0.073  | 1.8851   | 2.119     | 0.001     | 0.0007  | 5503.35  | 4848.91  | 1.0013 |

$\hat{R}$ diagnostic

$\hat{R}$ is a diagnostic tool that measures the between- and within-chain variances. It is a test that indicates a lack of convergence by comparing the variance between multiple chains to the variance within each chain. If the parameters are successfully exploring the full space for each chain, then $\hat{R} \approx 1$, since the between-chain and within-chain variance should be equal. $\hat{R}$ is calculated as

$\hat{R} = \frac{\hat{V}}{W}$

where $W$ is the within-chain variance and $\hat{V}$ is the posterior variance estimate for the pooled rank-traces. The take-away here is that $\hat{R}$ converges towards 1 when each of the Markov chains approaches perfect adaptation to the true posterior distribution. We do not recommend using inference results if $\hat{R} > 1.01$. More information about $\hat{R}$ can be found in the Vehtari et al paper.
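ArviZ can compute $\hat{R}$ directly from the inference results; a minimal sketch, assuming the `samples` object from the run above:

# Values close to 1 for every variable indicate the chains mixed well.
print(az.rhat(samples.to_inference_data()))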

Effective sample size (ess) diagnostic

MCMC samplers do not draw independent samples from the target distribution, which means that our samples are correlated. In an ideal situation all samples would be independent, but we do not have that luxury. We can, however, measure the number of effectively independent samples we draw, which is called the effective sample size. You can read more about how this value is calculated in the Vehtari et al paper; briefly, it is a measure that combines information from the $\hat{R}$ value with the autocorrelation estimates within the chains. There are many ways to estimate effective sample sizes; we will be using the method defined in the Vehtari et al paper.

The rule of thumb for ess_bulk is for this value to be greater than 100 per chain on average. Since we ran four chains, we need ess_bulk to be greater than 400 for each parameter. The ess_tail is an estimate for effectively independent samples considering the more extreme values of the posterior. This is not the number of samples that landed in the tails of the posterior. It is a measure of the number of effectively independent samples if we sampled the tails of the posterior. The rule of thumb for this value is also to be greater than 100 per chain on average.
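Both the bulk and tail estimates are available directly from ArviZ; a minimal sketch using the same inference results:

idata = samples.to_inference_data()
print(az.ess(idata, method="bulk"))  # bulk effective sample size
print(az.ess(idata, method="tail"))  # tail effective sample size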

We can plot diagnostic information to assess model fit using ArviZ. Let's take a look:

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

samples_diagnostic_plots = gridplot(plots.plot_diagnostics(samples))
show(samples_diagnostic_plots)

The diagnostics output shows two diagnostic plots for individual random variables: rank plots and autocorrelation plots.

  • Rank plots are a histogram of the samples over time. All samples across all chains are ranked and then we plot the average rank for each chain on regular intervals. If the chains are mixing well this histogram should look roughly uniform. If it looks highly irregular that suggests chains might be getting stuck and not adequately exploring the sample space.

  • Autocorrelation plots measure how predictive the last several samples are of the current sample. Autocorrelation may vary between -1.0 (deterministically anticorrelated) and 1.0 (deterministically correlated). We compute autocorrelation approximately, so it may sometimes exceed these bounds. In an ideal world, the current sample is chosen independently of the previous samples: an autocorrelation of zero. This is not possible in practice, due to stochastic noise and the mechanics of how inference works.

From the rank plots, we see that each of the chains is relatively healthy: none get stuck, and none explore only a chain-specific subset of the space.
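To make the autocorrelation measure concrete, here is a small hand-rolled sketch (the `autocorr` helper is illustrative, not a Bean Machine API) that computes the lag-k autocorrelation for one chain of $\beta$ samples:

def autocorr(x, k):
    # Sample autocorrelation at lag k: correlation between the chain and a
    # copy of itself shifted by k draws, after mean-centering.
    x = x - x.mean()
    return ((x[:-k] * x[k:]).sum() / (x * x).sum()).item()

chain = samples[beta()][0]  # draws from the first chain
print([round(autocorr(chain, k), 3) for k in [1, 5, 10]])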

Prediction

We've built and evaluated our model. Lastly, let's take a quick look at how to predict with it.

def predict(x):
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x).float()
    # For each posterior draw, form the predictive Normal(βx + α, σ), draw 10
    # samples from it, and report percentiles across all predictive samples.
    return pd.DataFrame(
        np.percentile(
            dist.Normal(
                x.view([-1, 1]) @ beta_marginal.unsqueeze(0)
                + alpha_marginal.unsqueeze(0),
                sigma_marginal.unsqueeze(0),
            )
            .sample([10])
            .transpose(0, 1)
            .flatten(1),
            [2.5, 50, 97.5],
            axis=1,
        ).T,
        index=x.view(-1).numpy(),
        columns=["2.5%", "50%", "97.5%"],
    )

Predict for a single value:

Markdown(predict(4).to_markdown())
|   | 2.5%    | 50%     | 97.5%   |
| - | :------ | :------ | :------ |
| 4 | 11.1601 | 13.0632 | 14.9579 |

Or for a range:

# Required for visualizing in Colab.
output_notebook(hide_banner=True)

predicted_df = predict(torch.linspace(-10, 10, 100))
low_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["2.5%"].tolist()}
)
med_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["50%"].tolist()}
)
high_cds = ColumnDataSource(
    {"x": predicted_df.index.tolist(), "y": predicted_df["97.5%"].tolist()}
)
prediction_plot = plots.line_plot(
    plot_sources=[low_cds, med_cds, high_cds],
    labels=predicted_df.columns.tolist(),
    figure_kwargs={"x_axis_label": "x", "y_axis_label": "y", "title": "Predictions"},
    plot_kwargs={"line_width": 2, "line_alpha": 0.6},
)
prediction_plot.legend.location = "top_left"
# Add the clean and corrupted data.
prediction_plot.circle(
    x="x",
    y="y",
    source=clean_cds,
    fill_color=Colorblind3[0],
    line_color="white",
    fill_alpha=0.6,
    size=7,
    legend_label="Clean data",
)
prediction_plot.circle(
    x="x",
    y="y",
    source=corrupted_cds,
    fill_color=Colorblind3[1],
    line_color="white",
    fill_alpha=0.6,
    size=10,
    legend_label="Corrupted data",
)

show(prediction_plot)

References

  • Baez-Ortega, A. "Bayesian robust regression for anomaly detection."
  • Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., and Bürkner, P.-C. "Rank-normalization, folding, and localization: An improved $\hat{R}$ for assessing convergence of MCMC." Bayesian Analysis, 2021.