Variational Inference
Params
A Param represents a variational parameter to be optimized during variational inference. Use @bm.param to decorate an "initialization function" that returns a tensor value used to initialize the variational parameter at the start of optimization.
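The lazy-initialization behaviour described above can be illustrated with a small stdlib-only sketch. It is not Bean Machine's implementation: param, ParamId, and ParamStore are hypothetical names, and plain floats stand in for tensors. The point it shows is that the initialization function runs only the first time a parameter is looked up. After that, the stored (and possibly optimized) value is returned.

```python
import functools
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class ParamId:
    """Hypothetical identifier for a variational parameter (illustration only)."""
    init_fn: Callable[..., float]
    args: Tuple


def param(init_fn: Callable[..., float]) -> Callable[..., ParamId]:
    """Hypothetical decorator: calling the decorated function returns a ParamId
    instead of running the initializer right away."""
    @functools.wraps(init_fn)
    def wrapper(*args) -> ParamId:
        return ParamId(init_fn, args)
    return wrapper


class ParamStore:
    """Holds current parameter values and runs each initializer on first access."""

    def __init__(self) -> None:
        self.values: Dict[ParamId, float] = {}
        self.init_calls = 0

    def get_param(self, pid: ParamId) -> float:
        if pid not in self.values:
            # Empty slot: initialize from the user-supplied function.
            self.values[pid] = pid.init_fn(*pid.args)
            self.init_calls += 1
        return self.values[pid]


@param
def loc(i: int) -> float:
    # "Initialization function" returning the starting value for parameter i.
    return 0.5 * i


store = ParamStore()
print(store.get_param(loc(2)))  # 1.0, produced by the initializer
store.values[loc(2)] -= 0.25    # stands in for an optimizer update
print(store.get_param(loc(2)))  # 0.75, the initializer is not run again
```

Keying parameters by (initializer, arguments) means that calling loc(2) twice refers to the same parameter, which is what lets an optimizer keep updating one stored value across steps.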
Variational Worlds
A VariationalWorld is a subclass of World that also contains data on guide distributions and their parameters, specifically:
- get_guide_distribution: given an RVIdentifier, returns its corresponding guide distribution.
- get_param: given an RVIdentifier for a Param, returns the value of the parameter, initializing it first if it is empty.
Note: as an implementation detail, update_graph is overridden so that the guide distribution is automatically used when one is available.
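As a rough illustration of how a variational world can route sampling to a guide, the sketch below uses hypothetical ToyWorld and ToyVariationalWorld classes with a minimal Normal type. None of these are Bean Machine classes, and get_param is left out to keep the sketch short. The update_graph override prefers the guide when one is registered and otherwise defers to the prior, mirroring the implementation note above.

```python
import random
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Normal:
    loc: float
    scale: float

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.loc, self.scale)


class ToyWorld:
    """Hypothetical base world: draws each random variable from its prior."""

    def __init__(self, priors: Dict[str, Normal], seed: int = 0) -> None:
        self.priors = priors
        self.values: Dict[str, float] = {}
        self.rng = random.Random(seed)

    def update_graph(self, rv: str) -> float:
        # Sample a value the first time rv is requested, then reuse it.
        if rv not in self.values:
            self.values[rv] = self.priors[rv].sample(self.rng)
        return self.values[rv]


class ToyVariationalWorld(ToyWorld):
    """Hypothetical variational world that also stores guide distributions."""

    def __init__(self, priors: Dict[str, Normal], guides: Dict[str, Normal],
                 seed: int = 0) -> None:
        super().__init__(priors, seed)
        self.guides = guides

    def get_guide_distribution(self, rv: str) -> Optional[Normal]:
        return self.guides.get(rv)

    def update_graph(self, rv: str) -> float:
        guide = self.get_guide_distribution(rv)
        if guide is None:
            return super().update_graph(rv)  # no guide: fall back to the prior
        if rv not in self.values:
            self.values[rv] = guide.sample(self.rng)  # guide replaces the prior
        return self.values[rv]


world = ToyVariationalWorld(
    priors={"z": Normal(0.0, 1.0), "w": Normal(0.0, 1.0)},
    guides={"z": Normal(5.0, 0.01)},
)
print(world.update_graph("z"))  # near 5.0, drawn from the guide
print(world.update_graph("w"))  # drawn from the prior
```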
Gradient Estimators and Divergences
A gradient_estimator computes a Monte Carlo (possibly surrogate) estimate of the objective whose gradients are used as the training signal. We structure our VI objective following the abstractions introduced in f-Divergence Variational Inference, where gradient_estimator takes as input a discrepancy function corresponding to an f-divergence.
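To make the discrepancy abstraction concrete, here is a stdlib-only sketch assuming a scalar Gaussian guide and a normalized Gaussian target. The discrepancy receives logu = log p(z) - log q(z) for a reparameterized sample z ~ q. With the reverse-KL choice f(u) = -log u, averaging it over samples estimates KL(q || p), which can be checked against the closed form. The names reverse_kl and mc_objective are hypothetical, and since there is no autograd here the sketch only computes the objective; a real estimator would differentiate through z = loc + scale * eps.

```python
import math
import random
from typing import Callable


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def reverse_kl(logu: float) -> float:
    # Discrepancy for f(u) = -log u: E_q[f(p/q)] = KL(q || p) for a normalized p.
    return -logu


def mc_objective(
    discrepancy: Callable[[float], float],
    log_joint: Callable[[float], float],
    loc: float,
    scale: float,
    num_samples: int,
    rng: random.Random,
) -> float:
    """Monte Carlo estimate of E_q[discrepancy(log p(z) - log q(z))], q = Normal(loc, scale)."""
    total = 0.0
    for _ in range(num_samples):
        z = loc + scale * rng.gauss(0.0, 1.0)  # reparameterized draw from q
        logu = log_joint(z) - normal_log_prob(z, loc, scale)
        total += discrepancy(logu)
    return total / num_samples


# Target p = Normal(1, 2); guide q = Normal(0, 1).
estimate = mc_objective(
    reverse_kl,
    lambda z: normal_log_prob(z, 1.0, 2.0),
    loc=0.0,
    scale=1.0,
    num_samples=50_000,
    rng=random.Random(0),
)
# Closed-form KL(N(0, 1) || N(1, 2)) for comparison.
exact = math.log(2.0 / 1.0) + (1.0 + (0.0 - 1.0) ** 2) / (2 * 2.0 ** 2) - 0.5
print(estimate, exact)  # the estimate should be close to exact (about 0.443)
```

Swapping reverse_kl for another convex f changes the divergence being minimized without touching the sampling code, which is the appeal of passing the discrepancy in as a function.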
VariationalInfer
The VariationalInfer class provides an entry point for VI. Model and guide RVIdentifiers are associated through the constructor's queries_to_guides argument, and optimizer configuration is provided through an optimizer callback. The infer() method offers easy one-call invocation, whereas step() permits more customized interactions (e.g. TensorBoard callbacks).
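The split between infer() and step() can be sketched with a toy class. ToyVariationalInfer and sgd are hypothetical: the optimizer callback here is a plain function that updates a dict of floats rather than a torch optimizer, and the target is simply Normal(3, 2) so that reparameterized reverse-KL gradients can be written out by hand.

```python
import math
import random
from typing import Callable, Dict, List, Optional

Params = Dict[str, float]


def sgd(lr: float) -> Callable[[Params, Params], None]:
    """Hypothetical optimizer callback: plain SGD that updates params in place."""
    def update(params: Params, grads: Params) -> None:
        for name, g in grads.items():
            params[name] -= lr * g
    return update


class ToyVariationalInfer:
    """Fits q(z) = Normal(loc, exp(log_scale)) to a Normal(target_loc, target_scale)
    target by stochastic minimization of KL(q || p) with reparameterized samples."""

    def __init__(self, target_loc: float, target_scale: float,
                 optimizer: Callable[[Params, Params], None],
                 num_samples: int = 16, seed: int = 0) -> None:
        self.mu, self.sigma = target_loc, target_scale
        self.params: Params = {"loc": 0.0, "log_scale": 0.0}
        self.optimizer = optimizer
        self.num_samples = num_samples
        self.rng = random.Random(seed)

    def step(self) -> Params:
        """Run one optimization step and return the gradient estimate it applied."""
        loc, scale = self.params["loc"], math.exp(self.params["log_scale"])
        grads: Params = {"loc": 0.0, "log_scale": 0.0}
        for _ in range(self.num_samples):
            eps = self.rng.gauss(0.0, 1.0)
            z = loc + scale * eps                 # reparameterized draw from q
            dz = (z - self.mu) / self.sigma ** 2  # d(-log p(z))/dz
            grads["loc"] += dz / self.num_samples
            # Pathwise derivative of log q(z) w.r.t. log_scale is -1.
            grads["log_scale"] += (dz * scale * eps - 1.0) / self.num_samples
        self.optimizer(self.params, grads)
        return grads

    def infer(self, num_steps: int,
              on_step: Optional[Callable[[int, Params, Params], None]] = None) -> Params:
        """Convenience loop over step(), with an optional per-step callback."""
        for i in range(num_steps):
            grads = self.step()
            if on_step is not None:
                on_step(i, self.params, grads)  # e.g. write TensorBoard summaries
        return self.params


history: List[float] = []
vi = ToyVariationalInfer(target_loc=3.0, target_scale=2.0, optimizer=sgd(0.05))
params = vi.infer(2000, on_step=lambda i, p, g: history.append(p["loc"]))
print(params["loc"], math.exp(params["log_scale"]))  # expected to end up near 3.0 and 2.0
```

Here infer() is only a loop over step() that hands each step's state to an optional callback; calling step() yourself gives the same control with room for arbitrary logic between steps.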
AutoGuides
Manually defining a guide for each random variable can become tedious. AutoGuideVI provides an initialization strategy for VariationalInfer that automatically defines guides by calling a method get_guide(query: RVIdentifier, distrib: dist.Distribution) implemented by subclasses.
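All AutoGuides currently make a mean-field assumption over RVIdentifiers, so the joint guide factorizes into one independent guide q_i per queried RVIdentifier z_i:
$$ q(z) = \prod_i q_i(z_i) $$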
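The mean-field pattern can be sketched as a base class that asks subclasses for one guide per query and combines them by summing log-densities. ToyAutoGuide and WidenedPriorGuide are hypothetical illustrations (queries are plain strings and guides are scalar Normals), not Bean Machine's AutoGuideVI.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict


@dataclass
class Normal:
    loc: float
    scale: float

    def log_prob(self, x: float) -> float:
        return (-0.5 * ((x - self.loc) / self.scale) ** 2
                - math.log(self.scale) - 0.5 * math.log(2 * math.pi))


class ToyAutoGuide(ABC):
    """Hypothetical auto-guide strategy: subclasses choose the guide for one
    query; the base class assembles a mean-field guide over all queries."""

    @abstractmethod
    def get_guide(self, query: str, prior: Normal) -> Normal:
        ...

    def build_guides(self, priors: Dict[str, Normal]) -> Dict[str, Normal]:
        # One independent guide per query (mean-field).
        return {q: self.get_guide(q, p) for q, p in priors.items()}

    @staticmethod
    def joint_log_prob(guides: Dict[str, Normal], values: Dict[str, float]) -> float:
        # Mean-field: log q(z) = sum_i log q_i(z_i)
        return sum(guides[q].log_prob(values[q]) for q in guides)


class WidenedPriorGuide(ToyAutoGuide):
    """Hypothetical subclass: each guide starts at the prior's loc with twice its scale."""

    def get_guide(self, query: str, prior: Normal) -> Normal:
        return Normal(prior.loc, 2.0 * prior.scale)


guides = WidenedPriorGuide().build_guides({"a": Normal(0.0, 1.0), "b": Normal(3.0, 0.5)})
lp = ToyAutoGuide.joint_log_prob(guides, {"a": 0.0, "b": 3.0})
print(guides, lp)
```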
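ADVI
In Automatic Differentiation Variational Inference (ADVI), a properly sized Gaussian is used as the guide to approximate each site:
$$ q_i(z_i) = \mathcal{N}(z_i; \mu_i, \Sigma_i) $$
where the variational parameters \mu_i and \Sigma_i are sized to match z_i.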
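A "properly sized" Gaussian means one location (and scale) entry per element of the site. The sketch below assumes a real-valued site flattened to a list of floats and, as a simplifying assumption, a diagonal covariance with zero-initialized parameters; DiagonalGaussianGuide and advi_guide_for are hypothetical names. ADVI as originally formulated also maps constrained variables to an unconstrained space first, which this sketch does not do.

```python
import math
import random
from dataclasses import dataclass
from typing import List


@dataclass
class DiagonalGaussianGuide:
    """Hypothetical per-site Gaussian guide with a diagonal covariance (a simplification)."""
    loc: List[float]
    log_scale: List[float]

    def rsample(self, rng: random.Random) -> List[float]:
        # Reparameterized draw: z_j = loc_j + exp(log_scale_j) * eps_j, eps_j ~ N(0, 1)
        return [m + math.exp(ls) * rng.gauss(0.0, 1.0)
                for m, ls in zip(self.loc, self.log_scale)]

    def log_prob(self, z: List[float]) -> float:
        # Sum of independent per-element Normal log-densities.
        total = 0.0
        for x, m, ls in zip(z, self.loc, self.log_scale):
            total += -0.5 * ((x - m) / math.exp(ls)) ** 2 - ls - 0.5 * math.log(2 * math.pi)
        return total


def advi_guide_for(prior_sample: List[float]) -> DiagonalGaussianGuide:
    """Size the guide from one prior draw: one loc and one log-scale per element."""
    n = len(prior_sample)
    return DiagonalGaussianGuide(loc=[0.0] * n, log_scale=[0.0] * n)


guide = advi_guide_for([0.3, -1.2, 2.0])  # e.g. a 3-element site
z = guide.rsample(random.Random(0))
print(len(z))  # 3
print(guide.log_prob([0.0, 0.0, 0.0]))  # equals -1.5 * log(2 * pi)
```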
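MAP
In Maximum A Posteriori (MAP) inference, a Delta point estimate is used as the guide for each site:
$$ q_i(z_i) = \delta(z_i - \theta_i) $$
where \theta_i is a learnable point with the same shape as z_i.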
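With a Delta guide every sample equals the guide's location θ, so the VI objective reduces to maximizing log p(x, θ), assuming the common convention that the Delta's own log-density term is treated as a constant. The sketch below assumes a toy conjugate model (a Normal(0, 1) prior on θ and a unit-variance Normal likelihood) whose MAP point is sum(xs) / (len(xs) + 1), so plain gradient ascent can be checked against it; fit_map and the model are hypothetical.

```python
from typing import List


def log_joint(theta: float, xs: List[float]) -> float:
    # Unnormalized log p(x, theta): Normal(0, 1) prior, Normal(theta, 1) likelihood.
    return -0.5 * theta ** 2 - 0.5 * sum((x - theta) ** 2 for x in xs)


def grad_log_joint(theta: float, xs: List[float]) -> float:
    # d/dtheta of log_joint
    return -theta + sum(x - theta for x in xs)


def fit_map(xs: List[float], lr: float = 0.1, num_steps: int = 200) -> float:
    theta = 0.0  # location of the Delta guide, the only variational parameter
    for _ in range(num_steps):
        # A draw from Delta(theta) is always theta, so each step is gradient
        # ascent on log p(x, theta).
        theta += lr * grad_log_joint(theta, xs)
    return theta


xs = [1.0, 2.0, 3.0]
theta_map = fit_map(xs)
print(theta_map)  # approximately 1.5, i.e. sum(xs) / (len(xs) + 1)
```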