Skip to content

Commit a3a8751

Browse files
committed
Use ScalarLoop for gammainc(c) gradients
1 parent 0091e38 commit a3a8751

File tree

2 files changed

+207
-136
lines changed

2 files changed

+207
-136
lines changed

pytensor/scalar/math.py

Lines changed: 180 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,27 @@
1818
BinaryScalarOp,
1919
ScalarOp,
2020
UnaryScalarOp,
21+
as_scalar,
2122
complex_types,
23+
constant,
2224
discrete_types,
25+
eq,
2326
exp,
2427
expm1,
2528
float64,
2629
float_types,
2730
isinf,
2831
log,
2932
log1p,
33+
sqrt,
3034
switch,
3135
true_div,
3236
upcast,
3337
upgrade_to_float,
3438
upgrade_to_float64,
3539
upgrade_to_float_no_complex,
3640
)
41+
from pytensor.scalar.loop import ScalarLoop
3742

3843

3944
class Erf(UnaryScalarOp):
@@ -595,7 +600,7 @@ def grad(self, inputs, grads):
595600
(k, x) = inputs
596601
(gz,) = grads
597602
return [
598-
gz * gammainc_der(k, x),
603+
gz * gammainc_grad(k, x),
599604
gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
600605
]
601606

@@ -644,7 +649,7 @@ def grad(self, inputs, grads):
644649
(k, x) = inputs
645650
(gz,) = grads
646651
return [
647-
gz * gammaincc_der(k, x),
652+
gz * gammaincc_grad(k, x),
648653
gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
649654
]
650655

@@ -675,162 +680,205 @@ def __hash__(self):
675680
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
676681

677682

