beanmachine.ppl.inference.vi.gradient_estimator module
Gradient estimators of f-divergences.
- beanmachine.ppl.inference.vi.gradient_estimator.monte_carlo_approximate_reparam(observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: int, discrepancy_fn: Callable[[torch.Tensor], torch.Tensor], params: Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], queries_to_guides: Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier], subsample_factor: float = 1.0, device: torch.device = device(type='cpu')) → torch.Tensor
Gradient estimator based on the pathwise derivative, also known as the reparameterization trick (https://arxiv.org/abs/1312.6114).
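To illustrate the idea behind the pathwise estimator, here is a minimal, self-contained sketch on a toy problem. It assumes a one-dimensional Gaussian guide N(mu, sigma^2) and a scalar integrand f with known derivative. The function `reparam_grad` and its parameters are hypothetical; this is not Bean Machine's implementation, which works on the model's log density, `discrepancy_fn`, and PyTorch autograd. Each sample is written as a differentiable function of parameter-free noise, so the gradient passes through the sample itself.

```python
import random
from typing import Callable, Tuple


def reparam_grad(
    f_prime: Callable[[float], float],
    mu: float,
    sigma: float,
    num_samples: int,
    rng: random.Random,
) -> Tuple[float, float]:
    """Pathwise (reparameterization) estimate of the gradient of
    E_{z ~ N(mu, sigma^2)}[f(z)] with respect to (mu, sigma).

    Hypothetical toy helper. Each sample is z = mu + sigma * eps with
    eps ~ N(0, 1), so d f(z)/d mu = f'(z) and d f(z)/d sigma = f'(z) * eps.
    """
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)  # parameter-free base noise
        z = mu + sigma * eps       # reparameterized sample
        d = f_prime(z)             # derivative of the integrand at the sample
        g_mu += d                  # chain rule with dz/dmu = 1
        g_sigma += d * eps         # chain rule with dz/dsigma = eps
    return g_mu / num_samples, g_sigma / num_samples


# f(z) = z**2 has E[f] = mu**2 + sigma**2, so the exact gradient is (2*mu, 2*sigma).
rng = random.Random(0)
print(reparam_grad(lambda z: 2.0 * z, mu=1.5, sigma=0.5, num_samples=200_000, rng=rng))
# prints a pair close to (3.0, 1.0)
```

The estimator needs the derivative of the integrand, which is why it requires a differentiable (reparameterizable) sampling path from the guide.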
- beanmachine.ppl.inference.vi.gradient_estimator.monte_carlo_approximate_sf(observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: int, discrepancy_fn: Callable[[torch.Tensor], torch.Tensor], params: Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], queries_to_guides: Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier], subsample_factor: float = 1, device: torch.device = device(type='cpu')) → torch.Tensor
Gradient estimator based on the score function (log-derivative trick), computed through a surrogate loss (https://arxiv.org/pdf/1506.05254).
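For comparison, here is a minimal sketch of the score-function estimator on the same toy problem. It assumes a one-dimensional Gaussian guide and a scalar integrand f, and the helper `score_function_grad` is hypothetical, not Bean Machine's code. It uses the identity grad E_q[f(z)] = E_q[f(z) * grad log q(z)], which needs only values of f, not its derivative. In an autograd framework the same estimator is usually obtained by differentiating a surrogate loss such as stop_gradient(f(z)) * log q(z). Here the Gaussian score is written out by hand instead.

```python
import random
from typing import Callable, Tuple


def score_function_grad(
    f: Callable[[float], float],
    mu: float,
    sigma: float,
    num_samples: int,
    rng: random.Random,
) -> Tuple[float, float]:
    """Score-function (log-derivative) estimate of the gradient of
    E_{z ~ N(mu, sigma^2)}[f(z)] with respect to (mu, sigma).

    Hypothetical toy helper. For a Gaussian q:
      d/dmu    log q(z) = (z - mu) / sigma**2
      d/dsigma log q(z) = (z - mu)**2 / sigma**3 - 1 / sigma
    """
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)  # plain sample; no derivative is taken through z
        fz = f(z)                 # only the value of the integrand is needed
        score_mu = (z - mu) / sigma**2
        score_sigma = (z - mu) ** 2 / sigma**3 - 1.0 / sigma
        g_mu += fz * score_mu     # f(z) weighted by the score
        g_sigma += fz * score_sigma
    return g_mu / num_samples, g_sigma / num_samples


# Same toy target: f(z) = z**2, exact gradient (2*mu, 2*sigma).
rng = random.Random(0)
print(score_function_grad(lambda z: z * z, mu=1.5, sigma=0.5, num_samples=200_000, rng=rng))
# prints a pair close to (3.0, 1.0), with more Monte Carlo noise than the pathwise estimate
```

Because it never differentiates through the sample, this estimator also applies to guides without a reparameterizable sampler. On this toy problem, however, its per-sample variance is much larger than that of the pathwise estimator.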