Skip to content

Commit c1e7726

Browse files
committed
Use ScalarLoop for betainc gradient
1 parent ccae99d commit c1e7726

File tree

2 files changed

+108
-81
lines changed

2 files changed

+108
-81
lines changed

pytensor/scalar/math.py

Lines changed: 106 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
from pytensor.configdefaults import config
1616
from pytensor.gradient import grad_not_implemented
17+
from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
18+
from pytensor.scalar.basic import abs as scalar_abs
1719
from pytensor.scalar.basic import (
18-
BinaryScalarOp,
19-
ScalarOp,
20-
UnaryScalarOp,
2120
as_scalar,
2221
complex_types,
2322
constant,
@@ -27,9 +26,12 @@
2726
expm1,
2827
float64,
2928
float_types,
29+
identity,
3030
isinf,
3131
log,
3232
log1p,
33+
reciprocal,
34+
scalar_maximum,
3335
sqrt,
3436
switch,
3537
true_div,
@@ -1329,8 +1331,8 @@ def grad(self, inp, grads):
13291331
(gz,) = grads
13301332

13311333
return [
1332-
gz * betainc_der(a, b, x, True),
1333-
gz * betainc_der(a, b, x, False),
1334+
gz * betainc_grad(a, b, x, True),
1335+
gz * betainc_grad(a, b, x, False),
13341336
gz
13351337
* exp(
13361338
log1p(-x) * (b - 1)
@@ -1346,28 +1348,28 @@ def c_code(self, *args, **kwargs):
13461348
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
13471349

13481350

1349-
class BetaIncDer(ScalarOp):
1350-
"""
1351-
Gradient of the regularized incomplete beta function wrt to the first
1352-
argument (alpha) or the second argument (beta), depending on whether the
1353-
fourth argument to betainc_der is `True` or `False`, respectively.
1351+
def betainc_grad(p, q, x, wrtp: bool):
1352+
"""Gradient of the regularized lower gamma function (P) wrt to the first
1353+
argument (k, a.k.a. alpha).
13541354
1355-
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1356-
Journal of Statistical Software, 3(1), 1-20.
1355+
Adapted from the former `BetaIncDer` numpy implementation, re-expressed with ScalarLoop.
1356+
1357+
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
1358+
Journal of Statistical Software, 3(1), 1-20.
13571359
"""
13581360

1359-
nin = 4
1361+
def _betainc_der(p, q, x, wrtp, skip_loop):
1362+
dtype = upcast(p.type.dtype, q.type.dtype, x.type.dtype, "float32")
1363+
1364+
def betaln(a, b):
1365+
return gammaln(a) + (gammaln(b) - gammaln(a + b))
13601366

1361-
def impl(self, p, q, x, wrtp):
13621367
def _betainc_a_n(f, p, q, n):
13631368
"""
13641369
Numerator (a_n) of the nth approximant of the continued fraction
13651370
representation of the regularized incomplete beta function
13661371
"""
13671372

1368-
if n == 1:
1369-
return p * f * (q - 1) / (q * (p + 1))
1370-
13711373
p2n = p + 2 * n
13721374
F1 = p**2 * f**2 * (n - 1) / (q**2)
13731375
F2 = (
@@ -1377,7 +1379,11 @@ def _betainc_a_n(f, p, q, n):
13771379
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
13781380
)
13791381

1380-
return F1 * F2
1382+
return switch(
1383+
eq(n, 1),
1384+
p * f * (q - 1) / (q * (p + 1)),
1385+
F1 * F2,
1386+
)
13811387

13821388
def _betainc_b_n(f, p, q, n):
13831389
"""
@@ -1397,9 +1403,6 @@ def _betainc_da_n_dp(f, p, q, n):
13971403
Derivative of a_n wrt p
13981404
"""
13991405

1400-
if n == 1:
1401-
return -p * f * (q - 1) / (q * (p + 1) ** 2)
1402-
14031406
pp = p**2
14041407
ppp = pp * p
14051408
p2n = p + 2 * n
@@ -1414,20 +1417,25 @@ def _betainc_da_n_dp(f, p, q, n):
14141417
D1 = q**2 * (p2n - 3) ** 2
14151418
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2
14161419

1417-
return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2
1420+
return switch(
1421+
eq(n, 1),
1422+
-p * f * (q - 1) / (q * (p + 1) ** 2),
1423+
(N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2,
1424+
)
14181425

14191426
def _betainc_da_n_dq(f, p, q, n):
14201427
"""
14211428
Derivative of a_n wrt q
14221429
"""
1423-
if n == 1:
1424-
return p * f / (q * (p + 1))
1425-
14261430
p2n = p + 2 * n
14271431
F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
14281432
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)
14291433

1430-
return F1 / D1
1434+
return switch(
1435+
eq(n, 1),
1436+
p * f / (q * (p + 1)),
1437+
F1 / D1,
1438+
)
14311439

14321440
def _betainc_db_n_dp(f, p, q, n):
14331441
"""
@@ -1452,42 +1460,44 @@ def _betainc_db_n_dq(f, p, q, n):
14521460
p2n = p + 2 * n
14531461
return -(p**2 * f) / (q * (p2n - 2) * p2n)
14541462

1455-
# Input validation
1456-
if not (0 <= x <= 1) or p < 0 or q < 0:
1457-
return np.nan
1458-
1459-
if x > (p / (p + q)):
1460-
return -self.impl(q, p, 1 - x, not wrtp)
1461-
1462-
min_iters = 3
1463-
max_iters = 200
1464-
err_threshold = 1e-12
1465-
1466-
derivative_old = 0
1463+
min_iters = np.array(3, dtype="int32")
1464+
max_iters = switch(
1465+
skip_loop, np.array(0, dtype="int32"), np.array(200, dtype="int32")
1466+
)
1467+
err_threshold = np.array(1e-12, dtype=config.floatX)
14671468

1468-
Am2, Am1 = 1, 1
1469-
Bm2, Bm1 = 0, 1
1470-
dAm2, dAm1 = 0, 0
1471-
dBm2, dBm1 = 0, 0
1469+
Am2, Am1 = np.array(1, dtype=dtype), np.array(1, dtype=dtype)
1470+
Bm2, Bm1 = np.array(0, dtype=dtype), np.array(1, dtype=dtype)
1471+
dAm2, dAm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
1472+
dBm2, dBm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype)
14721473

14731474
f = (q * x) / (p * (1 - x))
1474-
K = np.exp(
1475-
p * np.log(x)
1476-
+ (q - 1) * np.log1p(-x)
1477-
- np.log(p)
1478-
- scipy.special.betaln(p, q)
1479-
)
1475+
K = exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - betaln(p, q))
14801476
if wrtp:
1481-
dK = (
1482-
np.log(x)
1483-
- 1 / p
1484-
+ scipy.special.digamma(p + q)
1485-
- scipy.special.digamma(p)
1486-
)
1477+
dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
14871478
else:
1488-
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)
1489-
1490-
for n in range(1, max_iters + 1):
1479+
dK = log1p(-x) + psi(p + q) - psi(q)
1480+
1481+
derivative = np.array(0, dtype=dtype)
1482+
n = np.array(1, dtype="int16") # Enough for 200 max iters
1483+
1484+
def inner_loop(
1485+
derivative,
1486+
Am2,
1487+
Am1,
1488+
Bm2,
1489+
Bm1,
1490+
dAm2,
1491+
dAm1,
1492+
dBm2,
1493+
dBm1,
1494+
n,
1495+
f,
1496+
p,
1497+
q,
1498+
K,
1499+
dK,
1500+
):
14911501
a_n_ = _betainc_a_n(f, p, q, n)
14921502
b_n_ = _betainc_b_n(f, p, q, n)
14931503
if wrtp:
@@ -1502,36 +1512,53 @@ def _betainc_db_n_dq(f, p, q, n):
15021512
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
15031513
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1
15041514