678-
class GammaIncDer(BinaryScalarOp):
679-
"""
680-
Gradient of the regularized lower gamma function (P) wrt the first
681-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
683+
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
684+
init = [as_scalar(x) for x in init]
685+
constant = [as_scalar(x) for x in constant]
686+
# Create dummy types, in case some variables have the same initial form
687+
init_ = [x.type() for x in init]
688+
constant_ = [x.type() for x in constant]
689+
update_, until_ = inner_loop_fn(*init_, *constant_)
690+
op = ScalarLoop(
691+
init=init_,
692+
constant=constant_,
693+
update=update_,
694+
until=until_,
695+
until_condition_failed="warn",
696+
name=name,
697+
)
698+
S, *_ = op(n_steps, *init, *constant)
699+
return S
700+
701+
702+
def gammainc_grad(k, x):
703+
"""Gradient of the regularized lower gamma function (P) wrt the first
704+
argument (k, a.k.a. alpha).
705+
706+
Adapted from STAN `grad_reg_lower_inc_gamma.hpp`
682707
683708
Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
684709
ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
685710
"""
711+
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
686712

687-
def impl(self, k, x):
688-
if x == 0:
689-
return 0
690-
691-
sqrt_exp = -756 - x**2 + 60 * x
692-
if (
693-
(k < 0.8 and x > 15)
694-
or (k < 12 and x > 30)
695-
or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
696-
):
697-
return -GammaIncCDer.st_impl(k, x)
698-
699-
precision = 1e-10
700-
max_iters = int(1e5)
713+
def grad_approx(skip_loop):
714+
precision = np.array(1e-10, dtype=config.floatX)
715+
max_iters = np.array(1e5, dtype="int32")
701716

702-
log_x = np.log(x)
703-
log_gamma_k_plus_1 = scipy.special.gammaln(k + 1)
717+
log_x = log(x)
718+
log_gamma_k_plus_1 = gammaln(k + 1)
704719

705-
k_plus_n = k
720+
# First loop
721+
k_plus_n = k # Should not overflow unless k > 2,147,383,647
706722
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
707-
sum_a = 0.0
708-
for n in range(0, max_iters + 1):
709-
term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
710-
sum_a += term
723+
sum_a0 = np.array(0.0, dtype=dtype)
711724

712-
if term <= precision:
713-
break
725+
def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x, skip_loop):
726+
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
727+
sum_a += term
714728

715-
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
729+
log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
716730
k_plus_n += 1
717-
718-
if n >= max_iters:
719-
warnings.warn(
720-
f"gammainc_der did not converge after {n} iterations",
721-
RuntimeWarning,
731+
return (
732+
(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n),
733+
skip_loop | (term <= precision),
722734
)
723-
return np.nan
724735

725-
k_plus_n = k
736+
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
737+
constant = [log_x, skip_loop]
738+
sum_a = _make_scalar_loop(
739+
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
740+
)
741+
742+
# Second loop
743+
n = np.array(0, dtype="int32")
726744
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
727-
sum_b = 0.0
728-
for n in range(0, max_iters + 1):
729-
term = np.exp(
730-
k_plus_n * log_x - log_gamma_k_plus_n_plus_1
731-
) * scipy.special.digamma(k_plus_n + 1)
732-
sum_b += term
745+
k_plus_n = k
746+
sum_b0 = np.array(0.0, dtype=dtype)
733747

734-
if term <= precision and n >= 1: # Require at least two iterations
735-
return np.exp(-x) * (log_x * sum_a - sum_b)
748+
def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
749+
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
750+
sum_b += term
736751

737-
log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
752+
log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
753+
n += 1
738754
k_plus_n += 1
755+
return (
756+
(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n),
757+
# Require at least two iterations
758+
((term <= precision) & (n > 1)),
759+
)
739760

740-
warnings.warn(
741-
f"gammainc_der did not converge after {n} iterations",
742-
RuntimeWarning,
761+
init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
762+
constant = [log_x]
763+
sum_b = _make_scalar_loop(
764+
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
743765
)
744-
return np.nan
745766

746-
def c_code(self, *args, **kwargs):
747-
raise NotImplementedError()
748-
749-
750-
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
751-
752-
753-
class GammaIncCDer(BinaryScalarOp):
767+
grad_approx = exp(-x) * (log_x * sum_a - sum_b)
768+
return grad_approx
769+
770+
zero_branch = eq(x, 0)
771+
sqrt_exp = -756 - x**2 + 60 * x
772+
gammaincc_branch = (
773+
((k < 0.8) & (x > 15))
774+
| ((k < 12) & (x > 30))
775+
| ((sqrt_exp > 0) & (k < sqrt(sqrt_exp)))
776+
)
777+
grad = switch(
778+
zero_branch,
779+
0,
780+
switch(
781+
gammaincc_branch,
782+
-gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)),
783+
grad_approx(skip_loop=zero_branch | gammaincc_branch),
784+
),
785+
)
786+
return grad
787+
788+
789+
def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
790+
"""Gradient of the regularized upper gamma function (Q) wrt the first
791+
argument (k, a.k.a. alpha).
792+
793+
Adapted from STAN `grad_reg_inc_gamma.hpp`
794+
795+
skip_loops is used for faster branching when this function is called by `gammainc_grad`
754796
"""
755-
Gradient of the the regularized upper gamma function (Q) wrt to the first
756-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
757-
"""
758-
759-
@staticmethod
760-
def st_impl(k, x):
761-
gamma_k = scipy.special.gamma(k)
762-
digamma_k = scipy.special.digamma(k)
763-
log_x = np.log(x)
764-
765-
# asymptotic expansion http://dlmf.nist.gov/8.11#E2
766-
if (x >= k) and (x >= 8):
767-
S = 0
768-
k_minus_one_minus_n = k - 1
769-
fac = k_minus_one_minus_n
770-
dfac = 1
771-
xpow = x
797+
dtype = upcast(k.type.dtype, x.type.dtype, "float32")
798+
799+
gamma_k = gamma(k)
800+
digamma_k = psi(k)
801+
log_x = log(x)
802+
803+
def approx_a(skip_loop):
804+
n_steps = np.array(9, dtype="int32")
805+
sum_a0 = np.array(0.0, dtype=dtype)
806+
dfac = np.array(1.0, dtype=dtype)
807+
xpow = x
808+
k_minus_one_minus_n = k - 1
809+
fac = k_minus_one_minus_n
810+
delta = true_div(dfac, xpow)
811+
812+
def inner_loop_a(
813+
sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x, skip_loop
814+
):
815+
sum_a += delta
816+
xpow *= x
817+
k_minus_one_minus_n -= 1
818+
dfac = k_minus_one_minus_n * dfac + fac
819+
fac *= k_minus_one_minus_n
772820
delta = dfac / xpow
821+
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), skip_loop
773822

