beanmachine.ppl.inference.proposer.nuts_proposer module

class beanmachine.ppl.inference.proposer.nuts_proposer.NUTSProposer(initial_world:, target_rvs: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier], num_adaptive_sample: int, max_tree_depth: int = 10, max_delta_energy: float = 1000.0, initial_step_size: float = 1.0, adapt_step_size: bool = True, adapt_mass_matrix: bool = True, full_mass_matrix: bool = False, multinomial_sampling: bool = True, target_accept_prob: float = 0.8, jit_backend: beanmachine.ppl.experimental.torch_jit_backend.TorchJITBackend = TorchJITBackend.NNC)

Bases: beanmachine.ppl.inference.proposer.hmc_proposer.HMCProposer

The No-U-Turn Sampler (NUTS) as described in [1]. Unlike vanilla HMC, it does not require users to specify a trajectory length. Instead, it grows each trajectory until the trajectory starts to double back on itself (a “U-turn”) or max_tree_depth is reached. The current implementation roughly follows Algorithm 6 of [1]. If multinomial_sampling is True, the next state is drawn from the trajectory according to a multinomial distribution (weighted by acceptance probability, as introduced in Appendix 2 of [2]) instead of being drawn uniformly.
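
The trajectory-building ideas above can be illustrated with a minimal, self-contained Python sketch. It is not Bean Machine's implementation. Positions and momenta are plain lists of floats, and the helper names `is_u_turn`, `is_divergent` and `multinomial_index` are hypothetical. `is_u_turn` is the termination criterion of [1]: the trajectory stops growing once the momentum at either end points back along the span between its endpoints. `is_divergent` is an energy-error check in the spirit of the max_delta_energy threshold. `multinomial_index` selects a state with probability proportional to exp(-H), where H is each state's energy. Relative to the initial state with energy H0, this weight is the unclipped Metropolis acceptance ratio exp(H0 - H).

```python
import math
import random
from typing import List, Sequence


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def is_u_turn(theta_minus: Sequence[float], theta_plus: Sequence[float],
              r_minus: Sequence[float], r_plus: Sequence[float]) -> bool:
    """No-U-turn criterion of Hoffman & Gelman (2014).

    theta_minus/theta_plus are the leftmost/rightmost positions of the
    trajectory; r_minus/r_plus are the momenta at those positions.
    """
    span = [p - m for p, m in zip(theta_plus, theta_minus)]
    # A negative projection means that end is moving back toward the other end.
    return _dot(span, r_minus) < 0 or _dot(span, r_plus) < 0


def is_divergent(initial_energy: float, new_energy: float,
                 max_delta_energy: float = 1000.0) -> bool:
    """Flag a state whose energy error exceeds max_delta_energy."""
    delta = new_energy - initial_energy
    # Assumption of this sketch: NaN energies (e.g. from overflow) also count as divergent.
    return math.isnan(delta) or delta > max_delta_energy


def multinomial_index(energies: List[float], rng: random.Random) -> int:
    """Pick a trajectory state with probability proportional to exp(-H)."""
    h_min = min(energies)
    # Shifting by the minimum energy rescales all weights equally and avoids overflow.
    weights = [math.exp(-(h - h_min)) for h in energies]
    return rng.choices(range(len(energies)), weights=weights, k=1)[0]
```

For example, `multinomial_index([1000.0, 0.0, 1000.0], random.Random(0))` returns 1, because the weights of the two high-energy states underflow to zero. Uniform selection would instead pick among candidate states with equal probability.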

[1] Matthew Hoffman and Andrew Gelman. “The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo” (2014).

[2] Michael Betancourt. “A Conceptual Introduction to Hamiltonian Monte Carlo” (2017).

Parameters

  • initial_world – Initial world to propose from.

  • target_rvs – Set of RVIdentifiers to indicate which variables to propose.

  • num_adaptive_samples – Number of adaptive (warm-up) samples during which the step size and mass matrix are tuned.

  • max_tree_depth – Maximum tree depth, defaults to 10.

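  • max_delta_energy – Maximum allowed increase in energy (the Hamiltonian) relative to the start of the trajectory before a step is treated as divergent, used for numerical stability; defaults to 1000.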

  • initial_step_size – Initial step size of the leapfrog integrator, defaults to 1.0.

  • adapt_step_size – Whether to adapt the step size with dual averaging, as suggested in [1]; defaults to True.

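  • adapt_mass_matrix – Whether to adapt the mass matrix using Welford's online scheme, defaults to True.

  • full_mass_matrix – Whether to adapt a full (dense) mass matrix rather than a diagonal one, defaults to False.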

  • multinomial_sampling – Whether to use multinomial sampling as in [2], defaults to True.

  • target_accept_prob – Target acceptance probability for step-size adaptation. Increasing it leads to a smaller step size. Defaults to 0.8.

  • jit_backend – Backend used to JIT-compile and accelerate inference; defaults to TorchJITBackend.NNC (the NNC compiler).
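
The adapt_step_size and target_accept_prob options refer to the dual-averaging scheme of [1], as used in its Algorithms 5 and 6. The sketch below implements those updates in plain Python. The class name `DualAveragingStepSize` and its methods are hypothetical. The hyperparameters gamma = 0.05, t0 = 10 and kappa = 0.75 are the values suggested in [1] and are not necessarily Bean Machine's. The sketch also shows why a larger target_accept_prob leads to a smaller step size: the running statistic h_bar grows whenever the observed acceptance probability falls short of the target, and a larger h_bar lowers log(step size).

```python
import math


class DualAveragingStepSize:
    """Dual-averaging step-size adaptation following Hoffman & Gelman (2014)."""

    def __init__(self, initial_step_size: float, target_accept_prob: float = 0.8,
                 gamma: float = 0.05, t0: float = 10.0, kappa: float = 0.75):
        # Shrinkage point: the paper suggests log(10 * initial step size).
        self.mu = math.log(10.0 * initial_step_size)
        self.target = target_accept_prob
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.m = 0
        self.h_bar = 0.0
        self.log_eps_bar = 0.0

    def update(self, accept_prob: float) -> float:
        """Feed one iteration's acceptance statistic; return the next step size."""
        self.m += 1
        m = self.m
        w = 1.0 / (m + self.t0)
        # Running average of how far the acceptance statistic falls short of the target.
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target - accept_prob)
        log_eps = self.mu - math.sqrt(m) / self.gamma * self.h_bar
        # Weighted average of the iterates; this is the value kept after warm-up.
        eta = m ** (-self.kappa)
        self.log_eps_bar = eta * log_eps + (1.0 - eta) * self.log_eps_bar
        return math.exp(log_eps)

    def final_step_size(self) -> float:
        return math.exp(self.log_eps_bar)
```

In this sketch, `update` would be called once per warm-up iteration, and after the adaptive samples the sampler would continue with the fixed value returned by `final_step_size()`.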
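
The adapt_mass_matrix option refers to estimating the posterior variance online with Welford's algorithm, which updates a running mean and sum of squared deviations one sample at a time. Below is a minimal sketch of a diagonal estimator in plain Python. The class name `WelfordVariance` is hypothetical. A common choice in HMC implementations (for example Stan) is to use these variances as the diagonal of the inverse mass matrix. Any regularization or windowing Bean Machine applies, and the covariance tracking that full_mass_matrix=True would need, are not shown.

```python
from typing import List, Sequence


class WelfordVariance:
    """Welford's online algorithm for per-coordinate mean and variance."""

    def __init__(self, dim: int):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations per coordinate

    def update(self, x: Sequence[float]) -> None:
        self.n += 1
        for i, xi in enumerate(x):
            delta = xi - self.mean[i]
            self.mean[i] += delta / self.n
            # Multiplying by the deviation from the *updated* mean keeps this stable.
            self.m2[i] += delta * (xi - self.mean[i])

    def variance(self) -> List[float]:
        """Unbiased sample variance per coordinate (needs at least two samples)."""
        if self.n < 2:
            raise ValueError("need at least two samples")
        return [m / (self.n - 1) for m in self.m2]
```

In a warm-up loop, each adaptive sample would be passed to `update`, and `variance()` would be read at the end of an adaptation window.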

propose(world: Tuple[, torch.Tensor]