If any operation in the graph produces an empty output (a tensor with a zero-length dimension), we can truncate the whole graph above it. For example:
```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(10,))
out = pt.add.outer(x, x)[9:-1]
fn = pytensor.function([x], out)
fn.dprint(print_shape=True)
# Subtensor{start:stop} [id A] shape=(0, 10) 3
#  ├─ Add [id B] shape=(10, 10) 2
#  │  ├─ ExpandDims{axis=1} [id C] shape=(10, 1) 1
#  │  │  └─ x [id D] shape=(10,)
#  │  └─ ExpandDims{axis=0} [id E] shape=(1, 10) 0
#  │     └─ x [id D] shape=(10,)
#  ├─ 9 [id F] shape=()
#  └─ -1 [id G] shape=()
```
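PyTensor already knows statically that the output has shape (0, 10), so we could replace the whole graph with zeros of that shape (e.g. `pt.zeros((0, 10))` with the matching dtype). The worst this could do is hide shape errors that the original graph would have raised at runtime, so we can register the rewrite with the `shape_unsafe` tag.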
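A minimal sketch of this truncation on a hypothetical toy graph representation (the `Node` class and the `truncate_empty` and `count_nodes` helpers are illustrative assumptions, not PyTensor's rewrite API). Any node whose statically known shape contains a 0 is swapped for an empty constant, which disconnects everything above it:

```python
from dataclasses import dataclass, field

# Hypothetical toy IR (not PyTensor's API): each node has an op label, its inputs,
# and a static shape in which None marks a length unknown before runtime.
@dataclass(eq=False)
class Node:
    op: str
    inputs: list = field(default_factory=list)
    shape: tuple = ()


def is_statically_empty(node):
    # Every length must be known statically and at least one must be 0.
    return all(d is not None for d in node.shape) and 0 in node.shape


def truncate_empty(node, memo=None):
    """Return a copy of the graph where statically empty nodes become empty constants.

    Replacing a node drops all of its ancestors. This trusts the inferred shape,
    so (like a shape_unsafe rewrite) it can hide shape errors upstream.
    """
    memo = {} if memo is None else memo
    if id(node) in memo:  # keep shared subgraphs shared (x feeds both ExpandDims)
        return memo[id(node)]
    if is_statically_empty(node):
        new = Node(op=f"zeros{node.shape}", inputs=[], shape=node.shape)
    else:
        new = Node(node.op, [truncate_empty(i, memo) for i in node.inputs], node.shape)
    memo[id(node)] = new
    return new


def count_nodes(root):
    # Count distinct nodes reachable from root.
    seen, stack = set(), [root]
    while stack:
        n = stack.pop()
        if id(n) not in seen:
            seen.add(id(n))
            stack.extend(n.inputs)
    return len(seen)


# The graph from the dprint output above, with the shapes PyTensor inferred.
x = Node("x", shape=(10,))
b = Node("Add", [Node("ExpandDims{axis=1}", [x], (10, 1)),
                 Node("ExpandDims{axis=0}", [x], (1, 10))], (10, 10))
out = Node("Subtensor{start:stop}", [b, Node("9"), Node("-1")], (0, 10))

new_out = truncate_empty(out)
print(new_out.op, count_nodes(out), "->", count_nodes(new_out))  # zeros(0, 10) 7 -> 1
```

Here the output node itself is empty, so the seven-node graph collapses to a single constant. A real PyTensor rewrite would instead build something like `pt.zeros(shape, dtype=...)` for the replacement; this sketch only models the graph surgery.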
We need to be careful with operations that cannot determine their output shape without being computed (such as a while Scan). We can ask the ShapeFeature for the shape of the variable: if the variable itself is still part of its own shape graph, we know we can't avoid computing the graph. Otherwise, we can just replace it with zeros.
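The safety check can be sketched on a similar toy model. Here `shape_of` is a hypothetical stand-in for the symbolic shapes a ShapeFeature tracks, and `Var`, `ancestors` and `can_replace_by_zeros` are illustrative names, not PyTensor's API. We assume that an operation which cannot infer its shape (like a while Scan) reports a length that is read back from its own output, so the variable shows up in its own shape graph:

```python
from dataclasses import dataclass, field

# Hypothetical toy IR (not PyTensor's API): a variable and the variables it is computed from.
@dataclass(eq=False)
class Var:
    name: str
    inputs: list = field(default_factory=list)


def ancestors(roots):
    """All variables reachable from roots by following inputs, roots included."""
    seen, stack, found = set(), list(roots), []
    while stack:
        v = stack.pop()
        if id(v) not in seen:
            seen.add(id(v))
            found.append(v)
            stack.extend(v.inputs)
    return found


def can_replace_by_zeros(var, shape_of):
    """shape_of[var] lists one symbolic length per dimension of var.

    If var itself is an ancestor of its own shape graph, the shape is only
    known after computing var, so the graph cannot be skipped.
    """
    return all(a is not var for a in ancestors(shape_of[var]))


x = Var("x")
# Subtensor output: its lengths are computed from constants and x's shape.
sub_out = Var("Subtensor", [x])
# While-Scan output (assumed): the number of steps is only known after running
# the loop, so its length can only be read back from the output itself.
scan_out = Var("while_scan", [x])

shape_of = {
    sub_out: [Var("0"), Var("Shape_i{0}(x)", [x])],
    scan_out: [Var("Shape_i{0}(while_scan)", [scan_out])],
}

print(can_replace_by_zeros(sub_out, shape_of))   # True
print(can_replace_by_zeros(scan_out, shape_of))  # False
```

The check only asks whether the variable is reachable from its own shape expression; even if the static shape of a while-Scan output happened to contain a 0 elsewhere, the sketch would still refuse to drop it.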