@@ -758,40 +758,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
 
-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
 
         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))
 
-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))
+        subtensor = subtensor_fn(some_matrix)
 
-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
-
-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)
 
-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)
 
     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
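
For context, below is a minimal standalone sketch of the pattern this commit moves to: pytest parametrization over subtensor-builder lambdas, plus counting Alloc nodes via fgraph.apply_nodes rather than scanning a toposort. The toy graph, the test name, and the expected count of zero are illustrative assumptions for a trivially rewritable graph, not the cases or counts asserted in the hunk above.

import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc


@pytest.mark.parametrize(
    "subtensor_fn",
    [
        lambda x: x[:3],            # basic Subtensor
        lambda x: x[np.arange(3)],  # AdvancedSubtensor1
    ],
)
def test_no_alloc_in_forward_graph(subtensor_fn):
    x = pt.vector("x")
    out = subtensor_fn(x).sum()
    f = pytensor.function([x], out)
    # Count Alloc nodes that survive rewriting, as the updated test does.
    # For this toy forward graph we assume none should remain.
    n_alloc = sum(isinstance(node.op, Alloc) for node in f.maker.fgraph.apply_nodes)
    assert n_alloc == 0

Checking apply_nodes directly avoids materializing an ordering the assertion never needed, which is presumably why the commit drops the toposort calls.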