beanmachine.ppl.inference.base_inference module
- class beanmachine.ppl.inference.base_inference.BaseInference
Bases: object
Abstract base class that all inference methods should inherit from.
- abstract get_proposers(world: beanmachine.ppl.world.world.World, target_rvs: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier], num_adaptive_sample: int) → List[beanmachine.ppl.inference.proposer.base_proposer.BaseProposer]
Returns the proposer(s) for every non-observed variable in target_rvs. Each concrete inference algorithm must implement this method.
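The get_proposers contract is an ordinary abstract-method pattern. The sketch below reproduces it with Python's standard abc module and does not import Bean Machine: ToyWorld, ToyProposer, ToyBaseInference and ToySingleSiteInference are hypothetical stand-ins, and the world is reduced to the set of observed variable names.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class ToyWorld:
    # Hypothetical stand-in for a World: only tracks which variables are observed.
    observed: Set[str] = field(default_factory=set)


@dataclass
class ToyProposer:
    # Hypothetical stand-in for a BaseProposer: remembers the variable it updates.
    node: str


class ToyBaseInference(ABC):
    @abstractmethod
    def get_proposers(
        self, world: ToyWorld, target_rvs: Set[str], num_adaptive_sample: int
    ) -> List[ToyProposer]:
        """Return the proposer(s) covering every non-observed target variable."""


class ToySingleSiteInference(ToyBaseInference):
    def get_proposers(self, world, target_rvs, num_adaptive_sample):
        # One proposer per latent (non-observed) variable, in a stable order.
        return [ToyProposer(rv) for rv in sorted(target_rvs - world.observed)]


world = ToyWorld(observed={"y"})
proposers = ToySingleSiteInference().get_proposers(world, {"mu", "sigma", "y"}, 0)
print([p.node for p in proposers])  # ['mu', 'sigma']
```

Because get_proposers is abstract, ToyBaseInference itself cannot be instantiated. Returning one proposer per latent variable is only one possible choice; the "(s)" in the docstring leaves room for an algorithm that returns a single proposer covering several variables.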
- infer(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: int, num_chains: int = 4, num_adaptive_samples: Optional[int] = None, show_progress_bar: bool = True, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_init_retries: int = 100, run_in_parallel: bool = False, mp_context: Optional[typing_extensions.Literal['fork', 'spawn', 'forkserver']] = None, verbose: Optional[beanmachine.ppl.inference.utils.VerboseLevel] = None) → beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples
Performs inference and returns a MonteCarloSamples object containing samples from the posterior.
- Parameters
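queries – List of random variables (RVIdentifier) to draw posterior samples for.
observations – Observations as an RVDict mapping each observed RVIdentifier to its value (a torch.Tensor).
num_samples – Number of samples to draw per chain.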
num_chains – Number of chains to run, defaults to 4.
num_adaptive_samples – Number of adaptive samples. If not provided, Bean Machine falls back to an algorithm-specific default based on num_samples.
show_progress_bar – Whether to display the progress bar, defaults to True.
initialize_fn – A callable that takes a distribution and returns a Tensor. The default, init_to_uniform, samples from Uniform(-2, 2) and then bijects the value onto the support of the distribution.
max_init_retries – Number of attempts to find valid initial values before raising an error, defaults to 100.
run_in_parallel – Whether to run multiple chains in parallel (with multiple processes), defaults to False.
mp_context – The multiprocessing context to use for parallel inference.
verbose – (Deprecated) Whether to display the progress bar. Use show_progress_bar instead.
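To make the parameters above concrete, here is a toy, library-free sketch of what they control, using random-walk Metropolis on a one-dimensional log density. It is not Bean Machine's implementation: toy_infer and init_to_uniform_toy are hypothetical, the num_samples // 2 fallback for num_adaptive_samples is an arbitrary assumption (the real default is algorithm-specific), and instead of bijecting onto the support it simply redraws until the starting point has finite log density, bounded by max_init_retries.

```python
import math
import random
from typing import Callable, List, Optional


def init_to_uniform_toy(rng: random.Random) -> float:
    # Hypothetical: draw from Uniform(-2, 2) like the documented default,
    # but skip the bijection onto the distribution's support.
    return rng.uniform(-2.0, 2.0)


def toy_infer(
    log_density: Callable[[float], float],
    num_samples: int,
    num_chains: int = 4,
    num_adaptive_samples: Optional[int] = None,
    max_init_retries: int = 100,
    seed: int = 0,
) -> List[List[float]]:
    """Random-walk Metropolis; returns a [num_chains][num_samples] nested list."""
    if num_adaptive_samples is None:
        # Assumed fallback for this sketch only.
        num_adaptive_samples = num_samples // 2
    rng = random.Random(seed)
    chains = []
    for _ in range(num_chains):
        # Redraw the starting point until its log density is finite.
        for _ in range(max_init_retries):
            x = init_to_uniform_toy(rng)
            if math.isfinite(log_density(x)):
                break
        else:
            raise ValueError("Cannot find a valid initialization")
        step, kept = 1.0, []
        for i in range(num_adaptive_samples + num_samples):
            proposal = x + rng.gauss(0.0, step)
            log_ratio = log_density(proposal) - log_density(x)
            accepted = log_ratio >= 0 or rng.random() < math.exp(log_ratio)
            if accepted:
                x = proposal
            if i < num_adaptive_samples:
                # Crude step-size tuning, applied only during the adaptive phase;
                # these draws are discarded.
                step *= 1.1 if accepted else 0.9
            else:
                kept.append(x)
        chains.append(kept)
    return chains


samples = toy_infer(lambda x: -0.5 * x * x, num_samples=500, num_chains=2)
print(len(samples), len(samples[0]))  # 2 500
```

Each chain runs num_adaptive_samples tuning steps followed by num_samples retained draws, so the result is indexed first by chain and then by draw.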
- sampler(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: Optional[int] = None, num_adaptive_samples: Optional[int] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_init_retries: int = 100) → beanmachine.ppl.inference.sampler.Sampler
Returns a generator that yields a new World (representing a new state of the graph) each time it is iterated. If num_samples is not provided, the generator is infinite.
- Parameters
queries – List of random variables (RVIdentifier) to draw samples for.
observations – Observations as an RVDict mapping each observed RVIdentifier to its value (a torch.Tensor).
num_samples – Number of samples to generate; defaults to None, which produces an infinite sampler.
num_adaptive_samples – Number of adaptive samples. If not provided, Bean Machine falls back to an algorithm-specific default based on num_samples. If num_samples is not provided either, it defaults to 0.
initialize_fn – A callable that takes a distribution and returns a Tensor. The default, init_to_uniform, samples from Uniform(-2, 2) and then bijects the value onto the support of the distribution.
max_init_retries – Number of attempts to find valid initial values before raising an error, defaults to 100.
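The finite-versus-infinite behavior of sampler follows the usual Python generator pattern. In this sketch, toy_sampler is hypothetical, each "world" is a plain dict, and the transition is a placeholder Gaussian random walk rather than a real MCMC step; only the stopping behavior is meant to match the description above.

```python
import itertools
import random
from typing import Iterator, Optional


def toy_sampler(num_samples: Optional[int] = None, seed: int = 0) -> Iterator[dict]:
    """Yield a fresh toy 'world' (a dict of values) on every iteration.

    With num_samples=None the generator never stops; otherwise it stops
    after num_samples worlds.
    """
    rng = random.Random(seed)
    state = {"mu": rng.uniform(-2.0, 2.0)}
    steps = itertools.count() if num_samples is None else range(num_samples)
    for _ in steps:
        # Placeholder transition; a new dict is built so each yield is a new world.
        state = {"mu": state["mu"] + rng.gauss(0.0, 0.5)}
        yield state


finite = list(toy_sampler(num_samples=5))
print(len(finite))  # 5

# An infinite sampler must be bounded by the caller, e.g. with itertools.islice.
first_ten = list(itertools.islice(toy_sampler(), 10))
print(len(first_ten))  # 10
```

Since num_adaptive_samples defaults to 0 when num_samples is also omitted, an unbounded sampler performs no adaptation unless you request it explicitly.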