Skip to content

Simplify dots with ones #638

Open
Open
@ricardoV94

Description

@ricardoV94

Description

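We have a `local_0_dot_x` rewrite that removes useless dots with zeroed inputs, but we don't seem to have anything equivalent for dots with ones, as reported in #637 (comment). For example, for a column `x`, `x @ [[1.]]` is just `x`, yet the dot survives compilation: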

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

# Reproduction for the missing "dot with ones" simplification.
# `x` is a column; `x @ [[1.]]` is mathematically equal to `x`, so ideally the
# compiled graph would contain no dot at all.
# (Was `tn.col`, which raised NameError: the module is imported as `pt`.)
x = pt.col('x')
f = x @ [[1.]]
# BlasOpt is excluded only to keep the printed graph small; optimizer_verbose
# prints every rewrite that fires during compilation.
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))

pytensor.dprint(fn)
dot [id A] 0
 ├─ x [id B]
 └─ [[1.]] [id C]

I excluded BlasOpt only to get a simpler graph. With BlasOpt enabled, the dot is still not rewritten away; it is just replaced by the more complex Blas Op.

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_0_dot_x(fgraph, node):
    """Replace ``dot(x, y)`` with zeros when either operand is a constant zero.

    An operand counts as zero when ``get_underlying_scalar_constant_value``
    (restricted to constants via ``only_process_constants=True``) resolves it
    to a value equal to 0. The replacement is a 0 constant of the dot's
    output dtype. It is wrapped in an ``assert_op`` that checks the contracted
    dimensions agree, presumably so a shape mismatch is still reported once
    the dot is gone. It is then ``alloc``-ed to the dot's output shape.

    Returns:
        ``False`` if ``node.op`` is not a ``Dot``; ``None`` when no operand is
        a constant zero or the operand ranks are not 1 or 2; otherwise a
        one-element list holding the replacement variable.
    """
    if not isinstance(node.op, Dot):
        return False
    x, y = node.inputs[0], node.inputs[1]

    # Probe every operand (no short-circuit); operands that are not
    # scalar-resolvable constants are simply skipped.
    has_zero_operand = False
    for operand in (x, y):
        try:
            value = get_underlying_scalar_constant_value(
                operand, only_process_constants=True
            )
        except NotScalarConstantError:
            continue
        if value == 0:
            has_zero_operand = True

    if not has_zero_operand:
        return None

    zero = constant(0, dtype=node.outputs[0].type.dtype)
    if x.ndim not in (1, 2) or y.ndim not in (1, 2):
        return None

    # The last axis of x is contracted against the first axis of y.
    zero = assert_op(zero, eq(x.shape[x.ndim - 1], y.shape[0]))

    # Output dims: x's leading (non-contracted) axis, then y's trailing axis.
    out_dims = [x.shape[i] for i in range(x.ndim - 1)]
    out_dims += [y.shape[j] for j in range(1, y.ndim)]
    if not out_dims:
        # vector . vector yields a scalar; nothing to allocate.
        return [zero]
    return [alloc(zero, *out_dims)]

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions