
Commit ff7d38f

Use ScalarLoop for betainc gradient

1 parent a3a8751

File tree

2 files changed: +107 −81 lines

pytensor/scalar/math.py

Lines changed: 105 additions & 79 deletions
@@ -14,10 +14,9 @@

 from pytensor.configdefaults import config
 from pytensor.gradient import grad_not_implemented
+from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
+from pytensor.scalar.basic import abs as scalar_abs
 from pytensor.scalar.basic import (
-    BinaryScalarOp,
-    ScalarOp,
-    UnaryScalarOp,
     as_scalar,
     complex_types,
     constant,
@@ -27,9 +26,12 @@
     expm1,
     float64,
     float_types,
+    identity,
     isinf,
     log,
     log1p,
+    reciprocal,
+    scalar_maximum,
     sqrt,
     switch,
     true_div,
@@ -1325,8 +1327,8 @@ def grad(self, inp, grads):
         (gz,) = grads

         return [
-            gz * betainc_der(a, b, x, True),
-            gz * betainc_der(a, b, x, False),
+            gz * betainc_grad(a, b, x, True),
+            gz * betainc_grad(a, b, x, False),
             gz
             * exp(
                 log1p(-x) * (b - 1)
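Note: with this hunk, both parameter derivatives of `BetaInc` flow through the new `betainc_grad` helper added below. A quick way to sanity-check that wiring at the tensor level is a central finite difference against `scipy.special.betainc`; the sketch below is illustrative only (not part of the commit) and the sample values are arbitrary:

import numpy as np
import scipy.special
import pytensor
import pytensor.tensor as pt

# Symbolic derivative of the regularized incomplete beta I_x(a, b) wrt a
a, b, x = pt.dscalars("a", "b", "x")
dI_da = pytensor.grad(pt.betainc(a, b, x), a)
f = pytensor.function([a, b, x], dI_da)

# Central finite difference of SciPy's betainc as a reference value
av, bv, xv, eps = 2.0, 3.0, 0.25, 1e-6
fd = (
    scipy.special.betainc(av + eps, bv, xv) - scipy.special.betainc(av - eps, bv, xv)
) / (2 * eps)
np.testing.assert_allclose(f(av, bv, xv), fd, rtol=1e-5)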
@@ -1342,28 +1344,28 @@ def c_code(self, *args, **kwargs):
 betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")


-class BetaIncDer(ScalarOp):
-    """
-    Gradient of the regularized incomplete beta function wrt to the first
-    argument (alpha) or the second argument (beta), depending on whether the
-    fourth argument to betainc_der is `True` or `False`, respectively.
+def betainc_grad(p, q, x, wrtp: bool):
+    """Gradient of the regularized lower gamma function (P) wrt to the first
+    argument (k, a.k.a. alpha).
+
+    Adapted from STAN `grad_reg_lower_inc_gamma.hpp`

-    Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
-    Journal of Statistical Software, 3(1), 1-20.
+    Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
+    ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
     """

-    nin = 4
+    def _betainc_der(p, q, x, wrtp, skip_loop):
+        dtype = upcast(p.type.dtype, q.type.dtype, x.type.dtype, "float32")
+
+        def betaln(a, b):
+            return gammaln(a) + (gammaln(b) - gammaln(a + b))

-    def impl(self, p, q, x, wrtp):
         def _betainc_a_n(f, p, q, n):
             """
             Numerator (a_n) of the nth approximant of the continued fraction
             representation of the regularized incomplete beta function
             """

-            if n == 1:
-                return p * f * (q - 1) / (q * (p + 1))
-
             p2n = p + 2 * n
             F1 = p**2 * f**2 * (n - 1) / (q**2)
             F2 = (
@@ -1373,7 +1375,11 @@ def _betainc_a_n(f, p, q, n):
                 / ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
             )

-            return F1 * F2
+            return switch(
+                eq(n, 1),
+                p * f * (q - 1) / (q * (p + 1)),
+                F1 * F2,
+            )

         def _betainc_b_n(f, p, q, n):
             """
@@ -1393,9 +1399,6 @@ def _betainc_da_n_dp(f, p, q, n):
             Derivative of a_n wrt p
             """

-            if n == 1:
-                return -p * f * (q - 1) / (q * (p + 1) ** 2)
-
             pp = p**2
             ppp = pp * p
             p2n = p + 2 * n
@@ -1410,20 +1413,25 @@ def _betainc_da_n_dp(f, p, q, n):
             D1 = q**2 * (p2n - 3) ** 2
             D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2

-            return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2
+            return switch(
+                eq(n, 1),
+                -p * f * (q - 1) / (q * (p + 1) ** 2),
+                (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2,
+            )

         def _betainc_da_n_dq(f, p, q, n):
             """
             Derivative of a_n wrt q
             """
-            if n == 1:
-                return p * f / (q * (p + 1))
-
             p2n = p + 2 * n
             F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
             D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)

-            return F1 / D1
+            return switch(
+                eq(n, 1),
+                p * f / (q * (p + 1)),
+                F1 / D1,
+            )

         def _betainc_db_n_dp(f, p, q, n):
             """
@@ -1448,42 +1456,43 @@ def _betainc_db_n_dq(f, p, q, n):
             p2n = p + 2 * n
             return -(p**2 * f) / (q * (p2n - 2) * p2n)

-        # Input validation
-        if not (0 <= x <= 1) or p < 0 or q < 0:
-            return np.nan
-
-        if x > (p / (p + q)):
-            return -self.impl(q, p, 1 - x, not wrtp)
-
-        min_iters = 3
-        max_iters = 200
-        err_threshold = 1e-12
-
-        derivative_old = 0
+        min_iters = np.array(3, dtype="int32")
+        max_iters = np.array(200, dtype="int32")
+        err_threshold = np.array(1e-12, dtype=config.floatX)

-        Am2, Am1 = 1, 1
-        Bm2, Bm1 = 0, 1
-        dAm2, dAm1 = 0, 0
-        dBm2, dBm1 = 0, 0
+        Am2, Am1 = np.array(1, dtype=dtype), np.array(1, dtype=dtype)
+        Bm2, Bm1 = np.array(0, dtype=dtype), np.array(1, dtype=dtype)
+        dAm2, dAm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
+        dBm2, dBm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)

         f = (q * x) / (p * (1 - x))
-        K = np.exp(
-            p * np.log(x)
-            + (q - 1) * np.log1p(-x)
-            - np.log(p)
-            - scipy.special.betaln(p, q)
-        )
+        K = exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - betaln(p, q))
         if wrtp:
-            dK = (
-                np.log(x)
-                - 1 / p
-                + scipy.special.digamma(p + q)
-                - scipy.special.digamma(p)
-            )
+            dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
         else:
-            dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)
-
-        for n in range(1, max_iters + 1):
+            dK = log1p(-x) + psi(p + q) - psi(q)
+
+        derivative = np.array(0, dtype=dtype)
+        n = np.array(1, dtype="int16")  # Enough for 200 max iters
+
+        def inner_loop(
+            derivative,
+            Am2,
+            Am1,
+            Bm2,
+            Bm1,
+            dAm2,
+            dAm1,
+            dBm2,
+            dBm1,
+            n,
+            f,
+            p,
+            q,
+            K,
+            dK,
+            skip_loop,
+        ):
             a_n_ = _betainc_a_n(f, p, q, n)
             b_n_ = _betainc_b_n(f, p, q, n)
             if wrtp:
@@ -1498,36 +1507,53 @@ def _betainc_db_n_dq(f, p, q, n):
             dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
             dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1

-            Am2, Am1 = Am1, A
-            Bm2, Bm1 = Bm1, B
-            dAm2, dAm1 = dAm1, dA
-            dBm2, dBm1 = dBm1, dB
-
-            if n < min_iters - 1:
-                continue
+            Am2, Am1 = identity(Am1), identity(A)
+            Bm2, Bm1 = identity(Bm1), identity(B)
+            dAm2, dAm1 = identity(dAm1), identity(dA)
+            dBm2, dBm1 = identity(dBm1), identity(dB)

             F1 = A / B
             F2 = (dA - F1 * dB) / B
-            derivative = K * (F1 * dK + F2)
+            derivative_new = K * (F1 * dK + F2)

-            errapx = abs(derivative_old - derivative)
-            d_errapx = errapx / max(err_threshold, abs(derivative))
-            derivative_old = derivative
-
-            if d_errapx <= err_threshold:
-                return derivative
+            errapx = scalar_abs(derivative - derivative_new)
+            d_errapx = errapx / scalar_maximum(
+                err_threshold, scalar_abs(derivative_new)
+            )

-        warnings.warn(
-            f"betainc_der did not converge after {n} iterations",
-            RuntimeWarning,
-        )
-        return np.nan
+            min_iters_cond = n > (min_iters - 1)
+            derivative = switch(
+                min_iters_cond,
+                derivative_new,
+                derivative,
+            )
+            n += 1

-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
+            return (
+                (derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n),
+                (skip_loop | ((d_errapx <= err_threshold) & min_iters_cond)),
+            )

+        init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
+        constant = [f, p, q, K, dK, skip_loop]
+        grad = _make_scalar_loop(
+            max_iters, init, constant, inner_loop, name="betainc_grad"
+        )
+        return grad

-betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
+    # Input validation
+    nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
+    flip_branch = x > (p / (p + q))
+    grad = switch(
+        nan_branch,
+        np.nan,
+        switch(
+            flip_branch,
+            -_betainc_der(q, p, 1 - x, not wrtp, skip_loop=nan_branch | (~flip_branch)),
+            _betainc_der(p, q, x, wrtp, skip_loop=nan_branch | flip_branch),
+        ),
+    )
+    return grad


 class Hyp2F1(ScalarOp):
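Note: read off the code above (notation mine, not stated in the commit), the loop state carries the continued-fraction convergents A_n, B_n and their parameter derivatives, refining K * A_n / B_n ≈ I_x(p, q) together with its gradient, where

K = \frac{x^{p} (1 - x)^{q-1}}{p\,B(p, q)}, \qquad
f = \frac{q x}{p (1 - x)}, \qquad
\log B(p, q) = \log\Gamma(p) + \log\Gamma(q) - \log\Gamma(p + q)

A_n = a_n A_{n-2} + b_n A_{n-1}, \qquad
B_n = a_n B_{n-2} + b_n B_{n-1}

\frac{\partial I_x(p, q)}{\partial p}
    \approx K \left( \frac{A_n}{B_n}\,\frac{\partial \log K}{\partial p}
    + \frac{1}{B_n} \left( \frac{\partial A_n}{\partial p} - \frac{A_n}{B_n}\,\frac{\partial B_n}{\partial p} \right) \right)

with `dK` being the log-derivative of K wrt the chosen parameter. The loop stops once the relative change in the estimate drops below 1e-12 (after a minimum number of iterations) or after 200 iterations; `skip_loop` is OR-ed into that stopping condition, presumably because the outer `switch` evaluates both `_betainc_der` branches, so the branch that is not selected should exit after a single iteration rather than iterating for nothing.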

tests/scalar/test_math.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from pytensor.link.c.basic import CLinker
 from pytensor.scalar.math import (
     betainc,
-    betainc_der,
+    betainc_grad,
     gammainc,
     gammaincc,
     gammal,
@@ -82,7 +82,7 @@ def test_betainc():

 def test_betainc_derivative_nan():
     a, b, x = at.scalars("a", "b", "x")
-    res = betainc_der(a, b, x, True)
+    res = betainc_grad(a, b, x, True)
     test_func = function([a, b, x], res, mode=Mode("py"))
     assert not np.isnan(test_func(1, 1, 1))
     assert np.isnan(test_func(1, 1, -1))
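Note: beyond the NaN check above, the new helper can also be compared against a finite difference of `scipy.special.betainc` using the same `Mode("py")` setup as the test; a hypothetical extra check along these lines (not part of the commit, sample values arbitrary):

import numpy as np
import scipy.special
import pytensor.tensor as at
from pytensor import function
from pytensor.compile.mode import Mode
from pytensor.scalar.math import betainc_grad

# Same graph construction as test_betainc_derivative_nan, evaluated in Python mode
a, b, x = at.scalars("a", "b", "x")
res = betainc_grad(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py"))

# Central finite difference wrt the first argument as a reference
eps = 1e-6
fd = (
    scipy.special.betainc(2.0 + eps, 3.0, 0.3) - scipy.special.betainc(2.0 - eps, 3.0, 0.3)
) / (2 * eps)
np.testing.assert_allclose(test_func(2.0, 3.0, 0.3), fd, rtol=1e-5)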
