beanmachine.ppl.compiler.fix_matrix_scale module

beanmachine.ppl.compiler.fix_matrix_scale.matrix_scale_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder, sizer: beanmachine.ppl.compiler.sizer.Sizer) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]

This node fixer attempts to rewrite binary multiplications that involve a matrix and a scalar into a matrix_scale node.
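The rewrite can be pictured with a small, self-contained sketch. It assumes toy node classes (Node, Leaf, Multiply, MatrixScale, FixerError) and a factory toy_matrix_scale_fixer, all hypothetical and not part of Bean Machine's API. For brevity each toy node stores its own size, whereas the real fixer consults a Sizer. The returned callable follows the same contract as the signature above: it returns a replacement node, None when the node is left alone, or an error.

```python
from dataclasses import dataclass
from typing import Callable, Tuple, Union


# Hypothetical toy node types; these are NOT Bean Machine's BMGNode classes.
@dataclass(frozen=True)
class Node:
    size: Tuple[int, ...]  # () for a scalar, (rows, cols) for a matrix


@dataclass(frozen=True)
class Leaf(Node):
    name: str


@dataclass(frozen=True)
class Multiply(Node):
    left: Node
    right: Node


@dataclass(frozen=True)
class MatrixScale(Node):
    scalar: Node
    matrix: Node


class FixerError(Exception):
    """Hypothetical stand-in for an error a fixer may report."""


NodeFixer = Callable[[Node], Union[Node, None, FixerError]]


def toy_matrix_scale_fixer() -> NodeFixer:
    def fixer(node: Node) -> Union[Node, None, FixerError]:
        # Only binary multiplications are candidates for the rewrite.
        if not isinstance(node, Multiply):
            return None
        left, right = node.left, node.right
        # scalar * matrix  ->  matrix_scale(scalar, matrix)
        if left.size == () and len(right.size) == 2:
            return MatrixScale(size=right.size, scalar=left, matrix=right)
        # matrix * scalar  ->  matrix_scale(scalar, matrix)
        if right.size == () and len(left.size) == 2:
            return MatrixScale(size=left.size, scalar=right, matrix=left)
        # Not a scalar/matrix product: leave the node unchanged.
        return None

    return fixer
```

Because the scalar is always placed in the first operand slot, both operand orders of the multiplication produce the same matrix_scale shape.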

beanmachine.ppl.compiler.fix_matrix_scale.nested_matrix_scale_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]

beanmachine.ppl.compiler.fix_matrix_scale.trivial_matmul_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder, typer: beanmachine.ppl.compiler.lattice_typer.LatticeTyper) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]

This node fixer attempts to rewrite matrix multiplications of two scalars into an ordinary multiplication.
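A minimal sketch of this rewrite, under stated assumptions: the node classes, is_scalar, and toy_trivial_matmul_fixer are hypothetical illustrations, not Bean Machine code. The real fixer asks a LatticeTyper whether each operand is a scalar. Here "scalar" simply means a toy size of ().

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


# Hypothetical toy node types; these are NOT Bean Machine's BMGNode classes.
@dataclass(frozen=True)
class Node:
    size: Tuple[int, ...]  # () for a scalar, (rows, cols) for a matrix


@dataclass(frozen=True)
class Leaf(Node):
    name: str


@dataclass(frozen=True)
class MatMul(Node):
    left: Node
    right: Node


@dataclass(frozen=True)
class Multiply(Node):
    left: Node
    right: Node


def is_scalar(node: Node) -> bool:
    # Hypothetical stand-in for querying a typer about the node's type.
    return node.size == ()


def toy_trivial_matmul_fixer() -> Callable[[Node], Optional[Node]]:
    def fixer(node: Node) -> Optional[Node]:
        # A matrix multiplication of two scalars is just a product.
        if isinstance(node, MatMul) and is_scalar(node.left) and is_scalar(node.right):
            return Multiply(size=(), left=node.left, right=node.right)
        # Any other node is left unchanged.
        return None

    return fixer
```

Returning None for every other node keeps the fixer safe to run over the whole graph. Only the degenerate scalar case is rewritten.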