Skip to content

Commit 4a34ef1

Browse files
committed
Allow inplacing of SITSOT and last MITSOT in numba Scan, when they are discarded immediately
1 parent ca09602 commit 4a34ef1

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def range_arr(x):
5555

5656

5757
@numba_funcify.register(Scan)
58-
def numba_funcify_Scan(op, node, **kwargs):
58+
def numba_funcify_Scan(op: Scan, node, **kwargs):
5959
# Apply inner rewrites
6060
# TODO: Not sure this is the right place to do this, should we have a rewrite that
6161
# explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
6767
.optimizer
6868
)
6969
fgraph = op.fgraph
70+
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
71+
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
72+
# TODO: Allow inplace for MITMOT
73+
destroyable_sitsot = [
74+
inner_sitsot
75+
for outer_sitsot, inner_sitsot in zip(
76+
op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True
77+
)
78+
if outer_sitsot.type.shape[0] == 1
79+
]
80+
destroyable_mitsot = [
81+
oldest_inner_mitmot
82+
for outer_mitsot, oldest_inner_mitmot, taps in zip(
83+
op.outer_mitsot(node.inputs),
84+
op.oldest_inner_mitsot(fgraph.inputs),
85+
op.info.mit_sot_in_slices,
86+
strict=True,
87+
)
88+
if outer_mitsot.type.shape[0] == abs(min(taps))
89+
]
90+
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
7091
add_supervisor_to_fgraph(
7192
fgraph=fgraph,
72-
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
93+
input_specs=[
94+
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
95+
],
7396
accept_inplace=True,
7497
)
7598
rewriter(fgraph)

pytensor/scan/op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ def inner_mitsot(self, list_inputs):
321321
self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
322322
]
323323

324+
def oldest_inner_mitsot(self, list_inputs):
325+
inner_mitsot_inputs = self.inner_mitsot(list_inputs)
326+
oldest_inner_mitsot_inputs = []
327+
offset = 0
328+
for taps in self.info.mit_sot_in_slices:
329+
oldest_tap = np.argmin(taps)
330+
oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]]
331+
offset += len(taps)
332+
return oldest_inner_mitsot_inputs
333+
324334
def outer_mitsot(self, list_inputs):
325335
offset = 1 + self.info.n_seqs + self.info.n_mit_mot
326336
return list_inputs[offset : offset + self.info.n_mit_sot]

tests/link/numba/test_scan.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,68 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
451451
benchmark(numba_fn, *test.values())
452452

453453

454+
@pytest.mark.parametrize("n_steps_constant", (True, False))
455+
def test_inplace_taps(n_steps_constant):
456+
"""Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps."""
457+
n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int)
458+
a = scalar("a")
459+
x0 = scalar("x0")
460+
y0 = vector("y0", shape=(3,))
461+
462+
def step(xtm1, ytm1, ytm3, a):
463+
x = xtm1 + 1
464+
y = ytm1 + 1 + ytm3 + a
465+
return x, x + y, y
466+
467+
[xs, zs, ys], _ = scan(
468+
fn=step,
469+
outputs_info=[
470+
dict(initial=x0, taps=[-1]),
471+
None,
472+
dict(initial=y0, taps=[-1, -3]),
473+
],
474+
non_sequences=[a],
475+
n_steps=n_steps,
476+
)
477+
numba_fn, _ = compare_numba_and_py(
478+
[n_steps] * (not n_steps_constant) + [a, x0, y0],
479+
[xs[-1], zs[-1], ys[-1]],
480+
[10] * (not n_steps_constant) + [np.pi, np.e, [0, np.euler_gamma, 1]],
481+
numba_mode="NUMBA",
482+
eval_obj_mode=False,
483+
)
484+
[scan_op] = [
485+
node.op
486+
for node in numba_fn.maker.fgraph.toposort()
487+
if isinstance(node.op, Scan)
488+
]
489+
490+
# Scan reorders inputs internally, so we need to check its ordering
491+
inner_inps = scan_op.fgraph.inputs
492+
oldest_mit_sot_tap = scan_op.info.mit_sot_in_slices[0].index(-3)
493+
oldest_mit_sot_inp = scan_op.inner_mitsot(inner_inps)[oldest_mit_sot_tap]
494+
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
495+
496+
inner_outs = scan_op.fgraph.outputs
497+
[mit_sot_out] = scan_op.inner_mitsot_outs(inner_outs)
498+
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
499+
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
500+
501+
if n_steps_constant:
502+
assert mit_sot_out.owner.op.destroy_map == {
503+
0: [mit_sot_out.owner.inputs.index(oldest_mit_sot_inp)]
504+
}
505+
assert sit_sot_out.owner.op.destroy_map == {
506+
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
507+
}
508+
else:
509+
# This is not a feature, but a current limitation
510+
# https://github.com/pymc-devs/pytensor/issues/1283
511+
assert mit_sot_out.owner.op.destroy_map == {}
512+
assert sit_sot_out.owner.op.destroy_map == {}
513+
assert nit_sot_out.owner.op.destroy_map == {}
514+
515+
454516
@pytest.mark.parametrize(
455517
"buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
456518
)

0 commit comments

Comments
 (0)