774-
for n in range(1, 10):
775-
k_minus_one_minus_n -= 1
776-
S += delta
777-
xpow *= x
778-
dfac = k_minus_one_minus_n * dfac + fac
779-
fac *= k_minus_one_minus_n
780-
delta = dfac / xpow
781-
if np.isinf(delta):
782-
warnings.warn(
783-
"gammaincc_der did not converge",
784-
RuntimeWarning,
785-
)
786-
return np.nan
787-
823+
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
824+
constant = [x, skip_loop]
825+
sum_a = _make_scalar_loop(
826+
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
827+
)
828+
grad_approx_a = (
829+
gammaincc(k, x) * (log_x - digamma_k)
830+
+ exp(-x + (k - 1) * log_x) * sum_a / gamma_k
831+
)
832+
return grad_approx_a
833+
834+
def approx_b(skip_loop):
835+
max_iters = np.array(1e5, dtype="int32")
836+
log_precision = np.array(np.log(1e-6), dtype=config.floatX)
837+
838+
sum_b0 = np.array(0.0, dtype=dtype)
839+
log_s = np.array(0.0, dtype=dtype)
840+
s_sign = np.array(1, dtype="int8")
841+
n = np.array(1, dtype="int32")
842+
log_delta = log_s - 2 * log(k)
843+
844+
def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x, skip_loop):
845+
delta = exp(log_delta)
846+
sum_b += switch(s_sign > 0, delta, -delta)
847+
s_sign = -s_sign
848+
849+
# log will cast >int16 to float64
850+
log_s_inc = log_x - log(n)
851+
if log_s_inc.type.dtype != log_s.type.dtype:
852+
log_s_inc = log_s_inc.astype(log_s.type.dtype)
853+
log_s += log_s_inc
854+
855+
new_log_delta = log_s - 2 * log(n + k)
856+
if new_log_delta.type.dtype != log_delta.type.dtype:
857+
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
858+
log_delta = new_log_delta
859+
860+
n += 1
788861
return (
789-
scipy.special.gammaincc(k, x) * (log_x - digamma_k)
790-
+ np.exp(-x + (k - 1) * log_x) * S / gamma_k
862+
(sum_b, log_s, s_sign, log_delta, n),
863+
(skip_loop | (log_delta <= log_precision)),
791864
)
792865

793-
# gradient of series expansion http://dlmf.nist.gov/8.7#E3
794-
else:
795-
log_precision = np.log(1e-6)
796-
max_iters = int(1e5)
797-
S = 0
798-
log_s = 0.0
799-
s_sign = 1
800-
log_delta = log_s - 2 * np.log(k)
801-
for n in range(1, max_iters + 1):
802-
S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
803-
s_sign = -s_sign
804-
log_s += log_x - np.log(n)
805-
log_delta = log_s - 2 * np.log(n + k)
806-
807-
if np.isinf(log_delta):
808-
warnings.warn(
809-
"gammaincc_der did not converge",
810-
RuntimeWarning,
811-
)
812-
return np.nan
813-
814-
if log_delta <= log_precision:
815-
return (
816-
scipy.special.gammainc(k, x) * (digamma_k - log_x)
817-
+ np.exp(k * log_x) * S / gamma_k
818-
)
819-
820-
warnings.warn(
821-
f"gammaincc_der did not converge after {n} iterations",
822-
RuntimeWarning,
823-
)
824-
return np.nan
825-
826-
def impl(self, k, x):
827-
return self.st_impl(k, x)
828-
829-
def c_code(self, *args, **kwargs):
830-
raise NotImplementedError()
831-
832-
833-
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
866+
init = [sum_b0, log_s, s_sign, log_delta, n]
867+
constant = [k, log_x, skip_loop]
868+
sum_b = _make_scalar_loop(
869+
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
870+
)
871+
grad_approx_b = (
872+
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
873+
)
874+
return grad_approx_b
875+
876+
branch_a = (x >= k) & (x >= 8)
877+
return switch(
878+
branch_a,
879+
approx_a(skip_loop=~branch_a | skip_loops),
880+
approx_b(skip_loop=branch_a | skip_loops),
881+
)
834882

835883

836884
class GammaU(BinaryScalarOp):

0 commit comments

Comments
 (0)