beanmachine.ppl.inference.utils module
- class beanmachine.ppl.inference.utils.VerboseLevel(value)
Bases: enum.Enum
Enum that controls how much output is printed during inference. LOAD_BAR enables a tqdm progress bar for the full inference loop; OFF disables it.
- LOAD_BAR = 1
- OFF = 0
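As a rough illustration of how the two members behave, the sketch below defines a local stand-in enum with the same names and values rather than importing the real class. The `describe` helper is hypothetical and exists only for this example.

```python
from enum import Enum


class VerboseLevel(Enum):
    """Local stand-in that mirrors the two documented members."""

    OFF = 0
    LOAD_BAR = 1


def describe(verbose: VerboseLevel) -> str:
    # Hypothetical helper: reports what an inference loop might do
    # for each level. It is not part of the documented module.
    if verbose is VerboseLevel.LOAD_BAR:
        return "progress bar shown"
    return "silent"


# Members can be looked up by value, matching the VerboseLevel(value) signature.
level = VerboseLevel(1)
message = describe(level)
```

Comparing members with `is` works because each enum member is a singleton.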
- beanmachine.ppl.inference.utils.detach_samples(samples: Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]) → Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, numpy.ndarray]
Detaches PyTorch tensors from the computation graph and converts them to NumPy arrays.
- Parameters
samples (Dict[RVIdentifier, torch.Tensor]) – Dictionary of RVIdentifiers with original torch tensors.
- Returns
Dictionary of RVIdentifiers with converted NumPy arrays.
- Return type
Dict[RVIdentifier, np.ndarray]
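The conversion described above can be sketched as follows. This is a minimal approximation, not necessarily the library's implementation: `detach_samples_sketch` is a hypothetical name, and plain strings stand in for RVIdentifier keys so the example runs without a model.

```python
from typing import Dict, Hashable

import numpy as np
import torch


def detach_samples_sketch(
    samples: Dict[Hashable, torch.Tensor]
) -> Dict[Hashable, np.ndarray]:
    # detach() drops the autograd history, cpu() moves the data to host
    # memory, and numpy() exposes it as an ndarray.
    return {rv: t.detach().cpu().numpy() for rv, t in samples.items()}


# String keys are stand-ins for RVIdentifier objects (an assumption).
samples = {"mu": torch.tensor([0.5, 1.5], requires_grad=True)}
arrays = detach_samples_sketch(samples)
```

Calling `numpy()` directly on a tensor that requires gradients raises an error, which is why the detach step comes first.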
- beanmachine.ppl.inference.utils.merge_dicts(dicts: List[Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]], dim: int = 0, stack_not_cat: bool = True) → Dict[beanmachine.ppl.model.rv_identifier.RVIdentifier, torch.Tensor]
A helper function that merges multiple dicts of samples into a single dictionary, stacking across a new dimension.
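A sketch of the merging behavior, under two assumptions: that `stack_not_cat=False` switches from `torch.stack` to `torch.cat`, as the parameter name suggests, and that every input dict shares the same keys. `merge_dicts_sketch` is a hypothetical name, and string keys again stand in for RVIdentifier objects.

```python
from typing import Dict, Hashable, List

import torch


def merge_dicts_sketch(
    dicts: List[Dict[Hashable, torch.Tensor]],
    dim: int = 0,
    stack_not_cat: bool = True,
) -> Dict[Hashable, torch.Tensor]:
    # stack adds a new dimension at `dim`; cat joins along the existing one.
    combine = torch.stack if stack_not_cat else torch.cat
    # Assumption: all dicts have the same keys as the first one.
    return {k: combine([d[k] for d in dicts], dim=dim) for k in dicts[0]}


# Two chains with three draws each for one variable.
chain_a = {"mu": torch.tensor([0.1, 0.2, 0.3])}
chain_b = {"mu": torch.tensor([0.4, 0.5, 0.6])}

stacked = merge_dicts_sketch([chain_a, chain_b])  # shape (2, 3)
joined = merge_dicts_sketch([chain_a, chain_b], stack_not_cat=False)  # shape (6,)
```

Stacking keeps the chains separate along the new leading dimension, while concatenation pools all draws into one axis.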
- beanmachine.ppl.inference.utils.safe_log_prob_sum(distrib, value: torch.Tensor) → torch.Tensor
Computes the summed log_prob of value under distrib, converting out-of-support exceptions to -infinity.
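One way this behavior could look is sketched below. It assumes the out-of-support error surfaces as a `ValueError` or `RuntimeError` from a PyTorch distribution created with argument validation enabled. `safe_log_prob_sum_sketch` is a hypothetical name, and the exception types the library actually catches may differ.

```python
import torch
import torch.distributions as dist


def safe_log_prob_sum_sketch(distrib, value: torch.Tensor) -> torch.Tensor:
    try:
        # Sum the elementwise log-probabilities into a scalar tensor.
        return distrib.log_prob(value).sum()
    except (RuntimeError, ValueError):
        # Assumption: a value outside the support raises one of these,
        # so it is treated as having probability zero.
        return torch.tensor(float("-inf"))


# validate_args=True makes out-of-support values raise instead of passing through.
expo = dist.Exponential(torch.tensor(1.0), validate_args=True)

in_support = safe_log_prob_sum_sketch(expo, torch.tensor([1.0, 2.0]))
out_of_support = safe_log_prob_sum_sketch(expo, torch.tensor([1.0, -1.0]))
```

Returning negative infinity lets an acceptance step reject an invalid proposal without special-casing the exception.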
- beanmachine.ppl.inference.utils.seed(seed: int) → None