Skip to content

Commit 1a36f5c

Browse files
committed
Fuse hyp2f1 grads
1 parent c963117 commit 1a36f5c

File tree

3 files changed

+283
-137
lines changed

3 files changed

+283
-137
lines changed

pytensor/scalar/math.py

Lines changed: 186 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"""
66

77
import os
8+
from functools import reduce
89
from textwrap import dedent
10+
from typing import Tuple
911

1012
import numpy as np
1113
import scipy.special
@@ -684,12 +686,18 @@ def __hash__(self):
684686

685687

686688
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
687-
init = [as_scalar(x) for x in init]
689+
init = [as_scalar(x) if x is not None else None for x in init]
688690
constant = [as_scalar(x) for x in constant]
691+
689692
# Create dummy types, in case some variables have the same initial form
690-
init_ = [x.type() for x in init]
693+
init_ = [x.type() if x is not None else None for x in init]
691694
constant_ = [x.type() for x in constant]
692695
update_, until_ = inner_loop_fn(*init_, *constant_)
696+
697+
# Filter Nones
698+
init = [i for i in init if i is not None]
699+
init_ = [i for i in init_ if i is not None]
700+
update_ = [u for u in update_ if u is not None]
693701
op = ScalarLoop(
694702
init=init_,
695703
constant=constant_,
@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
698706
until_condition_failed="warn",
699707
name=name,
700708
)
701-
S, *_ = op(n_steps, *init, *constant)
702-
return S
709+
return op(n_steps, *init, *constant)
703710

704711

705712
def gammainc_grad(k, x):
@@ -738,7 +745,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x, skip_loop):
738745

739746
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
740747
constant = [log_x, skip_loop]
741-
sum_a = _make_scalar_loop(
748+
sum_a, *_ = _make_scalar_loop(
742749
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
743750
)
744751

@@ -825,7 +832,7 @@ def inner_loop_a(
825832

826833
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
827834
constant = [x, skip_loop]
828-
sum_a = _make_scalar_loop(
835+
sum_a, *_ = _make_scalar_loop(
829836
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
830837
)
831838
grad_approx_a = (
@@ -866,7 +873,7 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x, skip_loop):
866873

867874
init = [sum_b0, log_s, s_sign, log_delta, n]
868875
constant = [k, log_x, skip_loop]
869-
sum_b = _make_scalar_loop(
876+
sum_b, *_ = _make_scalar_loop(
870877
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
871878
)
872879
grad_approx_b = (
@@ -1535,7 +1542,7 @@ def inner_loop(
15351542

15361543
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
15371544
constant = [f, p, q, K, dK, skip_loop]
1538-
grad = _make_scalar_loop(
1545+
grad, *_ = _make_scalar_loop(
15391546
max_iters, init, constant, inner_loop, name="betainc_grad"
15401547
)
15411548
return grad
@@ -1574,10 +1581,11 @@ def impl(self, a, b, c, z):
15741581
def grad(self, inputs, grads):
15751582
a, b, c, z = inputs
15761583
(gz,) = grads
1584+
grad_a, grad_b, grad_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2])
15771585
return [
1578-
gz * hyp2f1_grad(a, b, c, z, wrt=0),
1579-
gz * hyp2f1_grad(a, b, c, z, wrt=1),
1580-
gz * hyp2f1_grad(a, b, c, z, wrt=2),
1586+
gz * grad_a,
1587+
gz * grad_b,
1588+
gz * grad_c,
15811589
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
15821590
]
15831591

@@ -1593,7 +1601,158 @@ def _unsafe_sign(x):
15931601
return switch(x > 0, 1, -1)
15941602

15951603

1596-
def hyp2f1_grad(a, b, c, z, wrt: int):
1604+
def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
1605+
"""
1606+
Notes
1607+
-----
1608+
The algorithm can be derived by looking at the ratio of two successive terms in the series
1609+
β_{k+1}/β_{k} = A(k)/B(k)
1610+
β_{k+1} = A(k)/B(k) * β_{k}
1611+
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1612+
1613+
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k)) / ((c + k) * (1 + k))) * z
1614+
1615+
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1616+
by dropping the respective term
1617+
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1618+
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1619+
d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
1620+
1621+
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1622+
tracking their signs.
1623+
"""
1624+
1625+
min_steps = np.array(
1626+
10, dtype="int32"
1627+
) # https://github.com/stan-dev/math/issues/2857
1628+
max_steps = np.array(int(1e6), dtype="int32")
1629+
precision = np.array(1e-14, dtype=config.floatX)
1630+
1631+
grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)]
1632+
log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)]
1633+
log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)]
1634+
1635+
log_t = np.array(0.0, dtype=dtype)
1636+
log_t_sign = np.array(1, dtype="int8")
1637+
1638+
log_z = log(scalar_abs(z))
1639+
sign_z = _unsafe_sign(z)
1640+
1641+
sign_zk = sign_z
1642+
k = np.array(0, dtype="int32")
1643+
1644+
def inner_loop(*args):
1645+
(
1646+
*grads_vars,
1647+
log_t,
1648+
log_t_sign,
1649+
sign_zk,
1650+
k,
1651+
a,
1652+
b,
1653+
c,
1654+
log_z,
1655+
sign_z,
1656+
skip_loop,
1657+
) = args
1658+
1659+
(
1660+
grad_a,
1661+
grad_b,
1662+
grad_c,
1663+
log_g_a,
1664+
log_g_b,
1665+
log_g_c,
1666+
log_g_sign_a,
1667+
log_g_sign_b,
1668+
log_g_sign_c,
1669+
) = grads_vars
1670+
1671+
p = (a + k) * (b + k) / ((c + k) * (k + 1))
1672+
if p.type.dtype != dtype:
1673+
p = p.astype(dtype)
1674+
1675+
# If p==0, don't update grad and get out of while loop next
1676+
p_zero = eq(p, 0)
1677+
1678+
if 0 in wrt:
1679+
term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t)
1680+
term_a += reciprocal(a + k)
1681+
if term_a.type.dtype != dtype:
1682+
term_a = term_a.astype(dtype)
1683+
if 1 in wrt:
1684+
term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t)
1685+
term_b += reciprocal(b + k)
1686+
if term_b.type.dtype != dtype:
1687+
term_b = term_b.astype(dtype)
1688+
if 2 in wrt:
1689+
term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t)
1690+
term_c -= reciprocal(c + k)
1691+
if term_c.type.dtype != dtype:
1692+
term_c = term_c.astype(dtype)
1693+
1694+
log_t = log_t + log(scalar_abs(p)) + log_z
1695+
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
1696+
1697+
grads = [None] * 3
1698+
log_gs = [None] * 3
1699+
log_gs_signs = [None] * 3
1700+
grad_incs = [None] * 3
1701+
1702+
if 0 in wrt:
1703+
log_g_a = log_t + log(scalar_abs(term_a))
1704+
log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8")
1705+
grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk
1706+
grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a)
1707+
log_gs[0] = log_g_a
1708+
log_gs_signs[0] = log_g_sign_a
1709+
grad_incs[0] = grad_inc_a
1710+
if 1 in wrt:
1711+
log_g_b = log_t + log(scalar_abs(term_b))
1712+
log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8")
1713+
grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk
1714+
grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b)
1715+
log_gs[1] = log_g_b
1716+
log_gs_signs[1] = log_g_sign_b
1717+
grad_incs[1] = grad_inc_b
1718+
if 2 in wrt:
1719+
log_g_c = log_t + log(scalar_abs(term_c))
1720+
log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8")
1721+
grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk
1722+
grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c)
1723+
log_gs[2] = log_g_c
1724+
log_gs_signs[2] = log_g_sign_c
1725+
grad_incs[2] = grad_inc_c
1726+
1727+
sign_zk *= sign_z
1728+
k += 1
1729+
1730+
abs_grad_incs = [
1731+
scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None
1732+
]
1733+
if len(grad_incs) == 1:
1734+
[max_abs_grad_inc] = grad_incs
1735+
else:
1736+
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
1737+
1738+
return (
1739+
(*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
1740+
(
1741+
skip_loop
1742+
| eq(p, 0)
1743+
| ((k > min_steps) & (max_abs_grad_inc <= precision))
1744+
),
1745+
)
1746+
1747+
init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
1748+
constant = [a, b, c, log_z, sign_z, skip_loop]
1749+
loop_outs = _make_scalar_loop(
1750+
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
1751+
)
1752+
return loop_outs[: len(wrt)]
1753+
1754+
1755+
def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
15971756
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
15981757

15991758
def check_2f1_converges(a, b, c, z):
@@ -1624,132 +1783,22 @@ def is_nonpositive_integer(x):
16241783
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
16251784
)
16261785

1627-
def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
1628-
"""
1629-
Notes
1630-
-----
1631-
The algorithm can be derived by looking at the ratio of two successive terms in the series
1632-
β_{k+1}/β_{k} = A(k)/B(k)
1633-
β_{k+1} = A(k)/B(k) * β_{k}
1634-
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1635-
1636-
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k)) / ((c + k) * (1 + k))) * z
1637-
1638-
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1639-
by dropping the respective term
1640-
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1641-
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1642-
d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
1643-
1644-
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1645-
tracking their signs.
1646-
"""
1647-
1648-
wrt_a = wrt_b = False
1649-
if wrt == 0:
1650-
wrt_a = True
1651-
elif wrt == 1:
1652-
wrt_b = True
1653-
elif wrt != 2:
1654-
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
1655-
1656-
min_steps = np.array(
1657-
10, dtype="int32"
1658-
) # https://github.com/stan-dev/math/issues/2857
1659-
max_steps = np.array(int(1e6), dtype="int32")
1660-
precision = np.array(1e-14, dtype=config.floatX)
1661-
1662-
grad = np.array(0, dtype=dtype)
1663-
1664-
log_g = np.array(-np.inf, dtype=dtype)
1665-
log_g_sign = np.array(1, dtype="int8")
1666-
1667-
log_t = np.array(0.0, dtype=dtype)
1668-
log_t_sign = np.array(1, dtype="int8")
1669-
1670-
log_z = log(scalar_abs(z))
1671-
sign_z = _unsafe_sign(z)
1672-
1673-
sign_zk = sign_z
1674-
k = np.array(0, dtype="int32")
1675-
1676-
def inner_loop(
1677-
grad,
1678-
log_g,
1679-
log_g_sign,
1680-
log_t,
1681-
log_t_sign,
1682-
sign_zk,
1683-
k,
1684-
a,
1685-
b,
1686-
c,
1687-
log_z,
1688-
sign_z,
1689-
skip_loop,
1690-
):
1691-
p = (a + k) * (b + k) / ((c + k) * (k + 1))
1692-
if p.type.dtype != dtype:
1693-
p = p.astype(dtype)
1694-
1695-
term = log_g_sign * log_t_sign * exp(log_g - log_t)
1696-
if wrt_a:
1697-
term += reciprocal(a + k)
1698-
elif wrt_b:
1699-
term += reciprocal(b + k)
1700-
else:
1701-
term -= reciprocal(c + k)
1702-
1703-
if term.type.dtype != dtype:
1704-
term = term.astype(dtype)
1705-
1706-
log_t = log_t + log(scalar_abs(p)) + log_z
1707-
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
1708-
log_g = log_t + log(scalar_abs(term))
1709-
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
1710-
1711-
g_current = log_g_sign * exp(log_g) * sign_zk
1712-
1713-
# If p==0, don't update grad and get out of while loop next
1714-
grad = switch(
1715-
eq(p, 0),
1716-
grad,
1717-
grad + g_current,
1718-
)
1719-
1720-
sign_zk *= sign_z
1721-
k += 1
1722-
1723-
return (
1724-
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
1725-
(
1726-
skip_loop
1727-
| eq(p, 0)
1728-
| ((k > min_steps) & (scalar_abs(g_current) <= precision))
1729-
),
1730-
)
1731-
1732-
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
1733-
constant = [a, b, c, log_z, sign_z, skip_loop]
1734-
grad = _make_scalar_loop(
1735-
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
1736-
)
1737-
1738-
return switch(
1739-
eq(z, 0),
1740-
0,
1741-
grad,
1742-
)
1743-
17441786
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
17451787
z_is_zero = eq(z, 0)
17461788
converges = check_2f1_converges(a, b, c, z)
1747-
return switch(
1748-
z_is_zero,
1749-
0,
1750-
switch(
1751-
converges,
1752-
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
1753-
np.nan,
1754-
),
1789+
grads = _grad_2f1_loop(
1790+
a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
17551791
)
1792+
1793+
return [
1794+
switch(
1795+
z_is_zero,
1796+
0,
1797+
switch(
1798+
converges,
1799+
grad,
1800+
np.nan,
1801+
),
1802+
)
1803+
for grad in grads
1804+
]

0 commit comments

Comments
 (0)