Skip to content

Commit f62401a

Browse files
committed
maintanance: unpin scipy
fix: cast to elemwise outputs to their respective dtypes fix: Relax scipy dependency, should work in both cases style: black wrap with asarray fix: make elemwise test check against dtype in the graph fix scalar issues Update pytensor/scalar/basic.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> fix test add a clarifying comment to checking nan fix: bool is deprecated in numpy deps: bound scipy version improve test
1 parent 4d0103b commit f62401a

File tree

7 files changed

+36
-34
lines changed

7 files changed

+36
-34
lines changed

environment-osx-arm64.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- python=>3.10
1111
- compilers
1212
- numpy>=1.17.0,<2
13-
- scipy>=0.14,<1.14.0
13+
- scipy>=1,<2
1414
- filelock>=3.15
1515
- etuples
1616
- logical-unification

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- python>=3.10
1111
- compilers
1212
- numpy>=1.17.0,<2
13-
- scipy>=0.14,<1.14.0
13+
- scipy>=1,<2
1414
- filelock>=3.15
1515
- etuples
1616
- logical-unification

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ keywords = [
4747
]
4848
dependencies = [
4949
"setuptools>=59.0.0",
50-
"scipy>=0.14,<1.14",
50+
"scipy>=1,<2",
5151
"numpy>=1.17.0,<2",
5252
"filelock>=3.15",
5353
"etuples",

pytensor/scalar/basic.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,14 +1140,25 @@ def output_types(self, types):
11401140
else:
11411141
raise NotImplementedError(f"Cannot calculate the output types for {self}")
11421142

1143+
@staticmethod
1144+
def _cast_scalar(x, dtype):
1145+
if hasattr(x, "astype"):
1146+
return x.astype(dtype)
1147+
elif dtype == "bool":
1148+
return np.bool_(x)
1149+
else:
1150+
return getattr(np, dtype)(x)
1151+
11431152
def perform(self, node, inputs, output_storage):
11441153
if self.nout == 1:
1145-
output_storage[0][0] = self.impl(*inputs)
1154+
dtype = node.outputs[0].dtype
1155+
output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype)
11461156
else:
11471157
variables = from_return_values(self.impl(*inputs))
11481158
assert len(variables) == len(output_storage)
1149-
for storage, variable in zip(output_storage, variables):
1150-
storage[0] = variable
1159+
for out, storage, variable in zip(node.outputs, output_storage, variables):
1160+
dtype = out.dtype
1161+
storage[0] = self._cast_scalar(variable, dtype)
11511162

11521163
def impl(self, *inputs):
11531164
raise MethodNotDefined("impl", type(self), self.__class__.__name__)

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -767,34 +767,16 @@ def perform(self, node, inputs, output_storage):
767767
for i, (variable, storage, nout) in enumerate(
768768
zip(variables, output_storage, node.outputs)
769769
):
770-
if getattr(variable, "dtype", "") == "object":
771-
# Since numpy 1.6, function created with numpy.frompyfunc
772-
# always return an ndarray with dtype object
773-
variable = np.asarray(variable, dtype=nout.dtype)
770+
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
774771

775772
if i in self.inplace_pattern:
776773
odat = inputs[self.inplace_pattern[i]]
777774
odat[...] = variable
778775
storage[0] = odat
779776

780-
# Sometimes NumPy return a Python type.
781-
# Some PyTensor op return a different dtype like floor, ceil,
782-
# trunc, eq, ...
783-
elif not isinstance(variable, np.ndarray) or variable.dtype != nout.dtype:
784-
variable = np.asarray(variable, nout.dtype)
785-
# The next line is needed for numpy 1.9. Otherwise
786-
# there are tests that fail in DebugMode.
787-
# Normally we would call pytensor.misc._asarray, but it
788-
# is faster to inline the code. We know that the dtype
789-
# are the same string, just different typenum.
790-
if np.dtype(nout.dtype).num != variable.dtype.num:
791-
variable = variable.view(dtype=nout.dtype)
792-
storage[0] = variable
793777
# numpy.real return a view!
794-
elif not variable.flags.owndata:
778+
if not variable.flags.owndata:
795779
storage[0] = variable.copy()
796-
else:
797-
storage[0] = variable
798780

799781
@staticmethod
800782
def _check_runtime_broadcast(node, inputs):

tests/scalar/test_loop.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,17 @@ def test_inner_composite(mode):
212212
y16 = op(n_steps, x16)
213213
assert y16.type.dtype == "float16"
214214

215-
fn32 = function([n_steps, x16], y16, mode=mode)
215+
fn16 = function([n_steps, x16], y16, mode=mode)
216+
out16 = fn16(n_steps=3, x16=np.array(4.73, dtype="float16"))
216217
np.testing.assert_allclose(
217-
fn32(n_steps=9, x16=np.array(4.73, dtype="float16")),
218-
4.73 + 9,
218+
out16,
219+
4.73 + 3,
219220
rtol=1e-3,
220221
)
222+
out16overflow = fn16(n_steps=9, x16=np.array(4.73, dtype="float16"))
223+
assert out16overflow.dtype == "float16"
224+
# with this dtype overflow happens
225+
assert np.isnan(out16overflow)
221226

222227

223228
@mode
@@ -243,8 +248,10 @@ def test_inner_loop(mode):
243248
y16 = outer_loop_op(n_steps, x16, n_steps)
244249
assert y16.type.dtype == "float16"
245250

246-
fn32 = function([n_steps, x16], y16, mode=mode)
251+
fn16 = function([n_steps, x16], y16, mode=mode)
252+
out16 = fn16(n_steps=3, x16=np.array(2.5, dtype="float16"))
253+
assert out16.dtype == "float16"
247254
np.testing.assert_allclose(
248-
fn32(n_steps=3, x16=np.array(2.5, dtype="float16")),
255+
out16,
249256
3**2 + 2.5,
250257
)

tests/tensor/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,15 +508,17 @@ def test_good(self):
508508
if not isinstance(expecteds, list | tuple):
509509
expecteds = (expecteds,)
510510

511-
for i, (variable, expected) in enumerate(zip(variables, expecteds)):
511+
for i, (variable, expected, out_symbol) in enumerate(
512+
zip(variables, expecteds, node.outputs)
513+
):
512514
condition = (
513-
variable.dtype != expected.dtype
515+
variable.dtype != out_symbol.type.dtype
514516
or variable.shape != expected.shape
515517
or not np.allclose(variable, expected, atol=eps, rtol=eps)
516518
)
517519
assert not condition, (
518520
f"Test {self.op}::{testname}: Output {i} gave the wrong"
519-
f" value. With inputs {inputs}, expected {expected} (dtype {expected.dtype}),"
521+
f" value. With inputs {inputs}, expected {expected} (dtype {out_symbol.type.dtype}),"
520522
f" got {variable} (dtype {variable.dtype}). eps={eps:f}"
521523
f" np.allclose returns {np.allclose(variable, expected, atol=eps, rtol=eps)} {np.allclose(variable, expected)}"
522524
)

0 commit comments

Comments
 (0)