@@ -2512,6 +2512,7 @@ def test_local_sum_prod_mul_by_scalar(self):
2512
2512
# 4-the inputs to the mul contain two scalars and no non-scalar
2513
2513
# 5-the inputs to the mul contain two scalars and one non-scalar
2514
2514
# 6-the inputs to the mul contain two scalars and two non-scalars
2515
+ # 7-the reduction happens across only the first of two axes
2515
2516
2516
2517
vect = dvector ()
2517
2518
mat = dmatrix ()
@@ -2524,10 +2525,15 @@ def test_local_sum_prod_mul_by_scalar(self):
2524
2525
s2_val = np .random .random ()
2525
2526
2526
2527
def test_reduction_rewrite (
2527
- inputs , inputs_val , reduction_op , expected_output , nb_expected_sum_nodes
2528
+ inputs ,
2529
+ inputs_val ,
2530
+ reduction_op ,
2531
+ expected_output ,
2532
+ nb_expected_sum_nodes ,
2533
+ axis = None ,
2528
2534
):
2529
2535
mul_out = mul (* inputs )
2530
- f = function (inputs , reduction_op ()(mul_out ), mode = self .mode )
2536
+ f = function (inputs , reduction_op (axis = axis )(mul_out ), mode = self .mode )
2531
2537
out = f (* inputs_val )
2532
2538
utt .assert_allclose (out , expected_output )
2533
2539
@@ -2581,6 +2587,16 @@ def test_reduction_rewrite(
2581
2587
1 ,
2582
2588
)
2583
2589
2590
+ # Case 7
2591
+ test_reduction_rewrite (
2592
+ [mat , scalar1 , scalar2 ],
2593
+ [m_val , s1_val , s2_val ],
2594
+ Sum ,
2595
+ (s1_val * s2_val * m_val ).sum (0 ),
2596
+ 1 ,
2597
+ axis = (0 ,),
2598
+ )
2599
+
2584
2600
# Test prod
2585
2601
2586
2602
# Case 1
@@ -2627,6 +2643,16 @@ def test_reduction_rewrite(
2627
2643
2 ,
2628
2644
)
2629
2645
2646
+ # Case 7
2647
+ test_reduction_rewrite (
2648
+ [mat , scalar1 , scalar2 ],
2649
+ [m_val , s1_val , s2_val ],
2650
+ Prod ,
2651
+ (s1_val * s2_val * m_val ).prod (0 ),
2652
+ 1 ,
2653
+ axis = (0 ,),
2654
+ )
2655
+
2630
2656
def test_local_sum_prod_all_to_none (self ):
2631
2657
a = tensor3 ()
2632
2658
input = np .arange (3 * 4 * 5 , dtype = config .floatX ).reshape (3 , 4 , 5 )
0 commit comments