14
14
15
15
from pytensor .configdefaults import config
16
16
from pytensor .gradient import grad_not_implemented
17
+ from pytensor .scalar .basic import BinaryScalarOp , ScalarOp , UnaryScalarOp
18
+ from pytensor .scalar .basic import abs as scalar_abs
17
19
from pytensor .scalar .basic import (
18
- BinaryScalarOp ,
19
- ScalarOp ,
20
- UnaryScalarOp ,
21
20
as_scalar ,
22
21
complex_types ,
23
22
constant ,
27
26
expm1 ,
28
27
float64 ,
29
28
float_types ,
29
+ identity ,
30
30
isinf ,
31
31
log ,
32
32
log1p ,
33
+ reciprocal ,
34
+ scalar_maximum ,
33
35
sqrt ,
34
36
switch ,
35
37
true_div ,
@@ -1329,8 +1331,8 @@ def grad(self, inp, grads):
1329
1331
(gz ,) = grads
1330
1332
1331
1333
return [
1332
- gz * betainc_der (a , b , x , True ),
1333
- gz * betainc_der (a , b , x , False ),
1334
+ gz * betainc_grad (a , b , x , True ),
1335
+ gz * betainc_grad (a , b , x , False ),
1334
1336
gz
1335
1337
* exp (
1336
1338
log1p (- x ) * (b - 1 )
@@ -1346,28 +1348,28 @@ def c_code(self, *args, **kwargs):
1346
1348
betainc = BetaInc (upgrade_to_float_no_complex , name = "betainc" )
1347
1349
1348
1350
1349
- class BetaIncDer (ScalarOp ):
1350
- """
1351
- Gradient of the regularized incomplete beta function wrt to the first
1352
- argument (alpha) or the second argument (beta), depending on whether the
1353
- fourth argument to betainc_der is `True` or `False`, respectively.
1351
+ def betainc_grad (p , q , x , wrtp : bool ):
1352
+ """Gradient of the regularized lower gamma function (P) wrt to the first
1353
+ argument (k, a.k.a. alpha).
1354
1354
1355
- Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1356
- Journal of Statistical Software, 3(1), 1-20.
1355
+ Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
1356
+
1357
+ Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
1358
+ ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
1357
1359
"""
1358
1360
1359
- nin = 4
1361
+ def _betainc_der (p , q , x , wrtp , skip_loop ):
1362
+ dtype = upcast (p .type .dtype , q .type .dtype , x .type .dtype , "float32" )
1363
+
1364
+ def betaln (a , b ):
1365
+ return gammaln (a ) + (gammaln (b ) - gammaln (a + b ))
1360
1366
1361
- def impl (self , p , q , x , wrtp ):
1362
1367
def _betainc_a_n (f , p , q , n ):
1363
1368
"""
1364
1369
Numerator (a_n) of the nth approximant of the continued fraction
1365
1370
representation of the regularized incomplete beta function
1366
1371
"""
1367
1372
1368
- if n == 1 :
1369
- return p * f * (q - 1 ) / (q * (p + 1 ))
1370
-
1371
1373
p2n = p + 2 * n
1372
1374
F1 = p ** 2 * f ** 2 * (n - 1 ) / (q ** 2 )
1373
1375
F2 = (
@@ -1377,7 +1379,11 @@ def _betainc_a_n(f, p, q, n):
1377
1379
/ ((p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 ))
1378
1380
)
1379
1381
1380
- return F1 * F2
1382
+ return switch (
1383
+ eq (n , 1 ),
1384
+ p * f * (q - 1 ) / (q * (p + 1 )),
1385
+ F1 * F2 ,
1386
+ )
1381
1387
1382
1388
def _betainc_b_n (f , p , q , n ):
1383
1389
"""
@@ -1397,9 +1403,6 @@ def _betainc_da_n_dp(f, p, q, n):
1397
1403
Derivative of a_n wrt p
1398
1404
"""
1399
1405
1400
- if n == 1 :
1401
- return - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 )
1402
-
1403
1406
pp = p ** 2
1404
1407
ppp = pp * p
1405
1408
p2n = p + 2 * n
@@ -1414,20 +1417,25 @@ def _betainc_da_n_dp(f, p, q, n):
1414
1417
D1 = q ** 2 * (p2n - 3 ) ** 2
1415
1418
D2 = (p2n - 2 ) ** 3 * (p2n - 1 ) ** 2
1416
1419
1417
- return (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2
1420
+ return switch (
1421
+ eq (n , 1 ),
1422
+ - p * f * (q - 1 ) / (q * (p + 1 ) ** 2 ),
1423
+ (N1 / D1 ) * (N2a + N2b + N2c + N2d + N2e ) / D2 ,
1424
+ )
1418
1425
1419
1426
def _betainc_da_n_dq (f , p , q , n ):
1420
1427
"""
1421
1428
Derivative of a_n wrt q
1422
1429
"""
1423
- if n == 1 :
1424
- return p * f / (q * (p + 1 ))
1425
-
1426
1430
p2n = p + 2 * n
1427
1431
F1 = (p ** 2 * f ** 2 / (q ** 2 )) * (n - 1 ) * (p + n - 1 ) * (2 * q + p - 2 )
1428
1432
D1 = (p2n - 3 ) * (p2n - 2 ) ** 2 * (p2n - 1 )
1429
1433
1430
- return F1 / D1
1434
+ return switch (
1435
+ eq (n , 1 ),
1436
+ p * f / (q * (p + 1 )),
1437
+ F1 / D1 ,
1438
+ )
1431
1439
1432
1440
def _betainc_db_n_dp (f , p , q , n ):
1433
1441
"""
@@ -1452,42 +1460,44 @@ def _betainc_db_n_dq(f, p, q, n):
1452
1460
p2n = p + 2 * n
1453
1461
return - (p ** 2 * f ) / (q * (p2n - 2 ) * p2n )
1454
1462
1455
- # Input validation
1456
- if not (0 <= x <= 1 ) or p < 0 or q < 0 :
1457
- return np .nan
1458
-
1459
- if x > (p / (p + q )):
1460
- return - self .impl (q , p , 1 - x , not wrtp )
1461
-
1462
- min_iters = 3
1463
- max_iters = 200
1464
- err_threshold = 1e-12
1465
-
1466
- derivative_old = 0
1463
+ min_iters = np .array (3 , dtype = "int32" )
1464
+ max_iters = switch (
1465
+ skip_loop , np .array (0 , dtype = "int32" ), np .array (200 , dtype = "int32" )
1466
+ )
1467
+ err_threshold = np .array (1e-12 , dtype = config .floatX )
1467
1468
1468
- Am2 , Am1 = 1 , 1
1469
- Bm2 , Bm1 = 0 , 1
1470
- dAm2 , dAm1 = 0 , 0
1471
- dBm2 , dBm1 = 0 , 0
1469
+ Am2 , Am1 = np . array ( 1 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1470
+ Bm2 , Bm1 = np . array ( 0 , dtype = dtype ), np . array ( 1 , dtype = dtype )
1471
+ dAm2 , dAm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
1472
+ dBm2 , dBm1 = np . array ( 0 , dtype = dtype ), np . array ( 0 , dtype = dtype )
1472
1473
1473
1474
f = (q * x ) / (p * (1 - x ))
1474
- K = np .exp (
1475
- p * np .log (x )
1476
- + (q - 1 ) * np .log1p (- x )
1477
- - np .log (p )
1478
- - scipy .special .betaln (p , q )
1479
- )
1475
+ K = exp (p * log (x ) + (q - 1 ) * log1p (- x ) - log (p ) - betaln (p , q ))
1480
1476
if wrtp :
1481
- dK = (
1482
- np .log (x )
1483
- - 1 / p
1484
- + scipy .special .digamma (p + q )
1485
- - scipy .special .digamma (p )
1486
- )
1477
+ dK = log (x ) - reciprocal (p ) + psi (p + q ) - psi (p )
1487
1478
else :
1488
- dK = np .log1p (- x ) + scipy .special .digamma (p + q ) - scipy .special .digamma (q )
1489
-
1490
- for n in range (1 , max_iters + 1 ):
1479
+ dK = log1p (- x ) + psi (p + q ) - psi (q )
1480
+
1481
+ derivative = np .array (0 , dtype = dtype )
1482
+ n = np .array (1 , dtype = "int16" ) # Enough for 200 max iters
1483
+
1484
+ def inner_loop (
1485
+ derivative ,
1486
+ Am2 ,
1487
+ Am1 ,
1488
+ Bm2 ,
1489
+ Bm1 ,
1490
+ dAm2 ,
1491
+ dAm1 ,
1492
+ dBm2 ,
1493
+ dBm1 ,
1494
+ n ,
1495
+ f ,
1496
+ p ,
1497
+ q ,
1498
+ K ,
1499
+ dK ,
1500
+ ):
1491
1501
a_n_ = _betainc_a_n (f , p , q , n )
1492
1502
b_n_ = _betainc_b_n (f , p , q , n )
1493
1503
if wrtp :
@@ -1502,36 +1512,53 @@ def _betainc_db_n_dq(f, p, q, n):
1502
1512
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
1503
1513
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
1504
1514
1505
- Am2 , Am1 = Am1 , A
1506
- Bm2 , Bm1 = Bm1 , B
1507
- dAm2 , dAm1 = dAm1 , dA
1508
- dBm2 , dBm1 = dBm1 , dB
1509
-
1510
- if n < min_iters - 1 :
1511
- continue
1515
+ Am2 , Am1 = identity (Am1 ), identity (A )
1516
+ Bm2 , Bm1 = identity (Bm1 ), identity (B )
1517
+ dAm2 , dAm1 = identity (dAm1 ), identity (dA )
1518
+ dBm2 , dBm1 = identity (dBm1 ), identity (dB )
1512
1519
1513
1520
F1 = A / B
1514
1521
F2 = (dA - F1 * dB ) / B
1515
- derivative = K * (F1 * dK + F2 )
1522
+ derivative_new = K * (F1 * dK + F2 )
1516
1523
1517
- errapx = abs (derivative_old - derivative )
1518
- d_errapx = errapx / max (err_threshold , abs (derivative ))
1519
- derivative_old = derivative
1520
-
1521
- if d_errapx <= err_threshold :
1522
- return derivative
1524
+ errapx = scalar_abs (derivative - derivative_new )
1525
+ d_errapx = errapx / scalar_maximum (
1526
+ err_threshold , scalar_abs (derivative_new )
1527
+ )
1523
1528
1524
- warnings .warn (
1525
- f"betainc_der did not converge after { n } iterations" ,
1526
- RuntimeWarning ,
1527
- )
1528
- return np .nan
1529
+ min_iters_cond = n > (min_iters - 1 )
1530
+ derivative = switch (
1531
+ min_iters_cond ,
1532
+ derivative_new ,
1533
+ derivative ,
1534
+ )
1535
+ n += 1
1529
1536
1530
- def c_code (self , * args , ** kwargs ):
1531
- raise NotImplementedError ()
1537
+ return (
1538
+ (derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ),
1539
+ (d_errapx <= err_threshold ) & min_iters_cond ,
1540
+ )
1532
1541
1542
+ init = [derivative , Am2 , Am1 , Bm2 , Bm1 , dAm2 , dAm1 , dBm2 , dBm1 , n ]
1543
+ constant = [f , p , q , K , dK ]
1544
+ grad = _make_scalar_loop (
1545
+ max_iters , init , constant , inner_loop , name = "betainc_grad"
1546
+ )
1547
+ return grad
1533
1548
1534
- betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1549
+ # Input validation
1550
+ nan_branch = (x < 0 ) | (x > 1 ) | (p < 0 ) | (q < 0 )
1551
+ flip_branch = x > (p / (p + q ))
1552
+ grad = switch (
1553
+ nan_branch ,
1554
+ np .nan ,
1555
+ switch (
1556
+ flip_branch ,
1557
+ - _betainc_der (q , p , 1 - x , not wrtp , skip_loop = nan_branch | (~ flip_branch )),
1558
+ _betainc_der (p , q , x , wrtp , skip_loop = nan_branch | flip_branch ),
1559
+ ),
1560
+ )
1561
+ return grad
1535
1562
1536
1563
1537
1564
class Hyp2F1 (ScalarOp ):
0 commit comments