Description
Reshape has only two inputs: `x` and a vector holding the output shape. This is cumbersome because we often want to analyze the individual dimensions, for example to rewrite a Reshape as `expand_dims` or to remove a useless Reshape. The Reshape Op also has to be parametrized with the output length (`ndim`), because historically we didn't have static shapes and couldn't always infer how many entries the shape vector had.
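As a rough illustration of the per-dimension analysis mentioned above: once each output dimension is available as its own entry, deciding whether a reshape is a no-op or just an `expand_dims` reduces to a loop over entries. This is a minimal pure-Python sketch, not PyTensor code. `classify_reshape` is a hypothetical helper, and dimension lengths are modeled as opaque symbols (strings) or the literal `1`.

```python
from typing import Hashable, Optional, Sequence

# Hypothetical model: each entry is the literal 1 or an opaque symbol
# standing for a (possibly symbolic) length, e.g. "x.shape[0]".
Dim = Hashable


def classify_reshape(in_shape: Sequence[Dim], out_shape: Sequence[Dim]) -> Optional[str]:
    """Hypothetical helper: report whether a reshape is a no-op or an expand_dims.

    Returns "noop", "expand_dims", or None when neither can be proven.
    """
    if list(in_shape) == list(out_shape):
        return "noop"
    # Match the input dimensions, in order, against the output entries.
    # Any output entry left unmatched must be a literal 1 (a new axis).
    i = 0
    for entry in out_shape:
        if i < len(in_shape) and entry == in_shape[i]:
            i += 1
        elif entry != 1:
            return None
    return "expand_dims" if i == len(in_shape) else None


print(classify_reshape(("s0", "s1"), ("s0", 1, "s1")))  # expand_dims
print(classify_reshape(("s0", "s1"), ("s0*s1",)))       # None
```

With a single opaque shape vector there are no entries to loop over, which is the difficulty described above.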
Most of the time Reshape is used to combine dimensions, so we end up with shape vectors like `[x.shape[0], ..., x.shape[n] * x.shape[m], ..., x.shape[-1]]` wrapped in a MakeVector. This makes Reshape rewrites harder, because they have to handle both the case where the entries are joined in a MakeVector and the case where they may have been constant folded into a single tensor.
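A minimal pure-Python sketch of the extra case handling a rewrite needs with a single shape-vector input. `MakeVector` and `Constant` here are hypothetical stand-ins for the two graph forms the shape vector can take, and `shape_entries` is a hypothetical helper; none of these are the PyTensor classes of the same name.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class MakeVector:
    """Hypothetical stand-in: shape vector built from one entry per output dim."""
    entries: Tuple[object, ...]


@dataclass(frozen=True)
class Constant:
    """Hypothetical stand-in: shape vector that was constant folded."""
    value: Tuple[int, ...]


def shape_entries(shape_input: object) -> Optional[list]:
    """Recover per-dimension entries from a single shape-vector input.

    Each rewrite has to branch on the form of the vector; any other
    symbolic vector offers no entries to inspect, so None is returned.
    """
    if isinstance(shape_input, MakeVector):
        return list(shape_input.entries)
    if isinstance(shape_input, Constant):
        return list(shape_input.value)
    return None


print(shape_entries(MakeVector(("s0", "s1*s2"))))  # ['s0', 's1*s2']
print(shape_entries(Constant((2, 12))))            # [2, 12]
```

With one input per dimension, this normalization step would not be needed: the entries are the inputs.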
See pytensor/pytensor/tensor/rewriting/shape.py, lines 921 to 926 in bf73f8a.
SpecifyShape already works with a variable number of inputs, and we haven't had any trouble with it.
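A toy sketch of what a Reshape with one input per output dimension could look like, following the variadic-input pattern described for SpecifyShape. `VariadicReshapeNode` is hypothetical and is not PyTensor API; it only shows that the output rank follows from the number of inputs, so the op would no longer need to be parametrized with it.

```python
class VariadicReshapeNode:
    """Hypothetical toy node: inputs are x followed by one entry per output dim."""

    def __init__(self, x, *shape):
        self.inputs = [x, *shape]

    @property
    def ndim(self) -> int:
        # Output rank is implied by the number of shape inputs.
        return len(self.inputs) - 1

    @property
    def shape(self) -> list:
        # Per-dimension entries are directly available to rewrites.
        return self.inputs[1:]


node = VariadicReshapeNode("x", "x.shape[0]", "x.shape[1] * x.shape[2]")
print(node.ndim)   # 2
print(node.shape)  # ['x.shape[0]', 'x.shape[1] * x.shape[2]']
```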