
`vectorize_node` should return a list of variables, not a node #902

Open
@ricardoV94

Description

`vectorize_node` implicitly assumes that whenever we vectorize a node, we return a new node whose outputs map 1-to-1 onto the original outputs. This is too restrictive. We may want to vectorize a single node into two variables that come from different nodes, or into a single output of a multi-output node. There's no reason to require a one node -> one node mapping.
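To make the restriction concrete, here is a toy sketch in plain Python (it does not use PyTensor). `Variable`, `Apply`, `make_node`, and the two-output `min_max` op are hypothetical stand-ins, and the idea that such an op would be batched as two separate reductions is an assumption for illustration. The vectorized outputs come from two different nodes, so a single returned `Apply` cannot express them, while a list of variables can.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Variable:
    # Hypothetical stand-in for a graph variable; `owner` is the node producing it.
    name: str
    owner: "Apply | None" = None


@dataclass(eq=False)
class Apply:
    # Hypothetical stand-in for pytensor.graph.basic.Apply: an op, its inputs, its outputs.
    op: str
    inputs: list
    outputs: list = field(default_factory=list)


def make_node(op, inputs, n_out=1):
    # Hypothetical helper: build a node and give it `n_out` owned output variables.
    node = Apply(op, list(inputs))
    node.outputs = [Variable(f"{op}_out{i}", owner=node) for i in range(n_out)]
    return node


def vectorize_min_max(node, *batched_inputs):
    # Hypothetical two-output "min_max" op whose batched version is built as two
    # independent reductions. The results are owned by two different nodes,
    # so only a list of variables can describe the replacement.
    (x,) = batched_inputs
    batched_min = make_node("min", [x]).outputs[0]
    batched_max = make_node("max", [x]).outputs[0]
    return [batched_min, batched_max]


x = Variable("x")
batched_x = Variable("batched_x")
node = make_node("min_max", [x], n_out=2)
new_outputs = vectorize_min_max(node, batched_x)
print([v.owner.op for v in new_outputs])  # ['min', 'max']
print(new_outputs[0].owner is new_outputs[1].owner)  # False
```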

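```python
from functools import singledispatch

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
    # Default implementation is provided in pytensor.tensor.blockwise
    raise NotImplementedError
```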

For backwards compatibility, we should check whether the returned object is an `Apply`. If it is, we still use it, but issue a warning that this form is deprecated and that a list of output variables (as in the rewrites) should be returned instead. All our implementations in PyTensor should switch to returning a list of variables.

The check and warning could be added here:

```python
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
    """Returns vectorized version of node with new batched inputs."""
    op = node.op
    return _vectorize_node(op, node, *batched_inputs)
```
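A minimal sketch of the backwards-compatible check, using hypothetical stand-in classes instead of the real PyTensor `Apply`/`Op` so that it runs on its own. `OldStyleOp` and `NewStyleOp` are made up. Choosing `FutureWarning` (rather than `DeprecationWarning`, which Python hides by default unless triggered from `__main__`) is an assumption, not something PyTensor prescribes.

```python
import warnings
from dataclasses import dataclass, field
from functools import singledispatch


@dataclass(eq=False)
class Variable:
    # Hypothetical stand-in for a graph variable.
    name: str


@dataclass(eq=False)
class Apply:
    # Hypothetical stand-in for pytensor.graph.basic.Apply.
    op: object
    inputs: list
    outputs: list = field(default_factory=list)


class OldStyleOp:
    """Hypothetical op whose vectorization still returns an Apply (deprecated form)."""


class NewStyleOp:
    """Hypothetical op whose vectorization returns a list of variables."""


@singledispatch
def _vectorize_node(op, node, *batched_inputs):
    raise NotImplementedError


@_vectorize_node.register(OldStyleOp)
def _(op, node, *batched_inputs):
    return Apply(op, list(batched_inputs), [Variable("old_out")])


@_vectorize_node.register(NewStyleOp)
def _(op, node, *batched_inputs):
    return [Variable("new_out")]


def vectorize_node(node, *batched_inputs):
    """Return the vectorized outputs of `node` given new batched inputs."""
    result = _vectorize_node(node.op, node, *batched_inputs)
    if isinstance(result, Apply):
        # Backwards compatibility: still accept an Apply, but warn about it.
        warnings.warn(
            "Returning an Apply from _vectorize_node is deprecated; "
            "return a list of output variables instead.",
            FutureWarning,
        )
        return list(result.outputs)
    return list(result)


old_node = Apply(OldStyleOp(), [Variable("x")], [Variable("y")])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    outputs = vectorize_node(old_node, Variable("batched_x"))
print([v.name for v in outputs], len(caught))  # ['old_out'] 1

new_node = Apply(NewStyleOp(), [Variable("x")], [Variable("y")])
print([v.name for v in vectorize_node(new_node, Variable("batched_x"))])  # ['new_out']
```

Callers always receive a list either way, so the old `Apply` form can be removed later without touching them again.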

Every caller of `vectorize_node` should then expect a list of variables as output, for example here:

```python
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
    if output in vect_vars:
        # This can happen when some outputs of a multi-output node are given a replacement,
        # while some of the remaining outputs are still needed in the graph.
        # We make sure we don't overwrite the provided replacement with the newly vectorized output
        continue
    vect_vars[output] = vect_output
```
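With `vectorize_node` returning a list, the caller loop only needs to iterate over that list instead of `vect_node.outputs`. The sketch below wraps the loop in a hypothetical `vectorize_outputs` helper and uses stand-in `Variable`, `Apply`, and `vectorize_node` definitions. The explicit length check is an added assumption: it guards against an implementation returning the wrong number of outputs, which `zip` would otherwise truncate silently.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Variable:
    # Hypothetical stand-in for a graph variable (hashed by identity).
    name: str


@dataclass(eq=False)
class Apply:
    # Hypothetical stand-in for pytensor.graph.basic.Apply.
    op: str
    inputs: list
    outputs: list = field(default_factory=list)


def vectorize_node(node, *batched_inputs):
    # Stand-in for the proposed API: returns a list with one batched
    # variable per original output.
    return [Variable(f"batched_{out.name}") for out in node.outputs]


def vectorize_outputs(node, vect_inputs, vect_vars):
    # Hypothetical helper wrapping the caller loop from the issue.
    vect_outputs = vectorize_node(node, *vect_inputs)
    if len(vect_outputs) != len(node.outputs):
        raise ValueError("vectorize_node must return one variable per node output")
    for output, vect_output in zip(node.outputs, vect_outputs):
        if output in vect_vars:
            # Keep a replacement that was already provided for this output.
            continue
        vect_vars[output] = vect_output
    return vect_vars


a, b = Variable("a"), Variable("b")
node = Apply("two_out", [Variable("x")], [a, b])
user_b = Variable("user_b")
vect_vars = vectorize_outputs(node, [Variable("batched_x")], {b: user_b})
print(vect_vars[a].name, vect_vars[b].name)  # batched_a user_b
```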
