Skip to content

Commit df1eaed

Browse files
committed
Vectorize Subtensor without batched indices
1 parent 503d939 commit df1eaed

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

pytensor/tensor/subtensor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.gradient import DisconnectedType
1414
from pytensor.graph.basic import Apply, Constant, Variable
1515
from pytensor.graph.op import Op
16+
from pytensor.graph.replace import _vectorize_node
1617
from pytensor.graph.type import Type
1718
from pytensor.graph.utils import MethodNotDefined
1819
from pytensor.link.c.op import COp
@@ -22,6 +23,7 @@
2223
from pytensor.scalar.basic import ScalarConstant
2324
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
2425
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
26+
from pytensor.tensor.blockwise import vectorize_node_fallback
2527
from pytensor.tensor.elemwise import DimShuffle
2628
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
2729
from pytensor.tensor.math import clip
@@ -1283,6 +1285,21 @@ def _process(self, idxs, op_inputs, pstate):
12831285
pprint.assign(Subtensor, SubtensorPrinter())
12841286

12851287

1288+
# TODO: Implement similar vectorize for Inc/SetSubtensor
1289+
@_vectorize_node.register(Subtensor)
1290+
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
1291+
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
1292+
1293+
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
1294+
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
1295+
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
1296+
1297+
old_x, *_ = node.inputs
1298+
batch_ndims = batch_x.type.ndim - old_x.type.ndim
1299+
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
1300+
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
1301+
1302+
12861303
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
12871304
"""
12881305
Return x with the given subtensor overwritten by y.

tests/tensor/test_subtensor.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
import pytensor
1010
import pytensor.scalar as scal
1111
import pytensor.tensor.basic as at
12+
from pytensor import function
1213
from pytensor.compile import DeepCopyOp, shared
1314
from pytensor.compile.io import In
1415
from pytensor.configdefaults import config
1516
from pytensor.graph.op import get_test_value
1617
from pytensor.graph.rewriting.utils import is_same_graph
1718
from pytensor.printing import pprint
1819
from pytensor.scalar.basic import as_scalar
19-
from pytensor.tensor import get_vector_length
20+
from pytensor.tensor import get_vector_length, vectorize
21+
from pytensor.tensor.blockwise import Blockwise
2022
from pytensor.tensor.elemwise import DimShuffle
2123
from pytensor.tensor.math import exp, isinf
2224
from pytensor.tensor.math import sum as at_sum
@@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected):
27092711
x = at.tensor(dtype="float64", shape=x_shape)
27102712
y = x[indices]
27112713
assert y.type.shape == expected
2714+
2715+
2716+
def test_vectorize_subtensor_without_batch_indices():
2717+
signature = "(t1,t2,t3),()->(t1,t3)"
2718+
2719+
def core_fn(x, start):
2720+
return x[:, start, :]
2721+
2722+
x = tensor(shape=(11, 7, 5, 3))
2723+
start = tensor(shape=(), dtype="int")
2724+
vectorize_pt = function(
2725+
[x, start], vectorize(core_fn, signature=signature)(x, start)
2726+
)
2727+
assert not any(
2728+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
2729+
)
2730+
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
2731+
start_test = np.random.randint(0, x.type.shape[-2])
2732+
vectorize_np = np.vectorize(core_fn, signature=signature)
2733+
np.testing.assert_allclose(
2734+
vectorize_pt(x_test, start_test),
2735+
vectorize_np(x_test, start_test),
2736+
)
2737+
2738+
# If we vectorize start, we should get a Blockwise that still works
2739+
x = tensor(shape=(11, 7, 5, 3))
2740+
start = tensor(shape=(11,), dtype="int")
2741+
vectorize_pt = function(
2742+
[x, start], vectorize(core_fn, signature=signature)(x, start)
2743+
)
2744+
assert any(
2745+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
2746+
)
2747+
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
2748+
start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0])
2749+
vectorize_np = np.vectorize(core_fn, signature=signature)
2750+
np.testing.assert_allclose(
2751+
vectorize_pt(x_test, start_test),
2752+
vectorize_np(x_test, start_test),
2753+
)

0 commit comments

Comments
 (0)