Variational Inference
Params
A Param represents
a variational parameter to be optimized during variational inference.
Use @bm.param to decorate an "initialization function" which returns a
tensor value to initialize the variational parameter at the start of optimization.
Variational Worlds
A VariationalWorld
is a sub-class of World
which also contains data on guide distributions and their parameters, specifically:
get_guide_distribution: given an RVIdentifier, returns its corresponding guide distribution
get_param: given an RVIdentifier for a Param, returns (possibly initializing if empty) the value of the parameter
Note: An implementation detail is that update_graph is overridden so that the
guide distribution is automatically used if one is available.
Gradient Estimators and Divergences
A gradient_estimator
computes a Monte-Carlo (possibly surrogate) objective estimate whose gradients
are used as the training signal.
We structure our VI objective following abstractions introduced in
f-Divergence Variational Inference, where
gradient_estimator takes as input a discrepancy function f
corresponding to an f-divergence.
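To make the role of the discrepancy function concrete, here is a minimal standard-library sketch of a Monte Carlo objective estimate. It is not Bean Machine's implementation: the toy model, the Gaussian guide, and the names normal_logpdf, kl_reverse, log_joint, and monte_carlo_objective are all assumptions made for illustration, and only the objective value is computed (no gradients). The discrepancy function is applied to log u, where u = p(x, z) / q(z); choosing f(u) = -log u gives the reverse KL objective, i.e. the negative ELBO.

```python
import math
import random


def normal_logpdf(value, mean, std):
    """Log density of a univariate Normal(mean, std) at value."""
    return (
        -0.5 * ((value - mean) / std) ** 2
        - math.log(std)
        - 0.5 * math.log(2.0 * math.pi)
    )


def kl_reverse(logu):
    """f(u) = -log u, which yields the reverse KL (negative ELBO) objective."""
    return -logu


def log_joint(z, x):
    """Toy model (an assumption): z ~ Normal(0, 1), x | z ~ Normal(z, 1)."""
    return normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)


def monte_carlo_objective(discrepancy_fn, x, guide_mean, guide_std, num_samples, rng):
    """Average f(log u) over guide samples, where u = p(x, z) / q(z)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(guide_mean, guide_std)
        logu = log_joint(z, x) - normal_logpdf(z, guide_mean, guide_std)
        total += discrepancy_fn(logu)
    return total / num_samples


rng = random.Random(0)
x_obs = 1.0

# The exact posterior for this toy model is Normal(x / 2, sqrt(1 / 2)).
loss_exact = monte_carlo_objective(kl_reverse, x_obs, 0.5, math.sqrt(0.5), 200, rng)
# A deliberately poor guide gives a larger objective value.
loss_poor = monte_carlo_objective(kl_reverse, x_obs, -2.0, 1.0, 200, rng)
# Marginally, x ~ Normal(0, sqrt(2)) in this toy model.
log_evidence = normal_logpdf(x_obs, 0.0, math.sqrt(2.0))
```

When the guide equals the exact posterior, every sample of log u equals the log evidence, so loss_exact matches -log_evidence up to floating-point error. Swapping in a different discrepancy function changes the divergence being minimized without touching the sampling loop.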
VariationalInfer
The VariationalInfer
class provides an entry point for VI. Model and guide RVIdentifiers are associated in the
constructor's queries_to_guides argument, and optimizer configuration is provided through
an optimizer callback. An infer() method is provided for easy invocation, whereas step()
permits more customized interactions (e.g. tensorboard callbacks).
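The relationship between infer() and step() can be illustrated with a toy stand-in written in plain Python. ToyVI below is hypothetical and does not reproduce the VariationalInfer class or its signatures; the model, the Gaussian guide, the learning rate, and the hand-derived reparameterized gradients are all assumptions chosen so the sketch runs without dependencies.

```python
import math
import random


class ToyVI:
    """Hypothetical stand-in exposing step() and infer(); not Bean Machine's class.

    Model (an assumption): z ~ Normal(0, 1), x | z ~ Normal(z, 1).
    Guide: z ~ Normal(loc, exp(log_scale)).
    """

    def __init__(self, x, lr=0.05, num_samples=16, seed=0):
        self.x = x
        self.lr = lr
        self.num_samples = num_samples
        self.rng = random.Random(seed)
        self.loc = 0.0
        self.log_scale = 0.0

    def step(self):
        """Take one SGD step on a reparameterized negative-ELBO estimate."""
        scale = math.exp(self.log_scale)
        loss = grad_loc = grad_log_scale = 0.0
        for _ in range(self.num_samples):
            eps = self.rng.gauss(0.0, 1.0)
            z = self.loc + scale * eps
            # Negative ELBO sample with additive constants dropped:
            # -log p(z) - log p(x | z) + log q(z)
            loss += 0.5 * z**2 + 0.5 * (self.x - z) ** 2 - 0.5 * eps**2 - self.log_scale
            dloss_dz = 2.0 * z - self.x
            grad_loc += dloss_dz
            grad_log_scale += dloss_dz * scale * eps - 1.0
        n = self.num_samples
        self.loc -= self.lr * grad_loc / n
        self.log_scale -= self.lr * grad_log_scale / n
        return loss / n

    def infer(self, num_steps):
        """Convenience wrapper: call step() num_steps times."""
        for _ in range(num_steps):
            self.step()
        return self


# Easy invocation.
fitted = ToyVI(x=1.0).infer(num_steps=1000)

# Custom loop: record the loss every 100 steps, e.g. to feed a logging callback.
vi = ToyVI(x=1.0)
loss_history = []
for i in range(1000):
    loss = vi.step()
    if i % 100 == 0:
        loss_history.append(loss)
```

Both runs use the same seed, so they end at the same parameters; the only difference is that the manual loop can do extra work between steps. For this toy model the guide should settle near the exact posterior, Normal(0.5, sqrt(0.5)).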
AutoGuides
Manually defining a guide for each random variable can become tedious.
AutoGuideVI
provides an initialization strategy for VariationalInfer which
automatically defines guides by calling a method
get_guide(query: RVIdentifier, distrib: dist.Distribution) implemented by
subclasses.
All AutoGuides currently make a mean-field assumption over RVIdentifiers:
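A mean-field assumption of this kind is conventionally written as a product of independent per-site factors. The notation below is an assumption (z_i for the value of the i-th queried RVIdentifier, \phi_i for that site's variational parameters), not taken from the source:

```latex
q_{\phi}(z) = \prod_{i} q_{\phi_i}(z_i)
```

Each factor is fitted independently of the others, so the guide cannot represent posterior correlations between sites.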
ADVI
In Automatic Differentiation Variational Inference (ADVI), a properly-sized Gaussian is used as a guide to approximate each site:
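As a sketch of what a per-site Gaussian guide looks like, reusing the per-site notation z_i and assuming (this is not stated in the source) a diagonal covariance with a location \mu_i and scale \sigma_i of the same shape as the site z_i:

```latex
q_{\phi_i}(z_i) = \mathcal{N}\left(z_i \mid \mu_i, \operatorname{diag}(\sigma_i^2)\right),
\qquad \phi_i = (\mu_i, \sigma_i)
```

"Properly-sized" is read here as meaning that \mu_i and \sigma_i match the shape of the site they approximate.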
MAP
In Maximum A Posteriori (MAP) inference,
a Delta
point estimate is used as the guide for each site:
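One way to read this: with a Delta guide located at theta, fitting the guide amounts to maximizing the log joint density at theta. The sketch below illustrates that with plain gradient ascent on a toy model. It is not Bean Machine's MAP class; the model, the function names grad_log_joint and map_estimate, and the step size are assumptions for illustration.

```python
def grad_log_joint(theta, x):
    """Gradient in theta of log p(theta) + log p(x | theta).

    Toy model (an assumption): theta ~ Normal(0, 1), x | theta ~ Normal(theta, 1).
    """
    return -theta + (x - theta)


def map_estimate(x, lr=0.1, num_steps=200):
    """Gradient ascent on the log joint, i.e. fitting a Delta guide's location."""
    theta = 0.0
    for _ in range(num_steps):
        theta += lr * grad_log_joint(theta, x)
    return theta


theta_map = map_estimate(x=1.0)
```

For this toy model the posterior mode is x / 2, so theta_map converges to 0.5. Unlike the Gaussian guide, the result carries no uncertainty estimate.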