Skip to content

Commit ccae99d

Browse files
committed
Use ScalarLoop for gammainc(c) gradients
1 parent 468a458 commit ccae99d

File tree

2 files changed

+209
-134
lines changed

2 files changed

+209
-134
lines changed

pytensor/scalar/math.py

Lines changed: 182 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,27 @@
1818
BinaryScalarOp,
1919
ScalarOp,
2020
UnaryScalarOp,
21+
as_scalar,
2122
complex_types,
23+
constant,
2224
discrete_types,
25+
eq,
2326
exp,
2427
expm1,
2528
float64,
2629
float_types,
2730
isinf,
2831
log,
2932
log1p,
33+
sqrt,
3034
switch,
3135
true_div,
3236
upcast,
3337
upgrade_to_float,
3438
upgrade_to_float64,
3539
upgrade_to_float_no_complex,
3640
)
41+
from pytensor.scalar.loop import ScalarLoop
3742

3843

3944
class Erf(UnaryScalarOp):
@@ -595,7 +600,7 @@ def grad(self, inputs, grads):
595600
(k, x) = inputs
596601
(gz,) = grads
597602
return [
598-
gz * gammainc_der(k, x),
603+
gz * gammainc_grad(k, x),
599604
gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
600605
]
601606

@@ -644,7 +649,7 @@ def grad(self, inputs, grads):
644649
(k, x) = inputs
645650
(gz,) = grads
646651
return [
647-
gz * gammaincc_der(k, x),
652+
gz * gammaincc_grad(k, x),
648653
gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
649654
]
650655

@@ -675,162 +680,209 @@ def __hash__(self):
675680
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
676681

677682

