15
15
from pytensor .scan .op import Scan
16
16
from pytensor .scan .rewriting import ScanInplaceOptimizer , ScanMerge
17
17
from pytensor .scan .utils import until
18
+ from pytensor .tensor import stack
18
19
from pytensor .tensor .blas import Dot22
19
20
from pytensor .tensor .elemwise import Elemwise
20
21
from pytensor .tensor .math import Dot , dot , sigmoid
@@ -796,7 +797,13 @@ def inner_fct(seq1, seq2, seq3, previous_output):
796
797
797
798
798
799
class TestScanMerge :
799
- mode = get_default_mode ().including ("scan" )
800
+ mode = get_default_mode ().including ("scan" ).excluding ("scan_pushout_seqs_ops" )
801
+
802
+ @staticmethod
803
+ def count_scans (fn ):
804
+ nodes = fn .maker .fgraph .apply_nodes
805
+ scans = [node for node in nodes if isinstance (node .op , Scan )]
806
+ return len (scans )
800
807
801
808
def test_basic (self ):
802
809
x = vector ()
@@ -808,56 +815,38 @@ def sum(s):
808
815
sx , upx = scan (sum , sequences = [x ])
809
816
sy , upy = scan (sum , sequences = [y ])
810
817
811
- f = function (
812
- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
813
- )
814
- topo = f .maker .fgraph .toposort ()
815
- scans = [n for n in topo if isinstance (n .op , Scan )]
816
- assert len (scans ) == 2
818
+ f = function ([x , y ], [sx , sy ], mode = self .mode )
819
+ assert self .count_scans (f ) == 2
817
820
818
821
sx , upx = scan (sum , sequences = [x ], n_steps = 2 )
819
822
sy , upy = scan (sum , sequences = [y ], n_steps = 3 )
820
823
821
- f = function (
822
- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
823
- )
824
- topo = f .maker .fgraph .toposort ()
825
- scans = [n for n in topo if isinstance (n .op , Scan )]
826
- assert len (scans ) == 2
824
+ f = function ([x , y ], [sx , sy ], mode = self .mode )
825
+ assert self .count_scans (f ) == 2
827
826
828
827
sx , upx = scan (sum , sequences = [x ], n_steps = 4 )
829
828
sy , upy = scan (sum , sequences = [y ], n_steps = 4 )
830
829
831
- f = function (
832
- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
833
- )
834
- topo = f .maker .fgraph .toposort ()
835
- scans = [n for n in topo if isinstance (n .op , Scan )]
836
- assert len (scans ) == 1
830
+ f = function ([x , y ], [sx , sy ], mode = self .mode )
831
+ assert self .count_scans (f ) == 1
837
832
838
833
sx , upx = scan (sum , sequences = [x ])
839
834
sy , upy = scan (sum , sequences = [x ])
840
835
841
- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
842
- topo = f .maker .fgraph .toposort ()
843
- scans = [n for n in topo if isinstance (n .op , Scan )]
844
- assert len (scans ) == 1
836
+ f = function ([x ], [sx , sy ], mode = self .mode )
837
+ assert self .count_scans (f ) == 1
845
838
846
839
sx , upx = scan (sum , sequences = [x ])
847
840
sy , upy = scan (sum , sequences = [x ], mode = "FAST_COMPILE" )
848
841
849
- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
850
- topo = f .maker .fgraph .toposort ()
851
- scans = [n for n in topo if isinstance (n .op , Scan )]
852
- assert len (scans ) == 1
842
+ f = function ([x ], [sx , sy ], mode = self .mode )
843
+ assert self .count_scans (f ) == 1
853
844
854
845
sx , upx = scan (sum , sequences = [x ])
855
846
sy , upy = scan (sum , sequences = [x ], truncate_gradient = 1 )
856
847
857
- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
858
- topo = f .maker .fgraph .toposort ()
859
- scans = [n for n in topo if isinstance (n .op , Scan )]
860
- assert len (scans ) == 2
848
+ f = function ([x ], [sx , sy ], mode = self .mode )
849
+ assert self .count_scans (f ) == 2
861
850
862
851
def test_three_scans (self ):
863
852
r"""
@@ -877,12 +866,8 @@ def sum(s):
877
866
sy , upy = scan (sum , sequences = [2 * y + 2 ], n_steps = 4 , name = "Y" )
878
867
sz , upz = scan (sum , sequences = [sx ], n_steps = 4 , name = "Z" )
879
868
880
- f = function (
881
- [x , y ], [sy , sz ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
882
- )
883
- topo = f .maker .fgraph .toposort ()
884
- scans = [n for n in topo if isinstance (n .op , Scan )]
885
- assert len (scans ) == 2
869
+ f = function ([x , y ], [sy , sz ], mode = self .mode )
870
+ assert self .count_scans (f ) == 2
886
871
887
872
rng = np .random .default_rng (utt .fetch_seed ())
888
873
x_val = rng .uniform (size = (4 ,)).astype (config .floatX )
@@ -913,6 +898,112 @@ def test_belongs_to_set(self):
913
898
assert not opt_obj .belongs_to_set (scan_node1 , [scan_node2 ])
914
899
assert not opt_obj .belongs_to_set (scan_node2 , [scan_node1 ])
915
900
901
+ @config .change_flags (cxx = "" ) # Just for faster compilation
902
+ def test_while_scan (self ):
903
+ x = vector ("x" )
904
+ y = vector ("y" )
905
+
906
+ def add (s ):
907
+ return s + 1 , until (s > 5 )
908
+
909
+ def sub (s ):
910
+ return s - 1 , until (s > 5 )
911
+
912
+ def sub_alt (s ):
913
+ return s - 1 , until (s > 4 )
914
+
915
+ sx , upx = scan (add , sequences = [x ])
916
+ sy , upy = scan (sub , sequences = [y ])
917
+
918
+ f = function ([x , y ], [sx , sy ], mode = self .mode )
919
+ assert self .count_scans (f ) == 2
920
+
921
+ sx , upx = scan (add , sequences = [x ])
922
+ sy , upy = scan (sub , sequences = [x ])
923
+
924
+ f = function ([x ], [sx , sy ], mode = self .mode )
925
+ assert self .count_scans (f ) == 1
926
+
927
+ sx , upx = scan (add , sequences = [x ])
928
+ sy , upy = scan (sub_alt , sequences = [x ])
929
+
930
+ f = function ([x ], [sx , sy ], mode = self .mode )
931
+ assert self .count_scans (f ) == 2
932
+
933
+ @config .change_flags (cxx = "" ) # Just for faster compilation
934
+ def test_while_scan_nominal_dependency (self ):
935
+ """Test case where condition depends on nominal variables.
936
+
937
+ This is a regression test for #509
938
+ """
939
+ c1 = scalar ("c1" )
940
+ c2 = scalar ("c2" )
941
+ x = vector ("x" , shape = (5 ,))
942
+ y = vector ("y" , shape = (5 ,))
943
+ z = vector ("z" , shape = (5 ,))
944
+
945
+ def add (s1 , s2 , const ):
946
+ return s1 + 1 , until (s2 > const )
947
+
948
+ def sub (s1 , s2 , const ):
949
+ return s1 - 1 , until (s2 > const )
950
+
951
+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
952
+ sy , _ = scan (sub , sequences = [y , - z ], non_sequences = [c1 ])
953
+
954
+ f = pytensor .function (inputs = [x , y , z , c1 ], outputs = [sx , sy ], mode = self .mode )
955
+ assert self .count_scans (f ) == 2
956
+ res_sx , res_sy = f (
957
+ x = [0 , 0 , 0 , 0 , 0 ],
958
+ y = [0 , 0 , 0 , 0 , 0 ],
959
+ z = [0 , 1 , 2 , 3 , 4 ],
960
+ c1 = 0 ,
961
+ )
962
+ np .testing .assert_array_equal (res_sx , [1 , 1 ])
963
+ np .testing .assert_array_equal (res_sy , [- 1 , - 1 , - 1 , - 1 , - 1 ])
964
+
965
+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
966
+ sy , _ = scan (sub , sequences = [y , z ], non_sequences = [c2 ])
967
+
968
+ f = pytensor .function (
969
+ inputs = [x , y , z , c1 , c2 ], outputs = [sx , sy ], mode = self .mode
970
+ )
971
+ assert self .count_scans (f ) == 2
972
+ res_sx , res_sy = f (
973
+ x = [0 , 0 , 0 , 0 , 0 ],
974
+ y = [0 , 0 , 0 , 0 , 0 ],
975
+ z = [0 , 1 , 2 , 3 , 4 ],
976
+ c1 = 3 ,
977
+ c2 = 1 ,
978
+ )
979
+ np .testing .assert_array_equal (res_sx , [1 , 1 , 1 , 1 , 1 ])
980
+ np .testing .assert_array_equal (res_sy , [- 1 , - 1 , - 1 ])
981
+
982
+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
983
+ sy , _ = scan (sub , sequences = [y , z ], non_sequences = [c1 ])
984
+
985
+ f = pytensor .function (inputs = [x , y , z , c1 ], outputs = [sx , sy ], mode = self .mode )
986
+ assert self .count_scans (f ) == 1
987
+
988
+ def nested_scan (c , x , z ):
989
+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c ])
990
+ sy , _ = scan (sub , sequences = [x , z ], non_sequences = [c ])
991
+ return sx .sum () + sy .sum ()
992
+
993
+ sz , _ = scan (
994
+ nested_scan ,
995
+ sequences = [stack ([c1 , c2 ])],
996
+ non_sequences = [x , z ],
997
+ mode = self .mode ,
998
+ )
999
+
1000
+ f = pytensor .function (inputs = [x , z , c1 , c2 ], outputs = sz , mode = mode )
1001
+ [scan_node ] = [
1002
+ node for node in f .maker .fgraph .apply_nodes if isinstance (node .op , Scan )
1003
+ ]
1004
+ inner_f = scan_node .op .fn
1005
+ assert self .count_scans (inner_f ) == 1
1006
+
916
1007
917
1008
class TestScanInplaceOptimizer :
918
1009
mode = get_default_mode ().including ("scan_make_inplace" , "inplace" )
0 commit comments