diff --git a/environment.yml b/environment.yml index 6be8e376da..d066c653d8 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - python>=3.10 - compilers - numpy>=1.17.0 - - scipy>=0.14 + - scipy>=1.14.0 - filelock - etuples - logical-unification diff --git a/pyproject.toml b/pyproject.toml index 31f0e1bb24..6ce096d7b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ keywords = [ ] dependencies = [ "setuptools>=59.0.0", - "scipy>=0.14", + "scipy>=1.14.0", "numpy>=1.17.0,<2", "filelock", "etuples", diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 2a3db168ba..49b6ebff07 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1027,6 +1027,17 @@ def upgrade_to_float_no_complex(*types): return upgrade_to_float(*types) +def upgrade_to_float64_no_complex(*types): + """ + Don't accept complex, otherwise call upgrade_to_float64(). + + """ + for type in types: + if type in complex_types: + raise TypeError("complex argument not supported") + return upgrade_to_float64(*types) + + def same_out_nocomplex(type): if type in complex_types: raise TypeError("complex argument not supported") diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index ac66fbd698..33c941679c 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -41,6 +41,7 @@ upcast, upgrade_to_float, upgrade_to_float64, + upgrade_to_float64_no_complex, upgrade_to_float_no_complex, ) from pytensor.scalar.basic import abs as scalar_abs @@ -323,7 +324,7 @@ def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("only floating point is implemented") -gamma = Gamma(upgrade_to_float, name="gamma") +gamma = Gamma(upgrade_to_float64, name="gamma") class GammaLn(UnaryScalarOp): @@ -460,7 +461,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floating point is implemented") -psi = Psi(upgrade_to_float, name="psi") +psi = Psi(upgrade_to_float64, name="psi") class TriGamma(UnaryScalarOp): @@ -549,7 +550,7 @@ def c_code(self, node, name, inp, out, sub): # Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410 -tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma") +tri_gamma = TriGamma(upgrade_to_float64_no_complex, name="tri_gamma") class PolyGamma(BinaryScalarOp): @@ -880,7 +881,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x): def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x): term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1) - sum_b += term + sum_b += term.astype(dtype) log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) n += 1 @@ -1051,7 +1052,7 @@ def __hash__(self): return hash(type(self)) -gammau = GammaU(upgrade_to_float, name="gammau") +gammau = GammaU(upgrade_to_float64, name="gammau") class GammaL(BinaryScalarOp): @@ -1089,7 +1090,7 @@ def __hash__(self): return hash(type(self)) -gammal = GammaL(upgrade_to_float, name="gammal") +gammal = GammaL(upgrade_to_float64, name="gammal") class Jv(BinaryScalarOp): @@ -1335,7 +1336,7 @@ def c_code_cache_version(self): return v -sigmoid = Sigmoid(upgrade_to_float, name="sigmoid") +sigmoid = Sigmoid(upgrade_to_float64, name="sigmoid") class Softplus(UnaryScalarOp): @@ -1631,6 +1632,7 @@ def _betainc_db_n_dq(f, p, q, n): dK = log(x) - reciprocal(p) + psi(p + q) - psi(p) else: dK = log1p(-x) + psi(p + q) - psi(q) + dK = dK.astype(dtype) derivative = np.array(0, dtype=dtype) n = np.array(1, dtype="int16") # Enough for 200 max iters