678-
class GammaIncDer(BinaryScalarOp):
679-
"""
680-
Gradient of the the regularized lower gamma function (P) wrt to the first
681-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
683+
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
684+
init = [as_scalar(x) for x in init]
685+
constant = [as_scalar(x) for x in constant]
686+
# Create dummy types, in case some variables have the same initial form
687+
init_ = [x.type() for x in init]
688+
constant_ = [x.type() for x in constant]
689+
update_, until_ = inner_loop_fn(*init_, *constant_)
690+
op = ScalarLoop(
691+
init=init_,
692+
constant=constant_,
693+
update=update_,
694+
until=until_,
695+
until_condition_failed="warn",
696+
name=name,
697+
)
698+
S, *_ = op(n_steps, *init, *constant)
699+
return S
700+
701+
702+
def gammainc_grad(k, x):
703+
"""Gradient of the regularized lower gamma function (P) wrt to the first
704+
argument (k, a.k.a. alpha).
705+
706+
Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
682707
683708
Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
684709
ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
685710
"""
711+
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
686712

687-
def impl(self, k, x):
688-
if x == 0:
689-
return 0
690-
691-
sqrt_exp = -756 - x**2 + 60 * x
692-
if (
693-
(k < 0.8 and x > 15)
694-
or (k < 12 and x > 30)
695-
or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
696-
):
697-
return -GammaIncCDer.st_impl(k, x)
698-
699-
precision = 1e-10
700-
max_iters = int(1e5)
713+
def grad_approx(skip_loop):
714+
precision = np.array(1e-10, dtype=config.floatX)
715+
max_iters = switch(
716+
skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
717+
)
701718

702-
log_x = np.log(x)
703-
log_gamma_k_plus_1 = scipy.special.gammaln(k + 1)
719+
log_x = log(x)
720+
log_gamma_k_plus_1 = gammaln(k + 1)
704721

705-
k_plus_n = k
722+
# First loop
723+
k_plus_n = k # Should not overflow unless k > 2,147,383,647
706724
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
707-
sum_a = 0.0
708-
for n in range(0, max_iters + 1):
709-
term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
710-
sum_a += term
725+
sum_a0 = np.array(0.0, dtype=dtype)
711726

712-
if term <= precision:
713-
break
727+
def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
728+
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
729+
sum_a += term
714730

715-
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
731+
log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
716732
k_plus_n += 1
717-
718-
if n >= max_iters:
719-
warnings.warn(
720-
f"gammainc_der did not converge after {n} iterations",
721-
RuntimeWarning,
733+
return (
734+
(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n),
735+
(term <= precision),
722736
)
723-
return np.nan
724737

725-
k_plus_n = k
738+
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
739+
constant = [log_x]
740+
sum_a = _make_scalar_loop(
741+
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
742+
)
743+
744+
# Second loop
745+
n = np.array(0, dtype="int32")
726746
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
727-
sum_b = 0.0
728-
for n in range(0, max_iters + 1):
729-
term = np.exp(
730-
k_plus_n * log_x - log_gamma_k_plus_n_plus_1
731-
) * scipy.special.digamma(k_plus_n + 1)
732-
sum_b += term
747+
k_plus_n = k
748+
sum_b0 = np.array(0.0, dtype=dtype)
733749

734-
if term <= precision and n >= 1: # Require at least two iterations
735-
return np.exp(-x) * (log_x * sum_a - sum_b)
750+
def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
751+
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
752+
sum_b += term
736753

737-
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
754+
log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
755+
n += 1
738756
k_plus_n += 1
757+
return (
758+
(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n),
759+
# Require at least two iterations
760+
((term <= precision) & (n > 1)),
761+
)
739762

740-
warnings.warn(
741-
f"gammainc_der did not converge after {n} iterations",
742-
RuntimeWarning,
763+
init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
764+
constant = [log_x]
765+
sum_b, *_ = _make_scalar_loop(
766+
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
743767
)
744-
return np.nan
745-
746-
def c_code(self, *args, **kwargs):
747-
raise NotImplementedError()
748-
749-
750-
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
751768

752-
753-
class GammaIncCDer(BinaryScalarOp):
754-
"""
755-
Gradient of the the regularized upper gamma function (Q) wrt to the first
756-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
769+
grad_approx = exp(-x) * (log_x * sum_a - sum_b)
770+
return grad_approx
771+
772+
zero_branch = eq(x, 0)
773+
sqrt_exp = -756 - x**2 + 60 * x
774+
gammaincc_branch = (
775+
((k < 0.8) & (x > 15))
776+
| ((k < 12) & (x > 30))
777+
| ((sqrt_exp > 0) & (k < sqrt(sqrt_exp)))
778+
)
779+
grad = switch(
780+
zero_branch,
781+
0,
782+
switch(
783+
gammaincc_branch,
784+
-gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)),
785+
grad_approx(skip_loop=zero_branch | gammaincc_branch),
786+
),
787+
)
788+
return grad
789+
790+
791+
def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
792+
"""Gradient of the regularized upper gamma function (Q) wrt to the first
793+
argument (k, a.k.a. alpha).
794+
795+
Adapted from STAN `grad_reg_inc_gamma.hpp`
796+
797+
skip_loops is used for faster branching when this function is called by `gammainc_grad`
757798
"""
799+
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
758800

759-
@staticmethod
760-
def st_impl(k, x):
761-
gamma_k = scipy.special.gamma(k)
762-
digamma_k = scipy.special.digamma(k)
763-
log_x = np.log(x)
764-
765-
# asymptotic expansion http://dlmf.nist.gov/8.11#E2
766-
if (x >= k) and (x >= 8):
767-
S = 0
768-
k_minus_one_minus_n = k - 1
769-
fac = k_minus_one_minus_n
770-
dfac = 1
771-
xpow = x
801+
gamma_k = gamma(k)
802+
digamma_k = psi(k)
803+
log_x = log(x)
804+
805+
def approx_a(skip_loop):
806+
n_steps = switch(
807+
skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32")
808+
)
809+
sum_a0 = np.array(0.0, dtype=dtype)
810+
dfac = np.array(1.0, dtype=dtype)
811+
xpow = x
812+
k_minus_one_minus_n = k - 1
813+
fac = k_minus_one_minus_n
814+
delta = true_div(dfac, xpow)
815+
816+
def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
817+
sum_a += delta
818+
xpow *= x
819+
k_minus_one_minus_n -= 1
820+
dfac = k_minus_one_minus_n * dfac + fac
821+
fac *= k_minus_one_minus_n
772822
delta = dfac / xpow
823+
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
773824

