beanmachine.ppl.inference.bmg_inference module
An inference engine which uses Bean Machine Graph to make inferences on Bean Machine models.
- class beanmachine.ppl.inference.bmg_inference.BMGInference
Bases: object
Interface to Bean Machine Graph (BMG) Inference, an experimental framework for high-performance implementations of inference algorithms.
Internally, BMGInference consists of a compiler and C++ runtime implementations of various inference algorithms. Currently, only Newtonian Monte Carlo (NMC) inference is supported, and is the algorithm used by default.
Please note that this is a highly experimental implementation under active development, and that it supports only a limited subset of Bean Machine models. In particular, the runtime graph must be static (that is, it does not change during inference), and only a limited set of primitive distributions is currently supported.
- infer(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], num_samples: int, num_chains: int = 4, num_adaptive_samples: int = 0, inference_type: beanmachine.graph.InferenceType = <InferenceType.NMC: 3>, skip_optimizations: Set[str] = {'beta_bernoulli_conjugate_fixer', 'beta_binomial_conjugate_fixer', 'normal_normal_conjugate_fixer'}) → beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples
Perform inference by compiling, at runtime, the Python source code of the model associated with the given queries and observations into a BMG graph, and then running the BMG implementation of the chosen inference method on that graph.
- Parameters
queries – queried random variables
observations – dict mapping observed random variables to their observed values
num_samples – number of samples in each chain
num_chains – number of chains to generate
num_adaptive_samples – number of adaptive (burn-in) samples to discard
inference_type – inference method; currently only NMC is supported
skip_optimizations – set of optimizations to disable in this call; by default, the beta-Bernoulli, beta-binomial, and normal-normal conjugate fixers are skipped
- Returns
The requested samples
- Return type
beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples
- to_bm_python(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → str
Produce a string containing a Bean Machine (BM) Python program generated from the graph deduced from the model.
- to_cpp(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → str
Produce a string containing a C++ program fragment which produces the graph deduced from the model.
- to_dot(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], after_transform: bool = True, label_edges: bool = False, skip_optimizations: Set[str] = {'beta_bernoulli_conjugate_fixer', 'beta_binomial_conjugate_fixer', 'normal_normal_conjugate_fixer'}) → str
Produce a string containing a program in the GraphViz DOT language representing the graph deduced from the model.
- to_graph(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → Tuple[beanmachine.graph.Graph, Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, int]]
Produce a BMG graph and a map from queried RVIdentifiers to the corresponding indices of the inference results.
- to_graphviz(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor], after_transform: bool = True, label_edges: bool = False, skip_optimizations: Set[str] = {'beta_bernoulli_conjugate_fixer', 'beta_binomial_conjugate_fixer', 'normal_normal_conjugate_fixer'}) → graphviz.sources.Source
A small wrapper that returns an actual graphviz.Source object rather than a DOT string; it takes the same parameters as to_dot.
- to_python(queries: List[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → str
Produce a string containing a Python program fragment which produces the graph deduced from the model.