     register_infer_shape,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, add
@@ -1865,3 +1866,40 @@ def local_uint_constant_indices(fgraph, node):
     copy_stack_trace(node.outputs, new_outs)
 
     return new_outs
+
+
+@register_canonicalize
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def useless_blockwise_advanced_inc_subtensor(fgraph, node):
+    """Rewrite a Blockwise advanced inc_subtensor without batched indices as a plain inc_subtensor with prepended empty slices."""
+    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
+        return None
+
+    op: Blockwise = node.op
+
+    x, y, *idxs = node.inputs
+
+    # Number of batch dimensions Blockwise adds on top of the core op
+    batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
+
+    new_idxs = []
+    for idx in idxs:
+        if all(idx.type.broadcastable[:batch_ndim]):
+            # Index is not batched; squeeze away its broadcastable batch dims
+            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
+        else:
+            # At least one index is batched; the rewrite does not apply
+            return None
+
+    # Broadcast x to the final shape if any of its batch dims are degenerate
+    if any(x.type.broadcastable[:batch_ndim]):
+        x = alloc(x, *tuple(y.shape)[:batch_ndim], *tuple(x.shape)[batch_ndim:])
+
+    # Prepend full slices over the batch dims and apply the core op directly
+    new_idxs = [slice(None)] * batch_ndim + new_idxs
+    symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
+    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
+    copy_stack_trace(node.outputs, new_out)
+    return new_out
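
To make the transformation concrete, here is a small NumPy sketch of the before/after semantics. The shapes, variable names, and the use of `np.add.at` as a stand-in for the core `AdvancedIncSubtensor` (assuming the default incrementing behavior, i.e. `set_instead_of_inc=False`) are illustrative assumptions, not part of the PR:

```python
import numpy as np

# Illustrative shapes: one batch dimension of size 4, core buffer of length 5.
x = np.zeros((4, 5))     # batched buffer
y = np.ones((4, 2))      # batched update values
idx = np.array([0, 3])   # advanced index with no batch dimensions

# Blockwise semantics: run the core increment once per batch entry.
looped = x.copy()
for b in range(x.shape[0]):
    np.add.at(looped[b], idx, y[b])  # in-place add on a view of row b

# After the rewrite: a single increment with a full slice prepended over the
# batch dimension, valid because the index is identical for every batch entry.
rewritten = x.copy()
np.add.at(rewritten, (slice(None), idx), y)

assert np.allclose(looped, rewritten)
```

This is why the rewrite bails out as soon as any index has a non-broadcastable batch dimension: a batched index would select different positions per batch entry, and a single prepended `slice(None)` could no longer express the operation.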