
Quick Start

Let's quickly translate the model we discussed in "Why Bean Machine?" into Bean Machine code! Although this will get you up and running, it's important to read through all of the pages in the Overview to get a complete understanding of Bean Machine. If you're interested, the full source code for this Overview is available as a notebook on GitHub and Colab. Happy modeling!

Modeling

As a quick refresher, we're writing a model to understand a disease's reproduction rate, based on the number of new cases of that disease we've seen. Though we never observe the true reproduction rate, let's start off with a prior distribution that represents our beliefs about the reproduction rate before seeing any data.

import beanmachine.ppl as bm
import torch.distributions as dist

reproduction_rate_rate = 10.0

@bm.random_variable
def reproduction_rate():
    # An Exponential distribution with rate 10 has mean 0.1.
    return dist.Exponential(rate=reproduction_rate_rate)

There are a few things to notice here!

  • Most importantly, we've decorated this function with @bm.random_variable. This is how you tell Bean Machine to interpret this function probabilistically. @bm.random_variable functions are the building blocks of Bean Machine models, and they let the framework explore different values for the quantity the function represents as it fits a distribution to the observed data you'll provide later.
  • Next, notice that the function returns a PyTorch distribution. This distribution encodes your prior belief about a particular random variable. In the case of Exponential(10.0), our prior has this shape:
[Plot: probability density of the Exponential(10.0) prior over the reproduction rate]
  • As you can see, the prior encourages smaller values for the reproduction rate, averaging at a rate of 10%, but allows for the possibility of much larger rates of spread (see the quick check after this list).
  • Lastly, note that although you've provided a prior distribution here, the framework will automatically "refine" this distribution as it searches for values that explain the observed data you'll provide later. So, after we fit the model to observed data, the random variable's distribution will no longer look like the graph shown above!
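As a quick sanity check -- this is plain PyTorch, not part of the Bean Machine model itself -- you can inspect the prior directly:

prior = dist.Exponential(rate=reproduction_rate_rate)

prior.mean      # tensor(0.1000) -- the mean of Exponential(rate) is 1 / rate
prior.sample()  # one random draw from the prior, e.g. tensor(0.0732)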

The last piece of the model describes how the reproduction rate relates to the number of new cases we observe on the following day. This number of new cases is related to the underlying reproduction rate -- how fast the virus tends to spread -- as well as the current number of cases. However, it's not a deterministic function of those two values. Instead, it depends on many environmental factors, such as social behavior, the stochasticity of transmission, and so on. It would be far too complicated to capture all of those factors in a single model. Instead, we'll aggregate all of these environmental factors in the form of a probability distribution, the Poisson distribution.

Let's say, for this example, we observed a little over a million cases today -- exactly 1,087,980. We use such a precise number here to remind you that this is a known value and not a random one. In this case, if the disease happened to have a reproduction rate of 0.1, our Poisson distribution for new cases would be centered at a mean of 0.1 × 1,087,980 ≈ 108,798 and look like this:

[Plot: Poisson distribution over the number of new cases, given a reproduction rate of 0.1 and 1,087,980 current cases]

Let's write this up in Bean Machine. Using the syntax we've already seen, it's pretty simple:

@bm.random_variable
def num_new(num_current):
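    # Expected new cases = reproduction rate * current case count.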
    return dist.Poisson(reproduction_rate() * num_current)

As you can see, this function relies on the reproduction_rate() random variable that we defined before. Note that even though reproduction_rate() is defined to return a distribution, here its return value is treated like a sample from that distribution! Bean Machine works hard behind the scenes to sample efficiently from distributions, so that you can easily build sophisticated models that only have to reason about these samples.

Data

With the model fully defined, we need some data to learn from! In the real world, you might work with a government agency to determine the number of new cases actually observed on the next day. For the sake of our example, let's say that we observed 238,154 new cases on the next day. Bean Machine's random variable syntax allows you to bind this information directly as an observation for the num_new() random variable within a simple Python dictionary. Here's how to do it:

from torch import tensor

num_init = 1087980

observations = {
    # PyTorch distributions expect tensors.
    num_new(num_init): tensor(238154.),
}

Using a random variable function and its arguments as keys in this dictionary may feel unusual at first, but it quickly becomes an intuitive way to reference these random variable functions by name! Note also that we're using num_init as an argument to the random variable function. This might seem unnecessary, since num_init could simply remain a global constant in this example, but a similar indexing scheme for num_new() will come in handy when we extend the model to time series with more than a single time step.
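To make that concrete, here's a hypothetical sketch (not part of this tutorial's model) of how indexed random variables could extend this pattern to multiple days. The name new_cases_on_day and the chaining of days are illustrative assumptions; the key point is that each distinct argument value names a distinct random variable:

@bm.random_variable
def new_cases_on_day(day):
    # Simplification for illustration: day 0 starts from the known initial
    # count, and each later day builds on the previous day's (random) count.
    current = num_init if day == 0 else new_cases_on_day(day - 1)
    return dist.Poisson(reproduction_rate() * current)

Observations could then be bound per day, e.g. {new_cases_on_day(0): tensor(238154.)}.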

Inference

With model and observations in hand, we're ready for the fun part: inference! Inference is the process of combining a model with data to obtain insights, in the form of probability distributions over values of interest. Bean Machine offers a powerful and general inference framework to enable fitting arbitrary models to data.

The call to inference involves first creating an appropriate inference engine object and then invoking the infer method:

samples = bm.CompositionalInference().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
)

There's a lot going on here! First, let's take a look at the inference method that we used, CompositionalInference(). Bean Machine supports generic inference, which means that it can fit your model to the data without knowing the intricate and particular workings of the model that you defined. However, there are many ways of performing inference, and Bean Machine supports a rich library of inference methods that work for different kinds of models. For now, all you need to know is that CompositionalInference is a general inference strategy that will try to automatically determine the best inference method(s) to use for your model, based on the definitions of random variables you've provided. It should work well for this simple model. You can check out our guides on Inference to learn more!
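If you'd rather choose a method yourself, you can swap in another engine from Bean Machine's inference library with the same infer() call. For instance -- assuming, as is the case here, that the model's random variables are continuous and suit a gradient-based sampler -- the global No-U-Turn Sampler:

samples = bm.GlobalNoUTurnSampler().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
)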

Let's take a look at the parameters to infer(). In queries, you provide a list of random variables that you're interested in learning about. Bean Machine will learn probability distributions for these, and will return them to you when inference completes! Note that this uses exactly the same pattern to reference random variables that we used when binding data.

We bind our real-world observations with the observations parameter. This provides a set of probabilistic constraints that Bean Machine seeks to satisfy during inference. In particular, Bean Machine tries to fit probability distributions for unobserved random variables, so that those probability distributions explain the observed data -- and your prior beliefs -- well.

Lastly, num_samples is the number of samples to draw. Bean Machine doesn't learn smooth probability distributions for your queries; instead, it accumulates a representative set of samples from those distributions, and this parameter specifies how many samples should comprise them. Relatedly, num_adaptive_samples is the number of warm-up samples the inference method may use to tune itself; these adaptive samples are not included in the results returned to you.

Analysis

Our results are ready! Let's visualize them for the reproduction rate parameter.

The samples object that we have now contains samples from the probability distributions that we've fit for our model and data. It supports dictionary-like indexing using -- you guessed it -- the same random variable referencing syntax we've seen before. A second index (here, [0]) selects one of the inference chains generated by the sampling algorithm; chains will be explained in the Inference section, so let's just use chain 0 for now.

reproduction_rate_samples = samples[reproduction_rate()][0]
reproduction_rate_samples
Out:

tensor([0.0146, 0.1720, 0.1720, ..., 0.2187, 0.2187, 0.2187])

Let's visualize that more intuitively.

[Plot: histogram of the posterior samples of the reproduction rate]
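If you're following along in a notebook of your own, a minimal sketch for producing such a histogram (using matplotlib, which this snippet assumes you have installed) is:

import matplotlib.pyplot as plt

plt.hist(reproduction_rate_samples.numpy(), bins=50)
plt.xlabel("reproduction rate")
plt.ylabel("number of samples")
plt.show()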

This histogram represents our beliefs about the underlying reproduction rate after observing the current day's worth of new cases. You'll note that it balances our prior beliefs with the rate we would infer from the new data alone. It also captures the uncertainty inherent in our estimate!
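You can also summarize these samples numerically with ordinary PyTorch operations; for example, a posterior mean and a rough 95% credible interval:

import torch

# Average of the posterior samples.
posterior_mean = reproduction_rate_samples.mean()

# The 2.5th and 97.5th percentiles bound a 95% credible interval.
credible_interval = torch.quantile(
    reproduction_rate_samples, torch.tensor([0.025, 0.975])
)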

We're Not Done Yet!

This is the tip of the iceberg. The rest of this Overview will cover critical concepts from the above sections. Read on to learn how to make the most of Bean Machine's powerful modeling and inference systems!