beanmachine.ppl.inference.vi.variational_world module

class beanmachine.ppl.inference.vi.variational_world.VariationalWorld(observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_from_prior>, params: Optional[MutableMapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, queries_to_guides: Optional[Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier]] = None)

Bases: beanmachine.ppl.world.base_world.BaseWorld, Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]

A World that, in addition to the values of random variables, also stores variational parameters.
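Because the Bases line lists VariationalWorld as a Mapping[RVIdentifier, torch.Tensor], read access to stored values should follow the standard Python mapping protocol (indexing, len, iteration, in). The following is a minimal stdlib-only sketch of that shape, not Bean Machine's implementation. ToyWorld is hypothetical, string keys stand in for RVIdentifier, and floats stand in for tensors. In this toy, parameters are kept in a separate dict and are not keys of the mapping. That separation is an assumption of the sketch.

```python
from collections.abc import Mapping
from typing import Dict, Iterator, Optional


class ToyWorld(Mapping):
    """Hypothetical stand-in: maps variable names to values and also holds params."""

    def __init__(self, values: Dict[str, float], params: Optional[Dict[str, float]] = None):
        self._values = dict(values)        # random-variable values (the mapping itself)
        self._params = dict(params or {})  # variational parameters, stored separately

    # The three abstract methods that collections.abc.Mapping requires.
    def __getitem__(self, key: str) -> float:
        return self._values[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._values)

    def __len__(self) -> int:
        return len(self._values)


world = ToyWorld({"mu": 0.5, "sigma": 1.2}, params={"mu_loc": 0.0})
# Mapping supplies keys(), items(), get() and `in` on top of the three methods.
print(world["mu"], len(world), sorted(world))
print("mu" in world, "mu_loc" in world)
```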

copy()
Returns

A shallow copy of the current world.
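A shallow copy creates new containers but shares the objects they hold. The sketch below illustrates that distinction with plain Python dicts and lists. ToyWorld is hypothetical, and the lists stand in for tensors that might be modified in place. Exactly which internal containers VariationalWorld.copy() duplicates is not stated here and is an assumption of the sketch.

```python
from typing import Dict, List


class ToyWorld:
    """Hypothetical world holding variational parameters as mutable lists."""

    def __init__(self, params: Dict[str, List[float]]):
        self._params = params

    def copy(self) -> "ToyWorld":
        # Shallow: a new dict object, but the same list objects inside it.
        return ToyWorld(dict(self._params))


original = ToyWorld({"loc": [0.0]})
clone = original.copy()

clone._params["scale"] = [1.0]  # adding a key changes only the clone's dict
clone._params["loc"][0] = 3.0   # in-place mutation of a shared value shows up in both

print("scale" in original._params)
print(original._params["loc"])
```

If independent values are needed, deep-copy the held objects (for example with copy.deepcopy) rather than relying on a shallow copy.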

get_guide_distribution(rv: beanmachine.ppl.model.rv_identifier.RVIdentifier) torch.distributions.distribution.Distribution

get_param(param: beanmachine.ppl.model.rv_identifier.RVIdentifier) torch.Tensor

Gets a parameter, first initializing it if it is not found.
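get_param behaves as a get-or-initialize lookup. The first request for a parameter creates and stores an initial value, and each later request returns the stored value, including any updates made since. The following is a minimal sketch of that pattern. ToyParamStore, string parameter names, and the init_fn callback are all hypothetical. How Bean Machine actually computes a parameter's initial value is not specified here.

```python
from typing import Callable, Dict


class ToyParamStore:
    """Hypothetical store illustrating get-or-initialize semantics."""

    def __init__(self, init_fn: Callable[[str], float]):
        self._params: Dict[str, float] = {}
        self._init_fn = init_fn
        self.init_calls = 0  # counts how often initialization actually ran

    def get_param(self, name: str) -> float:
        if name not in self._params:
            # First request: create and cache the initial value.
            self._params[name] = self._init_fn(name)
            self.init_calls += 1
        # Later requests return the cached (possibly updated) value.
        return self._params[name]


store = ToyParamStore(init_fn=lambda name: 0.0)
first = store.get_param("loc")
store._params["loc"] = 2.5  # e.g. the value after an optimizer step
second = store.get_param("loc")
print(first, second, store.init_calls)  # 0.0 2.5 1
```

Caching the value lets an optimizer update a parameter in place without a later lookup silently re-initializing it.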

set_params(params: MutableMapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor])

Sets the parameters in this World to the specified values.

update_graph(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) torch.Tensor

Initializes a new node using its guide if one is available, and its prior otherwise.

Parameters

node (RVIdentifier) – RVIdentifier of the node to update in the graph.

Returns

The value of the node stored in the world (in the original space).
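update_graph chooses a node's initial value from its guide when the query has one and falls back to the prior otherwise. The following is a stdlib-only sketch of that decision rule. ToyVariationalWorld, the string node names, the zero-argument sampler callables, and the sources dict are hypothetical. The queries_to_guides argument only mirrors the constructor parameter of the same name. Returning the stored value when a node already exists is also an assumption of the sketch.

```python
import random
from typing import Callable, Dict, Mapping


class ToyVariationalWorld:
    """Hypothetical sketch of 'initialize from the guide if available, else the prior'."""

    def __init__(
        self,
        priors: Mapping[str, Callable[[], float]],
        guides: Mapping[str, Callable[[], float]],
        queries_to_guides: Mapping[str, str],
    ):
        self._priors = priors
        self._guides = guides
        self._queries_to_guides = queries_to_guides
        self._values: Dict[str, float] = {}
        self.sources: Dict[str, str] = {}  # records where each value came from

    def update_graph(self, node: str) -> float:
        if node not in self._values:
            guide = self._queries_to_guides.get(node)
            if guide is not None:
                # The query has a guide: draw the initial value from it.
                self._values[node] = self._guides[guide]()
                self.sources[node] = "guide"
            else:
                # No guide: fall back to the node's prior.
                self._values[node] = self._priors[node]()
                self.sources[node] = "prior"
        return self._values[node]


rng = random.Random(0)
world = ToyVariationalWorld(
    priors={"mu": lambda: rng.gauss(0.0, 1.0), "tau": lambda: rng.expovariate(1.0)},
    guides={"q_mu": lambda: rng.gauss(0.3, 0.1)},
    queries_to_guides={"mu": "q_mu"},
)
mu = world.update_graph("mu")
tau = world.update_graph("tau")
print(world.sources)  # {'mu': 'guide', 'tau': 'prior'}
```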