Skip to content

Commit cc48800

Browse files
committed
Several scipy special functions now upcast integers to float64
Changed in scipy==1.14.0
1 parent f9dfe70 commit cc48800

File tree

4 files changed: +22 additions, −9 deletions

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- python>=3.10
1111
- compilers
1212
- numpy>=1.17.0
13-
- scipy>=0.14
13+
- scipy>=1.14.0
1414
- filelock
1515
- etuples
1616
- logical-unification

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ keywords = [
4747
]
4848
dependencies = [
4949
"setuptools>=59.0.0",
50-
"scipy>=0.14",
50+
"scipy>=1.14.0",
5151
"numpy>=1.17.0,<2",
5252
"filelock",
5353
"etuples",

pytensor/scalar/basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,17 @@ def upgrade_to_float_no_complex(*types):
10271027
return upgrade_to_float(*types)
10281028

10291029

1030+
def upgrade_to_float64_no_complex(*types):
1031+
"""
1032+
Don't accept complex, otherwise call upgrade_to_float64().
1033+
1034+
"""
1035+
for type in types:
1036+
if type in complex_types:
1037+
raise TypeError("complex argument not supported")
1038+
return upgrade_to_float64(*types)
1039+
1040+
10301041
def same_out_nocomplex(type):
10311042
if type in complex_types:
10321043
raise TypeError("complex argument not supported")

pytensor/scalar/math.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
upcast,
4242
upgrade_to_float,
4343
upgrade_to_float64,
44+
upgrade_to_float64_no_complex,
4445
upgrade_to_float_no_complex,
4546
)
4647
from pytensor.scalar.basic import abs as scalar_abs
@@ -323,7 +324,7 @@ def c_code(self, node, name, inputs, outputs, sub):
323324
raise NotImplementedError("only floating point is implemented")
324325

325326

326-
gamma = Gamma(upgrade_to_float, name="gamma")
327+
gamma = Gamma(upgrade_to_float64, name="gamma")
327328

328329

329330
class GammaLn(UnaryScalarOp):
@@ -460,7 +461,7 @@ def c_code(self, node, name, inp, out, sub):
460461
raise NotImplementedError("only floating point is implemented")
461462

462463

463-
psi = Psi(upgrade_to_float, name="psi")
464+
psi = Psi(upgrade_to_float64, name="psi")
464465

465466

466467
class TriGamma(UnaryScalarOp):
@@ -549,7 +550,7 @@ def c_code(self, node, name, inp, out, sub):
549550

550551

551552
# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
552-
tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
553+
tri_gamma = TriGamma(upgrade_to_float64_no_complex, name="tri_gamma")
553554

554555

555556
class PolyGamma(BinaryScalarOp):
@@ -880,7 +881,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):
880881

881882
def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
882883
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
883-
sum_b += term
884+
sum_b += term.astype(dtype)
884885

885886
log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
886887
n += 1
@@ -1051,7 +1052,7 @@ def __hash__(self):
10511052
return hash(type(self))
10521053

10531054

1054-
gammau = GammaU(upgrade_to_float, name="gammau")
1055+
gammau = GammaU(upgrade_to_float64, name="gammau")
10551056

10561057

10571058
class GammaL(BinaryScalarOp):
@@ -1089,7 +1090,7 @@ def __hash__(self):
10891090
return hash(type(self))
10901091

10911092

1092-
gammal = GammaL(upgrade_to_float, name="gammal")
1093+
gammal = GammaL(upgrade_to_float64, name="gammal")
10931094

10941095

10951096
class Jv(BinaryScalarOp):
@@ -1335,7 +1336,7 @@ def c_code_cache_version(self):
13351336
return v
13361337

13371338

1338-
sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
1339+
sigmoid = Sigmoid(upgrade_to_float64, name="sigmoid")
13391340

13401341

13411342
class Softplus(UnaryScalarOp):
@@ -1631,6 +1632,7 @@ def _betainc_db_n_dq(f, p, q, n):
16311632
dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
16321633
else:
16331634
dK = log1p(-x) + psi(p + q) - psi(q)
1635+
dK = dK.astype(dtype)
16341636

16351637
derivative = np.array(0, dtype=dtype)
16361638
n = np.array(1, dtype="int16") # Enough for 200 max iters

0 commit comments

Comments (0)