Skip to content

Commit 5e7b095

Browse files
Add a push_out_non_seq_scan test for an OpFromGraph with a shared variable
1 parent e65b0c5 commit 5e7b095

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/scan/test_opt.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import aesara
55
import aesara.tensor.basic as at
66
from aesara import function, scan, shared
7+
from aesara.compile.builders import OpFromGraph
78
from aesara.compile.io import In
89
from aesara.compile.mode import get_default_mode
910
from aesara.configdefaults import config
@@ -550,6 +551,28 @@ def inner_fct(seq1, previous_output1, nonseq1):
550551
utt.assert_allclose(output_opt[0], output_no_opt[0])
551552
utt.assert_allclose(output_opt[1], output_no_opt[1])
552553

554+
def test_OpFromGraph_shared(self):
555+
"""Make sure that a simple `OpFromGraph` with a shared variable can be pushed out."""
556+
557+
y = shared(1.0, name="y")
558+
559+
test_ofg = OpFromGraph([], [1 + y])
560+
561+
def inner_func():
562+
return test_ofg()
563+
564+
out, out_updates = aesara.scan(inner_func, n_steps=10)
565+
566+
out_fn = function([], out, updates=out_updates)
567+
568+
res = out_fn()
569+
assert np.array_equal(res, np.repeat(2.0, 10))
570+
571+
y.set_value(2.0)
572+
573+
res = out_fn()
574+
assert np.array_equal(res, np.repeat(3.0, 10))
575+
553576

554577
class TestPushOutAddScan:
555578
"""

0 commit comments

Comments
 (0)