|
92 | 92 | local_grad_log_erfc_neg,
|
93 | 93 | local_greedy_distributor,
|
94 | 94 | local_mul_canonizer,
|
| 95 | + local_sum_prod_of_mul, |
95 | 96 | mul_canonizer,
|
96 | 97 | parse_mul_tree,
|
97 | 98 | perform_sigm_times_exp,
|
@@ -2503,7 +2504,7 @@ class TestLocalSumProd:
|
2503 | 2504 | def setup_method(self):
|
2504 | 2505 | self.mode = get_default_mode().including("canonicalize", "specialize")
|
2505 | 2506 |
|
2506 |
| - def test_local_sum_prod_mul_by_scalar(self): |
| 2507 | + def test_local_sum_prod_of_scalar_mul(self): |
2507 | 2508 | # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and
|
2508 | 2509 | # Prod ops in six cases each :
|
2509 | 2510 | # 1-the inputs to the mul contain a scalar and no non-scalar
|
@@ -2653,6 +2654,157 @@ def test_reduction_rewrite(
|
2653 | 2654 | axis=(0,),
|
2654 | 2655 | )
|
2655 | 2656 |
|
| 2657 | + def test_sum_of_non_scalar_mul(self): |
| 2658 | + mode = Mode("vm", optimizer="None") |
| 2659 | + rewrite = out2in(local_sum_prod_of_mul) |
| 2660 | + |
| 2661 | + row1 = matrix(shape=(1, None), dtype="float64") |
| 2662 | + row2 = matrix(shape=(1, None), dtype="float64") |
| 2663 | + col1 = matrix(shape=(None, 1), dtype="float64") |
| 2664 | + col2 = matrix(shape=(None, 1), dtype="float64") |
| 2665 | + mat1 = matrix(shape=(None, None), dtype="float64") |
| 2666 | + mat2 = matrix(shape=(None, None), dtype="float64") |
| 2667 | + |
| 2668 | + inputs = [row1, row2, col1, col2, mat1, mat2] |
| 2669 | + test_vals = [ |
| 2670 | + np.random.random((1, 2)), |
| 2671 | + np.random.random((1, 2)), |
| 2672 | + np.random.random((2, 1)), |
| 2673 | + np.random.random((2, 1)), |
| 2674 | + np.random.random((2, 2)), |
| 2675 | + np.random.random((2, 2)), |
| 2676 | + ] |
| 2677 | + |
| 2678 | + for out, expected_out in [ |
| 2679 | + ( |
| 2680 | + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None), |
| 2681 | + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None), |
| 2682 | + ), |
| 2683 | + ( |
| 2684 | + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=0), |
| 2685 | + mul(row1.squeeze(), row2.squeeze()) |
| 2686 | + * mul(mat1, mat2, col1, col2).sum(axis=0), |
| 2687 | + ), |
| 2688 | + ( |
| 2689 | + mul(row1, mat1, mat2, col1, col2).sum(axis=0), |
| 2690 | + row1.squeeze() * mul(mat1, mat2, col1, col2).sum(axis=0), |
| 2691 | + ), |
| 2692 | + ( |
| 2693 | + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=1), |
| 2694 | + mul(col1.squeeze(), col2.squeeze()) |
| 2695 | + * mul(row1, row2, mat1, mat2).sum(axis=1), |
| 2696 | + ), |
| 2697 | + ( |
| 2698 | + mul(row1, row2, mat1, mat2, col2).sum(axis=1), |
| 2699 | + col2.squeeze() * mul(row1, row2, mat1, mat2).sum(axis=1), |
| 2700 | + ), |
| 2701 | + ( |
| 2702 | + mul(row1, row2).sum(axis=1), |
| 2703 | + mul(row1, row2).sum(axis=1), |
| 2704 | + ), |
| 2705 | + ( |
| 2706 | + mul(row1, row2).sum(axis=0), |
| 2707 | + mul(row1.squeeze(), row2.squeeze()), |
| 2708 | + ), |
| 2709 | + ( |
| 2710 | + mul(row1, col1).sum(axis=0), |
| 2711 | + row1.squeeze() * col1.sum(axis=0), |
| 2712 | + ), |
| 2713 | + ]: |
| 2714 | + out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore") |
| 2715 | + |
| 2716 | + rewritten_out = rewrite_graph(out, custom_rewrite=rewrite) |
| 2717 | + assert equal_computations([rewritten_out], [expected_out]) |
| 2718 | + |
| 2719 | + rewritten_out_fn = pytensor.function( |
| 2720 | + inputs, rewritten_out, mode=mode, on_unused_input="ignore" |
| 2721 | + ) |
| 2722 | + np.testing.assert_allclose( |
| 2723 | + out_fn(*test_vals), |
| 2724 | + rewritten_out_fn(*test_vals), |
| 2725 | + ) |
| 2726 | + |
| 2727 | + def test_prod_of_non_scalar_mul(self): |
| 2728 | + mode = Mode("vm", optimizer="None") |
| 2729 | + rewrite = out2in(local_sum_prod_of_mul) |
| 2730 | + |
| 2731 | + scl1 = matrix(shape=(1, 1), dtype="float64") |
| 2732 | + row1 = matrix(shape=(1, None), dtype="float64") |
| 2733 | + row2 = matrix(shape=(1, None), dtype="float64") |
| 2734 | + col1 = matrix(shape=(None, 1), dtype="float64") |
| 2735 | + col2 = matrix(shape=(None, 1), dtype="float64") |
| 2736 | + mat1 = matrix(shape=(None, None), dtype="float64") |
| 2737 | + mat2 = matrix(shape=(None, None), dtype="float64") |
| 2738 | + |
| 2739 | + inputs = [scl1, row1, row2, col1, col2, mat1, mat2] |
| 2740 | + test_vals = [ |
| 2741 | + np.random.random((1, 1)), |
| 2742 | + np.random.random((1, 2)), |
| 2743 | + np.random.random((1, 2)), |
| 2744 | + np.random.random((2, 1)), |
| 2745 | + np.random.random((2, 1)), |
| 2746 | + np.random.random((2, 2)), |
| 2747 | + np.random.random((2, 2)), |
| 2748 | + ] |
| 2749 | + |
| 2750 | + for out, expected_out in [ |
| 2751 | + ( |
| 2752 | + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None), |
| 2753 | + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None), |
| 2754 | + ), |
| 2755 | + ( |
| 2756 | + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0), |
| 2757 | + ( |
| 2758 | + mul(row1.squeeze(), row2.squeeze()) |
| 2759 | + ** prod([mul(mat1, mat2, col1, col2).shape[0]]) |
| 2760 | + * mul(mat1, mat2, col1, col2).prod(axis=0) |
| 2761 | + ), |
| 2762 | + ), |
| 2763 | + ( |
| 2764 | + mul(row1, mat1, mat2, col1, col2).prod(axis=0), |
| 2765 | + ( |
| 2766 | + row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]]) |
| 2767 | + * mul(mat1, mat2, col1, col2).prod(axis=0) |
| 2768 | + ), |
| 2769 | + ), |
| 2770 | + ( |
| 2771 | + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1), |
| 2772 | + ( |
| 2773 | + mul(col1.squeeze(), col2.squeeze()) |
| 2774 | + ** prod([mul(row1, row2, mat1, mat2).shape[1]]) |
| 2775 | + * mul(row1, row2, mat1, mat2).prod(axis=1) |
| 2776 | + ), |
| 2777 | + ), |
| 2778 | + ( |
| 2779 | + mul(row1, row2).prod(axis=0), |
| 2780 | + mul(row1.squeeze(), row2.squeeze()), |
| 2781 | + ), |
| 2782 | + ( |
| 2783 | + mul(row1, col1).prod(axis=0), |
| 2784 | + (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)), |
| 2785 | + ), |
| 2786 | + ( |
| 2787 | + mul(scl1, mat1, row1).prod(axis=None), |
| 2788 | + ( |
| 2789 | + scl1.squeeze() |
| 2790 | + ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]]) |
| 2791 | + * mul(mat1, row1).prod(axis=None) |
| 2792 | + ), |
| 2793 | + ), |
| 2794 | + ]: |
| 2795 | + out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore") |
| 2796 | + |
| 2797 | + rewritten_out = rewrite_graph(out, custom_rewrite=rewrite) |
| 2798 | + assert equal_computations([rewritten_out], [expected_out]) |
| 2799 | + |
| 2800 | + rewritten_out_fn = pytensor.function( |
| 2801 | + inputs, rewritten_out, mode=mode, on_unused_input="ignore" |
| 2802 | + ) |
| 2803 | + np.testing.assert_allclose( |
| 2804 | + out_fn(*test_vals), |
| 2805 | + rewritten_out_fn(*test_vals), |
| 2806 | + ) |
| 2807 | + |
2656 | 2808 | def test_local_sum_prod_all_to_none(self):
|
2657 | 2809 | a = tensor3()
|
2658 | 2810 | input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
|
|
0 commit comments