From 92b257ad27f01e0490d6c27a2249d1a841c38a91 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 10:53:07 -0400 Subject: [PATCH 01/12] Replace np.cast with np.asarray --- pytensor/scalar/basic.py | 14 +++++++------- tests/scan/test_rewriting.py | 2 +- tests/tensor/test_extra_ops.py | 6 +++--- tests/test_gradient.py | 26 +++++++++++++------------- tests/typed_list/test_basic.py | 8 ++++---- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 30919e16d9..f3388900cb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -3144,7 +3144,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * exp2(x) * log(np.cast[x.type](2)),) + return (gz * exp2(x) * log(np.asarray(2, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3391,7 +3391,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (-gz / sqrt(np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3465,7 +3465,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (gz / sqrt(np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3537,7 +3537,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) + sqr(x)),) + return (gz / (np.asarray(1, dtype=x.type) + sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3660,7 +3660,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) - np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) - np.asarray(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3737,7 +3737,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) + np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) + np.asarray(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3815,7 +3815,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) - sqr(x)),) + return (gz / (np.asarray(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index c9f11e891d..5d0d6e70fc 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -673,7 +673,7 @@ def test_machine_translation(self): zi = tensor3("zi") zi_value = x_value - init = pt.alloc(np.cast[config.floatX](0), batch_size, dim) + init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim) def rnn_step1( # sequences diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index cda745d023..1e4ec9f7ad 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -685,7 +685,7 @@ def test_perform(self, shp): y = scalar() f = function([x, y], fill_diagonal(x, y)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. 
         assert np.allclose(np.diag(out), val)
@@ -697,7 +697,7 @@ def test_perform_3d(self):
         x = tensor3()
         y = scalar()
         f = function([x, y], fill_diagonal(x, y))
-        val = np.cast[config.floatX](rng.random() + 10)
+        val = rng.random(dtype=config.floatX) + 10
         out = f(a, val)
         # We can't use np.fill_diagonal as it is bugged.
         assert out[0, 0, 0] == val
@@ -759,7 +759,7 @@ def test_perform(self, test_offset, shp):

         f = function([x, y, z], fill_diagonal_offset(x, y, z))
         a = rng.random(shp).astype(config.floatX)
-        val = np.cast[config.floatX](rng.random())
+        val = rng.random(dtype=config.floatX)
         out = f(a, val, test_offset)
         # We can't use np.fill_diagonal as it is bugged.
         assert np.allclose(np.diag(out, test_offset), val)
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index f9f1e8fe4b..12e7f7b5a2 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -480,12 +480,12 @@ def make_grad_func(X):
         int_type = imatrix().dtype
         float_type = "float64"

-        X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0)
-        W = np.cast[W.dtype](rng.standard_normal((d, n)))
-        b = np.cast[b.dtype](rng.standard_normal(n))
+        X = np.asarray(rng.standard_normal((m, d)) * 127.0, dtype=int_type)
+        W = rng.standard_normal((d, n), dtype=W.dtype)
+        b = rng.standard_normal(n, dtype=b.dtype)

         int_result = int_func(X, W, b)
-        float_result = float_func(np.cast[float_type](X), W, b)
+        float_result = float_func(np.asarray(X, dtype=float_type), W, b)

         assert np.allclose(int_result, float_result), (int_result, float_result)

@@ -507,7 +507,7 @@ def test_grad_disconnected(self):
         # the output
         f = pytensor.function([x], g)
         rng = np.random.default_rng([2012, 9, 5])
-        x = np.cast[x.dtype](rng.standard_normal(3))
+        x = rng.standard_normal(3, dtype=x.dtype)
         g = f(x)
         assert np.allclose(g, np.ones(x.shape, dtype=x.dtype))

@@ -629,7 +629,7 @@ def test_known_grads():

     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()]
-    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]
+    values = [np.asarray(value, dtype=ipt.dtype) for ipt, value in zip(inputs, values)]

     true_grads = grad(cost, inputs, disconnected_inputs="ignore")
     true_grads = pytensor.function(inputs, true_grads)
@@ -676,7 +676,7 @@ def test_known_grads_integers():
     f = pytensor.function([g_expected], g_grad)

     x = -3
-    gv = np.cast[config.floatX](0.6)
+    gv = np.asarray(0.6, dtype=config.floatX)

     g_actual = f(gv)

@@ -742,7 +742,7 @@ def test_subgraph_grad():
     inputs = [t, x]
     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(2), rng.standard_normal(3)]
-    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]
+    values = [np.asarray(value, dtype=ipt.dtype) for ipt, value in zip(inputs, values)]

     wrt = [w2, w1]
     cost = cost2 + cost1
@@ -1026,21 +1026,21 @@ def test_jacobian_scalar():
     # test when the jacobian is called with a tensor as wrt
     Jx = jacobian(y, x)
     f = pytensor.function([x], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a tuple as wrt
     Jx = jacobian(y, (x,))
     assert isinstance(Jx, tuple)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a list as wrt
     Jx = jacobian(y, [x])
     assert isinstance(Jx, list)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a list of two elements
@@ -1048,8 +1048,8 @@ def test_jacobian_scalar():
     y = x * z
     Jx = jacobian(y, [x, z])
     f = pytensor.function([x, z], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
-    vz = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
+    vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     vJx = f(vx, vz)

     assert np.allclose(vJx[0], vz)
diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py
index 4b309c2324..c790cc8f84 100644
--- a/tests/typed_list/test_basic.py
+++ b/tests/typed_list/test_basic.py
@@ -577,10 +577,10 @@ def test_correct_answer(self):
         x = tensor3()
         y = tensor3()

-        A = np.cast[pytensor.config.floatX](np.random.random((5, 3)))
-        B = np.cast[pytensor.config.floatX](np.random.random((7, 2)))
-        X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1)))
-        Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3)))
+        A = np.asarray(np.random.random((5, 3)), dtype=pytensor.config.floatX)
+        B = np.asarray(np.random.random((7, 2)), dtype=pytensor.config.floatX)
+        X = np.asarray(np.random.random((5, 6, 1)), dtype=pytensor.config.floatX)
+        Y = np.asarray(np.random.random((1, 9, 3)), dtype=pytensor.config.floatX)

         make_list((3.0, 4.0))
         c = make_list((a, b))

From 4b83f9433b716afbf01435533f2e27f02fd55394 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Wed, 3 Apr 2024 10:58:42 -0400
Subject: [PATCH 02/12] Replace np.sctype2char

---
 pytensor/tensor/elemwise.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 632fa0976d..ef5b2b9aa9 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -685,7 +685,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             and isinstance(self.nfunc, np.ufunc)
             and node.inputs[0].dtype in discrete_dtypes
         ):
-            char = np.sctype2char(out_dtype)
+            char = np.dtype(out_dtype).char
             sig = char * node.nin + "->" + char * node.nout
             node.tag.sig = sig
             node.tag.fake_node = Apply(

From 5cef8fb40e3a1a83a64858f5bb4a5f661c9ddd6b Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Wed, 3 Apr 2024 11:01:11 -0400
Subject: [PATCH 03/12] Remove np.obj2sctype

---
 pytensor/tensor/type.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index b55d226471..01911ff47a 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -102,7 +102,7 @@ def __init__(
         if str(dtype) == "floatX":
             self.dtype = config.floatX
         else:
-            if np.obj2sctype(dtype) is None:
+            if np.dtype(dtype).type is None:
                 raise TypeError(f"Invalid dtype: {dtype}")
             self.dtype = np.dtype(dtype).name

@@ -785,8 +785,7 @@ def tensor(
     if name is not None:
         # Help catching errors with the new tensor API
         # Many single letter strings are valid sctypes
-        if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)):
-            np.obj2sctype(name)
+        if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type):
            raise ValueError(
                 f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
                 "This name looks like a dtype, which you should pass as a keyword argument only."
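A minimal standalone sketch of the three replacement patterns used in patches 01-03 (illustrative only; the values and names below are not taken from the diff):

    import numpy as np

    value, spec = 0.5, "float32"

    # np.cast[spec](value)  ->  np.asarray(value, dtype=spec)
    assert np.asarray(value, dtype=spec).dtype == np.dtype(spec)

    # np.sctype2char(spec)  ->  np.dtype(spec).char
    assert np.dtype(spec).char == "f"

    # np.obj2sctype(bad_spec) returned None; np.dtype(bad_spec) raises instead,
    # which is why a later patch in this series wraps the check in try/except.
    try:
        np.dtype("not-a-dtype")
    except TypeError:
        pass
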
From bf55d44b08903abb20769d1b8c4499e7ebe87d16 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 11:09:39 -0400 Subject: [PATCH 04/12] Replace np.find_common_type with np.result_type --- tests/tensor/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 51c3a28ce3..a3ee820d01 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn): """ def ret(*args, **kwargs): - out_dtype = np.find_common_type([a.dtype for a in args], [np.float16]) + out_dtype = np.result_type(np.float16, *args) if out_dtype == "float16": # Force everything to float32 sig = "f" * fn.nin + "->" + "f" * fn.nout From b5c15dc7e2bd929fc7aab4409ba1520d5ab4d195 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 11:19:10 -0400 Subject: [PATCH 05/12] Add ruff numpy2 transition rule --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b7f4c86535..a52e5e9d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,7 @@ line-length = 88 exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"] [tool.ruff.lint] -select = ["C", "E", "F", "I", "UP", "W", "RUF"] +select = ["C", "E", "F", "I", "UP", "W", "RUF", "NPY201"] ignore = ["C408", "C901", "E501", "E741", "RUF012"] From c0de3fca7b8ddab4dc884af81eeab7c8023228d6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 11:16:08 +0200 Subject: [PATCH 06/12] Update numpy deprecated imports --- pytensor/tensor/__init__.py | 2 +- pytensor/tensor/basic.py | 2 +- pytensor/tensor/conv/abstract_conv.py | 3 ++- pytensor/tensor/extra_ops.py | 9 +++++---- pytensor/tensor/slinalg.py | 3 ++- tests/tensor/test_io.py | 2 +- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 3dfa1b4b7a..49ef4738bb 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -123,7 +123,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: on # Allow accessing numpy constants from pytensor.tensor -from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi +from numpy import e, euler_gamma, inf, nan, newaxis, pi from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index f36f8888ba..b04b4a3c46 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -4198,7 +4198,7 @@ def expand_dims( axis = (axis,) out_ndim = len(axis) + a.ndim - axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) + axis = normalize_axis_tuple(axis, out_ndim) dim_it = iter(range(a.ndim)) pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)] diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index d504e89386..febe44a40b 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -8,6 +8,7 @@ from math import gcd import numpy as np +from numpy.exceptions import ComplexWarning try: @@ -2341,7 +2342,7 @@ def conv( bval = _bvalfromboundary("fill") with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) for b in range(img.shape[0]): for g in range(self.num_groups): for n in range(output_channel_offset): diff --git a/pytensor/tensor/extra_ops.py 
b/pytensor/tensor/extra_ops.py index 94498eaed0..d0b49ec866 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1,7 +1,8 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.core.multiarray import normalize_axis_index +from numpy.exceptions import AxisError +from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple import pytensor import pytensor.scalar.basic as ps @@ -584,9 +585,9 @@ def squeeze(x, axis=None): # scalar inputs are treated as 1D regarding axis in this `Op` try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=_x.ndim) + axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) + except AxisError: + raise AxisError(axis, ndim=_x.ndim) if not axis: # Nothing to do diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 4f72c0263a..d379c29b0c 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,6 +6,7 @@ import numpy as np import scipy.linalg +from numpy.exceptions import ComplexWarning import pytensor import pytensor.tensor as pt @@ -633,7 +634,7 @@ def perform(self, node, inputs, outputs): Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) out[0] = Y.astype(A.dtype) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index f8d0495824..75cb08a835 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -51,7 +51,7 @@ def test_memmap(self): path = Variable(Generic(), None) x = load(path, "int32", (None,), mmap_mode="c") fn = function([path], x) - assert type(fn(self.filename)) == np.core.memmap + assert type(fn(self.filename)) == np.memmap def teardown_method(self): os.remove(os.path.join(pytensor.config.compiledir, "_test.npy")) From d3f138ce4bc513789cb5cdb59d1cb5357bdd328c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 13:06:07 +0200 Subject: [PATCH 07/12] Handle change in behavior np.dtype --- pytensor/tensor/type.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 01911ff47a..a13f5ab9d2 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -783,13 +783,16 @@ def tensor( **kwargs, ) -> "TensorVariable": if name is not None: - # Help catching errors with the new tensor API - # Many single letter strings are valid sctypes - if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): - raise ValueError( - f"The first and only positional argument of tensor is now `name`. Got {name}.\n" - "This name looks like a dtype, which you should pass as a keyword argument only." - ) + try: + # Help catching errors with the new tensor API + # Many single letter strings are valid sctypes + if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): + raise ValueError( + f"The first and only positional argument of tensor is now `name`. Got {name}.\n" + "This name looks like a dtype, which you should pass as a keyword argument only." 
+ ) + except TypeError: + pass if dtype is None: dtype = config.floatX From 75efbdc868fb79a169da74270d32460aeab8f86b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 12:47:02 +0200 Subject: [PATCH 08/12] Update access to `_get_ndarray_c_version` Also removes special case for old unsupported numpy 1.16 --- pytensor/link/c/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index e11247c9b3..2eb5856183 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1371,8 +1371,8 @@ def cmodule_key_( # We must always add the numpy ABI version here as # DynamicModule always add the include - if np.lib.NumpyVersion(np.__version__) < "1.16.0a": - ndarray_c_version = np.core.multiarray._get_ndarray_c_version() + if np.lib.NumpyVersion(np.__version__) >= "2.0.0rc": + ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() else: ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}") From 8a83fb0e90b9b42ef3444a30a4c6a3706d2ea229 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 11:42:14 +0200 Subject: [PATCH 09/12] Replace `->elsize` by `PyArray_ITEMSIZE` --- pytensor/sparse/basic.py | 20 +++---- pytensor/sparse/rewriting.py | 94 ++++++++++++++++----------------- pytensor/tensor/blas.py | 8 +-- pytensor/tensor/blas_headers.py | 4 +- tests/compile/test_debugmode.py | 6 +-- 5 files changed, 66 insertions(+), 66 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 7c89d81cb5..d824678564 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -3607,7 +3607,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3643,11 +3643,11 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; @@ -3740,7 +3740,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3777,11 +3777,11 @@ def c_code(self, node, name, inputs, outputs, sub): // extract number of rows npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = 
PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index 436209bd6d..9768b1c5e8 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -158,8 +158,8 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{y}* ydata = (dtype_{y}*)PyArray_DATA({y}); dtype_{z}* zdata = (dtype_{z}*)PyArray_DATA({z}); - npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_DESCR({y})->elsize; - npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; + npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_ITEMSIZE({y}); + npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); npy_intp pos; if ({format} == 0){{ @@ -186,7 +186,7 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[3]] def c_code_cache_version(self): - return (2,) + return (3,) @node_rewriter([sparse.AddSD]) @@ -360,13 +360,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. 
dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -435,7 +435,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3,) + return (4,) sd_csc = StructuredDotCSC() @@ -553,13 +553,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -612,7 +612,7 @@ def c_code(self, node, name, inputs, outputs, sub): """.format(**dict(locals(), **sub)) def c_code_cache_version(self): - return (2,) + return (3,) sd_csr = StructuredDotCSR() @@ -842,12 +842,12 @@ def c_code(self, node, name, inputs, outputs, sub): const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA({x_ptr}); const dtype_{alpha} alpha = ((dtype_{alpha}*)PyArray_DATA({alpha}))[0]; - npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_DESCR({zn})->elsize; - npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_DESCR({x_val})->elsize; - npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_DESCR({x_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_DESCR({x_ptr})->elsize; - npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_DESCR({y})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_ITEMSIZE({zn}); + npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_ITEMSIZE({x_val}); + npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_ITEMSIZE({x_ind}); + npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_ITEMSIZE({x_ptr}); + npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_ITEMSIZE({y}); // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) @@ -893,7 +893,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3, blas.blas_header_version()) + return (4, blas.blas_header_version()) usmm_csc_dense = UsmmCscDense(inplace=False) @@ -1031,13 +1031,13 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; // strides tell you how many bytes to 
skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; - npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_DESCR({b_val})->elsize; - npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_DESCR({b_ind})->elsize; - npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_DESCR({b_ptr})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); + npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_ITEMSIZE({b_val}); + npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_ITEMSIZE({b_ind}); + npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_ITEMSIZE({b_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -1082,7 +1082,7 @@ def c_code(self, node, name, inputs, outputs, sub): """.format(**dict(locals(), **sub)) def c_code_cache_version(self): - return (3,) + return (4,) csm_grad_c = CSMGradC() @@ -1476,7 +1476,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (2,) + return (3,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1537,7 +1537,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over rows for (npy_intp j = 0; j < N; ++j) @@ -1648,7 +1648,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1715,7 +1715,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over columns for (npy_intp j = 0; j < N; ++j) @@ -1860,7 +1860,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): ) def c_code_cache_version(self): - return (4, blas.blas_header_version()) + return (5, blas.blas_header_version()) def c_support_code(self, **kwargs): return blas.blas_header_text() @@ -1986,14 +1986,14 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{z_ind}* __restrict__ Dzi = (dtype_{z_ind}*)PyArray_DATA({z_ind}); dtype_{z_ptr}* __restrict__ Dzp = (dtype_{z_ptr}*)PyArray_DATA({z_ptr}); - const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_DESCR({x})->elsize; - const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; - const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_DESCR({p_data})->elsize; - const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_DESCR({p_ind})->elsize; - const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_DESCR({p_ptr})->elsize; - const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_DESCR({z_data})->elsize; - const npy_intp Sdzi = 
PyArray_STRIDES({z_ind})[0] / PyArray_DESCR({z_ind})->elsize; - const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_DESCR({z_ptr})->elsize; + const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_ITEMSIZE({x}); + const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); + const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_ITEMSIZE({p_data}); + const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_ITEMSIZE({p_ind}); + const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_ITEMSIZE({p_ptr}); + const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_ITEMSIZE({z_data}); + const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_ITEMSIZE({z_ind}); + const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_ITEMSIZE({z_ptr}); memcpy(Dzi, Dpi, PyArray_DIMS({p_ind})[0]*sizeof(dtype_{p_ind})); memcpy(Dzp, Dpp, PyArray_DIMS({p_ptr})[0]*sizeof(dtype_{p_ptr})); diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index a162cfc960..f6e7f23cc7 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -504,7 +504,7 @@ def c_header_dirs(self, **kwargs): int unit = 0; int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Ny = PyArray_DIMS(%(_y)s); @@ -795,7 +795,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (13, blas_header_version()) + return (14, blas_header_version()) class Gemm(GemmRelated): @@ -1870,7 +1870,7 @@ def contiguous(var, ndim): return """ int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_DESCR({_x})->elsize; // in bytes + int type_size = PyArray_ITEMSIZE({_x}); // in bytes if (PyArray_NDIM({_x}) != 3) {{ PyErr_Format(PyExc_NotImplementedError, @@ -1930,7 +1930,7 @@ def contiguous(var, ndim): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (5, blas_header_version()) + return (6, blas_header_version()) def grad(self, inp, grads): x, y = inp diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index af23cd76f1..81aaf941ce 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -1061,7 +1061,7 @@ def openblas_threads_text(): def blas_header_version(): # Version for the base header - version = (9,) + version = (10,) if detect_macos_sdot_bug(): if detect_macos_sdot_bug.fix_works: # Version with fix @@ -1079,7 +1079,7 @@ def ____gemm_code(check_ab, a_init, b_init): const char * error_string = NULL; int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_DESCR(_x)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(_x); // in bytes npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Ny = PyArray_DIMS(_y); diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index f8caed2c33..0dc6080659 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -145,7 +145,7 @@ def dontuse_perform(self, node, inp, out_): raise ValueError(self.behaviour) def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inp, out, sub): (a,) = inp @@ -164,8 +164,8 @@ def c_code(self, node, name, inp, out, sub): prep_vars = """ //the output array has size M x N npy_intp M = PyArray_DIMS(%(a)s)[0]; - npy_intp Sa = PyArray_STRIDES(%(a)s)[0] / PyArray_DESCR(%(a)s)->elsize; - npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Sa 
= PyArray_STRIDES(%(a)s)[0] / PyArray_ITEMSIZE(%(a)s); + npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / PyArray_ITEMSIZE(%(z)s); npy_double * Da = (npy_double*)PyArray_BYTES(%(a)s); npy_double * Dz = (npy_double*)PyArray_BYTES(%(z)s); From 5e6993738f53617c626efd76f650297cf45ad1f6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 13:05:36 +0200 Subject: [PATCH 10/12] Don't use deprecated PyArray_MoveInto --- pytensor/tensor/blas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index f6e7f23cc7..04d64a16d5 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1036,7 +1036,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(x_new, %(_x)s) == -1) + if(PyArray_CopyInto(x_new, %(_x)s) == -1) { %(fail)s } @@ -1062,7 +1062,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(y_new, %(_y)s) == -1) + if(PyArray_CopyInto(y_new, %(_y)s) == -1) { %(fail)s } @@ -1108,7 +1108,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): gv = self.build_gemm_version() if gv: - return (7, *gv) + return (8, *gv) else: return gv From 2b58be2ecc0187cc1c56006198b76bf7e1129c85 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 12:48:03 +0200 Subject: [PATCH 11/12] Remove custom Complex type --- pytensor/scalar/basic.py | 181 ++----------------------------- pytensor/sparse/type.py | 4 +- pytensor/tensor/elemwise_cgen.py | 2 - pytensor/tensor/type.py | 4 +- 4 files changed, 15 insertions(+), 176 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f3388900cb..d2b8f8d39f 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -279,8 +279,6 @@ class ScalarType(CType, HasDataType, HasShape): Analogous to TensorType, but for zero-dimensional objects. Maps directly to C primitives. - TODO: refactor to be named ScalarType for consistency with TensorType. 
- """ __props__ = ("dtype",) @@ -350,11 +348,14 @@ def c_element_type(self): return self.dtype_specs()[1] def c_headers(self, c_compiler=None, **kwargs): - l = [""] - # These includes are needed by ScalarType and TensorType, - # we declare them here and they will be re-used by TensorType - l.append("") - l.append("") + l = [ + "", + # These includes are needed by ScalarType and TensorType, + # we declare them here and they will be re-used by TensorType + "", + "", + "", + ] if config.lib__amblibm and c_compiler.supports_amdlibm: l += [""] return l @@ -396,8 +397,8 @@ def dtype_specs(self): "float16": (np.float16, "npy_float16", "Float16"), "float32": (np.float32, "npy_float32", "Float32"), "float64": (np.float64, "npy_float64", "Float64"), - "complex128": (np.complex128, "pytensor_complex128", "Complex128"), - "complex64": (np.complex64, "pytensor_complex64", "Complex64"), + "complex128": (np.complex128, "npy_complex128", "Complex128"), + "complex64": (np.complex64, "npy_complex64", "Complex64"), "bool": (np.bool_, "npy_bool", "Bool"), "uint8": (np.uint8, "npy_uint8", "UInt8"), "int8": (np.int8, "npy_int8", "Int8"), @@ -506,171 +507,11 @@ def c_sync(self, name, sub): def c_cleanup(self, name, sub): return "" - def c_support_code(self, **kwargs): - if self.dtype.startswith("complex"): - cplx_types = ["pytensor_complex64", "pytensor_complex128"] - real_types = [ - "npy_int8", - "npy_int16", - "npy_int32", - "npy_int64", - "npy_float32", - "npy_float64", - ] - # If the 'int' C type is not exactly the same as an existing - # 'npy_intX', some C code may not compile, e.g. when assigning - # the value 0 (cast to 'int' in C) to an PyTensor_complex64. - if np.dtype("intc").num not in [np.dtype(d[4:]).num for d in real_types]: - # In that case we add the 'int' type to the real types. 
- real_types.append("int") - - template = """ - struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s - { - typedef pytensor_complex%(nbits)s complex_type; - typedef npy_float%(half_nbits)s scalar_type; - - complex_type operator +(const complex_type &y) const { - complex_type ret; - ret.real = this->real + y.real; - ret.imag = this->imag + y.imag; - return ret; - } - - complex_type operator -() const { - complex_type ret; - ret.real = -this->real; - ret.imag = -this->imag; - return ret; - } - bool operator ==(const complex_type &y) const { - return (this->real == y.real) && (this->imag == y.imag); - } - bool operator ==(const scalar_type &y) const { - return (this->real == y) && (this->imag == 0); - } - complex_type operator -(const complex_type &y) const { - complex_type ret; - ret.real = this->real - y.real; - ret.imag = this->imag - y.imag; - return ret; - } - complex_type operator *(const complex_type &y) const { - complex_type ret; - ret.real = this->real * y.real - this->imag * y.imag; - ret.imag = this->real * y.imag + this->imag * y.real; - return ret; - } - complex_type operator /(const complex_type &y) const { - complex_type ret; - scalar_type y_norm_square = y.real * y.real + y.imag * y.imag; - ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square; - ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; - return ret; - } - template - complex_type& operator =(const T& y); - - pytensor_complex%(nbits)s() {} - - template - pytensor_complex%(nbits)s(const T& y) { *this = y; } - - template - pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } - }; - """ - - def operator_eq_real(mytype, othertype): - return f""" - template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y; this->imag=0; return *this; }} - """ - - def operator_eq_cplx(mytype, othertype): - return f""" - template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y.real; this->imag=y.imag; return *this; }} - """ - - operator_eq = "".join( - operator_eq_real(ctype, rtype) - for ctype in cplx_types - for rtype in real_types - ) + "".join( - operator_eq_cplx(ctype1, ctype2) - for ctype1 in cplx_types - for ctype2 in cplx_types - ) - - # We are not using C++ generic templating here, because this would - # generate two different functions for adding a complex64 and a - # complex128, one returning a complex64, the other a complex128, - # and the compiler complains it is ambiguous. - # Instead, we generate code for known and safe types only. 
- - def operator_plus_real(mytype, othertype): - return f""" - const {mytype} operator+(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real+y, x.imag); }} - - const {mytype} operator+(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real+y, x.imag); }} - """ - - operator_plus = "".join( - operator_plus_real(ctype, rtype) - for ctype in cplx_types - for rtype in real_types - ) - - def operator_minus_real(mytype, othertype): - return f""" - const {mytype} operator-(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real-y, x.imag); }} - - const {mytype} operator-(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(y-x.real, -x.imag); }} - """ - - operator_minus = "".join( - operator_minus_real(ctype, rtype) - for ctype in cplx_types - for rtype in real_types - ) - - def operator_mul_real(mytype, othertype): - return f""" - const {mytype} operator*(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real*y, x.imag*y); }} - - const {mytype} operator*(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real*y, x.imag*y); }} - """ - - operator_mul = "".join( - operator_mul_real(ctype, rtype) - for ctype in cplx_types - for rtype in real_types - ) - - return ( - template % dict(nbits=64, half_nbits=32) - + template % dict(nbits=128, half_nbits=64) - + operator_eq - + operator_plus - + operator_minus - + operator_mul - ) - - else: - return "" - def c_init_code(self, **kwargs): return ["import_array();"] def c_code_cache_version(self): - return (13, np.__version__) + return (14, np.__version__) def get_shape_info(self, obj): return obj.itemsize diff --git a/pytensor/sparse/type.py b/pytensor/sparse/type.py index 421f3d26a3..f364575270 100644 --- a/pytensor/sparse/type.py +++ b/pytensor/sparse/type.py @@ -59,8 +59,8 @@ class SparseTensorType(TensorType, HasDataType): "int32": (int, "npy_int32", "NPY_INT32"), "uint64": (int, "npy_uint64", "NPY_UINT64"), "int64": (int, "npy_int64", "NPY_INT64"), - "complex128": (complex, "pytensor_complex128", "NPY_COMPLEX128"), - "complex64": (complex, "pytensor_complex64", "NPY_COMPLEX64"), + "complex128": (complex, "npy_complex128", "NPY_COMPLEX128"), + "complex64": (complex, "npy_complex64", "NPY_COMPLEX64"), } ndim = 2 diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py index 57155db9f9..c58c4fec6e 100644 --- a/pytensor/tensor/elemwise_cgen.py +++ b/pytensor/tensor/elemwise_cgen.py @@ -166,8 +166,6 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): """ type = dtype.upper() - if type.startswith("PYTENSOR_COMPLEX"): - type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") nd = len(loop_orders[0]) init_dims = compute_output_dims_lengths("dims", loop_orders, sub) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index a13f5ab9d2..67a5fc0b16 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -50,8 +50,8 @@ "int32": (int, "npy_int32", "NPY_INT32"), "uint64": (int, "npy_uint64", "NPY_UINT64"), "int64": (int, "npy_int64", "NPY_INT64"), - "complex128": (complex, "pytensor_complex128", "NPY_COMPLEX128"), - "complex64": (complex, "pytensor_complex64", "NPY_COMPLEX64"), + "complex128": (complex, "npy_complex128", "NPY_COMPLEX128"), + "complex64": (complex, "npy_complex64", "NPY_COMPLEX64"), } From 3d228ebf2130522063458ce91d86bf48dce5da09 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 11:12:23 -0400 Subject: [PATCH 12/12] Try numpy 2.0.0rc1 --- environment.yml | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 
insertions(+), 3 deletions(-) diff --git a/environment.yml b/environment.yml index e84f1c5207..9c6046b4a4 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ channels: dependencies: - python>=3.10 - compilers - - numpy>=1.17.0 + - conda-forge/label/numpy_dev::numpy=2.0.0rc1 - scipy>=0.14 - filelock - etuples diff --git a/pyproject.toml b/pyproject.toml index a52e5e9d80..e014312d23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ requires = [ "setuptools>=48.0.0", "cython", - "numpy>=1.17.0", + "numpy>=2.0.0rc1", "versioneer[toml]>=0.28", ] build-backend = "setuptools.build_meta" @@ -48,7 +48,7 @@ keywords = [ dependencies = [ "setuptools>=48.0.0", "scipy>=0.14", - "numpy>=1.17.0", + "numpy>=2.0.0rc1", "filelock", "etuples", "logical-unification",
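
As a quick sanity check that the pinned NumPy exposes the C-API version key used for the compilation cache, the branch patch 08 adds to cmodule_key_ can be probed from an interpreter (a sketch, assuming only what that patch itself relies on):

    import numpy as np

    if np.lib.NumpyVersion(np.__version__) >= "2.0.0rc":
        from numpy._core import _multiarray_umath
    else:
        from numpy.core import _multiarray_umath

    # Mirrors the NPY_ABI_VERSION entry appended to the module cache key.
    print(f"NPY_ABI_VERSION=0x{_multiarray_umath._get_ndarray_c_version():X}")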