5
5
"""
6
6
7
7
import os
8
- import warnings
9
8
from textwrap import dedent
10
9
11
10
import numpy as np
26
25
expm1 ,
27
26
float64 ,
28
27
float_types ,
28
+ floor ,
29
29
identity ,
30
+ integer_types ,
30
31
isinf ,
31
32
log ,
32
33
log1p ,
@@ -849,15 +850,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x, skip_loop):
849
850
s_sign = - s_sign
850
851
851
852
# log will cast >int16 to float64
852
- log_s_inc = log_x - log (n )
853
- if log_s_inc .type .dtype != log_s .type .dtype :
854
- log_s_inc = log_s_inc .astype (log_s .type .dtype )
855
- log_s += log_s_inc
853
+ log_s += log_x - log (n )
854
+ if log_s .type .dtype != dtype :
855
+ log_s = log_s .astype (dtype )
856
856
857
- new_log_delta = log_s - 2 * log (n + k )
858
- if new_log_delta .type .dtype != log_delta .type .dtype :
859
- new_log_delta = new_log_delta .astype (log_delta .type .dtype )
860
- log_delta = new_log_delta
857
+ log_delta = log_s - 2 * log (n + k )
858
+ if log_delta .type .dtype != dtype :
859
+ log_delta = log_delta .astype (dtype )
861
860
862
861
n += 1
863
862
return (
@@ -1576,9 +1575,9 @@ def grad(self, inputs, grads):
1576
1575
a , b , c , z = inputs
1577
1576
(gz ,) = grads
1578
1577
return [
1579
- gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1580
- gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1581
- gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1578
+ gz * hyp2f1_grad (a , b , c , z , wrt = 0 ),
1579
+ gz * hyp2f1_grad (a , b , c , z , wrt = 1 ),
1580
+ gz * hyp2f1_grad (a , b , c , z , wrt = 2 ),
1582
1581
gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1583
1582
]
1584
1583
@@ -1589,134 +1588,168 @@ def c_code(self, *args, **kwargs):
1589
1588
hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1590
1589
1591
1590
1592
- class Hyp2F1Der ( ScalarOp ):
1593
- """
1594
- Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1591
def _unsafe_sign(x):
    """Select ``1`` where ``x > 0`` and ``-1`` everywhere else.

    Unlike scalar.sign, no special care is taken for ``x`` being 0 or nan:
    any value for which ``x > 0`` does not hold is mapped to ``-1``.
    """
    is_positive = x > 0
    return switch(is_positive, 1, -1)
1595
1594
1596
- Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
1597
- """
1598
1595
1599
- nin = 5
1596
def hyp2f1_grad(a, b, c, z, wrt: int):
    """Build the gradient of the Gaussian hypergeometric function ``2F1(a, b; c; z)``
    with respect to one of its first three inputs.

    Parameters
    ----------
    a, b, c, z
        Scalar graph variables; their ``.type.dtype`` is read to choose the working dtype.
    wrt : int
        Which input to differentiate against: 0 for ``a``, 1 for ``b``, 2 for ``c``.

    Returns
    -------
    An expression that is 0 when ``z == 0``, nan when the series convergence conditions
    are not met, and otherwise the result of the series loop.

    Raises
    ------
    ValueError
        If ``wrt`` is not 0, 1 or 2.
    """
    # Fail fast, before any graph is built.
    if wrt not in (0, 1, 2):
        raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

    def check_2f1_converges(a, b, c, z):
        """Return a boolean expression telling whether the 2F1 series can be summed."""

        def is_nonpositive_integer(x):
            if x.type.dtype not in integer_types:
                return eq(floor(x), x) & (x <= 0)
            else:
                # Integer dtypes only need the sign check
                return x <= 0

        # num_terms starts at 0, hence the comparison against 0 for `a`
        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
        num_terms = switch(
            a_is_polynomial,
            floor(scalar_abs(a)).astype("int64"),
            0,
        )

        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
        num_terms = switch(
            b_is_polynomial,
            floor(scalar_abs(b)).astype("int64"),
            num_terms,
        )

        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
        is_polynomial = a_is_polynomial | b_is_polynomial

        return (~is_undefined) & (
            is_polynomial
            | (scalar_abs(z) < 1)
            | (eq(scalar_abs(z), 1) & (c > (a + b)))
        )

    def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
        """
        Notes
        -----
        The algorithm can be derived by looking at the ratio of two successive terms in the series
        β_{k+1}/β_{k} = A(k)/B(k)
        β_{k+1} = A(k)/B(k) * β_{k}
        d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

        In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k) / ((c + k) * (1 + k))) * z

        The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
        by dividing by the respective term (with a sign flip for c, which sits in the denominator)
        d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
        d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
        d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

        The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
        tracking their signs.
        """

        # `wrt` was validated by the enclosing function; anything that is not a or b means c
        wrt_a = wrt == 0
        wrt_b = wrt == 1

        min_steps = np.array(
            10, dtype="int32"
        )  # https://github.com/stan-dev/math/issues/2857
        max_steps = np.array(int(1e6), dtype="int32")
        precision = np.array(1e-14, dtype=config.floatX)

        grad = np.array(0, dtype=dtype)

        log_g = np.array(-np.inf, dtype=dtype)
        log_g_sign = np.array(1, dtype="int8")

        log_t = np.array(0.0, dtype=dtype)
        log_t_sign = np.array(1, dtype="int8")

        log_z = log(scalar_abs(z))
        sign_z = _unsafe_sign(z)

        sign_zk = sign_z
        k = np.array(0, dtype="int32")

        def inner_loop(
            grad,
            log_g,
            log_g_sign,
            log_t,
            log_t_sign,
            sign_zk,
            k,
            a,
            b,
            c,
            log_z,
            sign_z,
            skip_loop,
        ):
            # Ratio between two successive series terms (without the z factor)
            p = (a + k) * (b + k) / ((c + k) * (k + 1))
            if p.type.dtype != dtype:
                p = p.astype(dtype)

            term = log_g_sign * log_t_sign * exp(log_g - log_t)
            if wrt_a:
                term += reciprocal(a + k)
            elif wrt_b:
                term += reciprocal(b + k)
            else:
                term -= reciprocal(c + k)

            if term.type.dtype != dtype:
                term = term.astype(dtype)

            log_t = log_t + log(scalar_abs(p)) + log_z
            log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
            log_g = log_t + log(scalar_abs(term))
            log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")

            g_current = log_g_sign * exp(log_g) * sign_zk

            # If p==0, don't update grad and get out of while loop next
            grad = switch(
                eq(p, 0),
                grad,
                grad + g_current,
            )

            sign_zk *= sign_z
            k += 1

            # k was already incremented above, so `k > min_steps` here
            # corresponds to the pre-increment `k >= min_steps`
            return (
                (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
                (
                    skip_loop
                    | eq(p, 0)
                    | ((k > min_steps) & (scalar_abs(g_current) <= precision))
                ),
            )

        init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
        constant = [a, b, c, log_z, sign_z, skip_loop]
        grad = _make_scalar_loop(
            max_steps, init, constant, inner_loop, name="hyp2f1_grad"
        )

        # The caller guards z == 0 as well
        return switch(
            eq(z, 0),
            0,
            grad,
        )

    # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
    z_is_zero = eq(z, 0)
    converges = check_2f1_converges(a, b, c, z)
    return switch(
        z_is_zero,
        0,
        switch(
            converges,
            compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
            np.nan,
        ),
    )
0 commit comments