Commit c963117

Use ScalarLoop for hyp2f1 gradient
1 parent ff7d38f commit c963117

2 files changed: +330 -281 lines changed

pytensor/scalar/math.py

Lines changed: 160 additions & 127 deletions
@@ -5,7 +5,6 @@
 """

 import os
-import warnings
 from textwrap import dedent

 import numpy as np
@@ -26,7 +25,9 @@
     expm1,
     float64,
     float_types,
+    floor,
     identity,
+    integer_types,
     isinf,
     log,
     log1p,
@@ -849,15 +850,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x, skip_loop):
         s_sign = -s_sign

         # log will cast >int16 to float64
-        log_s_inc = log_x - log(n)
-        if log_s_inc.type.dtype != log_s.type.dtype:
-            log_s_inc = log_s_inc.astype(log_s.type.dtype)
-        log_s += log_s_inc
+        log_s += log_x - log(n)
+        if log_s.type.dtype != dtype:
+            log_s = log_s.astype(dtype)

-        new_log_delta = log_s - 2 * log(n + k)
-        if new_log_delta.type.dtype != log_delta.type.dtype:
-            new_log_delta = new_log_delta.astype(log_delta.type.dtype)
-        log_delta = new_log_delta
+        log_delta = log_s - 2 * log(n + k)
+        if log_delta.type.dtype != dtype:
+            log_delta = log_delta.astype(dtype)

         n += 1
         return (
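Note on the casts in this hunk: `log(n)` promotes integer inputs to float64 (hence the `# log will cast >int16 to float64` comment), so without the explicit `astype` the carried `log_s` and `log_delta` could change dtype between iterations of the scalar loop; the new code simply casts both back to the enclosing `dtype`. A minimal numpy sketch of the promotion the cast guards against (illustrative variable names, not part of the commit):

import numpy as np

dtype = "float32"
log_s = np.zeros(1, dtype=dtype)   # state carried across loop iterations
n = np.array([3], dtype="int64")

log_s = log_s + np.log(n)          # np.log on an int64 array returns float64, so the sum is float64
print(log_s.dtype)                 # float64 -> the state dtype has drifted

log_s = log_s.astype(dtype)        # explicit cast keeps the carried state at a fixed dtype
print(log_s.dtype)                 # float32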
@@ -1576,9 +1575,9 @@ def grad(self, inputs, grads):
         a, b, c, z = inputs
         (gz,) = grads
         return [
-            gz * hyp2f1_der(a, b, c, z, wrt=0),
-            gz * hyp2f1_der(a, b, c, z, wrt=1),
-            gz * hyp2f1_der(a, b, c, z, wrt=2),
+            gz * hyp2f1_grad(a, b, c, z, wrt=0),
+            gz * hyp2f1_grad(a, b, c, z, wrt=1),
+            gz * hyp2f1_grad(a, b, c, z, wrt=2),
             gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
         ]

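The first three entries now route through the new `hyp2f1_grad` helper defined below; the last entry keeps the closed-form derivative with respect to z, d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z). A quick scipy-based sanity check of that identity (not part of the commit, scipy assumed available):

import numpy as np
from scipy.special import hyp2f1

a, b, c, z = 1.5, 2.0, 3.5, 0.3
eps = 1e-6

# Central finite difference in z vs. the closed form used in grad()
fd = (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)
closed_form = (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)
np.testing.assert_allclose(fd, closed_form, rtol=1e-6)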
@@ -1589,134 +1588,168 @@ def c_code(self, *args, **kwargs):
 hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")


-class Hyp2F1Der(ScalarOp):
-    """
-    Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
+def _unsafe_sign(x):
+    # Unlike scalar.sign we don't worry about x being 0 or nan
+    return switch(x > 0, 1, -1)

-    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
-    """

-    nin = 5
+def hyp2f1_grad(a, b, c, z, wrt: int):
+    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

-    def impl(self, a, b, c, z, wrt):
-        def check_2f1_converges(a, b, c, z) -> bool:
-            num_terms = 0
-            is_polynomial = False
+    def check_2f1_converges(a, b, c, z):
+        def is_nonpositive_integer(x):
+            if x.type.dtype not in integer_types:
+                return eq(floor(x), x) & (x <= 0)
+            else:
+                return x <= 0

-            def is_nonpositive_integer(x):
-                return x <= 0 and x.is_integer()
+        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
+        num_terms = switch(
+            a_is_polynomial,
+            floor(scalar_abs(a)).astype("int64"),
+            0,
+        )

-            if is_nonpositive_integer(a) and abs(a) >= num_terms:
-                is_polynomial = True
-                num_terms = int(np.floor(abs(a)))
-            if is_nonpositive_integer(b) and abs(b) >= num_terms:
-                is_polynomial = True
-                num_terms = int(np.floor(abs(b)))
+        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
+        num_terms = switch(
+            b_is_polynomial,
+            floor(scalar_abs(b)).astype("int64"),
+            num_terms,
+        )

-            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
+        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
+        is_polynomial = a_is_polynomial | b_is_polynomial

-            return not is_undefined and (
-                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
-            )
+        return (~is_undefined) & (
+            is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
+        )

-        def compute_grad_2f1(a, b, c, z, wrt):
-            """
-            Notes
-            -----
-            The algorithm can be derived by looking at the ratio of two successive terms in the series
-            β_{k+1}/β_{k} = A(k)/B(k)
-            β_{k+1} = A(k)/B(k) * β_{k}
-            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
-
-            In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
-
-            The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
-            by dropping the respective term
-            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
-            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
-            d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
-
-            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
-            tracking their signs.
-            """
+    def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
+        """
+        Notes
+        -----
+        The algorithm can be derived by looking at the ratio of two successive terms in the series
+        β_{k+1}/β_{k} = A(k)/B(k)
+        β_{k+1} = A(k)/B(k) * β_{k}
+        d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

-            wrt_a = wrt_b = False
-            if wrt == 0:
-                wrt_a = True
-            elif wrt == 1:
-                wrt_b = True
-            elif wrt != 2:
-                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
-
-            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
-            max_steps = int(1e6)
-            precision = 1e-14
-
-            res = 0
-
-            if z == 0:
-                return res
-
-            log_g_old = -np.inf
-            log_t_old = 0.0
-            log_t_new = 0.0
-            sign_z = np.sign(z)
-            log_z = np.log(np.abs(z))
-
-            log_g_old_sign = 1
-            log_t_old_sign = 1
-            log_t_new_sign = 1
-            sign_zk = sign_z
-
-            for k in range(max_steps):
-                p = (a + k) * (b + k) / ((c + k) * (k + 1))
-                if p == 0:
-                    return res
-                log_t_new += np.log(np.abs(p)) + log_z
-                log_t_new_sign = np.sign(p) * log_t_new_sign
-
-                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
-                if wrt_a:
-                    term += np.reciprocal(a + k)
-                elif wrt_b:
-                    term += np.reciprocal(b + k)
-                else:
-                    term -= np.reciprocal(c + k)
-
-                log_g_old = log_t_new + np.log(np.abs(term))
-                log_g_old_sign = np.sign(term) * log_t_new_sign
-                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
-                res += g_current
-
-                log_t_old = log_t_new
-                log_t_old_sign = log_t_new_sign
-                sign_zk *= sign_z
-
-                if k >= min_steps and np.abs(g_current) <= precision:
-                    return res
-
-            warnings.warn(
-                f"hyp2f1_der did not converge after {k} iterations",
-                RuntimeWarning,
-            )
-            return np.nan
+        In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z

-        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
-        if not check_2f1_converges(a, b, c, z):
-            warnings.warn(
-                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
-                RuntimeWarning,
+        The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
+        by dropping the respective term
+        d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
+        d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
+        d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
+
+        The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
+        tracking their signs.
+        """
+
+        wrt_a = wrt_b = False
+        if wrt == 0:
+            wrt_a = True
+        elif wrt == 1:
+            wrt_b = True
+        elif wrt != 2:
+            raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
+
+        min_steps = np.array(
+            10, dtype="int32"
+        )  # https://github.com/stan-dev/math/issues/2857
+        max_steps = np.array(int(1e6), dtype="int32")
+        precision = np.array(1e-14, dtype=config.floatX)
+
+        grad = np.array(0, dtype=dtype)
+
+        log_g = np.array(-np.inf, dtype=dtype)
+        log_g_sign = np.array(1, dtype="int8")
+
+        log_t = np.array(0.0, dtype=dtype)
+        log_t_sign = np.array(1, dtype="int8")
+
+        log_z = log(scalar_abs(z))
+        sign_z = _unsafe_sign(z)
+
+        sign_zk = sign_z
+        k = np.array(0, dtype="int32")
+
+        def inner_loop(
+            grad,
+            log_g,
+            log_g_sign,
+            log_t,
+            log_t_sign,
+            sign_zk,
+            k,
+            a,
+            b,
+            c,
+            log_z,
+            sign_z,
+            skip_loop,
+        ):
+            p = (a + k) * (b + k) / ((c + k) * (k + 1))
+            if p.type.dtype != dtype:
+                p = p.astype(dtype)
+
+            term = log_g_sign * log_t_sign * exp(log_g - log_t)
+            if wrt_a:
+                term += reciprocal(a + k)
+            elif wrt_b:
+                term += reciprocal(b + k)
+            else:
+                term -= reciprocal(c + k)
+
+            if term.type.dtype != dtype:
+                term = term.astype(dtype)
+
+            log_t = log_t + log(scalar_abs(p)) + log_z
+            log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
+            log_g = log_t + log(scalar_abs(term))
+            log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
+
+            g_current = log_g_sign * exp(log_g) * sign_zk
+
+            # If p==0, don't update grad and get out of while loop next
+            grad = switch(
+                eq(p, 0),
+                grad,
+                grad + g_current,
             )
-            return np.nan

-        return compute_grad_2f1(a, b, c, z, wrt=wrt)
+            sign_zk *= sign_z
+            k += 1

-    def __call__(self, a, b, c, z, wrt, **kwargs):
-        # This allows wrt to be a keyword argument
-        return super().__call__(a, b, c, z, wrt, **kwargs)
+            return (
+                (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
+                (
+                    skip_loop
+                    | eq(p, 0)
+                    | ((k > min_steps) & (scalar_abs(g_current) <= precision))
+                ),
+            )

-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
+        init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
+        constant = [a, b, c, log_z, sign_z, skip_loop]
+        grad = _make_scalar_loop(
+            max_steps, init, constant, inner_loop, name="hyp2f1_grad"
+        )

+        return switch(
+            eq(z, 0),
+            0,
+            grad,
+        )

-hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
+    # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
+    z_is_zero = eq(z, 0)
+    converges = check_2f1_converges(a, b, c, z)
+    return switch(
+        z_is_zero,
+        0,
+        switch(
+            converges,
+            compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
+            np.nan,
+        ),
+    )
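For reference, the series that the symbolic `inner_loop` above unrolls is the same log-scale recurrence the removed `Hyp2F1Der.impl` evaluated in plain Python: the inner function returns the updated carry plus a termination flag, and `skip_loop` is folded into that flag because `switch` is not lazy and evaluates both branches. A condensed numpy sketch of the recurrence (not part of the commit; names follow the old implementation, the convergence/domain checks are omitted, and scipy is used only for the finite-difference check):

import numpy as np
from scipy.special import hyp2f1


def hyp2f1_grad_reference(a, b, c, z, wrt, max_steps=100_000, min_steps=10, precision=1e-14):
    """Gradient of 2F1(a, b; c; z) with respect to a (wrt=0), b (wrt=1) or c (wrt=2)."""
    if z == 0:
        return 0.0
    grad = 0.0
    log_g, log_g_sign = -np.inf, 1          # log|d beta_k| and its sign
    log_t, log_t_sign = 0.0, 1              # log|beta_k| (sign of z^k tracked separately)
    log_z, sign_z = np.log(np.abs(z)), np.sign(z)
    sign_zk = sign_z
    for k in range(max_steps):
        p = (a + k) * (b + k) / ((c + k) * (k + 1))
        if p == 0:
            return grad
        # d[A(k)/B(k)] is the ratio itself divided by (a + k), (b + k) or -(c + k)
        term = log_g_sign * log_t_sign * np.exp(log_g - log_t)
        term += 1 / (a + k) if wrt == 0 else (1 / (b + k) if wrt == 1 else -1 / (c + k))
        log_t = log_t + np.log(np.abs(p)) + log_z
        log_t_sign = np.sign(p) * log_t_sign
        log_g = log_t + np.log(np.abs(term))
        log_g_sign = np.sign(term) * log_t_sign
        g_current = log_g_sign * np.exp(log_g) * sign_zk
        grad += g_current
        sign_zk *= sign_z
        if k >= min_steps and np.abs(g_current) <= precision:
            return grad
    return np.nan


# Finite-difference check of the derivative with respect to `a`
a, b, c, z, eps = 1.5, 2.0, 3.5, 0.3, 1e-6
fd = (hyp2f1(a + eps, b, c, z) - hyp2f1(a - eps, b, c, z)) / (2 * eps)
np.testing.assert_allclose(hyp2f1_grad_reference(a, b, c, z, wrt=0), fd, rtol=1e-6)

Working in log scale keeps the running term and its gradient representable even when individual series terms would overflow or underflow; the signs are tracked separately, which is what `_unsafe_sign` handles in the symbolic version.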