Commit 88979e7

Add rewrite to remove Blockwise of AdvancedIncSubtensor
1 parent 12f3cc3 · commit 88979e7

2 files changed: +160 -1 lines changed
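
Background on the transformation (not part of the commit diff): when none of the advanced indices are themselves batched, looping an advanced inc/set_subtensor over leading batch dimensions gives the same result as applying the core operation once with an empty slice prepended for each batch dimension, which is what the new rewrite emits. A minimal NumPy sketch of that equivalence, with illustrative shapes and values:

import numpy as np

# Core operation: x[idxs] += y, with x of shape (6,) and y of shape (3,).
idxs = [0, 2, 4]
x = np.zeros((2, 6))                      # x batched over a leading dim of size 2
y = np.array([[5.0, 6.0, 7.0],
              [3.0, 3.0, 3.0]])           # y batched the same way

# What Blockwise does conceptually: loop the core op over the batch dim.
looped = x.copy()
for b in range(2):
    looped[b][idxs] += y[b]

# What the rewrite produces instead: one advanced index with an empty
# slice prepended for the batch dimension.
direct = x.copy()
direct[:, idxs] += y

assert np.array_equal(looped, direct)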


pytensor/tensor/rewriting/subtensor.py

Lines changed: 56 additions & 0 deletions
@@ -29,6 +29,7 @@
     register_infer_shape,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, add
@@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node):
     copy_stack_trace(node.outputs, new_outs)
 
     return new_outs
+
+
+@register_canonicalize("shape_unsafe")
+@register_stabilize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Blockwise])
+def local_blockwise_advanced_inc_subtensor(fgraph, node):
+    """Rewrite blockwise advanced inc_subtensor without batched indexes as an inc_subtensor with prepended empty slices."""
+    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
+        return None
+
+    x, y, *idxs = node.inputs
+
+    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
+    if any(
+        (
+            isinstance(idx, (SliceType, NoneTypeT))
+            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
+        )
+        for idx in idxs
+    ):
+        return None
+
+    op: Blockwise = node.op  # type: ignore
+    batch_ndim = op.batch_ndim(node)
+
+    new_idxs = []
+    for idx in idxs:
+        if all(idx.type.broadcastable[:batch_ndim]):
+            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
+        else:
+            # Rewrite does not apply
+            return None
+
+    x_batch_bcast = x.type.broadcastable[:batch_ndim]
+    y_batch_bcast = y.type.broadcastable[:batch_ndim]
+    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)):
+        # Need to broadcast batch x dims
+        batch_shape = tuple(
+            x_dim if (not xb or yb) else y_dim
+            for xb, x_dim, yb, y_dim in zip(
+                x_batch_bcast,
+                tuple(x.shape)[:batch_ndim],
+                y_batch_bcast,
+                tuple(y.shape)[:batch_ndim],
+            )
+        )
+        core_shape = tuple(x.shape)[batch_ndim:]
+        x = alloc(x, *batch_shape, *core_shape)
+
+    new_idxs = [slice(None)] * batch_ndim + new_idxs
+    symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
+    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
+    copy_stack_trace(node.outputs, new_out)
+    return new_out

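As a usage-level sketch of the rewrite above (mirroring the test added in the next file; names and shapes here are illustrative, not part of the commit), one can vectorize a core inc_subtensor graph, compile it, and check that no Blockwise node survives:

import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor.blockwise import Blockwise

# Core (unbatched) graph: an advanced inc_subtensor on a length-6 vector.
core_x = pt.tensor("core_x", shape=(6,))
core_y = pt.tensor("core_y", shape=(3,))
core_graph = pt.inc_subtensor(core_x[[0, 2, 4]], core_y)

# Batch both inputs over a leading dimension; vectorize_graph wraps the
# core AdvancedIncSubtensor in a Blockwise.
x = pt.tensor("x", shape=(2, 6))
y = pt.tensor("y", shape=(2, 3))
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)

# After compiling with the default rewrites, the Blockwise should be gone.
fn = pytensor.function([x, y], out, mode="FAST_RUN")
assert not any(
    isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
)
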
tests/tensor/rewriting/test_subtensor.py

Lines changed: 104 additions & 1 deletion
@@ -9,7 +9,7 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.configdefaults import config
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, vectorize_graph
 from pytensor.graph.basic import Constant, Variable, ancestors
 from pytensor.graph.rewriting.basic import check_stack_trace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -18,6 +18,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import Dot, add, dot, exp, sqr
 from pytensor.tensor.rewriting.subtensor import (
@@ -2314,3 +2315,105 @@ def test_local_uint_constant_indices():
     new_index = subtensor_node.inputs[1]
     assert isinstance(new_index, Constant)
     assert new_index.type.dtype == "uint8"
+
+
+@pytest.mark.parametrize("set_instead_of_inc", (True, False))
+def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
+    core_x = tensor("x", shape=(6,))
+    core_y = tensor("y", shape=(3,))
+    core_idxs = [0, 2, 4]
+    if set_instead_of_inc:
+        core_graph = set_subtensor(core_x[core_idxs], core_y)
+    else:
+        core_graph = inc_subtensor(core_x[core_idxs], core_y)
+
+    # Only x is batched
+    x = tensor(
+        "x",
+        shape=(
+            5,
+            2,
+            6,
+        ),
+    )
+    y = tensor("y", shape=(3,))
+    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out, mode="FAST_RUN")
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+    test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
+    expected_out = test_x.copy()
+    if set_instead_of_inc:
+        expected_out[:, :, core_idxs] = test_y
+    else:
+        expected_out[:, :, core_idxs] += test_y
+    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
+
+    # Only y is batched
+    x = tensor("x", shape=(6,))
+    y = tensor("y", shape=(2, 3))
+    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out, mode="FAST_RUN")
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+    test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
+    expected_out = np.ones((2, *x.type.shape))
+    if set_instead_of_inc:
+        expected_out[:, core_idxs] = test_y
+    else:
+        expected_out[:, core_idxs] += test_y
+    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
+
+    # Both x and y are batched, and do not need to be broadcasted
+    x = tensor("x", shape=(2, 6))
+    y = tensor("y", shape=(2, 3))
+    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out, mode="FAST_RUN")
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+    test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
+    expected_out = test_x.copy()
+    if set_instead_of_inc:
+        expected_out[:, core_idxs] = test_y
+    else:
+        expected_out[:, core_idxs] += test_y
+    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
+
+    # Both x and y are batched, but must be broadcasted
+    x = tensor("x", shape=(5, 1, 6))
+    y = tensor("y", shape=(1, 2, 3))
+    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = pytensor.function([x, y], out, mode="FAST_RUN")
+    assert not any(
+        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    )
+
+    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+    test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
+    final_shape = (
+        *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
+        x.type.shape[-1],
+    )
+    expected_out = np.broadcast_to(test_x, final_shape).copy()
+    if set_instead_of_inc:
+        expected_out[:, :, core_idxs] = test_y
+    else:
+        expected_out[:, :, core_idxs] += test_y
+    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
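
The last case above exercises the broadcast branch of the rewrite: when a batch dimension of x is broadcastable while the corresponding dimension of y is not, x is first materialized at the broadcast batch shape (the alloc call in the rewrite) before the prepended-slice indexing is applied. A rough NumPy analogue of that step, with illustrative shapes (not the exact shapes used in the test):

import numpy as np

idxs = [0, 2, 4]
x = np.zeros((1, 6))                      # broadcastable batch dim on x
y = np.array([[5.0, 6.0, 7.0],
              [3.0, 3.0, 3.0]])           # y's batch dim has size 2

# Roughly what alloc(x, *batch_shape, *core_shape) does in the rewrite:
x_b = np.broadcast_to(x, (2, 6)).copy()

# ...after which the prepended-slice indexing applies as usual.
x_b[:, idxs] += y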
