|
2 | 2 | from pytensor.graph import node_rewriter
|
3 | 3 | from pytensor.graph.replace import vectorize_node
|
4 | 4 | from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
|
5 |
| -from pytensor.tensor.basic import Alloc, ARange, shape_padleft |
| 5 | +from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft |
6 | 6 | from pytensor.tensor.blockwise import Blockwise
|
| 7 | +from pytensor.tensor.extra_ops import broadcast_shape_iter |
7 | 8 | from pytensor.tensor.math import _matrix_matrix_matmul
|
8 | 9 | from pytensor.tensor.rewriting.basic import (
|
9 | 10 | register_canonicalize,
|
@@ -75,3 +76,94 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
|
75 | 76 | )
|
76 | 77 | ):
|
77 | 78 | return local_useless_unbatched_blockwise.fn(fgraph, node)
|
| 79 | + |
| 80 | + |
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
    """Push Allocs from the inputs to the output of Blockwise Ops.

    For example (with signature "(x),(x)->(x)"):
        BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)

    Only the *batch* dimensions added by an ``Alloc`` are lifted. Any
    expansion of core dimensions must stay on the input side, because the
    core Op may depend on the expanded core shape.
    """

    op: Blockwise = node.op  # type: ignore

    # Number of leading batch dims; shared by every input/output of a Blockwise.
    batch_ndim = node.inputs[0].type.ndim - len(op.inputs_sig[0])
    batch_axes = tuple(range(batch_ndim))
    new_inputs = []
    batch_shapes = []
    can_lift_alloc = False
    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
        if all(inp.type.broadcastable[:batch_ndim]):
            # The input only has dummy batch dims (if it has any);
            # drop them so the rewritten node works on the core value.
            inp = inp.squeeze(batch_axes)
            new_inputs.append(inp)
            continue

        core_ndim = len(inp_sig)
        if inp.owner and isinstance(inp.owner.op, Alloc):
            value, *shape = inp.owner.inputs
            value_ndim = value.type.ndim
            value_batch_ndim = value_ndim - core_ndim
            if value_batch_ndim:
                # The original value already has batch dims; see if they are dummy.
                if all(value.type.broadcastable[:value_batch_ndim]):
                    value = value.squeeze(axis=tuple(range(value_batch_ndim)))
                else:
                    # The original value has non-dummy batch dims.
                    # We cannot lift this Alloc.
                    new_inputs.append(inp)
                    continue

            # The Alloc output is the Blockwise input, so its ndim is
            # batch_ndim + core_ndim; it adds batch dims iff batch_ndim > 0.
            alloc_ndim = len(shape)
            if alloc_ndim > core_ndim:
                # The Alloc adds all the batch dims
                batch_shape = shape[:batch_ndim]
                core_shape = shape[batch_ndim:]
                # At this point `value` has exactly core_ndim dims. If its
                # broadcast pattern differs from the input's core pattern, the
                # Alloc was also expanding core dims, and those must be kept
                # on the input side. (Comparing against
                # inp.type.broadcastable[batch_ndim:], not slicing `value`'s
                # own pattern by batch_ndim, which would index past its dims.)
                if value.type.broadcastable != inp.type.broadcastable[batch_ndim:]:
                    # We still need an Alloc for the core dims
                    value = alloc(value, *core_shape)
                new_inputs.append(value)
                # Record this input's contribution to the output batch shape,
                # using 1 for broadcastable batch dims so they don't pin the
                # broadcasted shape.
                batch_shapes.append(
                    [
                        dim if not bcast else 1
                        for dim, bcast in zip(
                            batch_shape, inp.type.broadcastable[:batch_ndim]
                        )
                    ]
                )
                can_lift_alloc = True
                continue

        # Nothing to do with this input
        new_inputs.append(inp)

    if not can_lift_alloc:
        return None

    new_outs = node.op.make_node(*new_inputs).outputs

    # Pushed Allocs are still needed
    if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
        out = new_outs[0]
        batch_ndim = out.type.ndim - len(op.outputs_sig[0])
        if batch_ndim:
            # The new output already has batch dims; we need to consider this
            # when broadcasting
            bcast_shape = tuple(out.shape)[:batch_ndim]
            batch_shapes.append(
                [
                    dim if not bcast else 1
                    for dim, bcast in zip(
                        bcast_shape, out.type.broadcastable[:batch_ndim]
                    )
                ]
            )
        if len(batch_shapes) == 1:
            [batch_shape] = batch_shapes
        else:
            batch_shape = broadcast_shape_iter(batch_shapes, arrays_are_shapes=True)
        core_shapes = [out.shape[batch_ndim:] for out in new_outs]
        new_outs = [
            alloc(new_out, *batch_shape, *core_shape)
            for new_out, core_shape in zip(new_outs, core_shapes)
        ]

    copy_stack_trace(node.outputs, new_outs)
    return new_outs
0 commit comments