beanmachine.ppl.compiler.fix_matrix_scale module
- beanmachine.ppl.compiler.fix_matrix_scale.matrix_scale_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder, sizer: beanmachine.ppl.compiler.sizer.Sizer) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]
This node fixer attempts to rewrite binary multiplications that involve a matrix and a scalar into a matrix_scale node.
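The following is a minimal sketch of the rewrite described above, using hypothetical toy node classes (Const, Mul, MatrixScale) rather than Bean Machine's real BMGNode types. A shape function stands in for the sizer argument, and the scalar-first operand order of MatrixScale is an assumption of the toy model. As with the real fixer's signature, the returned callable yields a replacement node or None when it leaves the node unchanged; the NodeFixerError case is omitted.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

# Hypothetical toy node classes; these are NOT Bean Machine's BMGNode types.
@dataclass(frozen=True)
class Const:
    name: str
    shape: Tuple[int, ...]  # () or (1, 1) for a scalar, (rows, cols) for a matrix

@dataclass(frozen=True)
class Mul:
    left: "Node"
    right: "Node"

@dataclass(frozen=True)
class MatrixScale:
    scalar: "Node"  # assumption: scalar operand comes first
    matrix: "Node"

Node = Union[Const, Mul, MatrixScale]

def is_scalar(shape: Tuple[int, ...]) -> bool:
    # () and shapes made only of 1s, such as (1, 1), count as scalars.
    return all(d == 1 for d in shape)

def shape_of(node: Node) -> Tuple[int, ...]:
    # Stand-in for a sizer: computes the shape of a toy node.
    if isinstance(node, Const):
        return node.shape
    if isinstance(node, MatrixScale):
        return shape_of(node.matrix)
    left, right = shape_of(node.left), shape_of(node.right)
    return right if is_scalar(left) else left  # a scalar operand broadcasts

def matrix_scale_fixer_sketch(
    size: Callable[[Node], Tuple[int, ...]]
) -> Callable[[Node], Optional[Node]]:
    def fixer(node: Node) -> Optional[Node]:
        if not isinstance(node, Mul):
            return None  # only binary multiplications are candidates
        left_scalar = is_scalar(size(node.left))
        right_scalar = is_scalar(size(node.right))
        if left_scalar and not right_scalar:
            return MatrixScale(node.left, node.right)
        if right_scalar and not left_scalar:
            return MatrixScale(node.right, node.left)  # move the scalar first
        return None  # scalar * scalar or matrix * matrix: leave unchanged
    return fixer
```

For example, matrix_scale_fixer_sketch(shape_of)(Mul(Const("M", (2, 3)), Const("s", ()))) produces a MatrixScale with s as the scalar and M as the matrix, whichever side of the multiplication the scalar was on.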
- beanmachine.ppl.compiler.fix_matrix_scale.nested_matrix_scale_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]
- beanmachine.ppl.compiler.fix_matrix_scale.trivial_matmul_fixer(bmg: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder, typer: beanmachine.ppl.compiler.lattice_typer.LatticeTyper) → Callable[[beanmachine.ppl.compiler.bmg_nodes.BMGNode], Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, None, beanmachine.ppl.compiler.fix_problem.NodeFixerError]]
This node fixer attempts to rewrite matrix multiplications of two scalars into an ordinary multiplication.
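Below is a minimal sketch of this rewrite under stated assumptions. The node classes (Const, Mul, MatMul) are hypothetical toy types, not Bean Machine's BMGNode classes. The real fixer consults the LatticeTyper passed as typer; here a plain shape function stands in for it, and all matrix shapes are assumed to be two-dimensional.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

# Hypothetical toy node classes; these are NOT Bean Machine's BMGNode types.
@dataclass(frozen=True)
class Const:
    name: str
    shape: Tuple[int, int]  # assumption: every value is a 2-D (rows, cols) matrix

@dataclass(frozen=True)
class Mul:
    left: "Node"
    right: "Node"

@dataclass(frozen=True)
class MatMul:
    left: "Node"
    right: "Node"

Node = Union[Const, Mul, MatMul]

def is_scalar(shape: Tuple[int, ...]) -> bool:
    # A 1 x 1 matrix is treated as a scalar.
    return all(d == 1 for d in shape)

def shape_of(node: Node) -> Tuple[int, ...]:
    # Stand-in for the typer: computes the shape of a toy node.
    if isinstance(node, Const):
        return node.shape
    left, right = shape_of(node.left), shape_of(node.right)
    if isinstance(node, MatMul):
        return (left[0], right[1])  # (n x k) @ (k x m) -> (n x m)
    return right if is_scalar(left) else left  # Mul: a scalar broadcasts

def trivial_matmul_fixer_sketch(
    shape: Callable[[Node], Tuple[int, ...]]
) -> Callable[[Node], Optional[Node]]:
    def fixer(node: Node) -> Optional[Node]:
        if not isinstance(node, MatMul):
            return None  # only matrix multiplications are candidates
        if is_scalar(shape(node.left)) and is_scalar(shape(node.right)):
            # 1 x 1 @ 1 x 1 is just an ordinary product of two numbers.
            return Mul(node.left, node.right)
        return None  # a genuine matrix product: leave unchanged
    return fixer
```

The design choice illustrated is that the fixer changes only the operator, not the operands: both children of the matrix multiplication are reused unchanged as the children of the ordinary multiplication.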