beanmachine.ppl.compiler.fix_logsumexp module
- beanmachine.ppl.compiler.fix_logsumexp.logsumexp_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]
This fixer attempts to rewrite log expressions of the form log(exp(a) + exp(b) + exp(c) + …) into logsumexp(a, b, c, …).
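The rewrite can be illustrated with a small, self-contained sketch. It does not use Bean Machine's graph types: `Const`, `Exp`, `Add`, `Log`, `LogSumExp`, `toy_logsumexp_fixer` and `evaluate` are hypothetical names invented here. Unlike the real `logsumexp_fixer`, the toy version takes no `BMGraphBuilder` and is applied directly to a node. Assuming, as the `Union[..., None, ...]` return type suggests, that a fixer returns `None` when it has nothing to rewrite, the toy follows the same convention. The evaluator also shows why logsumexp is commonly preferred: subtracting the maximum before exponentiating avoids overflow.

```python
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# Hypothetical toy node types; these are NOT Bean Machine's BMGNode classes.
@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Exp:
    operand: "Node"

@dataclass(frozen=True)
class Add:
    operands: Tuple["Node", ...]

@dataclass(frozen=True)
class Log:
    operand: "Node"

@dataclass(frozen=True)
class LogSumExp:
    operands: Tuple["Node", ...]

Node = Union[Const, Exp, Add, Log, LogSumExp]

def toy_logsumexp_fixer(node: Node) -> Optional[Node]:
    """Hypothetical fixer: return a rewritten node, or None if the pattern does not match."""
    if not isinstance(node, Log):
        return None
    inner = node.operand
    # Only a sum whose every operand is exp(...) qualifies.
    if not isinstance(inner, Add):
        return None
    if not all(isinstance(op, Exp) for op in inner.operands):
        return None
    # log(exp(a) + exp(b) + ...) -> logsumexp(a, b, ...)
    return LogSumExp(tuple(op.operand for op in inner.operands))

def evaluate(node: Node) -> float:
    """Numerically evaluate a toy expression tree."""
    if isinstance(node, Const):
        return node.value
    if isinstance(node, Exp):
        return math.exp(evaluate(node.operand))  # raises OverflowError for large inputs
    if isinstance(node, Add):
        return math.fsum(evaluate(op) for op in node.operands)
    if isinstance(node, Log):
        return math.log(evaluate(node.operand))
    if isinstance(node, LogSumExp):
        xs = [evaluate(op) for op in node.operands]
        m = max(xs)
        # Subtract the max before exponentiating so exp() cannot overflow.
        return m + math.log(math.fsum(math.exp(x - m) for x in xs))
    raise TypeError(f"unknown node: {node!r}")
```

For example, `toy_logsumexp_fixer(Log(Add((Exp(Const(1.0)), Exp(Const(2.0))))))` yields `LogSumExp((Const(1.0), Const(2.0)))`, and both trees evaluate to the same value. With operands of 1000.0, evaluating the original form raises `OverflowError`, while the rewritten form returns about 1000.69.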