1505-
Am2, Am1 = Am1, A
1506-
Bm2, Bm1 = Bm1, B
1507-
dAm2, dAm1 = dAm1, dA
1508-
dBm2, dBm1 = dBm1, dB
1509-
1510-
if n < min_iters - 1:
1511-
continue
1515+
Am2, Am1 = identity(Am1), identity(A)
1516+
Bm2, Bm1 = identity(Bm1), identity(B)
1517+
dAm2, dAm1 = identity(dAm1), identity(dA)
1518+
dBm2, dBm1 = identity(dBm1), identity(dB)
15121519

15131520
F1 = A / B
15141521
F2 = (dA - F1 * dB) / B
1515-
derivative = K * (F1 * dK + F2)
1522+
derivative_new = K * (F1 * dK + F2)
15161523

1517-
errapx = abs(derivative_old - derivative)
1518-
d_errapx = errapx / max(err_threshold, abs(derivative))
1519-
derivative_old = derivative
1520-
1521-
if d_errapx <= err_threshold:
1522-
return derivative
1524+
errapx = scalar_abs(derivative - derivative_new)
1525+
d_errapx = errapx / scalar_maximum(
1526+
err_threshold, scalar_abs(derivative_new)
1527+
)
15231528

1524-
warnings.warn(
1525-
f"betainc_der did not converge after {n} iterations",
1526-
RuntimeWarning,
1527-
)
1528-
return np.nan
1529+
min_iters_cond = n > (min_iters - 1)
1530+
derivative = switch(
1531+
min_iters_cond,
1532+
derivative_new,
1533+
derivative,
1534+
)
1535+
n += 1
15291536

1530-
def c_code(self, *args, **kwargs):
1531-
raise NotImplementedError()
1537+
return (
1538+
(derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n),
1539+
(d_errapx <= err_threshold) & min_iters_cond,
1540+
)
15321541

1542+
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
1543+
constant = [f, p, q, K, dK]
1544+
grad = _make_scalar_loop(
1545+
max_iters, init, constant, inner_loop, name="betainc_grad"
1546+
)
1547+
return grad
15331548

1534-
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
1549+
# Input validation
1550+
nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
1551+
flip_branch = x > (p / (p + q))
1552+
grad = switch(
1553+
nan_branch,
1554+
np.nan,
1555+
switch(
1556+
flip_branch,
1557+
-_betainc_der(q, p, 1 - x, not wrtp, skip_loop=nan_branch | (~flip_branch)),
1558+
_betainc_der(p, q, x, wrtp, skip_loop=nan_branch | flip_branch),
1559+
),
1560+
)
1561+
return grad
15351562

15361563

15371564
class Hyp2F1(ScalarOp):

tests/scalar/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.link.c.basic import CLinker
99
from pytensor.scalar.math import (
1010
betainc,
11-
betainc_der,
11+
betainc_grad,
1212
gammainc,
1313
gammaincc,
1414
gammal,
@@ -82,7 +82,7 @@ def test_betainc():
8282

8383
def test_betainc_derivative_nan():
8484
a, b, x = at.scalars("a", "b", "x")
85-
res = betainc_der(a, b, x, True)
85+
res = betainc_grad(a, b, x, True)
8686
test_func = function([a, b, x], res, mode=Mode("py"))
8787
assert not np.isnan(test_func(1, 1, 1))
8888
assert np.isnan(test_func(1, 1, -1))

0 commit comments

Comments
 (0)