Skip to content

Commit b3df02c

Browse files
committed
ENH: add _is_numeric attribute for DType classes
1 parent 891ab8e commit b3df02c

File tree

6 files changed

+31
-2
lines changed

6 files changed

+31
-2
lines changed

numpy/core/include/numpy/_dtype_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_
66
#define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_
77

8-
#define __EXPERIMENTAL_DTYPE_API_VERSION 7
8+
#define __EXPERIMENTAL_DTYPE_API_VERSION 8
99

1010
struct PyArrayMethodObject_tag;
1111

@@ -263,6 +263,7 @@ typedef int translate_loop_descrs_func(int nin, int nout,
263263

264264
#define NPY_DT_ABSTRACT 1 << 1
265265
#define NPY_DT_PARAMETRIC 1 << 2
266+
#define NPT_DT_NUMERIC 1 << 3
266267

267268
#define NPY_DT_discover_descr_from_pyobject 1
268269
#define _NPY_DT_is_known_scalar_type 2

numpy/core/src/multiarray/dtypemeta.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,10 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
899899
}
900900
}
901901

902+
if (PyTypeNum_ISNUMBER(descr->type_num)) {
903+
dtype_class->flags |= NPY_DT_NUMERIC;
904+
}
905+
902906
if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj,
903907
PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) {
904908
Py_DECREF(dtype_class);
@@ -927,13 +931,19 @@ dtypemeta_get_parametric(PyArray_DTypeMeta *self) {
927931
return PyBool_FromLong(NPY_DT_is_parametric(self));
928932
}
929933

934+
static PyObject *
935+
dtypemeta_get_is_numeric(PyArray_DTypeMeta *self) {
936+
return PyBool_FromLong(NPY_DT_is_numeric(self));
937+
}
938+
930939
/*
931940
* Simple exposed information, defined for each DType (class).
932941
*/
933942
static PyGetSetDef dtypemeta_getset[] = {
934943
{"_abstract", (getter)dtypemeta_get_abstract, NULL, NULL, NULL},
935944
{"_legacy", (getter)dtypemeta_get_legacy, NULL, NULL, NULL},
936945
{"_parametric", (getter)dtypemeta_get_parametric, NULL, NULL, NULL},
946+
{"_is_numeric", (getter)dtypemeta_get_is_numeric, NULL, NULL, NULL},
937947
{NULL, NULL, NULL, NULL, NULL}
938948
};
939949

numpy/core/src/multiarray/dtypemeta.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ extern "C" {
1111
#define NPY_DT_LEGACY 1 << 0
1212
#define NPY_DT_ABSTRACT 1 << 1
1313
#define NPY_DT_PARAMETRIC 1 << 2
14+
#define NPY_DT_NUMERIC 1 << 3
1415

1516

1617
typedef struct {
@@ -53,6 +54,7 @@ typedef struct {
5354
#define NPY_DT_is_legacy(dtype) (((dtype)->flags & NPY_DT_LEGACY) != 0)
5455
#define NPY_DT_is_abstract(dtype) (((dtype)->flags & NPY_DT_ABSTRACT) != 0)
5556
#define NPY_DT_is_parametric(dtype) (((dtype)->flags & NPY_DT_PARAMETRIC) != 0)
57+
#define NPY_DT_is_numeric(dtype) (((dtype)->flags & NPY_DT_NUMERIC) != 0)
5658
#define NPY_DT_is_user_defined(dtype) (((dtype)->type_num == -1))
5759

5860
/*

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ static PyArray_DTypeMeta PyArray_SFloatDType = {{{
248248
}},
249249
.type_num = -1,
250250
.scalar_type = NULL,
251-
.flags = NPY_DT_PARAMETRIC,
251+
.flags = NPY_DT_PARAMETRIC | NPY_DT_NUMERIC,
252252
.dt_slots = &sfloat_slots,
253253
};
254254

numpy/core/tests/test_custom_dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,7 @@ def test_type_pickle():
231231
assert res is SF
232232

233233
del np._ScaledFloatTestDType
234+
235+
236+
def test_is_numeric():
237+
assert SF._is_numeric

numpy/core/tests/test_dtype.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,18 @@ def test_dtype_superclass(self):
15791579
assert type(np.dtype).__module__ == "numpy"
15801580
assert np.dtype._abstract
15811581

1582+
def test_is_numeric(self):
1583+
all_codes = set(np.typecodes['All'])
1584+
numeric_codes = set(np.typecodes['AllInteger'] +
1585+
np.typecodes['AllFloat'] + '?')
1586+
non_numeric_codes = all_codes - numeric_codes
1587+
1588+
for code in numeric_codes:
1589+
assert type(np.dtype(code))._is_numeric
1590+
1591+
for code in non_numeric_codes:
1592+
assert not type(np.dtype(code))._is_numeric
1593+
15821594

15831595
class TestFromCTypes:
15841596

0 commit comments

Comments
 (0)