
Commit ba10a3a

Armavica and brendan-m-murphy authored and committed
Changes for numpy 2.0 deprecations
- Replace np.cast with np.asarray: in numpy 2.0, `np.cast[new_dtype](arr)` is deprecated. The literal replacement is `np.asarray(arr, dtype=new_dtype)`.
- Replace np.sctype2char and np.obj2sctype. Added a try/except to handle the change in behavior of `np.dtype`.
- Replace np.find_common_type with np.result_type.

Further changes to `TensorType`: `TensorType.dtype` must be a string, so the code has been changed from `self.dtype = np.dtype(dtype).type`, where the right-hand side is of type `np.generic`, to `self.dtype = str(np.dtype(dtype))`, where the right-hand side is a string that satisfies `self.dtype == str(np.dtype(self.dtype))`. This doesn't change the behavior of `np.array(..., dtype=self.dtype)` etc.
1 parent 47d0124 commit ba10a3a
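
A minimal sketch of the replacements described in the commit message (illustrative values, not taken from the diff):

import numpy as np

arr = [1, 2, 3]

# np.cast[new_dtype](arr) is gone in numpy 2.0; the literal replacement:
x = np.asarray(arr, dtype="float32")

# np.sctype2char(dtype) is replaced by the dtype object's .char attribute
char = np.dtype("float64").char  # 'd'

# np.find_common_type is replaced by np.result_type, which mixes arrays and dtypes
out_dtype = np.result_type(np.float16, x)  # float32

# TensorType.dtype stays a string; str(np.dtype(...)) is its canonical form
dtype_str = str(np.dtype("float32"))
assert dtype_str == str(np.dtype(dtype_str))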

File tree: 8 files changed, +51 −46 lines


pytensor/scalar/basic.py

Lines changed: 11 additions & 11 deletions
@@ -2966,7 +2966,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),)
+        return (gz / (x * np.array(math.log(2.0), dtype=x.dtype)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3009,7 +3009,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / (x * np.asarray(math.log(10.0)).astype(x.dtype)),)
+        return (gz / (x * np.array(math.log(10.0), dtype=x.dtype)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3124,7 +3124,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz * exp2(x) * log(np.cast[x.type](2)),)
+        return (gz * exp2(x) * log(np.array(2, dtype=x.type)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3263,7 +3263,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz * np.asarray(np.pi / 180, gz.type),)
+        return (gz * np.array(np.pi / 180, dtype=gz.type),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3298,7 +3298,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz * np.asarray(180.0 / np.pi, gz.type),)
+        return (gz * np.array(180.0 / np.pi, dtype=gz.type),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3371,7 +3371,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),)
+        return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3445,7 +3445,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / sqrt(np.cast[x.type](1) - sqr(x)),)
+        return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3517,7 +3517,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / (np.cast[x.type](1) + sqr(x)),)
+        return (gz / (np.array(1, dtype=x.type) + sqr(x)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3640,7 +3640,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / sqrt(sqr(x) - np.cast[x.type](1)),)
+        return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3717,7 +3717,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / sqrt(sqr(x) + np.cast[x.type](1)),)
+        return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
@@ -3795,7 +3795,7 @@ def L_op(self, inputs, outputs, gout):
             else:
                 return [x.zeros_like()]
 
-        return (gz / (np.cast[x.type](1) - sqr(x)),)
+        return (gz / (np.array(1, dtype=x.type) - sqr(x)),)
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
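
The gradient constants above move from construct-then-cast (`np.asarray(...).astype(...)` or `np.cast[...]`) to a single `np.array(..., dtype=...)` call. A small illustrative check that the two spellings agree (float32 chosen arbitrarily):

import math
import numpy as np

old_style = np.asarray(math.log(2.0)).astype("float32")  # old spelling
new_style = np.array(math.log(2.0), dtype="float32")     # new spelling
assert old_style.dtype == new_style.dtype == np.dtype("float32")
assert old_style == new_style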

pytensor/tensor/elemwise.py

Lines changed: 1 addition & 1 deletion
@@ -664,7 +664,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             and isinstance(self.nfunc, np.ufunc)
             and node.inputs[0].dtype in discrete_dtypes
         ):
-            char = np.sctype2char(out_dtype)
+            char = np.dtype(out_dtype).char
             sig = char * node.nin + "->" + char * node.nout
             node.tag.sig = sig
             node.tag.fake_node = Apply(
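
`np.dtype(out_dtype).char` yields the same one-character type code that `np.sctype2char` used to return. A sketch of the signature string assembled above, assuming a binary ufunc with a single output (nin=2, nout=1):

import numpy as np

char = np.dtype("int64").char     # platform type code, e.g. 'l' or 'q'
sig = char * 2 + "->" + char * 1  # e.g. 'll->l'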

pytensor/tensor/type.py

Lines changed: 15 additions & 12 deletions
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Literal, Optional
 
 import numpy as np
+import numpy.typing as npt
 
 import pytensor
 from pytensor import scalar as ps
@@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
 
     def __init__(
         self,
-        dtype: str | np.dtype,
+        dtype: str | npt.DTypeLike,
         shape: Iterable[bool | int | None] | None = None,
         name: str | None = None,
         broadcastable: Iterable[bool] | None = None,
@@ -101,11 +102,11 @@ def __init__(
         if str(dtype) == "floatX":
             self.dtype = config.floatX
         else:
-            if np.obj2sctype(dtype) is None:
+            try:
+                self.dtype = str(np.dtype(dtype))
+            except TypeError:
                 raise TypeError(f"Invalid dtype: {dtype}")
 
-            self.dtype = np.dtype(dtype).name
-
         def parse_bcast_and_shape(s):
             if isinstance(s, bool | np.bool_):
                 return 1 if s else None
@@ -789,14 +790,16 @@ def tensor(
     **kwargs,
 ) -> "TensorVariable":
     if name is not None:
-        # Help catching errors with the new tensor API
-        # Many single letter strings are valid sctypes
-        if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)):
-            np.obj2sctype(name)
-            raise ValueError(
-                f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
-                "This name looks like a dtype, which you should pass as a keyword argument only."
-            )
+        try:
+            # Help catching errors with the new tensor API
+            # Many single letter strings are valid sctypes
+            if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type):
+                raise ValueError(
+                    f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
+                    "This name looks like a dtype, which you should pass as a keyword argument only."
+                )
+        except TypeError:
+            pass
 
     if dtype is None:
         dtype = config.floatX
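
The new `TensorType.__init__` keeps `dtype` as a string: `str(np.dtype(...))` is a fixed point, and invalid inputs raise `TypeError`, which the try/except above turns into the familiar "Invalid dtype" message. A standalone sketch of that pattern (hypothetical helper, not the actual class):

import numpy as np

dt = str(np.dtype(np.float32))  # 'float32'
assert dt == str(np.dtype(dt))  # normalizing twice changes nothing

def normalize_dtype(dtype):
    # illustrative stand-in for the logic in TensorType.__init__
    try:
        return str(np.dtype(dtype))
    except TypeError:
        raise TypeError(f"Invalid dtype: {dtype}")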

tests/scan/test_rewriting.py

Lines changed: 1 addition & 1 deletion
@@ -673,7 +673,7 @@ def test_machine_translation(self):
         zi = tensor3("zi")
         zi_value = x_value
 
-        init = pt.alloc(np.cast[config.floatX](0), batch_size, dim)
+        init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim)
 
         def rnn_step1(
             # sequences

tests/tensor/test_extra_ops.py

Lines changed: 3 additions & 3 deletions
@@ -708,7 +708,7 @@ def test_perform(self, shp):
         y = scalar()
         f = function([x, y], fill_diagonal(x, y))
         a = rng.random(shp).astype(config.floatX)
-        val = np.cast[config.floatX](rng.random())
+        val = rng.random(dtype=config.floatX)
         out = f(a, val)
         # We can't use np.fill_diagonal as it is bugged.
         assert np.allclose(np.diag(out), val)
@@ -720,7 +720,7 @@ def test_perform_3d(self):
         x = tensor3()
         y = scalar()
         f = function([x, y], fill_diagonal(x, y))
-        val = np.cast[config.floatX](rng.random() + 10)
+        val = rng.random(dtype=config.floatX) + 10
         out = f(a, val)
         # We can't use np.fill_diagonal as it is bugged.
         assert out[0, 0, 0] == val
@@ -782,7 +782,7 @@ def test_perform(self, test_offset, shp):
 
         f = function([x, y, z], fill_diagonal_offset(x, y, z))
         a = rng.random(shp).astype(config.floatX)
-        val = np.cast[config.floatX](rng.random())
+        val = rng.random(dtype=config.floatX)
         out = f(a, val, test_offset)
         # We can't use np.fill_diagonal as it is bugged.
         assert np.allclose(np.diag(out, test_offset), val)
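
`Generator.random` (and `Generator.standard_normal`, used in the test_gradient changes below) take a `dtype` argument directly, but only float32 and float64 are supported, so the change above assumes `config.floatX` is one of those. A quick illustration:

import numpy as np

rng = np.random.default_rng(42)
val = rng.random(dtype="float32")                # scalar in [0, 1)
noise = rng.standard_normal(3, dtype="float32")  # 3 samples of N(0, 1)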

tests/tensor/utils.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn):
     """
 
     def ret(*args, **kwargs):
-        out_dtype = np.find_common_type([a.dtype for a in args], [np.float16])
+        out_dtype = np.result_type(np.float16, *args)
         if out_dtype == "float16":
             # Force everything to float32
             sig = "f" * fn.nin + "->" + "f" * fn.nout

tests/test_gradient.py

Lines changed: 15 additions & 13 deletions
@@ -481,12 +481,12 @@ def make_grad_func(X):
         int_type = imatrix().dtype
         float_type = "float64"
 
-        X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0)
-        W = np.cast[W.dtype](rng.standard_normal((d, n)))
-        b = np.cast[b.dtype](rng.standard_normal(n))
+        X = np.asarray(rng.standard_normal((m, d)) * 127.0, dtype=int_type)
+        W = rng.standard_normal((d, n), dtype=W.dtype)
+        b = rng.standard_normal(n, dtype=b.dtype)
 
         int_result = int_func(X, W, b)
-        float_result = float_func(np.cast[float_type](X), W, b)
+        float_result = float_func(np.asarray(X, dtype=float_type), W, b)
 
         assert np.allclose(int_result, float_result), (int_result, float_result)
 
@@ -508,7 +508,7 @@ def test_grad_disconnected(self):
         # the output
         f = pytensor.function([x], g)
         rng = np.random.default_rng([2012, 9, 5])
-        x = np.cast[x.dtype](rng.standard_normal(3))
+        x = rng.standard_normal(3, dtype=x.dtype)
         g = f(x)
         assert np.allclose(g, np.ones(x.shape, dtype=x.dtype))
 
@@ -631,7 +631,8 @@ def test_known_grads():
     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()]
     values = [
-        np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True)
+        np.asarray(value, dtype=ipt.dtype)
+        for ipt, value in zip(inputs, values, strict=True)
     ]
 
     true_grads = grad(cost, inputs, disconnected_inputs="ignore")
@@ -679,7 +680,7 @@ def test_known_grads_integers():
     f = pytensor.function([g_expected], g_grad)
 
     x = -3
-    gv = np.cast[config.floatX](0.6)
+    gv = np.asarray(0.6, dtype=config.floatX)
 
     g_actual = f(gv)
 
@@ -746,7 +747,8 @@ def test_subgraph_grad():
     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(2), rng.standard_normal(3)]
     values = [
-        np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True)
+        np.asarray(value, dtype=ipt.dtype)
+        for ipt, value in zip(inputs, values, strict=True)
     ]
 
     wrt = [w2, w1]
@@ -1031,30 +1033,30 @@ def test_jacobian_scalar():
     # test when the jacobian is called with a tensor as wrt
     Jx = jacobian(y, x)
     f = pytensor.function([x], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
     # test when the jacobian is called with a tuple as wrt
     Jx = jacobian(y, (x,))
     assert isinstance(Jx, tuple)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
     # test when the jacobian is called with a list as wrt
     Jx = jacobian(y, [x])
     assert isinstance(Jx, list)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
     # test when the jacobian is called with a list of two elements
     z = scalar()
     y = x * z
     Jx = jacobian(y, [x, z])
     f = pytensor.function([x, z], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
-    vz = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
+    vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     vJx = f(vx, vz)
 
     assert np.allclose(vJx[0], vz)
tests/typed_list/test_basic.py

Lines changed: 4 additions & 4 deletions
@@ -577,10 +577,10 @@ def test_correct_answer(self):
         x = tensor3()
         y = tensor3()
 
-        A = np.cast[pytensor.config.floatX](np.random.random((5, 3)))
-        B = np.cast[pytensor.config.floatX](np.random.random((7, 2)))
-        X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1)))
-        Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3)))
+        A = np.random.random((5, 3)).astype(pytensor.config.floatX)
+        B = np.random.random((7, 2)).astype(pytensor.config.floatX)
+        X = np.random.random((5, 6, 1)).astype(pytensor.config.floatX)
+        Y = np.random.random((1, 9, 3)).astype(pytensor.config.floatX)
 
         make_list((3.0, 4.0))
         c = make_list((a, b))
