Commit 0e20b4c

Replace use of np.MAXDIMS
`np.MAXDIMS` was removed from the public API and no replacement is given in the migration docs. In numpy <= 1.26, the value of `np.MAXDIMS` was 32, and it was often used as a flag meaning `axis=None`. In numpy >= 2.0, the maximum number of dims of an array was increased to 64; simultaneously, a constant `NPY_RAVEL_AXIS` was added to the C-API to indicate `axis=None`. In most cases, the use of `np.MAXDIMS` to check for `axis=None` can be replaced by the new constant `NPY_RAVEL_AXIS`. To make this constant accessible when using numpy <= 1.26, I added a function to insert `npy_2_compat.h` into the support code for the affected ops.
1 parent 71a5071 commit 0e20b4c
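
For context, a minimal standalone sketch of the version gate this commit leans on. The names mirror `pytensor/npy_2_compat.py` as changed below, but the sketch is illustrative only, not the module's actual contents:

# Illustrative sketch -- mirrors the flags defined in pytensor/npy_2_compat.py below.
import numpy as np

using_numpy_2 = np.lib.NumpyVersion(np.__version__) >= "2.0.0"

if using_numpy_2:
    # np.iinfo(np.int32).min is the value of NPY_RAVEL_AXIS in the numpy 2.x C-API
    numpy_axis_is_none_flag = np.iinfo(np.int32).min
else:
    # numpy <= 1.26 used NPY_MAXDIMS (== 32) as the "axis is None" marker
    numpy_axis_is_none_flag = 32

# maximum number of array dimensions: 64 in numpy 2.x, 32 before
numpy_maxdims = 64 if using_numpy_2 else 32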

6 files changed: +106 −50 lines

pytensor/npy_2_compat.py
Lines changed: 13 additions & 2 deletions

@@ -46,10 +46,21 @@
 ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
 
 
+# used in tests: the type of error thrown if a value is too large for the specified
+# numpy data type is different in numpy 2.x
+UintOverflowError = OverflowError if using_numpy_2 else TypeError
+
+
+# to patch up some of the C code, we need to use these special values...
 if using_numpy_2:
-    UintOverflowError = OverflowError
+    numpy_axis_is_none_flag = np.iinfo(np.int32).min  # the value of "NPY_RAVEL_AXIS"
 else:
-    UintOverflowError = TypeError
+    # 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0
+    numpy_axis_is_none_flag = 32
+
+
+# max number of dims is 64 in numpy 2.x; 32 in older versions
+numpy_maxdims = 64 if using_numpy_2 else 32
 
 
 def npy_2_compat_header() -> str:
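
The last sentence of the commit message describes the key mechanism: any Op whose C code mentions `NPY_RAVEL_AXIS` now returns `npy_2_compat_header()` from `c_support_code_apply`, so the macro is defined even when compiling against numpy <= 1.26. A stripped-down, hypothetical COp subclass showing just that hook (real ops also define `make_node`, `perform`, etc.):

# Hypothetical skeleton -- only the compat-header hook is shown.
from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.npy_2_compat import npy_2_compat_header


class MyAxisOp(COp):
    __props__ = ("axis",)

    def __init__(self, axis=None):
        self.axis = axis

    def c_support_code_apply(self, node: Apply, name: str) -> str:
        # Inject npy_2_compat.h so NPY_RAVEL_AXIS exists on numpy <= 1.26 too.
        return npy_2_compat_header()

    def c_code(self, node, name, inputs, outputs, sub):
        axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS"
        return f"int axis = {axis};  /* ...rest of the kernel... */"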

pytensor/tensor/extra_ops.py
Lines changed: 31 additions & 16 deletions

@@ -2,7 +2,6 @@
 from collections.abc import Collection, Iterable
 
 import numpy as np
-from numpy.exceptions import AxisError
 
 import pytensor
 import pytensor.scalar.basic as ps
@@ -19,10 +18,11 @@
 from pytensor.link.c.type import EnumList, Generic
 from pytensor.npy_2_compat import (
     normalize_axis_index,
-    normalize_axis_tuple,
+    npy_2_compat_header,
+    numpy_axis_is_none_flag,
 )
 from pytensor.raise_op import Assert
-from pytensor.scalar import int32 as int_t
+from pytensor.scalar import int64 as int_t
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
@@ -47,6 +47,7 @@
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
 
@@ -302,7 +303,11 @@ def __init__(self, axis: int | None = None, mode="add"):
         self.axis = axis
         self.mode = mode
 
-    c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
+    @property
+    def c_axis(self) -> int:
+        if self.axis is None:
+            return numpy_axis_is_none_flag
+        return self.axis
 
     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
@@ -359,24 +364,37 @@ def infer_shape(self, fgraph, node, shapes):
 
         return shapes
 
+    def c_support_code_apply(self, node: Apply, name: str) -> str:
+        """Needed to define NPY_RAVEL_AXIS"""
+        return npy_2_compat_header()
+
     def c_code(self, node, name, inames, onames, sub):
         (x,) = inames
         (z,) = onames
         fail = sub["fail"]
         params = sub["params"]
 
-        code = f"""
-                int axis = {params}->c_axis;
+        if self.axis is None:
+            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
+        else:
+            axis_code = f"int axis = {params}->c_axis;\n"
+
+        code = (
+            axis_code
+            + f"""
+                #undef NPY_UF_DBG_TRACING
+                #define NPY_UF_DBG_TRACING 1
+
                 if (axis == 0 && PyArray_NDIM({x}) == 1)
-                    axis = NPY_MAXDIMS;
+                    axis = NPY_RAVEL_AXIS;
                 npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
-                if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
+                if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
                 {{
                     Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}));
+                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
                 }}
 
-                else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
+                else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
                 {{
                     Py_XDECREF({z});
                     {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
@@ -403,11 +421,12 @@ def c_code(self, node, name, inames, onames, sub):
                     Py_XDECREF(t);
                 }}
                 """
+        )
 
         return code
 
     def c_code_cache_version(self):
-        return (8,)
+        return (9,)
 
     def __str__(self):
         return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
@@ -598,11 +617,7 @@ def squeeze(x, axis=None):
     elif not isinstance(axis, Collection):
         axis = (axis,)
 
-    # scalar inputs are treated as 1D regarding axis in this `Op`
-    try:
-        axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim))
-    except AxisError:
-        raise AxisError(axis, ndim=_x.ndim)
+    axis = normalize_reduce_axis(axis, ndim=_x.ndim)
 
     if not axis:
         # Nothing to do
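
A quick look at the Python side of the `c_axis` change above: the property now returns the compat flag instead of the removed `np.MAXDIMS`. The check below is illustrative and assumes this commit is applied and that the Op in the `__init__(self, axis..., mode="add")` hunk is `CumOp`:

# Assumes the commit is applied; CumOp is inferred from the hunk context above.
from pytensor.npy_2_compat import numpy_axis_is_none_flag
from pytensor.tensor.extra_ops import CumOp

assert CumOp(axis=None).c_axis == numpy_axis_is_none_flag  # axis=None -> flag value
assert CumOp(axis=1).c_axis == 1                           # concrete axis passes through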

pytensor/tensor/math.py
Lines changed: 11 additions & 3 deletions

@@ -13,7 +13,11 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
-from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.npy_2_compat import (
+    normalize_axis_tuple,
+    npy_2_compat_header,
+    numpy_axis_is_none_flag,
+)
 from pytensor.printing import pprint
 from pytensor.raise_op import Assert
 from pytensor.scalar.basic import BinaryScalarOp
@@ -160,7 +164,7 @@ def get_params(self, node):
             c_axis = np.int64(self.axis[0])
         else:
             # The value here doesn't matter, it won't be used
-            c_axis = np.int64(-1)
+            c_axis = numpy_axis_is_none_flag
         return self.params_type.get_params(c_axis=c_axis)
 
     def make_node(self, x):
@@ -203,13 +207,17 @@ def perform(self, node, inp, outs):
 
         max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
 
+    def c_support_code_apply(self, node: Apply, name: str) -> str:
+        """Needed to define NPY_RAVEL_AXIS"""
+        return npy_2_compat_header()
+
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
         (argmax,) = out
         fail = sub["fail"]
         params = sub["params"]
         if self.axis is None:
-            axis_code = "axis = NPY_MAXDIMS;"
+            axis_code = "axis = NPY_RAVEL_AXIS;"
         else:
             if len(self.axis) != 1:
                 raise NotImplementedError()

pytensor/tensor/special.py
Lines changed: 43 additions & 23 deletions

@@ -6,6 +6,7 @@
 from pytensor.graph.basic import Apply
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
+from pytensor.npy_2_compat import npy_2_compat_header
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.math import gamma, gammaln, log, neg, sum
@@ -60,12 +61,16 @@ def infer_shape(self, fgraph, node, shape):
         return [shape[1]]
 
     def c_code_cache_version(self):
-        return (4,)
+        return (5,)
+
+    def c_support_code_apply(self, node: Apply, name: str) -> str:
+        # return super().c_support_code_apply(node, name)
+        return npy_2_compat_header()
 
     def c_code(self, node, name, inp, out, sub):
         dy, sm = inp
         (dx,) = out
-        axis = self.axis if self.axis is not None else np.MAXDIMS
+        axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS"
         fail = sub["fail"]
 
         return dedent(
@@ -79,7 +84,7 @@ def c_code(self, node, name, inp, out, sub):
 
                 int sm_ndim = PyArray_NDIM({sm});
                 int axis = {axis};
-                int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
+                int iterate_axis = !(axis == NPY_RAVEL_AXIS || sm_ndim == 1);
 
                 // Validate inputs
                 if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
@@ -95,13 +100,15 @@ def c_code(self, node, name, inp, out, sub):
                     {fail};
                 }}
 
-                if (axis < 0) axis = sm_ndim + axis;
-                if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
+                if (iterate_axis)
                 {{
-                    PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
-                    {fail};
+                    if (axis < 0) axis = sm_ndim + axis;
+                    if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
+                    {{
+                        PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
+                        {fail};
+                    }}
                 }}
-
                 if (({dx} == NULL)
                 || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
                 {{
@@ -289,10 +296,14 @@ def infer_shape(self, fgraph, node, shape):
     def c_headers(self, **kwargs):
         return ["<iostream>", "<cmath>"]
 
+    def c_support_code_apply(self, node: Apply, name: str) -> str:
+        """Needed to define NPY_RAVEL_AXIS"""
+        return npy_2_compat_header()
+
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
         (sm,) = out
-        axis = self.axis if self.axis is not None else np.MAXDIMS
+        axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS"
         fail = sub["fail"]
         # dtype = node.inputs[0].type.dtype_specs()[1]
         # TODO: put this into a templated function, in the support code
@@ -309,7 +320,7 @@ def c_code(self, node, name, inp, out, sub):
 
                 int x_ndim = PyArray_NDIM({x});
                 int axis = {axis};
-                int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
+                int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1);
 
                 // Validate inputs
                 if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
@@ -319,11 +330,14 @@ def c_code(self, node, name, inp, out, sub):
                     {fail}
                 }}
 
-                if (axis < 0) axis = x_ndim + axis;
-                if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
+                if (iterate_axis)
                 {{
-                    PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
-                    {fail}
+                    if (axis < 0) axis = x_ndim + axis;
+                    if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
+                    {{
+                        PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
+                        {fail}
+                    }}
                 }}
 
                 // Allocate Output Array
@@ -481,7 +495,7 @@ def c_code(self, node, name, inp, out, sub):
 
     @staticmethod
     def c_code_cache_version():
-        return (4,)
+        return (5,)
 
 
 def softmax(c, axis=None):
@@ -541,10 +555,14 @@ def infer_shape(self, fgraph, node, shape):
     def c_headers(self, **kwargs):
         return ["<cmath>"]
 
+    def c_support_code_apply(self, node: Apply, name: str) -> str:
+        """Needed to define NPY_RAVEL_AXIS"""
+        return npy_2_compat_header()
+
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
         (sm,) = out
-        axis = self.axis if self.axis is not None else np.MAXDIMS
+        axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS"
         fail = sub["fail"]
 
         return dedent(
@@ -558,7 +576,7 @@ def c_code(self, node, name, inp, out, sub):
 
                 int x_ndim = PyArray_NDIM({x});
                 int axis = {axis};
-                int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
+                int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1);
 
                 // Validate inputs
                 if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
@@ -568,13 +586,15 @@ def c_code(self, node, name, inp, out, sub):
                     {fail}
                 }}
 
-                if (axis < 0) axis = x_ndim + axis;
-                if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
+                if (iterate_axis)
                 {{
-                    PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
-                    {fail}
+                    if (axis < 0) axis = x_ndim + axis;
+                    if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
+                    {{
+                        PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
+                        {fail}
+                    }}
                 }}
-
                 // Allocate Output Array
                 if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
                 {{
@@ -730,7 +750,7 @@ def c_code(self, node, name, inp, out, sub):
 
     @staticmethod
    def c_code_cache_version():
-        return (1,)
+        return (2,)
 
 
 def log_softmax(c, axis=None):
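
One detail in the special.py hunks: when `axis` is None, the Python side now interpolates the literal token `NPY_RAVEL_AXIS` into the C template (the old code interpolated the numeric value of `np.MAXDIMS`), and the injected compat header resolves the macro at compile time. A tiny sketch of the same templating idea, with a hypothetical helper name:

# Hypothetical helper illustrating the string-templating trick used above.
def axis_expr(axis):
    # Emit the macro name when axis is None; otherwise a literal integer.
    return "NPY_RAVEL_AXIS" if axis is None else str(int(axis))

print(f"int axis = {axis_expr(None)};")  # int axis = NPY_RAVEL_AXIS;
print(f"int axis = {axis_expr(-1)};")    # int axis = -1;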

pytensor/tensor/subtensor.py
Lines changed: 5 additions & 5 deletions

@@ -18,7 +18,7 @@
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
-from pytensor.npy_2_compat import numpy_version, using_numpy_2
+from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2
 from pytensor.printing import Printer, pprint, set_precedence
 from pytensor.scalar.basic import ScalarConstant, ScalarVariable
 from pytensor.tensor import (
@@ -2149,7 +2149,7 @@ def infer_shape(self, fgraph, node, ishapes):
     def c_support_code(self, **kwargs):
         # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
         # which is not defined. It should be NPY_MIN_LONG instead in that case.
-        return dedent(
+        return npy_2_compat_header() + dedent(
            """\
            #ifndef MIN_LONG
            #define MIN_LONG NPY_MIN_LONG
@@ -2174,7 +2174,7 @@ def c_code(self, node, name, input_names, output_names, sub):
            if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
                PyArray_SIZE({i_name}) > 0) {{
                npy_int64 min_val, max_val;
-               PyObject* py_min_val = PyArray_Min({i_name}, NPY_MAXDIMS,
+               PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
                                                   NULL);
                if (py_min_val == NULL) {{
                    {fail};
@@ -2184,7 +2184,7 @@ def c_code(self, node, name, input_names, output_names, sub):
                if (min_val == -1 && PyErr_Occurred()) {{
                    {fail};
                }}
-               PyObject* py_max_val = PyArray_Max({i_name}, NPY_MAXDIMS,
+               PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
                                                   NULL);
                if (py_max_val == NULL) {{
                    {fail};
@@ -2243,7 +2243,7 @@ def c_code(self, node, name, input_names, output_names, sub):
            """
 
     def c_code_cache_version(self):
-        return (0, 1, 2)
+        return (0, 1, 2, 3)
 
 
 advanced_subtensor1 = AdvancedSubtensor1()
