Skip to content

Commit 91d3b7c

Browse files
committed
Do not merge while scans with different until condition
The rewrite did not check if nominal variables in the graph of the until condition corresponded to the equivalent outer variables
1 parent eb552ee commit 91d3b7c

File tree

2 files changed

+170
-48
lines changed

2 files changed

+170
-48
lines changed

pytensor/scan/rewriting.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from pytensor.graph.basic import (
1818
Apply,
1919
Constant,
20+
NominalVariable,
2021
Variable,
22+
ancestors,
2123
apply_depends_on,
2224
equal_computations,
2325
graph_inputs,
@@ -1950,11 +1952,13 @@ def belongs_to_set(self, node, set_nodes):
19501952
Questionable, we should also consider profile ?
19511953
19521954
"""
1953-
rep = set_nodes[0]
1955+
op = node.op
1956+
rep_node = set_nodes[0]
1957+
rep_op = rep_node.op
19541958
if (
1955-
rep.op.info.as_while != node.op.info.as_while
1956-
or node.op.truncate_gradient != rep.op.truncate_gradient
1957-
or node.op.mode != rep.op.mode
1959+
op.info.as_while != rep_op.info.as_while
1960+
or op.truncate_gradient != rep_op.truncate_gradient
1961+
or op.mode != rep_op.mode
19581962
):
19591963
return False
19601964

@@ -1964,7 +1968,7 @@ def belongs_to_set(self, node, set_nodes):
19641968
except NotScalarConstantError:
19651969
pass
19661970

1967-
rep_nsteps = rep.inputs[0]
1971+
rep_nsteps = rep_node.inputs[0]
19681972
try:
19691973
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
19701974
except NotScalarConstantError:
@@ -1978,13 +1982,40 @@ def belongs_to_set(self, node, set_nodes):
19781982
if apply_depends_on(node, nd) or apply_depends_on(nd, node):
19791983
return False
19801984

1981-
if not node.op.info.as_while:
1985+
if not op.info.as_while:
19821986
return True
1983-
cond = node.op.inner_outputs[-1]
1984-
rep_cond = rep.op.inner_outputs[-1]
1985-
return equal_computations(
1986-
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs
1987-
)
1987+
1988+
# We need to check the while conditions are identical
1989+
conds = [op.inner_outputs[-1]]
1990+
rep_conds = [rep_op.inner_outputs[-1]]
1991+
if not equal_computations(
1992+
conds, rep_conds, op.inner_inputs, rep_op.inner_inputs
1993+
):
1994+
return False
1995+
1996+
# If they depend on inner inputs we need to check for equivalence on the respective outer inputs
1997+
nominal_inputs = [a for a in ancestors(conds) if isinstance(a, NominalVariable)]
1998+
if not nominal_inputs:
1999+
return True
2000+
rep_nominal_inputs = [
2001+
a for a in ancestors(rep_conds) if isinstance(a, NominalVariable)
2002+
]
2003+
2004+
conds = []
2005+
rep_conds = []
2006+
mapping = op.get_oinp_iinp_iout_oout_mappings()["outer_inp_from_inner_inp"]
2007+
rep_mapping = rep_op.get_oinp_iinp_iout_oout_mappings()[
2008+
"outer_inp_from_inner_inp"
2009+
]
2010+
inner_inputs = op.inner_inputs
2011+
rep_inner_inputs = rep_op.inner_inputs
2012+
for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs):
2013+
conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]])
2014+
rep_conds.append(
2015+
rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]]
2016+
)
2017+
2018+
return equal_computations(conds, rep_conds)
19882019

19892020
def apply(self, fgraph):
19902021
# Collect all scan nodes ordered according to toposort

tests/scan/test_rewriting.py

Lines changed: 128 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.scan.op import Scan
1616
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
1717
from pytensor.scan.utils import until
18+
from pytensor.tensor import stack
1819
from pytensor.tensor.blas import Dot22
1920
from pytensor.tensor.elemwise import Elemwise
2021
from pytensor.tensor.math import Dot, dot, sigmoid
@@ -796,7 +797,13 @@ def inner_fct(seq1, seq2, seq3, previous_output):
796797

797798

798799
class TestScanMerge:
799-
mode = get_default_mode().including("scan")
800+
mode = get_default_mode().including("scan").excluding("scan_pushout_seqs_ops")
801+
802+
@staticmethod
803+
def count_scans(fn):
804+
nodes = fn.maker.fgraph.apply_nodes
805+
scans = [node for node in nodes if isinstance(node.op, Scan)]
806+
return len(scans)
800807

801808
def test_basic(self):
802809
x = vector()
@@ -808,56 +815,38 @@ def sum(s):
808815
sx, upx = scan(sum, sequences=[x])
809816
sy, upy = scan(sum, sequences=[y])
810817

811-
f = function(
812-
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
813-
)
814-
topo = f.maker.fgraph.toposort()
815-
scans = [n for n in topo if isinstance(n.op, Scan)]
816-
assert len(scans) == 2
818+
f = function([x, y], [sx, sy], mode=self.mode)
819+
assert self.count_scans(f) == 2
817820

818821
sx, upx = scan(sum, sequences=[x], n_steps=2)
819822
sy, upy = scan(sum, sequences=[y], n_steps=3)
820823

821-
f = function(
822-
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
823-
)
824-
topo = f.maker.fgraph.toposort()
825-
scans = [n for n in topo if isinstance(n.op, Scan)]
826-
assert len(scans) == 2
824+
f = function([x, y], [sx, sy], mode=self.mode)
825+
assert self.count_scans(f) == 2
827826

828827
sx, upx = scan(sum, sequences=[x], n_steps=4)
829828
sy, upy = scan(sum, sequences=[y], n_steps=4)
830829

831-
f = function(
832-
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
833-
)
834-
topo = f.maker.fgraph.toposort()
835-
scans = [n for n in topo if isinstance(n.op, Scan)]
836-
assert len(scans) == 1
830+
f = function([x, y], [sx, sy], mode=self.mode)
831+
assert self.count_scans(f) == 1
837832

838833
sx, upx = scan(sum, sequences=[x])
839834
sy, upy = scan(sum, sequences=[x])
840835

841-
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
842-
topo = f.maker.fgraph.toposort()
843-
scans = [n for n in topo if isinstance(n.op, Scan)]
844-
assert len(scans) == 1
836+
f = function([x], [sx, sy], mode=self.mode)
837+
assert self.count_scans(f) == 1
845838

846839
sx, upx = scan(sum, sequences=[x])
847840
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
848841

849-
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
850-
topo = f.maker.fgraph.toposort()
851-
scans = [n for n in topo if isinstance(n.op, Scan)]
852-
assert len(scans) == 1
842+
f = function([x], [sx, sy], mode=self.mode)
843+
assert self.count_scans(f) == 1
853844

854845
sx, upx = scan(sum, sequences=[x])
855846
sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
856847

857-
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
858-
topo = f.maker.fgraph.toposort()
859-
scans = [n for n in topo if isinstance(n.op, Scan)]
860-
assert len(scans) == 2
848+
f = function([x], [sx, sy], mode=self.mode)
849+
assert self.count_scans(f) == 2
861850

862851
def test_three_scans(self):
863852
r"""
@@ -877,12 +866,8 @@ def sum(s):
877866
sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y")
878867
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
879868

880-
f = function(
881-
[x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops")
882-
)
883-
topo = f.maker.fgraph.toposort()
884-
scans = [n for n in topo if isinstance(n.op, Scan)]
885-
assert len(scans) == 2
869+
f = function([x, y], [sy, sz], mode=self.mode)
870+
assert self.count_scans(f) == 2
886871

887872
rng = np.random.default_rng(utt.fetch_seed())
888873
x_val = rng.uniform(size=(4,)).astype(config.floatX)
@@ -913,6 +898,112 @@ def test_belongs_to_set(self):
913898
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
914899
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
915900

901+
@config.change_flags(cxx="") # Just for faster compilation
902+
def test_while_scan(self):
903+
x = vector("x")
904+
y = vector("y")
905+
906+
def add(s):
907+
return s + 1, until(s > 5)
908+
909+
def sub(s):
910+
return s - 1, until(s > 5)
911+
912+
def sub_alt(s):
913+
return s - 1, until(s > 4)
914+
915+
sx, upx = scan(add, sequences=[x])
916+
sy, upy = scan(sub, sequences=[y])
917+
918+
f = function([x, y], [sx, sy], mode=self.mode)
919+
assert self.count_scans(f) == 2
920+
921+
sx, upx = scan(add, sequences=[x])
922+
sy, upy = scan(sub, sequences=[x])
923+
924+
f = function([x], [sx, sy], mode=self.mode)
925+
assert self.count_scans(f) == 1
926+
927+
sx, upx = scan(add, sequences=[x])
928+
sy, upy = scan(sub_alt, sequences=[x])
929+
930+
f = function([x], [sx, sy], mode=self.mode)
931+
assert self.count_scans(f) == 2
932+
933+
@config.change_flags(cxx="") # Just for faster compilation
934+
def test_while_scan_nominal_dependency(self):
935+
"""Test case where condition depends on nominal variables.
936+
937+
This is a regression test for #509
938+
"""
939+
c1 = scalar("c1")
940+
c2 = scalar("c2")
941+
x = vector("x", shape=(5,))
942+
y = vector("y", shape=(5,))
943+
z = vector("z", shape=(5,))
944+
945+
def add(s1, s2, const):
946+
return s1 + 1, until(s2 > const)
947+
948+
def sub(s1, s2, const):
949+
return s1 - 1, until(s2 > const)
950+
951+
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
952+
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1])
953+
954+
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
955+
assert self.count_scans(f) == 2
956+
res_sx, res_sy = f(
957+
x=[0, 0, 0, 0, 0],
958+
y=[0, 0, 0, 0, 0],
959+
z=[0, 1, 2, 3, 4],
960+
c1=0,
961+
)
962+
np.testing.assert_array_equal(res_sx, [1, 1])
963+
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])
964+
965+
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
966+
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2])
967+
968+
f = pytensor.function(
969+
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
970+
)
971+
assert self.count_scans(f) == 2
972+
res_sx, res_sy = f(
973+
x=[0, 0, 0, 0, 0],
974+
y=[0, 0, 0, 0, 0],
975+
z=[0, 1, 2, 3, 4],
976+
c1=3,
977+
c2=1,
978+
)
979+
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
980+
np.testing.assert_array_equal(res_sy, [-1, -1, -1])
981+
982+
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
983+
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1])
984+
985+
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
986+
assert self.count_scans(f) == 1
987+
988+
def nested_scan(c, x, z):
989+
sx, _ = scan(add, sequences=[x, z], non_sequences=[c])
990+
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c])
991+
return sx.sum() + sy.sum()
992+
993+
sz, _ = scan(
994+
nested_scan,
995+
sequences=[stack([c1, c2])],
996+
non_sequences=[x, z],
997+
mode=self.mode,
998+
)
999+
1000+
f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
1001+
[scan_node] = [
1002+
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
1003+
]
1004+
inner_f = scan_node.op.fn
1005+
assert self.count_scans(inner_f) == 1
1006+
9161007

9171008
class TestScanInplaceOptimizer:
9181009
mode = get_default_mode().including("scan_make_inplace", "inplace")

0 commit comments

Comments
 (0)