
Replace 0 sized nodes by zeros #1298

Open
@ricardoV94

Description

If any operation produces an empty (zero-sized) output, we can truncate the whole graph above it:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(10,))
out = pt.add.outer(x, x)[9:-1]

fn = pytensor.function([x], out)
fn.dprint(print_shape=True)
# Subtensor{start:stop} [id A] shape=(0, 10) 3
#  ├─ Add [id B] shape=(10, 10) 2
#  │  ├─ ExpandDims{axis=1} [id C] shape=(10, 1) 1
#  │  │  └─ x [id D] shape=(10,)
#  │  └─ ExpandDims{axis=0} [id E] shape=(1, 10) 0
#  │     └─ x [id D] shape=(10,)
#  ├─ 9 [id F] shape=()
#  └─ -1 [id G] shape=()

PyTensor already knows statically that the output has shape=(0, 10), so we can replace the whole graph with zeros of that shape. The worst this could do is hide shape errors that evaluating the original graph would have raised, so we can add the shape_unsafe tag.
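
The (0, 10) follows from static information alone: the Add output has a first dimension of 10, and the slice 9:-1 on a length-10 axis normalizes to start 9, stop 9. A minimal pure-Python check of that arithmetic uses the built-in slice.indices, which applies the same bound normalization as NumPy basic slicing (mirrored by PyTensor's Subtensor). The helper sliced_length is hypothetical, not part of PyTensor.

```python
def sliced_length(dim_length: int, start: int, stop: int, step: int = 1) -> int:
    """Length of an axis of size dim_length after [start:stop:step] slicing."""
    # slice.indices clamps out-of-range bounds and resolves negative ones
    norm_start, norm_stop, norm_step = slice(start, stop, step).indices(dim_length)
    return len(range(norm_start, norm_stop, norm_step))


# Static shape of pt.add.outer(x, x)[9:-1] when x has shape (10,)
outer_shape = (10, 10)
out_shape = (sliced_length(outer_shape[0], 9, -1),) + outer_shape[1:]
print(out_shape)  # (0, 10)
```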

We need to be careful with operations whose output shape cannot be determined without computing them (such as a while Scan). We can ask the ShapeFeature for the shape of the variable: if the variable is still part of its own shape graph, we can't avoid computing the graph. Otherwise, we can just replace it with zeros.
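
A toy, pure-Python model of the proposed rewrite, under stated assumptions: Var, ShapeOf, shape_needs_itself, replace_empty_by_zeros and graph_ops are hypothetical names, not PyTensor's API. shape_needs_itself stands in for querying the ShapeFeature, and a dimension written as ShapeOf(name, axis) models a shape graph that still refers to the variable itself, as with a while Scan. The replacement node has no inputs, so the whole graph above it drops out.

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass(frozen=True)
class ShapeOf:
    """Hypothetical symbolic dimension, known only after computing variable `var`."""
    var: str
    axis: int


@dataclass
class Var:
    """Toy graph variable: producing op, its inputs, and its inferred shape."""
    name: str
    op: str
    inputs: list["Var"] = field(default_factory=list)
    shape: tuple[Union[int, ShapeOf], ...] = ()


def shape_needs_itself(var: Var) -> bool:
    # Stand-in for asking the ShapeFeature: does var's shape graph still contain var?
    return any(isinstance(d, ShapeOf) and d.var == var.name for d in var.shape)


def replace_empty_by_zeros(var: Var) -> Var:
    """Replace a zero-sized variable by an input-free zeros node, when that is safe."""
    is_empty = any(isinstance(d, int) and d == 0 for d in var.shape)
    if is_empty and not shape_needs_itself(var):
        # No inputs: every node above this one becomes dead and can be dropped.
        return Var(name=var.name, op="zeros", inputs=[], shape=var.shape)
    return var


def graph_ops(var: Var) -> list:
    """Ops reachable from var, collected depth-first."""
    ops, stack, seen = [], [var], set()
    while stack:
        v = stack.pop()
        if id(v) in seen:
            continue
        seen.add(id(v))
        ops.append(v.op)
        stack.extend(v.inputs)
    return ops


# The graph from the example above
x = Var("x", "input", shape=(10,))
col = Var("col", "ExpandDims{axis=1}", [x], (10, 1))
row = Var("row", "ExpandDims{axis=0}", [x], (1, 10))
add = Var("add", "Add", [col, row], (10, 10))
out = Var("out", "Subtensor{start:stop}", [add], (0, 10))

new_out = replace_empty_by_zeros(out)
print(new_out.op, new_out.shape)  # zeros (0, 10)
print(graph_ops(new_out))  # ['zeros']

# A while-Scan-like output: empty, but its length is only known after running the loop
loop = Var("loop", "Scan{while}", [x], (ShapeOf("loop", 0), 0))
print(replace_empty_by_zeros(loop).op)  # Scan{while}
```

The self-reference check is what keeps the rewrite from building zeros whose shape can only be obtained by evaluating the very node being removed.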

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests