Skip to content

Commit 849c556

Browse files
committed
Refactor sum_prod_mul rewrite test and add failing case
Rewrite from prod of mul was not correct when only some axes were reduced by prod
1 parent 5ad1181 commit 849c556

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,6 +2512,7 @@ def test_local_sum_prod_mul_by_scalar(self):
25122512
# 4-the inputs to the mul contain two scalars and no non-scalar
25132513
# 5-the inputs to the mul contain two scalars and one non-scalar
25142514
# 6-the inputs to the mul contain two scalars and two non-scalars
2515+
# 7-the reduction happens across only the first of two axes
25152516

25162517
vect = dvector()
25172518
mat = dmatrix()
@@ -2524,10 +2525,15 @@ def test_local_sum_prod_mul_by_scalar(self):
25242525
s2_val = np.random.random()
25252526

25262527
def test_reduction_rewrite(
2527-
inputs, inputs_val, reduction_op, expected_output, nb_expected_sum_nodes
2528+
inputs,
2529+
inputs_val,
2530+
reduction_op,
2531+
expected_output,
2532+
nb_expected_sum_nodes,
2533+
axis=None,
25282534
):
25292535
mul_out = mul(*inputs)
2530-
f = function(inputs, reduction_op()(mul_out), mode=self.mode)
2536+
f = function(inputs, reduction_op(axis=axis)(mul_out), mode=self.mode)
25312537
out = f(*inputs_val)
25322538
utt.assert_allclose(out, expected_output)
25332539

@@ -2581,6 +2587,16 @@ def test_reduction_rewrite(
25812587
1,
25822588
)
25832589

2590+
# Case 7
2591+
test_reduction_rewrite(
2592+
[mat, scalar1, scalar2],
2593+
[m_val, s1_val, s2_val],
2594+
Sum,
2595+
(s1_val * s2_val * m_val).sum(0),
2596+
1,
2597+
axis=(0,),
2598+
)
2599+
25842600
# Test prod
25852601

25862602
# Case 1
@@ -2627,6 +2643,16 @@ def test_reduction_rewrite(
26272643
2,
26282644
)
26292645

2646+
# Case 7
2647+
test_reduction_rewrite(
2648+
[mat, scalar1, scalar2],
2649+
[m_val, s1_val, s2_val],
2650+
Prod,
2651+
(s1_val * s2_val * m_val).prod(0),
2652+
1,
2653+
axis=(0,),
2654+
)
2655+
26302656
def test_local_sum_prod_all_to_none(self):
26312657
a = tensor3()
26322658
input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)

0 commit comments

Comments
 (0)