774-
for n in range(1, 10):
775-
k_minus_one_minus_n -= 1
776-
S += delta
777-
xpow *= x
778-
dfac = k_minus_one_minus_n * dfac + fac
779-
fac *= k_minus_one_minus_n
780-
delta = dfac / xpow
781-
if np.isinf(delta):
782-
warnings.warn(
783-
"gammaincc_der did not converge",
784-
RuntimeWarning,
785-
)
786-
return np.nan
825+
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
826+
constant = [x]
827+
sum_a = _make_scalar_loop(
828+
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
829+
)
830+
grad_approx_a = (
831+
gammaincc(k, x) * (log_x - digamma_k)
832+
+ exp(-x + (k - 1) * log_x) * sum_a / gamma_k
833+
)
834+
return grad_approx_a
787835

836+
def approx_b(skip_loop):
837+
max_iters = switch(
838+
skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32")
839+
)
840+
log_precision = np.array(np.log(1e-6), dtype=config.floatX)
841+
842+
sum_b0 = np.array(0.0, dtype=dtype)
843+
log_s = np.array(0.0, dtype=dtype)
844+
s_sign = np.array(1, dtype="int8")
845+
n = np.array(1, dtype="int32")
846+
log_delta = log_s - 2 * log(k)
847+
848+
def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
849+
delta = exp(log_delta)
850+
sum_b += switch(s_sign > 0, delta, -delta)
851+
s_sign = -s_sign
852+
853+
# log will cast >int16 to float64
854+
log_s_inc = log_x - log(n)
855+
if log_s_inc.type.dtype != log_s.type.dtype:
856+
log_s_inc = log_s_inc.astype(log_s.type.dtype)
857+
log_s += log_s_inc
858+
859+
new_log_delta = log_s - 2 * log(n + k)
860+
if new_log_delta.type.dtype != log_delta.type.dtype:
861+
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
862+
log_delta = new_log_delta
863+
864+
n += 1
788865
return (
789-
scipy.special.gammaincc(k, x) * (log_x - digamma_k)
790-
+ np.exp(-x + (k - 1) * log_x) * S / gamma_k
791-
)
792-
793-
# gradient of series expansion http://dlmf.nist.gov/8.7#E3
794-
else:
795-
log_precision = np.log(1e-6)
796-
max_iters = int(1e5)
797-
S = 0
798-
log_s = 0.0
799-
s_sign = 1
800-
log_delta = log_s - 2 * np.log(k)
801-
for n in range(1, max_iters + 1):
802-
S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
803-
s_sign = -s_sign
804-
log_s += log_x - np.log(n)
805-
log_delta = log_s - 2 * np.log(n + k)
806-
807-
if np.isinf(log_delta):
808-
warnings.warn(
809-
"gammaincc_der did not converge",
810-
RuntimeWarning,
811-
)
812-
return np.nan
813-
814-
if log_delta <= log_precision:
815-
return (
816-
scipy.special.gammainc(k, x) * (digamma_k - log_x)
817-
+ np.exp(k * log_x) * S / gamma_k
818-
)
819-
820-
warnings.warn(
821-
f"gammaincc_der did not converge after {n} iterations",
822-
RuntimeWarning,
866+
(sum_b, log_s, s_sign, log_delta, n),
867+
log_delta <= log_precision,
823868
)
824-
return np.nan
825-
826-
def impl(self, k, x):
827-
return self.st_impl(k, x)
828-
829-
def c_code(self, *args, **kwargs):
830-
raise NotImplementedError()
831869

832-
833-
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
870+
init = [sum_b0, log_s, s_sign, log_delta, n]
871+
constant = [k, log_x]
872+
sum_b = _make_scalar_loop(
873+
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
874+
)
875+
grad_approx_b = (
876+
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
877+
)
878+
return grad_approx_b
879+
880+
branch_a = (x >= k) & (x >= 8)
881+
return switch(
882+
branch_a,
883+
approx_a(skip_loop=~branch_a | skip_loops),
884+
approx_b(skip_loop=branch_a | skip_loops),
885+
)
834886

835887

836888
class GammaU(BinaryScalarOp):

0 commit comments

Comments
 (0)