|
18 | 18 | BinaryScalarOp,
|
19 | 19 | ScalarOp,
|
20 | 20 | UnaryScalarOp,
|
| 21 | + as_scalar, |
21 | 22 | complex_types,
|
| 23 | + constant, |
22 | 24 | discrete_types,
|
| 25 | + eq, |
23 | 26 | exp,
|
24 | 27 | expm1,
|
25 | 28 | float64,
|
26 | 29 | float_types,
|
27 | 30 | isinf,
|
28 | 31 | log,
|
29 | 32 | log1p,
|
| 33 | + sqrt, |
30 | 34 | switch,
|
31 | 35 | true_div,
|
32 | 36 | upcast,
|
33 | 37 | upgrade_to_float,
|
34 | 38 | upgrade_to_float64,
|
35 | 39 | upgrade_to_float_no_complex,
|
36 | 40 | )
|
| 41 | +from pytensor.scalar.loop import ScalarLoop |
37 | 42 |
|
38 | 43 |
|
39 | 44 | class Erf(UnaryScalarOp):
|
@@ -595,7 +600,7 @@ def grad(self, inputs, grads):
|
595 | 600 | (k, x) = inputs
|
596 | 601 | (gz,) = grads
|
597 | 602 | return [
|
598 |
| - gz * gammainc_der(k, x), |
| 603 | + gz * gammainc_grad(k, x), |
599 | 604 | gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
|
600 | 605 | ]
|
601 | 606 |
|
@@ -644,7 +649,7 @@ def grad(self, inputs, grads):
|
644 | 649 | (k, x) = inputs
|
645 | 650 | (gz,) = grads
|
646 | 651 | return [
|
647 |
| - gz * gammaincc_der(k, x), |
| 652 | + gz * gammaincc_grad(k, x), |
648 | 653 | gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
|
649 | 654 | ]
|
650 | 655 |
|
@@ -675,162 +680,209 @@ def __hash__(self):
|
675 | 680 | gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
|
676 | 681 |
|
677 | 682 |
|
678 |
| -class GammaIncDer(BinaryScalarOp): |
679 |
| - """ |
680 |
| - Gradient of the the regularized lower gamma function (P) wrt to the first |
681 |
| - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp` |
| 683 | +def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): |
| 684 | + init = [as_scalar(x) for x in init] |
| 685 | + constant = [as_scalar(x) for x in constant] |
| 686 | + # Create dummy types, in case some variables have the same initial form |
| 687 | + init_ = [x.type() for x in init] |
| 688 | + constant_ = [x.type() for x in constant] |
| 689 | + update_, until_ = inner_loop_fn(*init_, *constant_) |
| 690 | + op = ScalarLoop( |
| 691 | + init=init_, |
| 692 | + constant=constant_, |
| 693 | + update=update_, |
| 694 | + until=until_, |
| 695 | + until_condition_failed="warn", |
| 696 | + name=name, |
| 697 | + ) |
| 698 | + S, *_ = op(n_steps, *init, *constant) |
| 699 | + return S |
| 700 | + |
| 701 | + |
| 702 | +def gammainc_grad(k, x): |
| 703 | + """Gradient of the regularized lower gamma function (P) wrt to the first |
| 704 | + argument (k, a.k.a. alpha). |
| 705 | +
|
| 706 | + Adapted from STAN `grad_reg_lower_inc_gamma.hpp` |
682 | 707 |
|
683 | 708 | Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
|
684 | 709 | ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
|
685 | 710 | """
|
| 711 | + dtype = upcast(k.type.dtype, x.type.dtype, "float32") |
686 | 712 |
|
687 |
| - def impl(self, k, x): |
688 |
| - if x == 0: |
689 |
| - return 0 |
690 |
| - |
691 |
| - sqrt_exp = -756 - x**2 + 60 * x |
692 |
| - if ( |
693 |
| - (k < 0.8 and x > 15) |
694 |
| - or (k < 12 and x > 30) |
695 |
| - or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp)) |
696 |
| - ): |
697 |
| - return -GammaIncCDer.st_impl(k, x) |
698 |
| - |
699 |
| - precision = 1e-10 |
700 |
| - max_iters = int(1e5) |
| 713 | + def grad_approx(skip_loop): |
| 714 | + precision = np.array(1e-10, dtype=config.floatX) |
| 715 | + max_iters = switch( |
| 716 | + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") |
| 717 | + ) |
701 | 718 |
|
702 |
| - log_x = np.log(x) |
703 |
| - log_gamma_k_plus_1 = scipy.special.gammaln(k + 1) |
| 719 | + log_x = log(x) |
| 720 | + log_gamma_k_plus_1 = gammaln(k + 1) |
704 | 721 |
|
705 |
| - k_plus_n = k |
| 722 | + # First loop |
| 723 | + k_plus_n = k # Should not overflow unless k > 2,147,383,647 |
706 | 724 | log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
|
707 |
| - sum_a = 0.0 |
708 |
| - for n in range(0, max_iters + 1): |
709 |
| - term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) |
710 |
| - sum_a += term |
| 725 | + sum_a0 = np.array(0.0, dtype=dtype) |
711 | 726 |
|
712 |
| - if term <= precision: |
713 |
| - break |
| 727 | + def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x): |
| 728 | + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) |
| 729 | + sum_a += term |
714 | 730 |
|
715 |
| - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) |
| 731 | + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) |
716 | 732 | k_plus_n += 1
|
717 |
| - |
718 |
| - if n >= max_iters: |
719 |
| - warnings.warn( |
720 |
| - f"gammainc_der did not converge after {n} iterations", |
721 |
| - RuntimeWarning, |
| 733 | + return ( |
| 734 | + (sum_a, log_gamma_k_plus_n_plus_1, k_plus_n), |
| 735 | + (term <= precision), |
722 | 736 | )
|
723 |
| - return np.nan |
724 | 737 |
|
725 |
| - k_plus_n = k |
| 738 | + init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] |
| 739 | + constant = [log_x] |
| 740 | + sum_a = _make_scalar_loop( |
| 741 | + max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" |
| 742 | + ) |
| 743 | + |
| 744 | + # Second loop |
| 745 | + n = np.array(0, dtype="int32") |
726 | 746 | log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
|
727 |
| - sum_b = 0.0 |
728 |
| - for n in range(0, max_iters + 1): |
729 |
| - term = np.exp( |
730 |
| - k_plus_n * log_x - log_gamma_k_plus_n_plus_1 |
731 |
| - ) * scipy.special.digamma(k_plus_n + 1) |
732 |
| - sum_b += term |
| 747 | + k_plus_n = k |
| 748 | + sum_b0 = np.array(0.0, dtype=dtype) |
733 | 749 |
|
734 |
| - if term <= precision and n >= 1: # Require at least two iterations |
735 |
| - return np.exp(-x) * (log_x * sum_a - sum_b) |
| 750 | + def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x): |
| 751 | + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1) |
| 752 | + sum_b += term |
736 | 753 |
|
737 |
| - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) |
| 754 | + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) |
| 755 | + n += 1 |
738 | 756 | k_plus_n += 1
|
| 757 | + return ( |
| 758 | + (sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n), |
| 759 | + # Require at least two iterations |
| 760 | + ((term <= precision) & (n > 1)), |
| 761 | + ) |
739 | 762 |
|
740 |
| - warnings.warn( |
741 |
| - f"gammainc_der did not converge after {n} iterations", |
742 |
| - RuntimeWarning, |
| 763 | + init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n] |
| 764 | + constant = [log_x] |
| 765 | + sum_b, *_ = _make_scalar_loop( |
| 766 | + max_iters, init, constant, inner_loop_b, name="gammainc_grad_b" |
743 | 767 | )
|
744 |
| - return np.nan |
745 |
| - |
746 |
| - def c_code(self, *args, **kwargs): |
747 |
| - raise NotImplementedError() |
748 |
| - |
749 |
| - |
750 |
| -gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der") |
751 | 768 |
|
752 |
| - |
753 |
| -class GammaIncCDer(BinaryScalarOp): |
754 |
| - """ |
755 |
| - Gradient of the the regularized upper gamma function (Q) wrt to the first |
756 |
| - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp` |
| 769 | + grad_approx = exp(-x) * (log_x * sum_a - sum_b) |
| 770 | + return grad_approx |
| 771 | + |
| 772 | + zero_branch = eq(x, 0) |
| 773 | + sqrt_exp = -756 - x**2 + 60 * x |
| 774 | + gammaincc_branch = ( |
| 775 | + ((k < 0.8) & (x > 15)) |
| 776 | + | ((k < 12) & (x > 30)) |
| 777 | + | ((sqrt_exp > 0) & (k < sqrt(sqrt_exp))) |
| 778 | + ) |
| 779 | + grad = switch( |
| 780 | + zero_branch, |
| 781 | + 0, |
| 782 | + switch( |
| 783 | + gammaincc_branch, |
| 784 | + -gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)), |
| 785 | + grad_approx(skip_loop=zero_branch | gammaincc_branch), |
| 786 | + ), |
| 787 | + ) |
| 788 | + return grad |
| 789 | + |
| 790 | + |
| 791 | +def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): |
| 792 | + """Gradient of the regularized upper gamma function (Q) wrt to the first |
| 793 | + argument (k, a.k.a. alpha). |
| 794 | +
|
| 795 | + Adapted from STAN `grad_reg_inc_gamma.hpp` |
| 796 | +
|
| 797 | + skip_loops is used for faster branching when this function is called by `gammainc_der` |
757 | 798 | """
|
| 799 | + dtype = upcast(k.type.dtype, x.type.dtype, "float32") |
758 | 800 |
|
759 |
| - @staticmethod |
760 |
| - def st_impl(k, x): |
761 |
| - gamma_k = scipy.special.gamma(k) |
762 |
| - digamma_k = scipy.special.digamma(k) |
763 |
| - log_x = np.log(x) |
764 |
| - |
765 |
| - # asymptotic expansion http://dlmf.nist.gov/8.11#E2 |
766 |
| - if (x >= k) and (x >= 8): |
767 |
| - S = 0 |
768 |
| - k_minus_one_minus_n = k - 1 |
769 |
| - fac = k_minus_one_minus_n |
770 |
| - dfac = 1 |
771 |
| - xpow = x |
| 801 | + gamma_k = gamma(k) |
| 802 | + digamma_k = psi(k) |
| 803 | + log_x = log(x) |
| 804 | + |
| 805 | + def approx_a(skip_loop): |
| 806 | + n_steps = switch( |
| 807 | + skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32") |
| 808 | + ) |
| 809 | + sum_a0 = np.array(0.0, dtype=dtype) |
| 810 | + dfac = np.array(1.0, dtype=dtype) |
| 811 | + xpow = x |
| 812 | + k_minus_one_minus_n = k - 1 |
| 813 | + fac = k_minus_one_minus_n |
| 814 | + delta = true_div(dfac, xpow) |
| 815 | + |
| 816 | + def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x): |
| 817 | + sum_a += delta |
| 818 | + xpow *= x |
| 819 | + k_minus_one_minus_n -= 1 |
| 820 | + dfac = k_minus_one_minus_n * dfac + fac |
| 821 | + fac *= k_minus_one_minus_n |
772 | 822 | delta = dfac / xpow
|
| 823 | + return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), () |
773 | 824 |
|
774 |
| - for n in range(1, 10): |
775 |
| - k_minus_one_minus_n -= 1 |
776 |
| - S += delta |
777 |
| - xpow *= x |
778 |
| - dfac = k_minus_one_minus_n * dfac + fac |
779 |
| - fac *= k_minus_one_minus_n |
780 |
| - delta = dfac / xpow |
781 |
| - if np.isinf(delta): |
782 |
| - warnings.warn( |
783 |
| - "gammaincc_der did not converge", |
784 |
| - RuntimeWarning, |
785 |
| - ) |
786 |
| - return np.nan |
| 825 | + init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] |
| 826 | + constant = [x] |
| 827 | + sum_a = _make_scalar_loop( |
| 828 | + n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" |
| 829 | + ) |
| 830 | + grad_approx_a = ( |
| 831 | + gammaincc(k, x) * (log_x - digamma_k) |
| 832 | + + exp(-x + (k - 1) * log_x) * sum_a / gamma_k |
| 833 | + ) |
| 834 | + return grad_approx_a |
787 | 835 |
|
| 836 | + def approx_b(skip_loop): |
| 837 | + max_iters = switch( |
| 838 | + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") |
| 839 | + ) |
| 840 | + log_precision = np.array(np.log(1e-6), dtype=config.floatX) |
| 841 | + |
| 842 | + sum_b0 = np.array(0.0, dtype=dtype) |
| 843 | + log_s = np.array(0.0, dtype=dtype) |
| 844 | + s_sign = np.array(1, dtype="int8") |
| 845 | + n = np.array(1, dtype="int32") |
| 846 | + log_delta = log_s - 2 * log(k) |
| 847 | + |
| 848 | + def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x): |
| 849 | + delta = exp(log_delta) |
| 850 | + sum_b += switch(s_sign > 0, delta, -delta) |
| 851 | + s_sign = -s_sign |
| 852 | + |
| 853 | + # log will cast >int16 to float64 |
| 854 | + log_s_inc = log_x - log(n) |
| 855 | + if log_s_inc.type.dtype != log_s.type.dtype: |
| 856 | + log_s_inc = log_s_inc.astype(log_s.type.dtype) |
| 857 | + log_s += log_s_inc |
| 858 | + |
| 859 | + new_log_delta = log_s - 2 * log(n + k) |
| 860 | + if new_log_delta.type.dtype != log_delta.type.dtype: |
| 861 | + new_log_delta = new_log_delta.astype(log_delta.type.dtype) |
| 862 | + log_delta = new_log_delta |
| 863 | + |
| 864 | + n += 1 |
788 | 865 | return (
|
789 |
| - scipy.special.gammaincc(k, x) * (log_x - digamma_k) |
790 |
| - + np.exp(-x + (k - 1) * log_x) * S / gamma_k |
791 |
| - ) |
792 |
| - |
793 |
| - # gradient of series expansion http://dlmf.nist.gov/8.7#E3 |
794 |
| - else: |
795 |
| - log_precision = np.log(1e-6) |
796 |
| - max_iters = int(1e5) |
797 |
| - S = 0 |
798 |
| - log_s = 0.0 |
799 |
| - s_sign = 1 |
800 |
| - log_delta = log_s - 2 * np.log(k) |
801 |
| - for n in range(1, max_iters + 1): |
802 |
| - S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta) |
803 |
| - s_sign = -s_sign |
804 |
| - log_s += log_x - np.log(n) |
805 |
| - log_delta = log_s - 2 * np.log(n + k) |
806 |
| - |
807 |
| - if np.isinf(log_delta): |
808 |
| - warnings.warn( |
809 |
| - "gammaincc_der did not converge", |
810 |
| - RuntimeWarning, |
811 |
| - ) |
812 |
| - return np.nan |
813 |
| - |
814 |
| - if log_delta <= log_precision: |
815 |
| - return ( |
816 |
| - scipy.special.gammainc(k, x) * (digamma_k - log_x) |
817 |
| - + np.exp(k * log_x) * S / gamma_k |
818 |
| - ) |
819 |
| - |
820 |
| - warnings.warn( |
821 |
| - f"gammaincc_der did not converge after {n} iterations", |
822 |
| - RuntimeWarning, |
| 866 | + (sum_b, log_s, s_sign, log_delta, n), |
| 867 | + log_delta <= log_precision, |
823 | 868 | )
|
824 |
| - return np.nan |
825 |
| - |
826 |
| - def impl(self, k, x): |
827 |
| - return self.st_impl(k, x) |
828 |
| - |
829 |
| - def c_code(self, *args, **kwargs): |
830 |
| - raise NotImplementedError() |
831 | 869 |
|
832 |
| - |
833 |
| -gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") |
| 870 | + init = [sum_b0, log_s, s_sign, log_delta, n] |
| 871 | + constant = [k, log_x] |
| 872 | + sum_b = _make_scalar_loop( |
| 873 | + max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" |
| 874 | + ) |
| 875 | + grad_approx_b = ( |
| 876 | + gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k |
| 877 | + ) |
| 878 | + return grad_approx_b |
| 879 | + |
| 880 | + branch_a = (x >= k) & (x >= 8) |
| 881 | + return switch( |
| 882 | + branch_a, |
| 883 | + approx_a(skip_loop=~branch_a | skip_loops), |
| 884 | + approx_b(skip_loop=branch_a | skip_loops), |
| 885 | + ) |
834 | 886 |
|
835 | 887 |
|
836 | 888 | class GammaU(BinaryScalarOp):
|
|
0 commit comments