
Several scipy special functions now upcast integers to float64 #859

Closed · wants to merge 1 commit into from
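Context for the diffs below: as the title says, scipy >= 1.14 upcasts integer inputs to several `scipy.special` functions to float64, where narrower integers could previously resolve to float32. A minimal sketch of the new behavior, assuming scipy >= 1.14 is installed (`scipy.special.psi` is the digamma alias):

```python
import numpy as np
import scipy.special

x = np.arange(1, 4, dtype=np.int16)

# With scipy >= 1.14, integer inputs are upcast to float64 before the
# ufunc loop runs, so the result dtype no longer depends on the int width.
print(scipy.special.gamma(x).dtype)  # float64
print(scipy.special.psi(x).dtype)    # float64
```

The version bumps in `environment.yml` and `pyproject.toml` pin that floor, and the scalar ops further down switch from `upgrade_to_float` to `upgrade_to_float64` to match.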
2 changes: 1 addition & 1 deletion environment.yml
@@ -10,7 +10,7 @@ dependencies:
   - python>=3.10
   - compilers
   - numpy>=1.17.0
-  - scipy>=0.14
+  - scipy>=1.14.0
   - filelock
   - etuples
   - logical-unification
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -47,7 +47,7 @@ keywords = [
 ]
 dependencies = [
     "setuptools>=59.0.0",
-    "scipy>=0.14",
+    "scipy>=1.14.0",
     "numpy>=1.17.0,<2",
     "filelock",
     "etuples",
11 changes: 11 additions & 0 deletions pytensor/scalar/basic.py
@@ -1027,6 +1027,17 @@ def upgrade_to_float_no_complex(*types):
     return upgrade_to_float(*types)


+def upgrade_to_float64_no_complex(*types):
+    """
+    Don't accept complex, otherwise call upgrade_to_float64().
+
+    """
+    for type in types:
+        if type in complex_types:
+            raise TypeError("complex argument not supported")
+    return upgrade_to_float64(*types)
+
+
 def same_out_nocomplex(type):
     if type in complex_types:
         raise TypeError("complex argument not supported")
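For context, a sketch of the new helper's contract. It assumes the scalar-type singletons (`int16`, `float32`, `float64`, `complex128`) exported by `pytensor.scalar.basic` and the one-element-tuple return convention of `upgrade_to_float64`:

```python
from pytensor.scalar.basic import (
    complex128,
    float32,
    float64,
    int16,
    upgrade_to_float64_no_complex,
)

# Any mix of real input types is promoted straight to float64.
assert upgrade_to_float64_no_complex(int16, float32) == (float64,)

# Complex inputs are rejected, mirroring upgrade_to_float_no_complex.
try:
    upgrade_to_float64_no_complex(complex128)
except TypeError as err:
    print(err)  # complex argument not supported
```

This is what lets `tri_gamma` below keep rejecting complex inputs (scipy's polygamma does not support them) while picking up the float64 promotion.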
16 changes: 9 additions & 7 deletions pytensor/scalar/math.py
@@ -41,6 +41,7 @@
     upcast,
     upgrade_to_float,
     upgrade_to_float64,
+    upgrade_to_float64_no_complex,
     upgrade_to_float_no_complex,
 )
 from pytensor.scalar.basic import abs as scalar_abs
@@ -323,7 +324,7 @@ def c_code(self, node, name, inputs, outputs, sub):
         raise NotImplementedError("only floating point is implemented")


-gamma = Gamma(upgrade_to_float, name="gamma")
+gamma = Gamma(upgrade_to_float64, name="gamma")


 class GammaLn(UnaryScalarOp):
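The same `upgrade_to_float` → `upgrade_to_float64` swap repeats for `psi`, `gammau`, `gammal`, and `sigmoid` in the hunks below. The user-visible effect, sketched under the assumption that the tensor-level wrappers in `pytensor.tensor` take their output dtype from these scalar ops:

```python
import pytensor.tensor as pt

x = pt.vector("x", dtype="int16")

# upgrade_to_float mapped int16 to float32; upgrade_to_float64 always
# lands on float64, matching scipy >= 1.14.
print(pt.gamma(x).dtype)    # "float64" with this change (previously "float32")
print(pt.sigmoid(x).dtype)  # "float64" likewise
```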
@@ -460,7 +461,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floating point is implemented")


-psi = Psi(upgrade_to_float, name="psi")
+psi = Psi(upgrade_to_float64, name="psi")


 class TriGamma(UnaryScalarOp):
@@ -549,7 +550,7 @@ def c_code(self, node, name, inp, out, sub):


 # Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
-tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
+tri_gamma = TriGamma(upgrade_to_float64_no_complex, name="tri_gamma")


 class PolyGamma(BinaryScalarOp):
@@ -880,7 +881,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):

     def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
         term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
-        sum_b += term
+        sum_b += term.astype(dtype)

         log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
         n += 1
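Why the added cast: `psi` now returns float64 even when the surrounding loop carries a lower-precision `dtype`, so `term` would otherwise upcast the `sum_b` accumulator mid-iteration. A numpy illustration of the promotion, with hypothetical values:

```python
import numpy as np

dtype = "float32"
sum_b = np.float32(0.0)
term = np.float64(0.25)  # e.g. a psi() result, now float64

print((sum_b + term).dtype)                # float64: accumulator silently upcast
print((sum_b + term.astype(dtype)).dtype)  # float32: carried dtype stays stable
```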
@@ -1051,7 +1052,7 @@ def __hash__(self):
         return hash(type(self))


-gammau = GammaU(upgrade_to_float, name="gammau")
+gammau = GammaU(upgrade_to_float64, name="gammau")


 class GammaL(BinaryScalarOp):
@@ -1089,7 +1090,7 @@ def __hash__(self):
         return hash(type(self))


-gammal = GammaL(upgrade_to_float, name="gammal")
+gammal = GammaL(upgrade_to_float64, name="gammal")


 class Jv(BinaryScalarOp):
@@ -1335,7 +1336,7 @@ def c_code_cache_version(self):
         return v


-sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
+sigmoid = Sigmoid(upgrade_to_float64, name="sigmoid")


 class Softplus(UnaryScalarOp):
@@ -1631,6 +1632,7 @@ def _betainc_db_n_dq(f, p, q, n):
         dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
     else:
         dK = log1p(-x) + psi(p + q) - psi(q)
+    dK = dK.astype(dtype)

     derivative = np.array(0, dtype=dtype)
     n = np.array(1, dtype="int16")  # Enough for 200 max iters
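The added `dK = dK.astype(dtype)` is the same stabilization as in `inner_loop_b` above: both branches compute `dK` from `psi(...)` terms that are now float64, and the cast keeps `dK` consistent with the `derivative` accumulator created just below via `np.array(0, dtype=dtype)`.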