Skip to content

Commit 5e4d2da

Browse files
committed
Several scipy special functions now upcast integers to float64
Changed in scipy==1.14.0
1 parent f9dfe70 commit 5e4d2da

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- python>=3.10
1111
- compilers
1212
- numpy>=1.17.0
13-
- scipy>=0.14
13+
- scipy>=1.14.0
1414
- filelock
1515
- etuples
1616
- logical-unification

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ keywords = [
4747
]
4848
dependencies = [
4949
"setuptools>=59.0.0",
50-
"scipy>=0.14",
50+
"scipy>=1.14.0",
5151
"numpy>=1.17.0,<2",
5252
"filelock",
5353
"etuples",

pytensor/scalar/basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,17 @@ def upgrade_to_float_no_complex(*types):
10271027
return upgrade_to_float(*types)
10281028

10291029

1030+
def upgrade_to_float64_no_complex(*types):
1031+
"""
1032+
Don't accept complex, otherwise call upgrade_to_float64().
1033+
1034+
"""
1035+
for type in types:
1036+
if type in complex_types:
1037+
raise TypeError("complex argument not supported")
1038+
return upgrade_to_float64(*types)
1039+
1040+
10301041
def same_out_nocomplex(type):
10311042
if type in complex_types:
10321043
raise TypeError("complex argument not supported")

pytensor/scalar/math.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
upcast,
4242
upgrade_to_float,
4343
upgrade_to_float64,
44+
upgrade_to_float64_no_complex,
4445
upgrade_to_float_no_complex,
4546
)
4647
from pytensor.scalar.basic import abs as scalar_abs
@@ -323,7 +324,7 @@ def c_code(self, node, name, inputs, outputs, sub):
323324
raise NotImplementedError("only floating point is implemented")
324325

325326

326-
gamma = Gamma(upgrade_to_float, name="gamma")
327+
gamma = Gamma(upgrade_to_float64, name="gamma")
327328

328329

329330
class GammaLn(UnaryScalarOp):
@@ -460,7 +461,7 @@ def c_code(self, node, name, inp, out, sub):
460461
raise NotImplementedError("only floating point is implemented")
461462

462463

463-
psi = Psi(upgrade_to_float, name="psi")
464+
psi = Psi(upgrade_to_float64, name="psi")
464465

465466

466467
class TriGamma(UnaryScalarOp):
@@ -549,7 +550,7 @@ def c_code(self, node, name, inp, out, sub):
549550

550551

551552
# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
552-
tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
553+
tri_gamma = TriGamma(upgrade_to_float64_no_complex, name="tri_gamma")
553554

554555

555556
class PolyGamma(BinaryScalarOp):
@@ -1051,7 +1052,7 @@ def __hash__(self):
10511052
return hash(type(self))
10521053

10531054

1054-
gammau = GammaU(upgrade_to_float, name="gammau")
1055+
gammau = GammaU(upgrade_to_float64, name="gammau")
10551056

10561057

10571058
class GammaL(BinaryScalarOp):
@@ -1089,7 +1090,7 @@ def __hash__(self):
10891090
return hash(type(self))
10901091

10911092

1092-
gammal = GammaL(upgrade_to_float, name="gammal")
1093+
gammal = GammaL(upgrade_to_float64, name="gammal")
10931094

10941095

10951096
class Jv(BinaryScalarOp):
@@ -1335,7 +1336,7 @@ def c_code_cache_version(self):
13351336
return v
13361337

13371338

1338-
sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
1339+
sigmoid = Sigmoid(upgrade_to_float64, name="sigmoid")
13391340

13401341

13411342
class Softplus(UnaryScalarOp):

0 commit comments

Comments
 (0)