beanmachine.ppl.world package
Submodules
Module contents
- class beanmachine.ppl.world.BetaDimensionTransform(cache_size=0)
Bases: torch.distributions.transforms.Transform
Volume-preserving transformation to the support of the Beta distribution.
- bijective = True
- codomain: torch.distributions.constraints.Constraint = IndependentConstraint(Real(), 1)
- domain: torch.distributions.constraints.Constraint = Real()
- forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
- inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
- log_abs_det_jacobian(x, y)
Computes the log absolute determinant of the Jacobian, log |dy/dx|, given input x and output y.
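The domain (Real()) and codomain (IndependentConstraint(Real(), 1)) listed above imply that this transform turns a scalar into a vector. The pure-Python sketch below assumes, without confirmation from this page, that the mapping is x to (x, 1 - x), so that a Beta sample can be handled like a two-component simplex vector. Under that assumption the forward shape gains a trailing dimension of size 2 and the inverse shape drops it, rather than the shape being preserved. All function names are hypothetical, and this is not Bean Machine's implementation.

```python
from typing import Tuple

# Hypothetical sketch: the (x, 1 - x) mapping is an assumption, not taken from this page.


def beta_to_vector(x: float) -> Tuple[float, float]:
    """Forward map: a scalar Beta sample becomes a 2-vector on the simplex."""
    return (x, 1.0 - x)


def vector_to_beta(y: Tuple[float, float]) -> float:
    """Inverse map: recover x by normalizing the first coordinate."""
    return y[0] / (y[0] + y[1])


def forward_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Under the assumed mapping, the forward direction appends a dimension of size 2."""
    return shape + (2,)


def inverse_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """The inverse direction drops that trailing dimension again."""
    return shape[:-1]


def log_abs_det_jacobian(x: float, y: Tuple[float, float]) -> float:
    """Treated as volume preserving, so the log-Jacobian correction is zero."""
    return 0.0


y = beta_to_vector(0.25)
print(y)                    # (0.25, 0.75)
print(vector_to_beta(y))    # 0.25
print(forward_shape((3,)))  # (3, 2)
```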
- class beanmachine.ppl.world.World(observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_from_prior>)
Bases: beanmachine.ppl.world.base_world.BaseWorld, Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]
A World represents an instantiation of the graphical model that can be manipulated or evaluated. In the context of MCMC inference, a world represents a single Monte Carlo posterior sample.
A World can also be used as a context manager to run and sample random variables. Example:
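    import beanmachine.ppl as bm
    from beanmachine.ppl.world import World
    from torch.distributions import Normal

    @bm.random_variable
    def foo():
        return Normal(0.0, 1.0)

    world = World()
    with world:
        x = foo()  # returns a sample, i.e. a tensor
    with world:
        y = foo()  # same world, so the same tensor
    assert x == y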
- Parameters
observations (Optional) – Optional observations, which fix the random variables to their observed values.
initialize_fn (callable, Optional) – Callable that takes a torch.distributions.Distribution object as its argument and returns a torch.Tensor.
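The memoization in the example above can be pictured with a small pure-Python sketch. A world maps random-variable identifiers to values, and entering it with a with block pushes it onto a stack of active worlds. A decorated function then looks up its value in the innermost world and initializes that value on first use. The names SketchWorld, random_variable, get_active_world and _WORLD_STACK are hypothetical stand-ins loosely modelled on World, bm.random_variable, get_world_context and World.update_graph. Several other details are assumptions made for illustration: a "distribution" is just a (mean, std) pair, and returning an identifier when no world is active is assumed behaviour. The sketch also omits the dependency graph that the real World maintains.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple, Union

# Hypothetical sketch, not Bean Machine's implementation.
Dist = Tuple[float, float]  # a "distribution" is just (mean, std) here
_WORLD_STACK: List["SketchWorld"] = []


def get_active_world() -> Optional["SketchWorld"]:
    """Innermost active world, or None outside any with block."""
    return _WORLD_STACK[-1] if _WORLD_STACK else None


class SketchWorld:
    def __init__(self, initialize_fn: Callable[[Dist], float]) -> None:
        self._values: Dict[Callable[[], Dist], float] = {}
        self._initialize_fn = initialize_fn

    def __enter__(self) -> "SketchWorld":
        _WORLD_STACK.append(self)
        return self

    def __exit__(self, *exc) -> None:
        _WORLD_STACK.pop()

    def update_graph(self, node: Callable[[], Dist]) -> float:
        # First lookup: build the node's distribution and initialize a value.
        if node not in self._values:
            self._values[node] = self._initialize_fn(node())
        # Later lookups in the same world reuse the stored value.
        return self._values[node]


def random_variable(fn: Callable[[], Dist]) -> Callable[[], Union[float, Callable[[], Dist]]]:
    def wrapper():
        world = get_active_world()
        # Outside a world, hand back an identifier (here: the undecorated function).
        return fn if world is None else world.update_graph(fn)
    return wrapper


rng = random.Random(0)


@random_variable
def foo() -> Dist:
    return (0.0, 1.0)


world = SketchWorld(initialize_fn=lambda d: rng.gauss(d[0], d[1]))
with world:
    x = foo()
with world:
    y = foo()
assert x == y  # same world, same value
```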
- copy() → beanmachine.ppl.world.world.World
- Returns
Shallow copy of the current world.
- enumerate_node(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
- Parameters
node (RVIdentifier) – RVIdentifier of node.
- Returns
A tensor enumerating the support of the node.
- get_variable(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → beanmachine.ppl.world.variable.Variable
- Parameters
node (RVIdentifier) – RVIdentifier of node.
- Returns
Variable object that contains the metadata of the current node in the world.
- initialize_value(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → None
- classmethod initialize_world(queries: Iterable[beanmachine.ppl.model.rv_identifier.RVIdentifier], observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_to_uniform>, max_retries: int = 100, **kwargs) → beanmachine.ppl.world.world.T
Initializes a world with all of the random variables (queries and observations). If an initialized value falls outside the support of its distribution, the method keeps resampling until it finds a valid initialization, up to max_retries times.
- Parameters
queries – A list of random variables that need to be inferred.
observations – Observations, which fix the random variables to their observed values.
initialize_fn – Function for initializing the values of random variables.
max_retries – The number of attempts this method makes before raising an error (defaults to 100).
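The retry behaviour can be sketched as a loop with three steps: draw a candidate value, accept it if its log probability is finite, and raise after max_retries failed attempts. This is a hypothetical single-variable illustration in pure Python. The Exponential target, the deliberately poor Uniform(-2, 2) initializer, the finiteness check and the ValueError are all assumptions chosen for the example, not details taken from Bean Machine.

```python
import math
import random
from typing import Callable


def exponential_log_prob(x: float, rate: float = 1.0) -> float:
    """Log density of Exponential(rate); -inf outside the support x >= 0."""
    return math.log(rate) - rate * x if x >= 0 else -math.inf


def initialize_with_retries(
    initialize_fn: Callable[[], float],
    log_prob: Callable[[float], float],
    max_retries: int = 100,
) -> float:
    """Redraw until the value has a finite log probability, at most max_retries times."""
    for _ in range(max_retries):
        value = initialize_fn()
        # math.isfinite rejects both -inf (outside the support) and nan.
        if math.isfinite(log_prob(value)):
            return value
    raise ValueError(f"no valid initialization found in {max_retries} attempts")


rng = random.Random(42)
value = initialize_with_retries(lambda: rng.uniform(-2.0, 2.0), exponential_log_prob)
print(value >= 0)  # True
```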
- property latent_nodes: Set[beanmachine.ppl.model.rv_identifier.RVIdentifier]
All the latent nodes in the current world.
- log_prob(nodes: Optional[Collection[beanmachine.ppl.model.rv_identifier.RVIdentifier]] = None) → torch.Tensor
- Parameters
nodes (Optional) – Optional collection of RVIdentifiers for evaluating the log prob of a subset of the graph. If None, all the variables in the world are used.
- Returns
The joint log prob of the selected nodes (by default, all nodes in the current world).
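Because the model's joint density factorizes over its nodes, the joint log prob is a sum of per-node log densities. The nodes argument restricts which terms enter that sum. For simplicity, the sketch below assumes that every node is an independent standard normal; in a real model, each term would be the node's log density given its parents' values. The function names are hypothetical.

```python
import math
from typing import Collection, Dict, Optional


def normal_log_prob(x: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Log density of Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def joint_log_prob(values: Dict[str, float], nodes: Optional[Collection[str]] = None) -> float:
    """Sum the per-node log densities; use every node when nodes is None."""
    selected = values.keys() if nodes is None else nodes
    return sum(normal_log_prob(values[n]) for n in selected)


world_values = {"mu": 0.0, "x": 1.0}
print(joint_log_prob(world_values))         # all nodes
print(joint_log_prob(world_values, ["x"]))  # subset containing only "x"
```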
- replace(values: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → beanmachine.ppl.world.world.World
- Parameters
values (RVDict) – Dict of RVIdentifiers and their values to replace.
- Returns
A new world where values specified in the dictionary are replaced. This method will update the internal graph structure.
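Replacing values is a non-mutating update: the original world is left unchanged, and a new world carries the overrides. Here is a minimal sketch, assuming a world can be modelled as a read-only mapping from names to floats (World is a Mapping, per its bases). The replace function below is hypothetical and omits the graph update that the real method performs.

```python
from types import MappingProxyType
from typing import Dict, Mapping


def replace(world: Mapping[str, float], values: Dict[str, float]) -> Mapping[str, float]:
    """Return a new read-only world in which values override existing entries."""
    updated = dict(world)  # copy, so the original world is never mutated
    updated.update(values)
    return MappingProxyType(updated)


old = MappingProxyType({"mu": 0.0, "x": 1.0})
new = replace(old, {"mu": 2.5})
print(old["mu"], new["mu"])  # 0.0 2.5
```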
- update_graph(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
Adds a node to the graph and initializes its value if the node is not already in the graph.
- Parameters
node (RVIdentifier) – RVIdentifier of node to update in the graph.
- Returns
The value of the node stored in the world (in the original space).
- beanmachine.ppl.world.get_default_transforms(distribution: torch.distributions.distribution.Distribution) → torch.distributions.transforms.Transform
Gets the transform that takes a distribution from constrained space into unconstrained space.
- Parameters
distribution – the distribution to check
- Returns
a Transform that needs to be applied to the distribution to transform it from constrained space into unconstrained space
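To illustrate what "constrained to unconstrained" means, the sketch below hand-codes two familiar bijections, each paired with its inverse: log for positive support and logit for the unit interval. The support labels and the lookup table are hypothetical. This page does not say which transform get_default_transforms returns for a given distribution; PyTorch itself exposes such bijections through torch.distributions.biject_to.

```python
import math
from typing import Callable, Dict, Tuple

Bijection = Tuple[Callable[[float], float], Callable[[float], float]]

# Hypothetical lookup: support label -> (to_unconstrained, to_constrained).
DEFAULT_TRANSFORMS: Dict[str, Bijection] = {
    "real": (lambda x: x, lambda y: y),
    "positive": (math.log, math.exp),
    "unit_interval": (
        lambda x: math.log(x) - math.log1p(-x),  # logit
        lambda y: 1.0 / (1.0 + math.exp(-y)),    # sigmoid
    ),
}


def to_unconstrained(support: str, x: float) -> float:
    forward, _ = DEFAULT_TRANSFORMS[support]
    return forward(x)


def to_constrained(support: str, y: float) -> float:
    _, inverse = DEFAULT_TRANSFORMS[support]
    return inverse(y)


y = to_unconstrained("unit_interval", 0.5)  # close to 0.0
x = to_constrained("positive", 0.0)         # exp(0.0) is 1.0
```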
- beanmachine.ppl.world.get_world_context() → Optional[beanmachine.ppl.world.base_world.BaseWorld]
Returns the world whose context is currently active, or None if no world context is active.
- beanmachine.ppl.world.init_from_prior(distribution: torch.distributions.distribution.Distribution) → torch.Tensor
Samples from the distribution. Used as an argument for World.
- Parameters
distribution – torch.distributions.Distribution corresponding to the distribution to sample from.
- beanmachine.ppl.world.init_to_uniform(distribution: torch.distributions.distribution.Distribution) → torch.Tensor
Initializes a value by sampling from a uniform distribution transformed to the support of distribution. A Categorical is used for discrete distributions, a bijective transform is used for constrained continuous distributions, and distribution itself is used otherwise. Used as an argument for World.
- Parameters
distribution – torch.distributions.Distribution of the RV, usually the prior distribution.
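The three cases described above can be sketched in pure Python. The sketch is hypothetical in several respects: the function name and arguments are invented, and a discrete support is represented as an explicit list. The constrained continuous case is illustrated with a positive-support variable mapped back with exp. The Uniform(-2, 2) range in unconstrained space is also an assumption, not a value stated on this page.

```python
import math
import random
from typing import Callable, Optional, Sequence


def init_to_uniform_sketch(
    rng: random.Random,
    support: Optional[Sequence[int]] = None,
    to_constrained: Optional[Callable[[float], float]] = None,
    sample_prior: Optional[Callable[[], float]] = None,
) -> float:
    if support is not None:
        # Discrete case: every support value is equally likely (a uniform Categorical).
        return support[rng.randrange(len(support))]
    if to_constrained is not None:
        # Constrained continuous case: draw uniformly in unconstrained space, then
        # map into the support with a bijection. The (-2, 2) range is an assumption.
        return to_constrained(rng.uniform(-2.0, 2.0))
    if sample_prior is None:
        raise ValueError("need support, to_constrained, or sample_prior")
    # Otherwise fall back to sampling from the distribution itself.
    return sample_prior()


rng = random.Random(0)
k = init_to_uniform_sketch(rng, support=[0, 1, 2])
scale = init_to_uniform_sketch(rng, to_constrained=math.exp)
z = init_to_uniform_sketch(rng, sample_prior=lambda: rng.gauss(0.0, 1.0))
```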