beanmachine.ppl.world.world module

class beanmachine.ppl.world.world.World(observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_from_prior>)

Bases: beanmachine.ppl.world.base_world.BaseWorld, Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]

A World represents an instantiation of the graphical model and can be manipulated or evaluated. In the context of MCMC inference, a world represents a single Monte Carlo posterior sample.

A World can also be used as a context manager to run and sample random variables. Example:

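import beanmachine.ppl as bm
from beanmachine.ppl.world.world import World
from torch.distributions import Normal

@bm.random_variable
def foo():
  return Normal(0., 1.)

world = World()
with world:
  x = foo()  # returns a sample, i.e. a tensor

with world:
  y = foo()  # same world, so the same tensor

assert x == y
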
Parameters
  • observations (Optional) – Optional observations, which fix the random variables to observed values

  • initialize_fn (callable, Optional) – Callable that takes a torch.distributions.Distribution object as its argument and returns a torch.Tensor

copy() → beanmachine.ppl.world.world.World
Returns

Shallow copy of the current world.

enumerate_node(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
Parameters

node (RVIdentifier) – RVIdentifier of node.

Returns

A tensor enumerating the support of the node.
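
To make "enumerating the support" concrete: for a discrete distribution with finitely many outcomes, the support is the full list of values it can take. The sketch below is a plain-Python illustration under that reading. `toy_enumerate_support` and its `kind` argument are hypothetical, and it returns lists rather than tensors.

```python
def toy_enumerate_support(kind, **params):
    """Hypothetical helper: list every value a finite discrete variable can take."""
    if kind == "bernoulli":
        return [0, 1]
    if kind == "categorical":
        # One outcome per probability entry: 0, 1, ..., k - 1.
        return list(range(len(params["probs"])))
    # Continuous or unbounded supports cannot be listed exhaustively.
    raise NotImplementedError(f"cannot enumerate the support of {kind!r}")


coin = toy_enumerate_support("bernoulli", probs=0.3)
die = toy_enumerate_support("categorical", probs=[0.2, 0.3, 0.5])
```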

get_variable(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → beanmachine.ppl.world.variable.Variable
Parameters

node (RVIdentifier) – RVIdentifier of node.

Returns

Variable object that contains the metadata of the current node in the world.

initialize_value(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → None

classmethod initialize_world(queries: Iterable[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_retries: int = 100, **kwargs) → beanmachine.ppl.world.world.T

Initializes a world with all of the random variables (queries and observations). If the initial values fall outside the support of the distributions, the method resamples until it finds a valid initialization, up to max_retries times.

Parameters
  • queries – An iterable of random variables to be inferred.

  • observations – Observations, which fix the random variables to observed values

  • initialize_fn – Function for initializing the values of random variables

  • max_retries – The number of attempts this method makes before raising an error (defaults to 100).
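
The retry behaviour can be sketched as a loop that resamples until the initialization is valid. This is a standard-library illustration, not the library's code: `initialize_with_retries`, `sample_fn`, and `log_prob_fn` are hypothetical names, and treating "valid" as "the joint log probability is finite" is an assumption not stated above.

```python
import math
import random


def initialize_with_retries(sample_fn, log_prob_fn, max_retries=100):
    """Hypothetical helper: resample until the joint log prob is finite."""
    for _ in range(max_retries):
        values = sample_fn()
        if math.isfinite(log_prob_fn(values)):
            return values
    raise ValueError(f"no valid initialization after {max_retries} attempts")


rng = random.Random(0)


def sample_fn():
    # Proposes values on (-2, 2), so roughly half the draws are invalid.
    return {"x": rng.uniform(-2.0, 2.0)}


def log_prob_fn(values):
    # Exponential(1) log density: only x >= 0 lies inside the support.
    x = values["x"]
    return -x if x >= 0 else float("-inf")


init = initialize_with_retries(sample_fn, log_prob_fn)
```

If every attempt lands outside the support, the helper raises `ValueError` after `max_retries` draws instead of looping forever.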

property latent_nodes: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier]

All the latent nodes in the current world.

log_prob(nodes: Optional[Collection[beanmachine.ppl.model.rv_identifier.RVIdentifier]] = None) → torch.Tensor
Parameters

nodes (Optional) – Optional collection of RVIdentifiers whose log prob is evaluated, which restricts the computation to a subset of the graph. If None, all the variables in the world are used.

Returns

The joint log prob of the specified nodes, or of all of the nodes in the current world if none are specified.
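
As a rough illustration of a joint log prob over all nodes versus a subset, the sketch below sums per-node log densities in plain Python. The world layout (name mapped to a value and a log-density function) and the names `toy_world`, `joint_log_prob`, and `normal_log_prob` are assumptions made for this example only.

```python
import math


def normal_log_prob(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)


# Toy world (hypothetical layout): node name -> (value, log-density function).
toy_world = {
    "mu": (0.5, lambda v: normal_log_prob(v, 0.0, 1.0)),
    "y": (1.0, lambda v: normal_log_prob(v, 0.5, 1.0)),
}


def joint_log_prob(world, nodes=None):
    """Sum the log densities of `nodes`, or of every node when `nodes` is None."""
    if nodes is None:
        nodes = world.keys()
    return sum(world[name][1](world[name][0]) for name in nodes)


total = joint_log_prob(toy_world)           # all nodes
only_y = joint_log_prob(toy_world, ["y"])   # a subset of the graph
```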

replace(values: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → beanmachine.ppl.world.world.World
Parameters

values (RVDict) – Dict of RVIdentifiers and their values to replace.

Returns

A new world where values specified in the dictionary are replaced. This method will update the internal graph structure.
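
The "returns a new world" contract can be sketched with a copy-then-update pattern. `DictWorld` is a hypothetical toy class and only mirrors the value replacement. It does not model the graph-structure update mentioned above, and leaving the original world untouched is an assumption of this sketch.

```python
class DictWorld:
    """Toy stand-in: a world reduced to a mapping from names to values."""

    def __init__(self, values):
        self._values = dict(values)

    def __getitem__(self, name):
        return self._values[name]

    def copy(self):
        # Shallow copy: a new mapping that holds the same value objects.
        return DictWorld(self._values)

    def replace(self, values):
        # Work on a copy so the world being replaced is not modified.
        new_world = self.copy()
        new_world._values.update(values)
        return new_world


old = DictWorld({"x": 1.0, "y": 2.0})
new = old.replace({"x": 3.0})
```

This pattern suits proposal-style workflows, where a candidate world is built and evaluated while the current world remains available for comparison.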

update_graph(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor

This function adds a node to the graph and initializes its value if the node is not already in the graph.

Parameters

node (RVIdentifier) – RVIdentifier of node to update in the graph.

Returns

The value of the node stored in the world (in original space).