beanmachine.ppl.inference.vi.variational_world module
- class beanmachine.ppl.inference.vi.variational_world.VariationalWorld(observations: Optional[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, initialize_fn: Callable[[torch.distributions.distribution.Distribution], torch.Tensor] = <function init_from_prior>, params: Optional[MutableMapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]] = None, queries_to_guides: Optional[Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, beanmachine.ppl.model.rv_identifier.RVIdentifier]] = None)
Bases: beanmachine.ppl.world.base_world.BaseWorld, Mapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]
A World which also contains (variational) parameters.
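Because the class is also a Mapping from RVIdentifier to torch.Tensor, a world can be read like a read-only dict of sampled values. The sketch below is hypothetical and does not use Bean Machine: `ToyMappingWorld`, its `set_value` helper, string keys standing in for RVIdentifier, and floats standing in for tensors are all assumptions. It only illustrates how a class can expose sampled values through the Mapping interface while keeping its variational parameters in a separate store.

```python
from collections.abc import Mapping
from typing import Dict, Iterator, Optional


class ToyMappingWorld(Mapping):
    """Hypothetical stand-in: a Mapping over sampled values plus a param store."""

    def __init__(self, params: Optional[Dict[str, float]] = None) -> None:
        self._values: Dict[str, float] = {}  # sampled random-variable values
        self._params: Dict[str, float] = dict(params or {})  # variational params

    def set_value(self, key: str, value: float) -> None:
        # Hypothetical helper; in a real world, inference fills in values.
        self._values[key] = value

    # The three abstract methods that collections.abc.Mapping requires.
    def __getitem__(self, key: str) -> float:
        return self._values[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._values)

    def __len__(self) -> int:
        return len(self._values)


world = ToyMappingWorld(params={"loc": 0.0})
world.set_value("x", 1.5)
print(dict(world))     # {'x': 1.5}
print("loc" in world)  # False: in this sketch, params are not part of the Mapping view
```

Mapping supplies `keys`, `items`, `get`, and `__contains__` for free once the three abstract methods are defined.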
- copy()
- Returns
Shallow copy of the current world.
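"Shallow copy" usually means the copy gets new container objects while sharing the stored values. The sketch below assumes the world keeps its parameters in a plain dict; `ToyCopyWorld` is hypothetical, and Python lists stand in for tensors because both are mutable objects shared by reference.

```python
from typing import Dict, List


class ToyCopyWorld:
    """Hypothetical world holding parameters in a plain dict."""

    def __init__(self, params: Dict[str, List[float]]) -> None:
        self._params = params

    def copy(self) -> "ToyCopyWorld":
        # Shallow copy: a new dict container, but the same value objects.
        return ToyCopyWorld(dict(self._params))


original = ToyCopyWorld({"loc": [0.0]})
clone = original.copy()
clone._params["scale"] = [1.0]      # adding a key does not affect the original
clone._params["loc"][0] = 2.0       # mutating a shared value in place does
print("scale" in original._params)  # False
print(original._params["loc"])      # [2.0]
```

The same reasoning applies to tensors: an in-place operation on a tensor reached through the copy would also be visible through the original.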
- get_guide_distribution(rv: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.distributions.distribution.Distribution
- get_param(param: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
Gets a parameter or initializes it if not found.
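The get-or-initialize behavior of `get_param` is a lazy-initialization pattern: the initializer runs only the first time a parameter is requested, and later calls return the stored value. The sketch below is hypothetical; `ToyParamStore`, `init_fn`, string keys, and float values are assumptions, not Bean Machine's implementation.

```python
from typing import Callable, Dict, List


class ToyParamStore:
    """Hypothetical sketch of lazy get-or-initialize parameter lookup."""

    def __init__(self, init_fn: Callable[[str], float]) -> None:
        self._params: Dict[str, float] = {}
        self._init_fn = init_fn  # called only the first time a key is requested

    def get_param(self, name: str) -> float:
        if name not in self._params:
            self._params[name] = self._init_fn(name)
        return self._params[name]


calls: List[str] = []


def init_fn(name: str) -> float:
    calls.append(name)  # record each initialization
    return 0.0


store = ToyParamStore(init_fn)
store.get_param("loc")
store.get_param("loc")
print(calls)  # ['loc']: the second call reused the stored value
```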
- set_params(params: MutableMapping[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor])
Sets the parameters in this World to specified values.
- update_graph(node: beanmachine.ppl.model.rv_identifier.RVIdentifier) → torch.Tensor
Initialize a new node using its guide if available and the prior otherwise.
- Parameters
node (RVIdentifier) – RVIdentifier of the node to update in the graph.
- Returns
The value of the node stored in the world (in the original space).
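The dispatch described for `update_graph` (use the guide if the node has one, otherwise the prior) can be sketched in plain Python. Everything here is hypothetical: `ToyVariationalWorld`, the `priors` and `queries_to_guides` dicts of sampler functions, and float values standing in for tensors. The sketch also ignores transforms, so every value is already "in the original space".

```python
import random
from typing import Callable, Dict, Mapping

# A sampler draws one value using the given random generator.
Sampler = Callable[[random.Random], float]


class ToyVariationalWorld:
    """Hypothetical: initialize a node from its guide if one exists, else its prior."""

    def __init__(
        self,
        priors: Mapping[str, Sampler],
        queries_to_guides: Mapping[str, Sampler],
        seed: int = 0,
    ) -> None:
        self._priors = priors
        self._guides = queries_to_guides
        self._values: Dict[str, float] = {}
        self._rng = random.Random(seed)

    def update_graph(self, node: str) -> float:
        if node not in self._values:
            # Prefer the guide (variational) sampler when the node has one.
            if node in self._guides:
                sampler = self._guides[node]
            else:
                sampler = self._priors[node]
            self._values[node] = sampler(self._rng)
        return self._values[node]


world = ToyVariationalWorld(
    priors={"x": lambda rng: rng.gauss(0.0, 1.0), "y": lambda rng: rng.gauss(0.0, 1.0)},
    queries_to_guides={"x": lambda rng: rng.gauss(3.0, 0.1)},
)
x = world.update_graph("x")  # drawn from the guide
y = world.update_graph("y")  # drawn from the prior
assert world.update_graph("x") == x  # already initialized, so the stored value is returned
```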