5
5
"""
6
6
7
7
import os
8
+ from functools import reduce
8
9
from textwrap import dedent
10
+ from typing import Tuple
9
11
10
12
import numpy as np
11
13
import scipy .special
@@ -684,12 +686,18 @@ def __hash__(self):
684
686
685
687
686
688
def _make_scalar_loop (n_steps , init , constant , inner_loop_fn , name ):
687
- init = [as_scalar (x ) for x in init ]
689
+ init = [as_scalar (x ) if x is not None else None for x in init ]
688
690
constant = [as_scalar (x ) for x in constant ]
691
+
689
692
# Create dummy types, in case some variables have the same initial form
690
- init_ = [x .type () for x in init ]
693
+ init_ = [x .type () if x is not None else None for x in init ]
691
694
constant_ = [x .type () for x in constant ]
692
695
update_ , until_ = inner_loop_fn (* init_ , * constant_ )
696
+
697
+ # Filter Nones
698
+ init = [i for i in init if i is not None ]
699
+ init_ = [i for i in init_ if i is not None ]
700
+ update_ = [u for u in update_ if u is not None ]
693
701
op = ScalarLoop (
694
702
init = init_ ,
695
703
constant = constant_ ,
@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
698
706
until_condition_failed = "warn" ,
699
707
name = name ,
700
708
)
701
- S , * _ = op (n_steps , * init , * constant )
702
- return S
709
+ return op (n_steps , * init , * constant )
703
710
704
711
705
712
def gammainc_grad (k , x ):
@@ -738,7 +745,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x, skip_loop):
738
745
739
746
init = [sum_a0 , log_gamma_k_plus_n_plus_1 , k_plus_n ]
740
747
constant = [log_x , skip_loop ]
741
- sum_a = _make_scalar_loop (
748
+ sum_a , * _ = _make_scalar_loop (
742
749
max_iters , init , constant , inner_loop_a , name = "gammainc_grad_a"
743
750
)
744
751
@@ -825,7 +832,7 @@ def inner_loop_a(
825
832
826
833
init = [sum_a0 , delta , xpow , k_minus_one_minus_n , fac , dfac ]
827
834
constant = [x , skip_loop ]
828
- sum_a = _make_scalar_loop (
835
+ sum_a , * _ = _make_scalar_loop (
829
836
n_steps , init , constant , inner_loop_a , name = "gammaincc_grad_a"
830
837
)
831
838
grad_approx_a = (
@@ -866,7 +873,7 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x, skip_loop):
866
873
867
874
init = [sum_b0 , log_s , s_sign , log_delta , n ]
868
875
constant = [k , log_x , skip_loop ]
869
- sum_b = _make_scalar_loop (
876
+ sum_b , * _ = _make_scalar_loop (
870
877
max_iters , init , constant , inner_loop_b , name = "gammaincc_grad_b"
871
878
)
872
879
grad_approx_b = (
@@ -1535,7 +1542,7 @@ def inner_loop(
1535
1542
1536
1543
init = [derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ]
1537
1544
constant = [f , p , q , K , dK , skip_loop ]
1538
- grad = _make_scalar_loop (
1545
+ grad , * _ = _make_scalar_loop (
1539
1546
max_iters , init , constant , inner_loop , name = "betainc_grad"
1540
1547
)
1541
1548
return grad
@@ -1574,10 +1581,11 @@ def impl(self, a, b, c, z):
1574
1581
def grad (self , inputs , grads ):
1575
1582
a , b , c , z = inputs
1576
1583
(gz ,) = grads
1584
+ grad_a , grad_b , grad_c = hyp2f1_grad (a , b , c , z , wrt = [0 , 1 , 2 ])
1577
1585
return [
1578
- gz * hyp2f1_grad ( a , b , c , z , wrt = 0 ) ,
1579
- gz * hyp2f1_grad ( a , b , c , z , wrt = 1 ) ,
1580
- gz * hyp2f1_grad ( a , b , c , z , wrt = 2 ) ,
1586
+ gz * grad_a ,
1587
+ gz * grad_b ,
1588
+ gz * grad_c ,
1581
1589
gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1582
1590
]
1583
1591
@@ -1593,7 +1601,158 @@ def _unsafe_sign(x):
1593
1601
return switch (x > 0 , 1 , - 1 )
1594
1602
1595
1603
1596
- def hyp2f1_grad (a , b , c , z , wrt : int ):
1604
+ def _grad_2f1_loop (a , b , c , z , * , skip_loop , wrt , dtype ):
1605
+ """
1606
+ Notes
1607
+ -----
1608
+ The algorithm can be derived by looking at the ratio of two successive terms in the series
1609
+ β_{k+1}/β_{k} = A(k)/B(k)
1610
+ β_{k+1} = A(k)/B(k) * β_{k}
1611
+ d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1612
+
1613
+ In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1614
+
1615
+ The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1616
+ by dropping the respective term
1617
+ d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1618
+ d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1619
+ d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1620
+
1621
+ The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1622
+ tracking their signs.
1623
+ """
1624
+
1625
+ min_steps = np .array (
1626
+ 10 , dtype = "int32"
1627
+ ) # https://github.com/stan-dev/math/issues/2857
1628
+ max_steps = np .array (int (1e6 ), dtype = "int32" )
1629
+ precision = np .array (1e-14 , dtype = config .floatX )
1630
+
1631
+ grads = [np .array (0 , dtype = dtype ) if i in wrt else None for i in range (3 )]
1632
+ log_gs = [np .array (- np .inf , dtype = dtype ) if i in wrt else None for i in range (3 )]
1633
+ log_gs_signs = [np .array (1 , dtype = "int8" ) if i in wrt else None for i in range (3 )]
1634
+
1635
+ log_t = np .array (0.0 , dtype = dtype )
1636
+ log_t_sign = np .array (1 , dtype = "int8" )
1637
+
1638
+ log_z = log (scalar_abs (z ))
1639
+ sign_z = _unsafe_sign (z )
1640
+
1641
+ sign_zk = sign_z
1642
+ k = np .array (0 , dtype = "int32" )
1643
+
1644
+ def inner_loop (* args ):
1645
+ (
1646
+ * grads_vars ,
1647
+ log_t ,
1648
+ log_t_sign ,
1649
+ sign_zk ,
1650
+ k ,
1651
+ a ,
1652
+ b ,
1653
+ c ,
1654
+ log_z ,
1655
+ sign_z ,
1656
+ skip_loop ,
1657
+ ) = args
1658
+
1659
+ (
1660
+ grad_a ,
1661
+ grad_b ,
1662
+ grad_c ,
1663
+ log_g_a ,
1664
+ log_g_b ,
1665
+ log_g_c ,
1666
+ log_g_sign_a ,
1667
+ log_g_sign_b ,
1668
+ log_g_sign_c ,
1669
+ ) = grads_vars
1670
+
1671
+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1672
+ if p .type .dtype != dtype :
1673
+ p = p .astype (dtype )
1674
+
1675
+ # If p==0, don't update grad and get out of while loop next
1676
+ p_zero = eq (p , 0 )
1677
+
1678
+ if 0 in wrt :
1679
+ term_a = log_g_sign_a * log_t_sign * exp (log_g_a - log_t )
1680
+ term_a += reciprocal (a + k )
1681
+ if term_a .type .dtype != dtype :
1682
+ term_a = term_a .astype (dtype )
1683
+ if 1 in wrt :
1684
+ term_b = log_g_sign_b * log_t_sign * exp (log_g_b - log_t )
1685
+ term_b += reciprocal (b + k )
1686
+ if term_b .type .dtype != dtype :
1687
+ term_b = term_b .astype (dtype )
1688
+ if 2 in wrt :
1689
+ term_c = log_g_sign_c * log_t_sign * exp (log_g_c - log_t )
1690
+ term_c -= reciprocal (c + k )
1691
+ if term_c .type .dtype != dtype :
1692
+ term_c = term_c .astype (dtype )
1693
+
1694
+ log_t = log_t + log (scalar_abs (p )) + log_z
1695
+ log_t_sign = (_unsafe_sign (p ) * log_t_sign ).astype ("int8" )
1696
+
1697
+ grads = [None ] * 3
1698
+ log_gs = [None ] * 3
1699
+ log_gs_signs = [None ] * 3
1700
+ grad_incs = [None ] * 3
1701
+
1702
+ if 0 in wrt :
1703
+ log_g_a = log_t + log (scalar_abs (term_a ))
1704
+ log_g_sign_a = (_unsafe_sign (term_a ) * log_t_sign ).astype ("int8" )
1705
+ grad_inc_a = log_g_sign_a * exp (log_g_a ) * sign_zk
1706
+ grads [0 ] = switch (p_zero , grad_a , grad_a + grad_inc_a )
1707
+ log_gs [0 ] = log_g_a
1708
+ log_gs_signs [0 ] = log_g_sign_a
1709
+ grad_incs [0 ] = grad_inc_a
1710
+ if 1 in wrt :
1711
+ log_g_b = log_t + log (scalar_abs (term_b ))
1712
+ log_g_sign_b = (_unsafe_sign (term_b ) * log_t_sign ).astype ("int8" )
1713
+ grad_inc_b = log_g_sign_b * exp (log_g_b ) * sign_zk
1714
+ grads [1 ] = switch (p_zero , grad_b , grad_b + grad_inc_b )
1715
+ log_gs [1 ] = log_g_b
1716
+ log_gs_signs [1 ] = log_g_sign_b
1717
+ grad_incs [1 ] = grad_inc_b
1718
+ if 2 in wrt :
1719
+ log_g_c = log_t + log (scalar_abs (term_c ))
1720
+ log_g_sign_c = (_unsafe_sign (term_c ) * log_t_sign ).astype ("int8" )
1721
+ grad_inc_c = log_g_sign_c * exp (log_g_c ) * sign_zk
1722
+ grads [2 ] = switch (p_zero , grad_c , grad_c + grad_inc_c )
1723
+ log_gs [2 ] = log_g_c
1724
+ log_gs_signs [2 ] = log_g_sign_c
1725
+ grad_incs [2 ] = grad_inc_c
1726
+
1727
+ sign_zk *= sign_z
1728
+ k += 1
1729
+
1730
+ abs_grad_incs = [
1731
+ scalar_abs (grad_inc ) for grad_inc in grad_incs if grad_inc is not None
1732
+ ]
1733
+ if len (grad_incs ) == 1 :
1734
+ [max_abs_grad_inc ] = grad_incs
1735
+ else :
1736
+ max_abs_grad_inc = reduce (scalar_maximum , abs_grad_incs )
1737
+
1738
+ return (
1739
+ (* grads , * log_gs , * log_gs_signs , log_t , log_t_sign , sign_zk , k ),
1740
+ (
1741
+ skip_loop
1742
+ | eq (p , 0 )
1743
+ | ((k > min_steps ) & (max_abs_grad_inc <= precision ))
1744
+ ),
1745
+ )
1746
+
1747
+ init = [* grads , * log_gs , * log_gs_signs , log_t , log_t_sign , sign_zk , k ]
1748
+ constant = [a , b , c , log_z , sign_z , skip_loop ]
1749
+ loop_outs = _make_scalar_loop (
1750
+ max_steps , init , constant , inner_loop , name = "hyp2f1_grad"
1751
+ )
1752
+ return loop_outs [: len (wrt )]
1753
+
1754
+
1755
+ def hyp2f1_grad (a , b , c , z , wrt : Tuple [int , ...]):
1597
1756
dtype = upcast (a .type .dtype , b .type .dtype , c .type .dtype , z .type .dtype , "float32" )
1598
1757
1599
1758
def check_2f1_converges (a , b , c , z ):
@@ -1624,132 +1783,22 @@ def is_nonpositive_integer(x):
1624
1783
is_polynomial | (scalar_abs (z ) < 1 ) | (eq (scalar_abs (z ), 1 ) & (c > (a + b )))
1625
1784
)
1626
1785
1627
- def compute_grad_2f1 (a , b , c , z , wrt , skip_loop ):
1628
- """
1629
- Notes
1630
- -----
1631
- The algorithm can be derived by looking at the ratio of two successive terms in the series
1632
- β_{k+1}/β_{k} = A(k)/B(k)
1633
- β_{k+1} = A(k)/B(k) * β_{k}
1634
- d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1635
-
1636
- In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1637
-
1638
- The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1639
- by dropping the respective term
1640
- d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1641
- d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1642
- d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1643
-
1644
- The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1645
- tracking their signs.
1646
- """
1647
-
1648
- wrt_a = wrt_b = False
1649
- if wrt == 0 :
1650
- wrt_a = True
1651
- elif wrt == 1 :
1652
- wrt_b = True
1653
- elif wrt != 2 :
1654
- raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1655
-
1656
- min_steps = np .array (
1657
- 10 , dtype = "int32"
1658
- ) # https://github.com/stan-dev/math/issues/2857
1659
- max_steps = np .array (int (1e6 ), dtype = "int32" )
1660
- precision = np .array (1e-14 , dtype = config .floatX )
1661
-
1662
- grad = np .array (0 , dtype = dtype )
1663
-
1664
- log_g = np .array (- np .inf , dtype = dtype )
1665
- log_g_sign = np .array (1 , dtype = "int8" )
1666
-
1667
- log_t = np .array (0.0 , dtype = dtype )
1668
- log_t_sign = np .array (1 , dtype = "int8" )
1669
-
1670
- log_z = log (scalar_abs (z ))
1671
- sign_z = _unsafe_sign (z )
1672
-
1673
- sign_zk = sign_z
1674
- k = np .array (0 , dtype = "int32" )
1675
-
1676
- def inner_loop (
1677
- grad ,
1678
- log_g ,
1679
- log_g_sign ,
1680
- log_t ,
1681
- log_t_sign ,
1682
- sign_zk ,
1683
- k ,
1684
- a ,
1685
- b ,
1686
- c ,
1687
- log_z ,
1688
- sign_z ,
1689
- skip_loop ,
1690
- ):
1691
- p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1692
- if p .type .dtype != dtype :
1693
- p = p .astype (dtype )
1694
-
1695
- term = log_g_sign * log_t_sign * exp (log_g - log_t )
1696
- if wrt_a :
1697
- term += reciprocal (a + k )
1698
- elif wrt_b :
1699
- term += reciprocal (b + k )
1700
- else :
1701
- term -= reciprocal (c + k )
1702
-
1703
- if term .type .dtype != dtype :
1704
- term = term .astype (dtype )
1705
-
1706
- log_t = log_t + log (scalar_abs (p )) + log_z
1707
- log_t_sign = (_unsafe_sign (p ) * log_t_sign ).astype ("int8" )
1708
- log_g = log_t + log (scalar_abs (term ))
1709
- log_g_sign = (_unsafe_sign (term ) * log_t_sign ).astype ("int8" )
1710
-
1711
- g_current = log_g_sign * exp (log_g ) * sign_zk
1712
-
1713
- # If p==0, don't update grad and get out of while loop next
1714
- grad = switch (
1715
- eq (p , 0 ),
1716
- grad ,
1717
- grad + g_current ,
1718
- )
1719
-
1720
- sign_zk *= sign_z
1721
- k += 1
1722
-
1723
- return (
1724
- (grad , log_g , log_g_sign , log_t , log_t_sign , sign_zk , k ),
1725
- (
1726
- skip_loop
1727
- | eq (p , 0 )
1728
- | ((k > min_steps ) & (scalar_abs (g_current ) <= precision ))
1729
- ),
1730
- )
1731
-
1732
- init = [grad , log_g , log_g_sign , log_t , log_t_sign , sign_zk , k ]
1733
- constant = [a , b , c , log_z , sign_z , skip_loop ]
1734
- grad = _make_scalar_loop (
1735
- max_steps , init , constant , inner_loop , name = "hyp2f1_grad"
1736
- )
1737
-
1738
- return switch (
1739
- eq (z , 0 ),
1740
- 0 ,
1741
- grad ,
1742
- )
1743
-
1744
1786
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
1745
1787
z_is_zero = eq (z , 0 )
1746
1788
converges = check_2f1_converges (a , b , c , z )
1747
- return switch (
1748
- z_is_zero ,
1749
- 0 ,
1750
- switch (
1751
- converges ,
1752
- compute_grad_2f1 (a , b , c , z , wrt , skip_loop = z_is_zero | (~ converges )),
1753
- np .nan ,
1754
- ),
1789
+ grads = _grad_2f1_loop (
1790
+ a , b , c , z , skip_loop = z_is_zero | (~ converges ), wrt = wrt , dtype = dtype
1755
1791
)
1792
+
1793
+ return [
1794
+ switch (
1795
+ z_is_zero ,
1796
+ 0 ,
1797
+ switch (
1798
+ converges ,
1799
+ grad ,
1800
+ np .nan ,
1801
+ ),
1802
+ )
1803
+ for grad in grads
1804
+ ]
0 commit comments