beanmachine.ppl.compiler.tensorizer_transformer module

class beanmachine.ppl.compiler.tensorizer_transformer.ElementType(value)

Bases: enum.Enum

An enumeration.

MATRIX = 3
SCALAR = 2
TENSOR = 1
UNKNOWN = 4
class beanmachine.ppl.compiler.tensorizer_transformer.Tensorizer(cloner: beanmachine.ppl.compiler.copy_and_replace.Cloner, sizer: beanmachine.ppl.compiler.sizer.Sizer)

Bases: beanmachine.ppl.compiler.copy_and_replace.NodeTransformer

assess_node(node: beanmachine.ppl.compiler.bmg_nodes.BMGNode, original: beanmachine.ppl.compiler.bm_graph_builder.BMGraphBuilder) → beanmachine.ppl.compiler.copy_and_replace.TransformAssessment
can_be_tensorized(original_node: beanmachine.ppl.compiler.bmg_nodes.BMGNode) → bool
div_can_be_tensorized(node: beanmachine.ppl.compiler.bmg_nodes.DivisionNode) → bool
mult_can_be_tensorized(node: beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode) → bool
negate_can_be_tensorized(node: beanmachine.ppl.compiler.bmg_nodes.NegateNode) → bool
transform_node(node: beanmachine.ppl.compiler.bmg_nodes.BMGNode, new_inputs: List[beanmachine.ppl.compiler.bmg_nodes.BMGNode]) → Optional[Union[beanmachine.ppl.compiler.bmg_nodes.BMGNode, List[beanmachine.ppl.compiler.bmg_nodes.BMGNode]]]