beanmachine.ppl.model.statistical_model module
- class beanmachine.ppl.model.statistical_model.StatisticalModel
Bases: object
Parent class of all statistical models implemented in Bean Machine.
Every random variable in the model needs to be defined with a function declaration wrapped with the bm.random_variable decorator.
Every deterministic functional that a user would like to query during inference should be wrapped with the bm.functional decorator.
Every parameter of the guide distribution that is to be learned via variational inference should be wrapped with the bm.param decorator.
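To make this division of labour concrete, here is a toy, stdlib-only model of how decorated calls behave. It is a hypothetical simplification, not Bean Machine's implementation: `toy_random_variable`, `toy_functional`, `Key` and `World` are illustrative names, and a `(mean, std)` pair sampled with `random.gauss` stands in for a torch distribution. Outside inference a decorated call returns an identifier (mirroring the `RVIdentifier` branch of the documented return type); inside an active "world" it returns a value (mirroring the `torch.Tensor` branch), and repeated calls with the same arguments refer to the same variable.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass(frozen=True)
class Key:
    """Hypothetical stand-in for RVIdentifier: a function plus call arguments."""
    wrapper: Callable[..., Any]
    arguments: Tuple[Any, ...]


class World:
    """Hypothetical inference state mapping each Key to one sampled value."""
    active: Optional["World"] = None

    def __init__(self, rng: random.Random) -> None:
        self.rng = rng
        self.values: Dict[Key, float] = {}

    def __enter__(self) -> "World":
        World.active = self
        return self

    def __exit__(self, *exc: object) -> None:
        World.active = None


def toy_random_variable(f: Callable[..., Tuple[float, float]]) -> Callable[..., Any]:
    def wrapper(*args: Any) -> Any:
        key = Key(wrapper, args)
        world = World.active
        if world is None:
            return key  # outside inference: an identifier, not a number
        if key not in world.values:
            mean, std = f(*args)  # toy "distribution": a (mean, std) pair
            world.values[key] = world.rng.gauss(mean, std)
        return world.values[key]  # inside inference: the variable's value
    return wrapper


def toy_functional(f: Callable[..., float]) -> Callable[..., Any]:
    def wrapper(*args: Any) -> Any:
        if World.active is None:
            return Key(wrapper, args)
        return f(*args)  # no sampling: recomputed from the world's values
    return wrapper


@toy_random_variable
def foo():
    return (0.0, 1.0)  # plays the role of Normal(0., 1.) in this toy encoding


@toy_functional
def bar():
    return foo() * 2.0


print(foo())  # a Key identifying foo(), since no World is active
with World(random.Random(0)):
    print(foo(), bar())  # a sampled value and twice that value
```

The `param` role is left out here because it belongs to the guide rather than to the model's generative structure.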
- static functional(f: Callable[[beanmachine.ppl.model.statistical_model.P], torch.Tensor]) → Callable[[beanmachine.ppl.model.statistical_model.P], Union[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]]
Decorator to be used for every query defined in a statistical model; such queries are deterministic functions of random variables defined with bm.random_variable. E.g.:
    @bm.random_variable
    def foo():
        return Normal(0., 1.)

    @bm.functional
    def bar():
        return foo() * 2.0
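A functional such as `bar` adds no randomness of its own, so one way to read its query results is as the function applied to each joint sample of the random variables it reads. The stdlib sketch below assumes `foo_samples` is a stand-in for posterior draws of `foo()` (seeded `random.gauss` values, hypothetical, not real inference output) and that `bar_from` is a hypothetical helper holding `bar`'s body; it checks that doubling every draw doubles the sample mean and quadruples the sample variance.

```python
import random
import statistics


def bar_from(foo_value: float) -> float:
    # The body of the functional `bar`, applied to one concrete value of foo().
    return foo_value * 2.0


rng = random.Random(42)
# Hypothetical stand-in for posterior draws of foo() ~ Normal(0, 1).
foo_samples = [rng.gauss(0.0, 1.0) for _ in range(1000)]
# The functional is evaluated draw by draw; no extra sampling happens here.
bar_samples = [bar_from(x) for x in foo_samples]

print(statistics.fmean(bar_samples), statistics.pvariance(bar_samples))
```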
- static get_func_key(wrapper, arguments) → beanmachine.ppl.model.rv_identifier.RVIdentifier
Creates a key that uniquely identifies the random variable.
- Parameters
wrapper – reference to the wrapper function
arguments – function arguments
- Returns
An RVIdentifier pairing the function with its arguments, used to identify a particular function call.
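Keying on both the function and the call arguments means `flip(0)` and `flip(1)` are different variables, while two calls to `flip(0)` denote the same one. The sketch below models such a key with a frozen dataclass; `ToyKey`, `get_toy_key` and `flip` are hypothetical names, not the library's `RVIdentifier`. Freezing the dataclass makes it hashable, which is what lets a key serve as a dictionary key, for instance when mapping observed variables to data.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ToyKey:
    # Hypothetical analogue of the key returned by get_func_key:
    # the wrapper function together with the call's arguments.
    wrapper: Callable[..., Any]
    arguments: Tuple[Any, ...]


def get_toy_key(wrapper: Callable[..., Any], arguments: Tuple[Any, ...]) -> ToyKey:
    return ToyKey(wrapper, arguments)


def flip(i: int) -> None:
    """Hypothetical model function; only its identity matters here."""


k0 = get_toy_key(flip, (0,))
k0_again = get_toy_key(flip, (0,))
k1 = get_toy_key(flip, (1,))

# Same function + same arguments -> equal, equally hashed keys.
observations = {k0: 1.0, k1: 0.0}
print(observations[k0_again])  # 1.0
```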
- static param(init_fn)
Decorator to be used for params (variables to be optimized with VI). E.g.:
    @bm.param
    def mu():
        return torch.zeros(2)

    @bm.random_variable
    def foo():
        return Normal(mu(), 1.)
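What "optimized with VI" means can be seen in a case small enough to solve by hand. Assume (all hypothetical, chosen for illustration) a latent z ~ Normal(0, 1), observations x_i ~ Normal(z, 1), and a guide q(z) = Normal(mu, 1) whose only learnable parameter is mu, the role a param such as `mu()` above plays. For this guide the ELBO's gradient with respect to mu has the closed form sum_i (x_i - mu) - mu, so gradient ascent drives mu to sum(x) / (n + 1), which here is also the exact posterior mean. The sketch uses that analytic gradient in place of the stochastic gradient estimates a real VI engine would compute.

```python
from typing import List


def elbo_grad(mu: float, xs: List[float]) -> float:
    # d ELBO / d mu for prior z ~ N(0, 1), likelihood x_i ~ N(z, 1),
    # and guide q(z) = N(mu, 1); the guide's entropy does not depend on mu.
    return sum(x - mu for x in xs) - mu


def fit_mu(xs: List[float], lr: float = 0.1, steps: int = 200) -> float:
    mu = 0.0  # initial value, analogous to what an init_fn returns
    for _ in range(steps):
        mu += lr * elbo_grad(mu, xs)  # gradient *ascent* on the ELBO
    return mu


xs = [1.2, 0.8, 1.5, 0.9]
mu_hat = fit_mu(xs)
print(mu_hat, sum(xs) / (len(xs) + 1))  # both approximately 0.88
```

With lr = 0.1 and n = 4, each step shrinks the distance to the optimum by a factor of 1 - lr * (n + 1) = 0.5, so 200 steps reach float precision.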
- static random_variable(f: Callable[[beanmachine.ppl.model.statistical_model.P], torch.distributions.distribution.Distribution]) → Callable[[beanmachine.ppl.model.statistical_model.P], Union[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]]
Decorator to be used for every random variable defined in a statistical model. E.g.:
    @bm.random_variable
    def foo():
        return Normal(0., 1.)
which is equivalent to applying the decorator by hand:
    def foo():
        return Normal(0., 1.)
    foo = bm.random_variable(foo)
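Two points in the signature are easy to miss: the decorated function must return a distribution object (not a sampled number), and the `@` form is plain Python sugar for calling the decorator by hand. The stdlib sketch below illustrates both with a hypothetical `ToyNormal` standing in for `torch.distributions.Normal` and a hypothetical `toy_random_variable`. Unlike `bm.random_variable`, this toy wrapper samples eagerly and takes an explicit `rng` argument so the draw is reproducible.

```python
import random
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyNormal:
    """Hypothetical stand-in for torch.distributions.Normal."""
    loc: float
    scale: float

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.loc, self.scale)


def toy_random_variable(f: Callable[[], object]) -> Callable[[random.Random], float]:
    # Toy contract check: the model function returns a distribution object,
    # and the wrapper (not the model function) draws from it.
    def wrapper(rng: random.Random) -> float:
        d = f()
        if not isinstance(d, ToyNormal):
            raise TypeError(f"{f.__name__} must return a distribution, got {type(d).__name__}")
        return d.sample(rng)
    return wrapper


# Decorator syntax ...
@toy_random_variable
def foo():
    return ToyNormal(0.0, 1.0)


# ... and the equivalent manual application.
def foo_manual():
    return ToyNormal(0.0, 1.0)


foo_manual = toy_random_variable(foo_manual)


def bad():
    return 0.5  # a number, not a distribution: rejected when called


bad = toy_random_variable(bad)

print(foo(random.Random(7)) == foo_manual(random.Random(7)))  # True
```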
- beanmachine.ppl.model.statistical_model.functional(f: Callable[[beanmachine.ppl.model.statistical_model.P], torch.Tensor]) → Callable[[beanmachine.ppl.model.statistical_model.P], Union[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]]
Decorator to be used for every query defined in a statistical model; such queries are deterministic functions of random variables defined with bm.random_variable. E.g.:
    @bm.random_variable
    def foo():
        return Normal(0., 1.)

    @bm.functional
    def bar():
        return foo() * 2.0
- beanmachine.ppl.model.statistical_model.param(init_fn)
Decorator to be used for params (variables to be optimized with VI). E.g.:
    @bm.param
    def mu():
        return torch.zeros(2)

    @bm.random_variable
    def foo():
        return Normal(mu(), 1.)
- beanmachine.ppl.model.statistical_model.random_variable(f: Callable[[beanmachine.ppl.model.statistical_model.P], torch.distributions.distribution.Distribution]) → Callable[[beanmachine.ppl.model.statistical_model.P], Union[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]]
Decorator to be used for every random variable defined in a statistical model. E.g.:
    @bm.random_variable
    def foo():
        return Normal(0., 1.)
which is equivalent to applying the decorator by hand:
    def foo():
        return Normal(0., 1.)
    foo = bm.random_variable(foo)