beanmachine.ppl.world.world module
- class beanmachine.ppl.world.world.World(observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_from_prior>)
Bases: beanmachine.ppl.world.base_world.BaseWorld, Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]
A World represents an instantiation of the graphical model and can be manipulated or evaluated. In the context of MCMC inference, a world represents a single Monte Carlo posterior sample.
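A World can also be used as a context manager to run and sample random variables. Example:

```python
import beanmachine.ppl as bm
from beanmachine.ppl.world import World
from torch.distributions import Normal

@bm.random_variable
def foo():
    return Normal(0., 1.)

world = World()
with world:
    x = foo()  # returns a sample, i.e. a tensor
with world:
    y = foo()  # same world, so the same tensor is returned
assert x == y
```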
- Parameters
observations (Optional) – Optional observations, which fix the random variables to observed values.
initialize_fn (callable, Optional) – Callable which takes a torch.distributions.Distribution object as argument and returns a torch.Tensor. Defaults to init_from_prior.
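To illustrate the initialize_fn contract (a callable that maps a distribution to an initial value), here is a stdlib-only sketch. ToyNormal, init_from_sample and init_at_loc are hypothetical stand-ins for a torch distribution and for initializers; they are not Bean Machine's init_from_prior or init_to_uniform.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyNormal:
    """Hypothetical stand-in for torch.distributions.Normal."""
    loc: float
    scale: float

    def sample(self) -> float:
        return random.gauss(self.loc, self.scale)


def init_from_sample(dist: ToyNormal) -> float:
    # Prior-style initialization: draw the starting value from the distribution itself
    return dist.sample()


def init_at_loc(dist: ToyNormal) -> float:
    # Hypothetical deterministic initializer: start at the location parameter
    return dist.loc


d = ToyNormal(loc=3.0, scale=0.5)
start = init_at_loc(d)  # 3.0
```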
- copy() → beanmachine.ppl.world.world.World
- Returns
Shallow copy of the current world.
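Assuming "shallow" here means the world's node-to-value table is duplicated while the stored values themselves are shared, the semantics can be sketched with a plain dict. ToyWorld is hypothetical, and lists stand in for tensors so that in-place mutation is visible.

```python
from typing import Dict, List


class ToyWorld:
    def __init__(self, values: Dict[str, List[float]]):
        self.values = values  # node -> value; lists stand in for tensors

    def copy(self) -> "ToyWorld":
        # Shallow copy: a new dict, but the value objects inside it are shared
        return ToyWorld(dict(self.values))


w1 = ToyWorld({"x": [1.0]})
w2 = w1.copy()
w2.values["y"] = [2.0]      # adding a node to the copy leaves w1 unchanged
w2.values["x"].append(9.0)  # mutating a shared value is visible through both worlds
```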
- enumerate_node(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
- Parameters
node (RVIdentifier) – RVIdentifier of node.
- Returns
A tensor enumerating the support of the node.
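Enumerating a support only makes sense for discrete distributions with finitely many outcomes; in PyTorch this is exposed as Distribution.enumerate_support(), which continuous distributions such as Normal do not implement. The sketch below, built around a hypothetical enumerate_support helper keyed by distribution name, shows what such an enumeration contains; it does not reflect how enumerate_node is implemented.

```python
from typing import List, Sequence


def enumerate_support(kind: str, probs: Sequence[float] = ()) -> List[float]:
    """Hypothetical helper: list every value a finite discrete distribution can take."""
    if kind == "bernoulli":
        return [0.0, 1.0]
    if kind == "categorical":
        # One outcome per category index
        return [float(i) for i in range(len(probs))]
    raise ValueError(f"the support of {kind!r} is not finite")


coin = enumerate_support("bernoulli")                      # [0.0, 1.0]
die = enumerate_support("categorical", probs=[1 / 6] * 6)  # [0.0, 1.0, ..., 5.0]
```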
- get_variable(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → beanmachine.ppl.world.variable.Variable
- Parameters
node (RVIdentifier) – RVIdentifier of node.
- Returns
Variable object that contains the metadata of the current node in the world.
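Because World is a Mapping from RVIdentifier to torch.Tensor, indexing a world yields only a node's value, while get_variable returns the full record. A stdlib-only sketch of that split follows; ToyVariable and its fields (value, parents, children) are assumptions for illustration, not the documented layout of beanmachine.ppl.world.variable.Variable.

```python
from dataclasses import dataclass, field
from typing import Dict, Set


@dataclass
class ToyVariable:
    """Hypothetical metadata record; the field names are assumptions."""
    value: float
    parents: Set[str] = field(default_factory=set)
    children: Set[str] = field(default_factory=set)


class ToyWorld:
    def __init__(self) -> None:
        self._variables: Dict[str, ToyVariable] = {}

    def __getitem__(self, node: str) -> float:
        # Mapping-style access returns only the value
        return self._variables[node].value

    def get_variable(self, node: str) -> ToyVariable:
        # Full record: value plus graph metadata
        return self._variables[node]


w = ToyWorld()
w._variables["mu"] = ToyVariable(0.5, children={"x"})
w._variables["x"] = ToyVariable(1.2, parents={"mu"})
```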
- initialize_value(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → None
- classmethod initialize_world(queries: Iterable[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_retries: int = 100, **kwargs) → beanmachine.ppl.world.world.T
Initializes a world with all of the random variables (queries and observations). If an initial value falls outside the support of its distribution, the method keeps resampling until a valid initialization is found, up to max_retries times.
- Parameters
queries – A list of random variables that need to be inferred.
observations – Observations, which fix the random variables to observed values.
initialize_fn – Function for initializing the values of random variables.
max_retries – The number of attempts this method will make before throwing an error (defaults to 100).
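The retry behaviour can be sketched as a loop that redraws every initial value and accepts the first draw whose joint log probability is finite. This assumes that "outside the support" is detected through a non-finite log probability; initialize_world_sketch and half_normal_log_prob are hypothetical names, and plain floats stand in for tensors.

```python
import math
import random
from typing import Callable, Dict, Iterable


def initialize_world_sketch(
    nodes: Iterable[str],
    log_prob_fn: Callable[[Dict[str, float]], float],
    init_fn: Callable[[str], float],
    max_retries: int = 100,
) -> Dict[str, float]:
    """Hypothetical sketch: redraw all values until the joint log prob is finite."""
    nodes = list(nodes)
    for _ in range(max_retries):
        values = {name: init_fn(name) for name in nodes}
        lp = log_prob_fn(values)
        # Assumption: an out-of-support value shows up as a -inf (or nan) log prob
        if math.isfinite(lp):
            return values
    raise ValueError(f"no valid initialization found after {max_retries} attempts")


def half_normal_log_prob(values: Dict[str, float]) -> float:
    # Standard half-normal density on x >= 0; -inf outside the support
    x = values["sigma"]
    if x < 0:
        return -math.inf
    return math.log(2.0) - 0.5 * math.log(2 * math.pi) - 0.5 * x * x


random.seed(0)
init = initialize_world_sketch(["sigma"], half_normal_log_prob,
                               lambda name: random.uniform(-2.0, 2.0))
```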
- property latent_nodes: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier]
All the latent nodes in the current world.
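Assuming a node is latent exactly when it is present in the world but not among the observations, latent_nodes reduces to a set difference. The function and node names below are hypothetical.

```python
from typing import Dict, Set


def latent_nodes(world_values: Dict[str, float], observations: Dict[str, float]) -> Set[str]:
    """Hypothetical sketch: every node in the world that is not observed is latent."""
    return set(world_values) - set(observations)


world_values = {"mu": 0.3, "sigma": 1.1, "y": 2.0}
observations = {"y": 2.0}
latent = latent_nodes(world_values, observations)
```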
- log_prob(nodes: Optional[Collection[beanmachine.ppl.model.rv_identifier.RVIdentifier]] = None) → torch.Tensor
- Parameters
nodes (Optional) – Optional collection of RVIdentifiers to evaluate the log prob of a subset of the graph. If none is specified, then all the variables in the world are used.
- Returns
The joint log prob of the specified nodes, or of all the nodes in the current world if nodes is None.
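Assuming the joint log probability factorizes as the sum of each node's log density given its parents (the usual factorization for a directed graphical model), log_prob over all nodes or over a subset can be sketched as below. Floats stand in for tensors, every node gets a Normal distribution purely for illustration, and all names are hypothetical.

```python
import math
from typing import Callable, Dict, Iterable, Optional, Tuple


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    # log N(x | loc, scale^2)
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


# Hypothetical model: node -> function giving (loc, scale) of its Normal from the current values
Model = Dict[str, Callable[[Dict[str, float]], Tuple[float, float]]]
model: Model = {
    "mu": lambda v: (0.0, 1.0),
    "y": lambda v: (v["mu"], 1.0),  # y's distribution depends on its parent mu
}
values = {"mu": 0.5, "y": 1.0}


def log_prob(values: Dict[str, float], model: Model,
             nodes: Optional[Iterable[str]] = None) -> float:
    """Sum of per-node log densities over `nodes`, or over every node when `nodes` is None."""
    if nodes is None:
        nodes = values.keys()
    return sum(normal_log_prob(values[n], *model[n](values)) for n in nodes)
```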
- replace(values: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → beanmachine.ppl.world.world.World
- Parameters
values (RVDict) – Dict of RVIdentifiers and their values to replace.
- Returns
A new world where values specified in the dictionary are replaced. This method will update the internal graph structure.
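Because a world's graph can depend on the values it holds (for example when a model branches on a sampled value), replacing values may change which nodes are parents of which. The stdlib-only sketch below uses a hypothetical branching model and assumes the original world is left unchanged; replace returns a new world whose graph is recomputed from the merged values. It is an illustration, not Bean Machine's implementation.

```python
from typing import Dict, Set


def parents_of_y(values: Dict[str, float]) -> Set[str]:
    # Hypothetical branching model: y depends on a when z > 0, otherwise on b
    return {"z", "a"} if values["z"] > 0 else {"z", "b"}


class ToyWorld:
    def __init__(self, values: Dict[str, float]):
        self.values = dict(values)
        # The graph is derived from the current values
        self.y_parents = parents_of_y(self.values)

    def replace(self, new_values: Dict[str, float]) -> "ToyWorld":
        # Build a new world (and hence a new graph); self is not modified
        return ToyWorld({**self.values, **new_values})


w1 = ToyWorld({"z": 1.0, "a": 0.2, "b": -0.7, "y": 0.1})
w2 = w1.replace({"z": -1.0})
```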
- update_graph(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
This function adds a node to the graph and initializes its value if the node is not already in the graph.
- Parameters
node (RVIdentifier) – RVIdentifier of node to update in the graph.
- Returns
The value of the node stored in world (in original space).
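The phrase "if the node is not already in the graph" suggests lazy construction: a node is evaluated and initialized the first time it is requested, and any nodes requested while it is being evaluated become its parents. The sketch below makes that assumption explicit with a hypothetical ToyWorld whose model maps each node to a function returning its distribution's parameter; init_fn turns that parameter into an initial value.

```python
from typing import Callable, Dict, List, Set


class ToyWorld:
    """Hypothetical sketch of lazy graph construction; not Bean Machine's implementation."""

    def __init__(self, model: Dict[str, Callable[["ToyWorld"], float]],
                 init_fn: Callable[[float], float]):
        self.model = model        # node -> function computing its distribution parameter
        self.init_fn = init_fn    # maps that parameter to an initial value
        self.values: Dict[str, float] = {}
        self.parents: Dict[str, Set[str]] = {}
        self._stack: List[str] = []  # nodes whose evaluation is in progress

    def update_graph(self, node: str) -> float:
        # A node requested while another is being evaluated is a parent of that node
        if self._stack:
            self.parents[self._stack[-1]].add(node)
        if node not in self.values:
            self.parents[node] = set()
            self._stack.append(node)
            # Evaluating the model may call update_graph on parents, adding them first
            param = self.model[node](self)
            self._stack.pop()
            self.values[node] = self.init_fn(param)
        return self.values[node]


model = {
    "mu": lambda w: 0.0,
    "y": lambda w: w.update_graph("mu") + 1.0,  # y's location depends on mu
}
world = ToyWorld(model, init_fn=lambda loc: loc)  # deterministic init for clarity
y = world.update_graph("y")
```

Calling update_graph("y") a second time returns the stored value without re-initializing, which matches the "only if not already in the graph" behaviour described above.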