beanmachine.ppl.inference.base_inference module

class beanmachine.ppl.inference.base_inference.BaseInference

Bases: object

Abstract class all inference methods should inherit from.

abstract get_proposers(world: beanmachine.ppl.world.world.World, target_rvs: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier], num_adaptive_sample: int) → List[beanmachine.ppl.inference.proposer.base_proposer.BaseProposer]

Returns the proposer(s) for every non-observed variable in target_rvs. Each specific inference algorithm must implement this method.
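The abstract-method contract can be illustrated with Python's abc module. This is a minimal sketch, not Bean Machine code: InferenceSketch, SingleSiteSketch and ProposerSketch are hypothetical stand-ins, worlds are plain dicts, and variables are strings rather than RVIdentifier objects.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Set


class ProposerSketch:
    """Hypothetical stand-in for BaseProposer: proposes new values for one variable."""

    def __init__(self, rv: str) -> None:
        self.rv = rv


class InferenceSketch(ABC):
    """Hypothetical, simplified analogue of BaseInference."""

    @abstractmethod
    def get_proposers(
        self, world: Dict[str, float], target_rvs: Set[str], num_adaptive_sample: int
    ) -> List[ProposerSketch]:
        """Return proposer(s) for every non-observed variable in target_rvs."""


class SingleSiteSketch(InferenceSketch):
    """Hypothetical concrete subclass: one proposer per latent variable."""

    def __init__(self, observed: Set[str]) -> None:
        self.observed = observed

    def get_proposers(self, world, target_rvs, num_adaptive_sample):
        # Skip observed variables; sort so the proposer order is deterministic.
        return [ProposerSketch(rv) for rv in sorted(target_rvs - self.observed)]
```

Because get_proposers is abstract, instantiating InferenceSketch directly raises TypeError; a concrete subclass only has to decide which proposer handles each latent variable.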

infer(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: int, num_chains: int = 4, num_adaptive_samples: Optional[int] = None, verbose: beanmachine.ppl.inference.utils.VerboseLevel = VerboseLevel.LOAD_BAR, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_init_retries: int = 100, run_in_parallel: bool = False, mp_context: Optional[typing_extensions.Literal['fork', 'spawn', 'forkserver']] = None) → beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples

Performs inference and returns a MonteCarloSamples object with samples from the posterior.

Parameters
  • queries – List of random variables (RVIdentifier) to query.

  • observations – Observations as an RVDict keyed by RVIdentifier

  • num_samples – Number of samples to draw per chain.

  • num_chains – Number of chains to run, defaults to 4.

  • num_adaptive_samples – Number of adaptive samples. If not provided, Bean Machine falls back to an algorithm-specific default value based on num_samples.

  • verbose – Verbosity level (a VerboseLevel) that controls whether a progress bar is displayed; the default, VerboseLevel.LOAD_BAR, displays one.

  • initialize_fn – A callable that takes a distribution and returns a Tensor of initial values. By default (init_to_uniform), a value is sampled from Uniform(-2, 2) and then mapped to the support of the distribution through a bijective transform.

  • max_init_retries – Maximum number of attempts to initialize values for inference before raising an error (defaults to 100).

  • run_in_parallel – Whether to run multiple chains in parallel (with multiple processes) or not.

  • mp_context – The multiprocessing context to use for parallel inference.
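The default initialize_fn behavior described above (draw from Uniform(-2, 2) in unconstrained space, then map into the distribution's support) can be sketched in pure Python. This is a minimal illustration, not Bean Machine's init_to_uniform: supports are hypothetical string labels rather than torch constraint objects, and the identity, exp and sigmoid maps are assumed stand-ins similar to the transforms torch.distributions.biject_to returns for real, positive and unit-interval supports.

```python
import math
import random
from typing import Callable, Dict

# Hypothetical support labels mapped to a bijection from the real line onto that support.
BIJECTIONS: Dict[str, Callable[[float], float]] = {
    "real": lambda u: u,                                  # identity
    "positive": math.exp,                                 # R -> (0, inf)
    "unit_interval": lambda u: 1.0 / (1.0 + math.exp(-u)),  # sigmoid: R -> (0, 1)
}


def init_to_uniform_sketch(support: str, rng: random.Random) -> float:
    """Draw u ~ Uniform(-2, 2) in unconstrained space, then map u into `support`."""
    u = rng.uniform(-2.0, 2.0)
    return BIJECTIONS[support](u)
```

Drawing in unconstrained space keeps every initial value strictly inside the support (for example, a positive variable starts in [e^-2, e^2]), which avoids initializing on a boundary where the log density may be undefined.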

sampler(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: Optional[int] = None, num_adaptive_samples: Optional[int] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_init_retries: int = 100) → beanmachine.ppl.inference.sampler.Sampler

Returns a generator that yields a new world (representing a new state of the graph) on each iteration. If num_samples is not provided, the generator is infinite.
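The finite-versus-infinite behavior can be sketched with a plain Python generator. This is a hypothetical illustration, not the Sampler class: the "world" is just a dict holding one variable x, and the Gaussian random-walk step is a placeholder for whatever transition the real inference method applies.

```python
import itertools
import random
from typing import Dict, Iterator, Optional


def sampler_sketch(num_samples: Optional[int], seed: int = 0) -> Iterator[Dict[str, float]]:
    """Yield a new hypothetical world (a dict of variable values) per iteration.

    If num_samples is None the generator never stops, mirroring sampler().
    """
    rng = random.Random(seed)
    world = {"x": 0.0}
    steps = itertools.count() if num_samples is None else range(num_samples)
    for _ in steps:
        # Placeholder transition: perturb x. A real sampler would apply its proposers.
        world = {"x": world["x"] + rng.gauss(0.0, 1.0)}
        yield world
```

With an infinite generator the caller decides when to stop, for example with itertools.islice(sampler_sketch(None), 5) to take the first five worlds.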

Parameters
  • queries – List of random variables (RVIdentifier) to query.

  • observations – Observations as an RVDict keyed by RVIdentifier

  • num_samples – Number of samples; defaults to None, which produces an infinite sampler.

  • num_adaptive_samples – Number of adaptive samples. If not provided, Bean Machine falls back to an algorithm-specific default value based on num_samples; if num_samples is not provided either, it defaults to 0.

  • initialize_fn – A callable that takes a distribution and returns a Tensor of initial values. By default (init_to_uniform), a value is sampled from Uniform(-2, 2) and then mapped to the support of the distribution through a bijective transform.

  • max_init_retries – Maximum number of attempts to initialize values for inference before raising an error (defaults to 100).
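The role of max_init_retries can be sketched as a bounded retry loop. This is a hypothetical illustration, not Bean Machine's initialization code: it assumes an initial value counts as valid when the model's log density at that value is finite, and the function name, the float-valued "model", and the ValueError are all stand-ins chosen for the example.

```python
import math
import random
from typing import Callable


def initialize_with_retries(
    initialize_fn: Callable[[random.Random], float],
    log_prob: Callable[[float], float],
    max_init_retries: int = 100,
    seed: int = 0,
) -> float:
    """Hypothetical sketch: redraw an initial value until its log density is finite."""
    rng = random.Random(seed)
    for _ in range(max_init_retries):
        value = initialize_fn(rng)
        # Assumed validity check: a finite log density means the value is usable.
        if math.isfinite(log_prob(value)):
            return value
    raise ValueError(f"No valid initialization found after {max_init_retries} attempts.")
```

For a half-normal target (log density is -inf for negative values), a Uniform(-2, 2) initializer fails about half the time, so a few retries almost always suffice; a model whose log density is never finite exhausts the budget and raises.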