Tutorial: Bayesian Logistic Regression
The purpose of this tutorial is to show how to build a simple Bayesian model to deduce the line which separates two categories of points.
Problem
A logistic model is a statistical model where we have a collection of things that can be divided into two categories: pictures of cats or dogs, patients who are immune or susceptible, students who pass or fail, and so on. An assumption of the model is that the probability that a given thing is in a category can be computed from a linear combination of characteristics of the thing.
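For example (a worked illustration with an arbitrary value), if the linear combination of a thing's characteristics comes out to $2$, the model assigns it a probability of

$$\frac{1}{1 + e^{-2}} \approx 0.88$$

of being in the second category.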
The problem we seek to solve in this tutorial is: when given a set of examples of each category, can we infer the boundary between the categories?
Prerequisites
We will be using the following packages within this tutorial.
Let's code this in Bean Machine! First, import the Bean Machine library and some fundamental PyTorch classes.
# Install Bean Machine in Colab if using Colab.
import sys
if "google.colab" in sys.modules and "beanmachine" not in sys.modules:
!pip install beanmachine
import os
import warnings
import arviz as az
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.tutorials.utils import plots
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource, MultiLine, Span
from bokeh.plotting import show
from IPython.display import Markdown
The next cell includes convenient configuration settings to improve the notebook presentation, as well as a manual seed for reproducibility.
# Eliminate excess UserWarnings from Python.
warnings.filterwarnings("ignore")
# Plotting settings
az.rcParams["plot.backend"] = "bokeh"
az.rcParams["stats.hdi_prob"] = 0.89
# Manual seed
bm.seed(0)
# Other settings for the notebook.
smoke_test = "SANDCASTLE_NEXUS" in os.environ or "CI" in os.environ
Example: orange and blue points
For the purposes of this tutorial our "things" will be two-dimensional points $(x, y)$ with both coordinates between $-10$ and $10$. Our two categories will be represented by $0$, which we'll render as blue, and $1$, which we'll render as orange.
The probability $p$ that a given point $(x, y)$ will be orange is

$$p = \frac{1}{1 + e^{-(\beta_0 + \beta_1 x + \beta_2 y)}}$$

for some coefficients $\beta_0$, $\beta_1$, $\beta_2$.
That equation is seen to be a linear combination of properties of the point if we express the probability as log odds:

$$\log\left(\frac{p}{1 - p}\right) = \beta_0 + \beta_1 x + \beta_2 y$$

The boundary between the two categories is the set of points where the probability is $0.5$ that the point is orange; the log odds of $0.5$ is $0$, so the equation of the set of points separating the categories is a straight line:

$$\beta_0 + \beta_1 x + \beta_2 y = 0$$

Or, expressed as slope and intercept:

$$y = -\frac{\beta_1}{\beta_2} x - \frac{\beta_0}{\beta_2}$$

Our goal therefore is to infer possible values for $\beta_0$, $\beta_1$, $\beta_2$, from which we can compute the line separating the categories.
Creating sample data
We start by creating a data set to illustrate the model. In order to make computations faster and easier, we will express the entire data set as a single tensor representing 200 points. Each row of the tensor will be in the form $(1, x, y)$. That way we can compute the linear combination by matrix-multiplying each row by the column vector of coefficients $\beta = (\beta_0, \beta_1, \beta_2)^\top$ to obtain the log odds that the point is orange: $\beta_0 + \beta_1 x + \beta_2 y$.
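As a quick sanity check (a minimal sketch; the point and coefficient values here are arbitrary), one row of that tensor matrix-multiplied by a 3x1 coefficient column yields the scalar log odds:
# One row (1, x, y) times a 3x1 coefficient column gives the scalar
# log odds beta_0 + beta_1 * x + beta_2 * y. Values are arbitrary.
row = torch.tensor([[1.0, 2.0, 3.0]])  # (1, x, y) with x=2, y=3
beta = torch.tensor([[-2.0], [0.3], [-0.5]])
row.mm(beta)  # tensor([[-2.9000]]) == -2.0 + 0.3*2 - 0.5*3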
For our synthetic dataset, we will assume the following parameters.
N = 200
low = -10.0
high = 10.0
uniform = dist.Uniform(
low=torch.tensor([1.0, low, low]),
high=torch.tensor([1.0, high, high]),
)
points = torch.tensor([uniform.sample().tolist() for i in range(N)]).view(N, 3)
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
x = points[:, 1]
y = points[:, 2]
cds = ColumnDataSource({"x": x.tolist(), "y": y.tolist()})
tips = [("y", "@y{0.000}"), ("x", "@x{0.000}")]
synthetic_data_plot = plots.scatter_plot(
plot_sources=cds,
tooltips=tips,
figure_kwargs={
"title": "Synthetic data",
"x_axis_label": "x",
"y_axis_label": "y",
},
plot_kwargs={"fill_color": "black"},
)
show(synthetic_data_plot)
We now assign points to categories $0$ (blue) and $1$ (orange).
For this example we will assign categories to points using $\beta = (-2.0,\ 0.3,\ -0.5)$ as our coefficients, which makes the line separating the categories $y = 0.6x - 4$.
true_coefficients = torch.tensor([-2.0, 0.3, -0.5]).view(3, 1)
true_slope = -float(true_coefficients[1] / true_coefficients[2])
true_intercept = -float(true_coefficients[0] / true_coefficients[2])
def log_odds(point):
return point.view(1, 3).mm(true_coefficients)
observed_categories = torch.tensor(
[dist.Bernoulli(logits=log_odds(point)).sample() for point in points]
)
Data visualization methods
It is useful to have helper methods to visualize the synthetic data set and the line which separates the two categories.
def categorize_points(points, categories):
orange_x = []
orange_y = []
blue_x = []
blue_y = []
for point, category in zip(points, categories):
if category == 1:
orange_x.append(float(point[1]))
orange_y.append(float(point[2]))
else:
blue_x.append(float(point[1]))
blue_y.append(float(point[2]))
return {
"orange": {"x": orange_x, "y": orange_y, "label": ["orange"] * len(orange_x)},
"blue": {"x": blue_x, "y": blue_y, "label": ["blue"] * len(blue_x)},
}
def plot_line(slope, intercept, high=10, low=-10):
    # Compute the endpoints of the line y = slope * x + intercept,
    # clipped to the plotting window [low, high] x [low, high].
    # Lines whose intercept falls outside the window are skipped.
    if intercept > high or intercept < low:
        return
    xs = [low, high]
    ys = [slope * low + intercept, slope * high + intercept]
    # Clip the left endpoint to the top or bottom of the window.
    if ys[0] > high:
        xs[0] = (high - intercept) / slope
        ys[0] = high
    elif ys[0] < low:
        xs[0] = (low - intercept) / slope
        ys[0] = low
    # Clip the right endpoint to the top or bottom of the window.
    if ys[1] > high:
        xs[1] = (high - intercept) / slope
        ys[1] = high
    elif ys[1] < low:
        xs[1] = (low - intercept) / slope
        ys[1] = low
    return xs, ys
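For example (a quick check using the true separating line defined above, plus an arbitrary steeper line to exercise the clipping):
# The true line y = 0.6 * x - 4 lies entirely inside the window, so
# its endpoints are returned unclipped; the steeper line is clipped
# at the top and bottom of the window.
plot_line(0.6, -4)  # ([-10, 10], [-10.0, 2.0])
plot_line(2.0, 5)   # ([-7.5, 2.5], [-10, 10])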
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
points_with_categories = categorize_points(points, observed_categories)
orange_cds = ColumnDataSource(
{
"x": points_with_categories["orange"]["x"],
"y": points_with_categories["orange"]["y"],
"label": points_with_categories["orange"]["label"],
}
)
orange_tips = [("Category", "@label"), ("y", "@y{0.000}"), ("x", "@x{0.000}")]
blue_cds = ColumnDataSource(
{
"x": points_with_categories["blue"]["x"],
"y": points_with_categories["blue"]["y"],
"label": points_with_categories["blue"]["label"],
}
)
blue_tips = [("Category", "@label"), ("y", "@y{0.000}"), ("x", "@x{0.000}")]
synthetic_data_with_categories_plot = plots.scatter_plot(
plot_sources=[orange_cds, blue_cds],
tooltips=[orange_tips, blue_tips],
figure_kwargs={
"title": "Synthetic data with categories",
"x_axis_label": "x",
"y_axis_label": "y",
},
legend_items=["Category orange", "Category blue"],
plot_kwargs={"fill_color": "label"},
)
# Add the separating line.
x, y = plot_line(true_slope, true_intercept)
synthetic_data_with_categories_plot.line(
x=x,
y=y,
legend_label="Separating line",
line_color="black",
line_width=3,
line_alpha=1,
)
show(synthetic_data_with_categories_plot)
Model
We can now start building our model.
The first thing we need is a prior distribution for our three coefficients.
We have no reason to believe that the coefficients will be either positive or negative, so we should choose a prior distribution that is centered on zero. We use a matrix multiplication to compute the linear combination, and therefore make the prior a column vector of samples from a normal distribution:
@bm.random_variable
def coefficients():
mean = torch.zeros(3, 1)
sigma = torch.ones(3, 1)
return dist.Normal(mean, sigma)
Our model for categories is now straightforward: the log odds of each point are computed by matrix-multiplying the points by the prior distribution of coefficients, and we get a set of categories from the Bernoulli distribution: either $0$ (blue) or $1$ (orange):
@bm.random_variable
def categories():
return dist.Bernoulli(logits=points.mm(coefficients()))
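Note that the logits parameter means the Bernoulli distribution is parameterized by log odds; Bernoulli(logits=z) is the same distribution as Bernoulli(probs=torch.sigmoid(z)). A quick check (the value 0.0 is an arbitrary example):
# Log odds of 0.0 correspond to a probability of 0.5.
dist.Bernoulli(logits=torch.tensor(0.0)).probs  # tensor(0.5000)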
Inference
We can now infer the posterior distribution of the coefficients given the observations:
num_samples = 2 if smoke_test else 2000
num_adaptive_samples = 0 if smoke_test else num_samples // 2
num_chains = 1 if smoke_test else 4
observations = {categories(): observed_categories.view(N, 1)}
queries = [coefficients()]
mc = bm.GlobalNoUTurnSampler()
%%time
samples = mc.infer(
queries=queries,
observations=observations,
num_samples=num_samples,
num_chains=num_chains,
num_adaptive_samples=num_adaptive_samples,
)
sampled_coefficients = samples.get_chain()[coefficients()]
The slopes and intercepts are computed from the sampled coefficients:
slopes = [-float(s[1] / s[2]) for s in sampled_coefficients if float(s[2]) != 0.0]
intercepts = [-float(s[0] / s[2]) for s in sampled_coefficients if float(s[2]) != 0.0]
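As a quick numerical summary (a minimal sketch; the exact values vary from run to run), we can look at the posterior medians of the derived slope and intercept:
# Posterior medians of the derived slope and intercept; these should
# land near the true values of 0.6 and -4 used to generate the data.
float(torch.tensor(slopes).median()), float(torch.tensor(intercepts).median())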
Posterior results
A histogram of the inferred slopes should cluster near the true value, marked in red. A slight deviation is to be expected since there is noise in our data.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
slope_hist_plot = plots.histogram_plot(slopes)
# Add a line showing the true slope.
span = Span(
location=true_slope,
dimension="height",
line_color="red",
line_width=3,
)
slope_hist_plot.add_layout(span)
show(slope_hist_plot)
And similarly for the intercepts. Our model actually infers a slightly higher median intercept than the true value, to capture the orange points that fall above the true separating line.
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
intercept_hist_plot = plots.histogram_plot(intercepts)
# Add a line showing the true intercept.
span = Span(
location=true_intercept,
dimension="height",
line_color="red",
line_width=3,
)
intercept_hist_plot.add_layout(span)
show(intercept_hist_plot)
Inference typically finds a reasonable distribution of possible lines that separate these two categories.
Diagnostics
In addition to visualizing the inferred lines directly, we can also use ArviZ to examine diagnostic statistics for our queried variables:
filtered_samples = {k: v for k, v in samples.items() if k == coefficients()}
az_data = MonteCarloSamples(filtered_samples).to_inference_data()
summary_df = az.summary(az_data, round_to=3).to_markdown()
Markdown(summary_df)
| | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| coefficients()[0,0] | -1.774 | 0.294 | -2.245 | -1.313 | 0.005 | 0.004 | 3561.2 | 5042.12 | 1.001 |
| coefficients()[1,0] | 0.262 | 0.048 | 0.185 | 0.336 | 0.001 | 0.001 | 3543.11 | 4425.9 | 1 |
| coefficients()[2,0] | -0.452 | 0.064 | -0.548 | -0.342 | 0.001 | 0.001 | 3137.52 | 4031.13 | 1.001 |
As you can see, the inferred means are reasonably close to the true values, but there is a large spread. Remember that any constant multiple of the true values would produce the same line; if we got $-4.0$, $0.6$ and $-1.0$ then the slope and intercept would be the same and only the amount of "mixing" near the line would change, so we can expect that this inference problem may have some fairly large variance.
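To see this concretely (a quick check; the factor of $2$ is arbitrary), scaling the true coefficients leaves the derived slope and intercept unchanged:
# The common factor cancels in -beta_1/beta_2 and -beta_0/beta_2, so
# any nonzero multiple of the coefficients gives the same line.
scaled = 2.0 * true_coefficients  # coefficients (-4.0, 0.6, -1.0)
(-float(scaled[1] / scaled[2]), -float(scaled[0] / scaled[2]))  # (0.6, -4.0)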
The ess_bulk and ess_tail columns report the effective sample size, which indicates how correlated the posterior samples are with one another; higher is better. Here the effective sample size is a little low relative to the total number of draws.
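If you want the effective sample sizes on their own, ArviZ can compute them directly from the inference data (a minimal sketch using arviz.ess):
# Effective sample size for each coefficient, computed directly with
# ArviZ instead of reading it off the summary table.
az.ess(az_data)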
Notice that we generated the data by separating points along a line with slope $0.6$ and intercept $-4$, but were we to observe the specific set of 200 points we classified without knowing the true parameters ahead of time, we would deduce that the separating line had a slightly lower slope than $0.6$ and an intercept around $-3.9$.
To more clearly illustrate the accuracy of the inference, we can take a random selection of the inferred lines and plot them on the data set:
# Required for visualizing in Colab.
output_notebook(hide_banner=True)
# Replicate the original data separating plot above.
randomly_selected_lines_plot = plots.scatter_plot(
plot_sources=[orange_cds, blue_cds],
tooltips=[orange_tips, blue_tips],
figure_kwargs={
"title": "Synthetic data with categories",
"x_axis_label": "x",
"y_axis_label": "y",
},
legend_items=["Category orange", "Category blue"],
plot_kwargs={"fill_color": "label"},
)
# Add randomly selected sampled separating lines.
num_lines = 25
sampled_indices = torch.randint(0, len(slopes), (num_lines,)).tolist()
xs = []
ys = []
for sampled_index in sampled_indices:
sampled_slope = slopes[sampled_index]
sampled_intercept = intercepts[sampled_index]
x, y = plot_line(sampled_slope, sampled_intercept)
xs.append(x)
ys.append(y)
cds = ColumnDataSource({"xs": xs, "ys": ys})
glyph = MultiLine(xs="xs", ys="ys", line_color="magenta", line_alpha=0.2)
randomly_selected_lines_plot.add_glyph(cds, glyph)
# Add the separating line.
x, y = plot_line(true_slope, true_intercept)
randomly_selected_lines_plot.line(
x=x,
y=y,
legend_label="Separating line",
line_color="black",
line_width=3,
line_alpha=1,
)
show(randomly_selected_lines_plot)