beanmachine.ppl.inference.vi.variational_infer module

class beanmachine.ppl.inference.vi.variational_infer.VariationalInfer(queries_to_guides: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], optimizer: Callable[[torch.Tensor], torch.optim.optimizer.Optimizer] = <function VariationalInfer.<lambda>>, device: torch.device = device(type='cpu'))

Bases: object

infer(num_steps: int, num_samples: int = 1, discrepancy_fn=<function kl_reverse>, mc_approx=<function monte_carlo_approximate_reparam>, step_callback: Optional[Callable[[int, torch.Tensor, beanmachine.ppl.inference.vi.variational_infer.VariationalInfer], None]] = None, subsample_factor: float = 1) beanmachine.ppl.inference.vi.variational_world.VariationalWorld

Perform variational inference.

Parameters
  • num_steps – number of optimizer steps

  • num_samples – number of samples per Monte-Carlo gradient estimate of E[f(logp - logq)]

  • discrepancy_fn – discrepancy function f; use kl_reverse (the default) to minimize the negative ELBO

  • mc_approx – Monte-Carlo gradient estimator to use

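  • step_callback – optional callback invoked at each optimizer step; as its annotation Callable[[int, torch.Tensor, VariationalInfer], None] indicates, it receives the step index, the loss and this VariationalInfer instance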

  • subsample_factor – subsampling factor for when only a subset of the observations is used; it scales the observations' contribution to the loss so that the fit is not overshrunk towards the prior
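This page does not state the exact scaling rule, so the sketch below assumes a common convention: subsample_factor is the fraction of the observations in the current minibatch, and their log-likelihood is multiplied by 1 / subsample_factor so that it is on the same scale as the full-data log-likelihood. Without that rescaling the likelihood looks weaker than it is, and the fit is pulled too far towards the prior. The helpers normal_log_prob and scaled_log_likelihood are hypothetical and written in plain Python, not Bean Machine code.

```python
import math
import random
from typing import List


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def scaled_log_likelihood(batch: List[float], loc: float, subsample_factor: float) -> float:
    """Minibatch log-likelihood rescaled by 1 / subsample_factor (assumed convention)."""
    return sum(normal_log_prob(x, loc, 1.0) for x in batch) / subsample_factor


rng = random.Random(0)
data = [rng.gauss(2.0, 1.0) for _ in range(1000)]
full = sum(normal_log_prob(x, 0.0, 1.0) for x in data)

# Split the data into 4 disjoint minibatches, each holding a fraction 0.25 of it.
factor = 0.25
batches = [data[i::4] for i in range(4)]
estimates = [scaled_log_likelihood(b, 0.0, factor) for b in batches]

# Each rescaled minibatch term is on the same scale as the full-data term,
# and averaging over the partition recovers the full-data value.
print(full, sum(estimates) / len(estimates))
```

Each unscaled minibatch term would be only about a quarter of the full-data log-likelihood, which is the imbalance against the prior that the rescaling removes.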

Returns

A world with variational guide distributions initialized with optimized parameters

Return type

VariationalWorld
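The interplay between num_steps, step and step_callback can be illustrated with a toy loop. In this minimal sketch, ToyVariationalInfer is a hypothetical stand-in, not the Bean Machine class. It fits q = Normal(loc, 1) to p = Normal(target, 1) using the exact reverse KL, (loc - target)**2 / 2, instead of a Monte Carlo estimate, with plain gradient descent in place of the configurable optimizer. It calls the callback with (step index, loss, inference object), following the annotated signature, and returns the fitted parameter where the real method returns a VariationalWorld.

```python
from typing import Callable, List, Optional

# Hypothetical callback type mirroring the (step index, loss, inference object) annotation
StepCallback = Callable[[int, float, "ToyVariationalInfer"], None]


class ToyVariationalInfer:
    """Hypothetical stand-in that fits q = Normal(loc, 1) to p = Normal(target, 1).

    For two unit-variance Gaussians KL(q || p) = (loc - target) ** 2 / 2, so the
    loss and its gradient are computed exactly rather than by Monte Carlo.
    """

    def __init__(self, target: float, lr: float = 0.1) -> None:
        self.target = target
        self.lr = lr
        self.loc = 0.0  # the single variational parameter

    def step(self) -> float:
        loss = 0.5 * (self.loc - self.target) ** 2  # loss before the update
        self.loc -= self.lr * (self.loc - self.target)  # gradient-descent update
        return loss

    def infer(self, num_steps: int, step_callback: Optional[StepCallback] = None) -> float:
        for it in range(num_steps):
            loss = self.step()
            if step_callback is not None:
                step_callback(it, loss, self)
        return self.loc  # stands in for the world built from the optimized parameters


# Record the loss at every step through the callback.
losses: List[float] = []
vi = ToyVariationalInfer(target=3.0)
final_loc = vi.infer(num_steps=100, step_callback=lambda it, loss, inf: losses.append(loss))
```

A callback like this is a convenient place to log or plot the loss curve, since infer itself only returns the final world.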

initialize_world() beanmachine.ppl.inference.vi.variational_world.VariationalWorld

Initializes a VariationalWorld using samples from guide distributions evaluated at the current parameter values.

Returns

A VariationalWorld in which guide samples and distributions have replaced their corresponding queries

Return type

VariationalWorld
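The idea of replacing each query by a sample from its guide, evaluated at the current parameters, can be mimicked with plain dictionaries. This sketch is hypothetical throughout: NormalGuide, the string keys and the (distribution, value) tuples are illustrative stand-ins for Bean Machine's RVIdentifier keys and VariationalWorld. It shows only that each query receives both the guide distribution and a fresh sample drawn from it, and that changing the parameters changes what a newly initialized world contains.

```python
import random
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class NormalGuide:
    """Hypothetical guide distribution with variational parameters loc and scale."""
    loc: float
    scale: float

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.loc, self.scale)


def initialize_world(
    queries_to_guides: Dict[str, str],
    guides: Dict[str, NormalGuide],
    rng: random.Random,
) -> Dict[str, Tuple[NormalGuide, float]]:
    """Map each query to (guide distribution, sample from it) at the current parameters."""
    world: Dict[str, Tuple[NormalGuide, float]] = {}
    for query, guide_name in queries_to_guides.items():
        guide = guides[guide_name]
        world[query] = (guide, guide.sample(rng))
    return world


rng = random.Random(0)
guides = {"q_mu": NormalGuide(loc=1.5, scale=0.1)}
world = initialize_world({"mu": "q_mu"}, guides, rng)
distribution, value = world["mu"]

# After the parameters change, a newly initialized world reflects the new values.
guides["q_mu"].loc = 5.0
_, new_value = initialize_world({"mu": "q_mu"}, guides, rng)["mu"]
```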

step(num_samples: int = 1, discrepancy_fn=<function kl_reverse>, mc_approx=<function monte_carlo_approximate_reparam>, subsample_factor: float = 1) torch.Tensor

Perform one step of variational inference.

Parameters
  • num_samples – number of samples per Monte-Carlo gradient estimate of E[f(logp - logq)]

  • discrepancy_fn – discrepancy function f; use kl_reverse (the default) to minimize the negative ELBO

  • mc_approx – Monte-Carlo gradient estimator to use

  • subsample_factor – subsampling factor for when only a subset of the observations is used; it scales the observations' contribution to the loss so that the fit is not overshrunk towards the prior

Returns

The loss value, computed before the optimizer step is applied

Return type

torch.Tensor
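A single step can be sketched for a toy model in plain Python. The default estimator's name, monte_carlo_approximate_reparam, suggests the reparameterization trick, which is what this sketch uses: it draws z = loc + eps with eps ~ Normal(0, 1) and differentiates through that map with eps held fixed. Everything else is assumed for illustration: q = Normal(loc, 1), p = Normal(target, 1), kl_reverse taken as f(logu) = -logu (consistent with it minimizing the negative ELBO), and plain SGD in place of the configurable optimizer. reparam_step is a hypothetical helper. Like step, it returns the loss estimated before the update.

```python
import random
from typing import Tuple


def reparam_step(loc: float, target: float, num_samples: int, lr: float,
                 rng: random.Random) -> Tuple[float, float]:
    """One hypothetical VI step for q = Normal(loc, 1) and p = Normal(target, 1).

    Returns (loss_before_step, new_loc), where the loss is a Monte Carlo estimate
    of E_q[f(log p(z) - log q(z))] with f(logu) = -logu, i.e. the negative ELBO.
    """
    loss = 0.0
    grad = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = loc + eps  # reparameterized sample: deterministic in loc given eps
        # log p(z) - log q(z); normalizing constants cancel since both scales are 1
        logu = -0.5 * (z - target) ** 2 + 0.5 * (z - loc) ** 2
        loss += -logu  # reverse-KL discrepancy f(logu) = -logu
        # d(-logu)/d(loc): (z - loc) = eps does not depend on loc, so only the
        # first term contributes, giving (z - target)
        grad += z - target
    loss /= num_samples
    grad /= num_samples
    return loss, loc - lr * grad  # loss is from before the parameter update


rng = random.Random(0)
first_loss, loc = reparam_step(0.0, target=3.0, num_samples=1000, lr=0.1, rng=rng)
for _ in range(200):
    _, loc = reparam_step(loc, target=3.0, num_samples=10, lr=0.1, rng=rng)
```

For this model the exact negative ELBO at loc = 0 is KL(q || p) = 4.5, so first_loss should land near that value, and loc should settle near target up to sampling noise.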