beanmachine.ppl.inference.vi.variational_infer module
- class beanmachine.ppl.inference.vi.variational_infer.VariationalInfer(queries_to_guides: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], optimizer: Callable[[torch.Tensor], torch.optim.optimizer.Optimizer] = <function VariationalInfer.<lambda>>, device: torch.device = device(type='cpu'))
Bases: object
- infer(num_steps: int, num_samples: int = 1, discrepancy_fn=<function kl_reverse>, mc_approx=<function monte_carlo_approximate_reparam>, step_callback: Optional[Callable[[int, torch.Tensor, beanmachine.ppl.inference.vi.variational_infer.VariationalInfer], None]] = None, subsample_factor: float = 1) → beanmachine.ppl.inference.vi.variational_world.VariationalWorld
Perform variational inference.
- Parameters
num_steps – number of optimizer steps
num_samples – number of samples per Monte-Carlo gradient estimate of E[f(logp - logq)]
discrepancy_fn – discrepancy function f; use kl_reverse to minimize the negative ELBO
mc_approx – Monte-Carlo gradient estimator to use
step_callback – callback function invoked at each optimizer step
subsample_factor – subsampling factor, used when the observations are a subsample of the data; it rescales the observations so the approximation does not over-shrink towards the prior
- Returns
A world with variational guide distributions initialized with optimized parameters
- Return type
beanmachine.ppl.inference.vi.variational_world.VariationalWorld
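The subsample_factor parameter matters when each step sees only a minibatch of the observations. The following standard-library sketch shows the usual minibatch correction from stochastic variational inference, assuming subsample_factor is the fraction of observations in the minibatch (batch size divided by dataset size) and that the observations' log-likelihood is divided by it. The helper scaled_log_likelihood and the data are hypothetical; this is not Bean Machine's code.

```python
def scaled_log_likelihood(batch_logliks, subsample_factor):
    """Sum a minibatch's per-observation log-likelihoods, rescaled by 1 / subsample_factor.

    Assumption: subsample_factor = batch_size / num_observations, so the rescaled
    value is an unbiased estimate of the full-data log-likelihood.
    """
    if not 0 < subsample_factor <= 1:
        raise ValueError("subsample_factor must be in (0, 1]")
    return sum(batch_logliks) / subsample_factor


# Hypothetical per-observation log-likelihoods for a 4-point dataset.
full_data = [-1.0, -2.0, -3.0, -4.0]
full_loglik = sum(full_data)  # -10.0

# Two disjoint minibatches of size 2, so subsample_factor = 2 / 4 = 0.5.
batches = [full_data[:2], full_data[2:]]
scaled = [scaled_log_likelihood(b, 0.5) for b in batches]  # [-6.0, -14.0]
unscaled = [sum(b) for b in batches]                       # [-3.0, -7.0]

# Averaged over minibatches, the scaled terms recover the full-data value,
# while the unscaled terms carry only half of it, so the prior gets too much weight.
avg_scaled = sum(scaled) / len(scaled)          # -10.0
avg_unscaled = sum(unscaled) / len(unscaled)    # -5.0
```

Without the rescaling, each step effectively sees a smaller dataset than the real one, which is the over-shrinking towards the prior that the parameter description refers to.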
- initialize_world() → beanmachine.ppl.inference.vi.variational_world.VariationalWorld
Initializes a VariationalWorld using samples from guide distributions evaluated at the current parameter values.
- Returns
a World where guide samples and distributions have replaced their corresponding queries
- Return type
beanmachine.ppl.inference.vi.variational_world.VariationalWorld
- step(num_samples: int = 1, discrepancy_fn=<function kl_reverse>, mc_approx=<function monte_carlo_approximate_reparam>, subsample_factor: float = 1) → torch.Tensor
Perform one step of variational inference.
- Parameters
num_samples – number of samples per Monte-Carlo gradient estimate of E[f(logp - logq)]
discrepancy_fn – discrepancy function f; use kl_reverse to minimize the negative ELBO
mc_approx – Monte-Carlo gradient estimator to use
subsample_factor – subsampling factor, used when the observations are a subsample of the data; it rescales the observations so the approximation does not over-shrink towards the prior
- Returns
the loss value (before the step)
- Return type
torch.Tensor
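The loss returned by step is a Monte-Carlo estimate of E[f(logp - logq)] taken under the guide q, averaged over num_samples draws. The standard-library sketch below mirrors that estimate for a one-dimensional Normal guide, assuming kl_reverse corresponds to f(u) = -u, so the expectation is the negative ELBO. The names normal_logpdf, kl_reverse_sketch, mc_loss and log_target are hypothetical helpers, not Bean Machine functions, and the real estimator works on torch tensors with gradients.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def kl_reverse_sketch(logu):
    """Hypothetical stand-in for kl_reverse: f(logu) = -logu, where logu = logp - logq."""
    return -logu


def mc_loss(logp, q_mu, q_sigma, num_samples=1, discrepancy_fn=kl_reverse_sketch, seed=0):
    """Monte-Carlo estimate of E_q[f(logp(z) - logq(z))] for a Normal(q_mu, q_sigma) guide.

    Draws are reparameterized as z = q_mu + q_sigma * eps with eps ~ Normal(0, 1).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = q_mu + q_sigma * rng.gauss(0.0, 1.0)
        logu = logp(z) - normal_logpdf(z, q_mu, q_sigma)
        total += discrepancy_fn(logu)
    return total / num_samples


# Unnormalized target p(z) = 2 * Normal(z; 0, 1), so the log normalizer is log Z = log 2.
def log_target(z):
    return normal_logpdf(z, 0.0, 1.0) + math.log(2.0)


# A guide with the same shape as the target attains the minimum negative ELBO, -log Z.
exact = mc_loss(log_target, 0.0, 1.0, num_samples=10)

# A guide shifted by 1 pays an extra KL(q || posterior) = 0.5 on top of -log Z.
shifted = mc_loss(log_target, 1.0, 1.0, num_samples=20000)
```

Since the negative ELBO equals -log Z plus KL(q || posterior), minimizing this loss over the guide parameters pulls q towards the posterior; increasing num_samples lowers the variance of each estimate at the cost of more computation per step.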