14
14
15
15
from pytensor .configdefaults import config
16
16
from pytensor .gradient import grad_not_implemented
17
+ from pytensor .scalar .basic import BinaryScalarOp , ScalarOp , UnaryScalarOp
18
+ from pytensor .scalar .basic import abs as scalar_abs
17
19
from pytensor .scalar .basic import (
18
- BinaryScalarOp ,
19
- ScalarOp ,
20
- UnaryScalarOp ,
21
20
as_scalar ,
22
21
complex_types ,
23
22
constant ,
27
26
expm1 ,
28
27
float64 ,
29
28
float_types ,
29
+ identity ,
30
30
isinf ,
31
31
log ,
32
32
log1p ,
33
+ reciprocal ,
34
+ scalar_maximum ,
33
35
sqrt ,
34
36
switch ,
35
37
true_div ,
@@ -1325,8 +1327,8 @@ def grad(self, inp, grads):
1325
1327
(gz ,) = grads
1326
1328
1327
1329
return [
1328
- gz * betainc_der (a , b , x , True ),
1329
- gz * betainc_der (a , b , x , False ),
1330
+ gz * betainc_grad (a , b , x , True ),
1331
+ gz * betainc_grad (a , b , x , False ),
1330
1332
gz
1331
1333
* exp (
1332
1334
log1p (- x ) * (b - 1 )
@@ -1342,28 +1344,28 @@ def c_code(self, *args, **kwargs):
1342
1344
betainc = BetaInc (upgrade_to_float_no_complex , name = "betainc" )
1343
1345
1344
1346
1345
- class BetaIncDer ( ScalarOp ):
1346
- """
1347
- Gradient of the regularized incomplete beta function wrt to the first
1348
- argument (alpha) or the second argument (beta), depending on whether the
1349
- fourth argument to betainc_der is `True` or `False`, respectively.
1347
+ def betainc_grad ( p , q , x , wrtp : bool ):
1348
+ """Gradient of the regularized incomplete beta function with respect to its
1349
+ first argument (alpha) if `wrtp` is True, else its second argument (beta).
1350
+
1351
+ Computed from the continued fraction representation of the function.
1350
1352
1351
- Reference: Boik, R. J., & Robison-Cox, J. F. (1998 ). Derivatives of the incomplete beta function .
1352
- Journal of Statistical Software, 3(1 ), 1-20 .
1353
+ Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1354
+ Journal of Statistical Software, 3(1), 1-20.
1353
1355
"""
1354
1356
1355
- nin = 4
1357
+ def _betainc_der (p , q , x , wrtp , skip_loop ):
1358
+ dtype = upcast (p .type .dtype , q .type .dtype , x .type .dtype , "float32" )
1359
+
1360
+ def betaln (a , b ):
1361
+ return gammaln (a ) + (gammaln (b ) - gammaln (a + b ))
1356
1362
1357
- def impl (self , p , q , x , wrtp ):
1358
1363
def _betainc_a_n (f , p , q , n ):
1359
1364
"""
1360
1365
Numerator (a_n) of the nth approximant of the continued fraction
1361
1366
representation of the regularized incomplete beta function
1362
1367
"""
1363
1368
1364
- if n == 1 :
1365
- return p * f * (q - 1 ) / (q * (p + 1 ))
1366
-
1367
1369
p2n = p + 2 * n
1368
1370
F1 = p ** 2 * f ** 2 * (n - 1 ) / (q ** 2 )
1369
1371
F2 = (
@@ -1373,7 +1375,11 @@ def _betainc_a_n(f, p, q, n):
1373
1375
/ ((p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 ))
1374
1376
)
1375
1377
1376
- return F1 * F2
1378
+ return switch (
1379
+ eq (n , 1 ),
1380
+ p * f * (q - 1 ) / (q * (p + 1 )),
1381
+ F1 * F2 ,
1382
+ )
1377
1383
1378
1384
def _betainc_b_n (f , p , q , n ):
1379
1385
"""
@@ -1393,9 +1399,6 @@ def _betainc_da_n_dp(f, p, q, n):
1393
1399
Derivative of a_n wrt p
1394
1400
"""
1395
1401
1396
- if n == 1 :
1397
- return - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 )
1398
-
1399
1402
pp = p ** 2
1400
1403
ppp = pp * p
1401
1404
p2n = p + 2 * n
@@ -1410,20 +1413,25 @@ def _betainc_da_n_dp(f, p, q, n):
1410
1413
D1 = q ** 2 * (p2n - 3 ) ** 2
1411
1414
D2 = (p2n - 2 ) ** 3 * (p2n - 1 ) ** 2
1412
1415
1413
- return (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2
1416
+ return switch (
1417
+ eq (n , 1 ),
1418
+ - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 ),
1419
+ (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2 ,
1420
+ )
1414
1421
1415
1422
def _betainc_da_n_dq (f , p , q , n ):
1416
1423
"""
1417
1424
Derivative of a_n wrt q
1418
1425
"""
1419
- if n == 1 :
1420
- return p * f / (q * (p + 1 ))
1421
-
1422
1426
p2n = p + 2 * n
1423
1427
F1 = (p ** 2 * f ** 2 / (q ** 2 )) * (n - 1 ) * (p + n - 1 ) * (2 * q + p - 2 )
1424
1428
D1 = (p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 )
1425
1429
1426
- return F1 / D1
1430
+ return switch (
1431
+ eq (n , 1 ),
1432
+ p * f / (q * (p + 1 )),
1433
+ F1 / D1 ,
1434
+ )
1427
1435
1428
1436
def _betainc_db_n_dp (f , p , q , n ):
1429
1437
"""
@@ -1448,42 +1456,43 @@ def _betainc_db_n_dq(f, p, q, n):
1448
1456
p2n = p + 2 * n
1449
1457
return - (p ** 2 * f ) / (q * (p2n - 2 ) * p2n )
1450
1458
1451
- # Input validation
1452
- if not (0 <= x <= 1 ) or p < 0 or q < 0 :
1453
- return np .nan
1454
-
1455
- if x > (p / (p + q )):
1456
- return - self .impl (q , p , 1 - x , not wrtp )
1457
-
1458
- min_iters = 3
1459
- max_iters = 200
1460
- err_threshold = 1e-12
1461
-
1462
- derivative_old = 0
1459
+ min_iters = np .array (3 , dtype = "int32" )
1460
+ max_iters = np .array (200 , dtype = "int32" )
1461
+ err_threshold = np .array (1e-12 , dtype = config .floatX )
1463
1462
1464
- Am2 , Am1 = 1 , 1
1465
- Bm2 , Bm1 = 0 , 1
1466
- dAm2 , dAm1 = 0 , 0
1467
- dBm2 , dBm1 = 0 , 0
1463
+ Am2 , Am1 = np . array ( 1 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1464
+ Bm2 , Bm1 = np . array ( 0 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1465
+ dAm2 , dAm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
1466
+ dBm2 , dBm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
1468
1467
1469
1468
f = (q * x ) / (p * (1 - x ))
1470
- K = np .exp (
1471
- p * np .log (x )
1472
- + (q - 1 ) * np .log1p (- x )
1473
- - np .log (p )
1474
- - scipy .special .betaln (p , q )
1475
- )
1469
+ K = exp (p * log (x ) + (q - 1 ) * log1p (- x ) - log (p ) - betaln (p , q ))
1476
1470
if wrtp :
1477
- dK = (
1478
- np .log (x )
1479
- - 1 / p
1480
- + scipy .special .digamma (p + q )
1481
- - scipy .special .digamma (p )
1482
- )
1471
+ dK = log (x ) - reciprocal (p ) + psi (p + q ) - psi (p )
1483
1472
else :
1484
- dK = np .log1p (- x ) + scipy .special .digamma (p + q ) - scipy .special .digamma (q )
1485
-
1486
- for n in range (1 , max_iters + 1 ):
1473
+ dK = log1p (- x ) + psi (p + q ) - psi (q )
1474
+
1475
+ derivative = np .array (0 , dtype = dtype )
1476
+ n = np .array (1 , dtype = "int16" ) # Enough for 200 max iters
1477
+
1478
+ def inner_loop (
1479
+ derivative ,
1480
+ Am2 ,
1481
+ Am1 ,
1482
+ Bm2 ,
1483
+ Bm1 ,
1484
+ dAm2 ,
1485
+ dAm1 ,
1486
+ dBm2 ,
1487
+ dBm1 ,
1488
+ n ,
1489
+ f ,
1490
+ p ,
1491
+ q ,
1492
+ K ,
1493
+ dK ,
1494
+ skip_loop ,
1495
+ ):
1487
1496
a_n_ = _betainc_a_n (f , p , q , n )
1488
1497
b_n_ = _betainc_b_n (f , p , q , n )
1489
1498
if wrtp :
@@ -1498,36 +1507,53 @@ def _betainc_db_n_dq(f, p, q, n):
1498
1507
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
1499
1508
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
1500
1509
1501
- Am2 , Am1 = Am1 , A
1502
- Bm2 , Bm1 = Bm1 , B
1503
- dAm2 , dAm1 = dAm1 , dA
1504
- dBm2 , dBm1 = dBm1 , dB
1505
-
1506
- if n < min_iters - 1 :
1507
- continue
1510
+ Am2 , Am1 = identity (Am1 ), identity (A )
1511
+ Bm2 , Bm1 = identity (Bm1 ), identity (B )
1512
+ dAm2 , dAm1 = identity (dAm1 ), identity (dA )
1513
+ dBm2 , dBm1 = identity (dBm1 ), identity (dB )
1508
1514
1509
1515
F1 = A / B
1510
1516
F2 = (dA - F1 * dB ) / B
1511
- derivative = K * (F1 * dK + F2 )
1517
+ derivative_new = K * (F1 * dK + F2 )
1512
1518
1513
- errapx = abs (derivative_old - derivative )
1514
- d_errapx = errapx / max (err_threshold , abs (derivative ))
1515
- derivative_old = derivative
1516
-
1517
- if d_errapx <= err_threshold :
1518
- return derivative
1519
+ errapx = scalar_abs (derivative - derivative_new )
1520
+ d_errapx = errapx / scalar_maximum (
1521
+ err_threshold , scalar_abs (derivative_new )
1522
+ )
1519
1523
1520
- warnings .warn (
1521
- f"betainc_der did not converge after { n } iterations" ,
1522
- RuntimeWarning ,
1523
- )
1524
- return np .nan
1524
+ min_iters_cond = n > (min_iters - 1 )
1525
+ derivative = switch (
1526
+ min_iters_cond ,
1527
+ derivative_new ,
1528
+ derivative ,
1529
+ )
1530
+ n += 1
1525
1531
1526
- def c_code (self , * args , ** kwargs ):
1527
- raise NotImplementedError ()
1532
+ return (
1533
+ (derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ),
1534
+ (skip_loop | ((d_errapx <= err_threshold ) & min_iters_cond )),
1535
+ )
1528
1536
1537
+ init = [derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ]
1538
+ constant = [f , p , q , K , dK , skip_loop ]
1539
+ grad = _make_scalar_loop (
1540
+ max_iters , init , constant , inner_loop , name = "betainc_grad"
1541
+ )
1542
+ return grad
1529
1543
1530
- betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1544
+ # Input validation
1545
+ nan_branch = (x < 0 ) | (x > 1 ) | (p < 0 ) | (q < 0 )
1546
+ flip_branch = x > (p / (p + q ))
1547
+ grad = switch (
1548
+ nan_branch ,
1549
+ np .nan ,
1550
+ switch (
1551
+ flip_branch ,
1552
+ - _betainc_der (q , p , 1 - x , not wrtp , skip_loop = nan_branch | (~ flip_branch )),
1553
+ _betainc_der (p , q , x , wrtp , skip_loop = nan_branch | flip_branch ),
1554
+ ),
1555
+ )
1556
+ return grad
1531
1557
1532
1558
1533
1559
class Hyp2F1 (ScalarOp ):
0 commit comments