Skip to content

Commit b15e80a

Browse files
authored
Merge pull request #43 from ngoldbaum/is-numeric
Apply NPY_DT_NUMERIC flag where appropriate
2 parents 05de73b + 16fb7e9 commit b15e80a

File tree

11 files changed

+83
-70
lines changed

11 files changed

+83
-70
lines changed

.pre-commit-config.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,20 @@ repos:
102102
rev: 22.12.0
103103
hooks:
104104
- id: black
105+
name: 'black for asciidtype'
106+
files: '^asciidtype/.*\.py'
107+
- id: black
108+
name: 'black for metadatadtype'
109+
files: '^metadatadtype/.*\.py'
110+
- id: black
111+
name: 'black for mpfdtype'
112+
files: '^mpfdtype/.*\.py'
113+
- id: black
114+
name: 'black for quaddtype'
115+
files: '^quaddtype/.*\.py'
116+
- id: black
117+
name: 'black for stringdtype'
118+
files: '^stringdtype/.*\.py'
119+
- id: black
120+
name: 'black for unytdtype'
121+
files: '^unytdtype/.*\.py'

asciidtype/tests/test_asciidtype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,7 @@ def test_pickle():
248248
assert res[1] == dtype
249249

250250
os.remove(f.name)
251+
252+
253+
def test_is_numeric():
254+
assert not ASCIIDType._is_numeric

metadatadtype/metadatadtype/src/dtype.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ init_metadata_dtype(void)
257257
PyArrayMethod_Spec **casts = get_casts();
258258

259259
PyArrayDTypeMeta_Spec MetadataDType_DTypeSpec = {
260-
.flags = NPY_DT_PARAMETRIC,
260+
.flags = NPY_DT_PARAMETRIC | NPY_DT_NUMERIC,
261261
.casts = casts,
262262
.typeobj = MetadataScalar_Type,
263263
.slots = MetadataDType_Slots,

metadatadtype/tests/test_metadatadtype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,9 @@ def test_cast_to_float64():
5050
dtype = MetadataDType("test")
5151
scalar = MetadataScalar(1, dtype)
5252
arr = np.array([scalar, scalar, scalar])
53-
conv = arr.astype('float64')
53+
conv = arr.astype("float64")
5454
assert str(conv) == "[1. 1. 1.]"
55+
56+
57+
def test_is_numeric():
58+
assert MetadataDType._is_numeric

mpfdtype/mpfdtype/src/dtype.c

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
#include "casts.h"
1414
#include "dtype.h"
1515

16-
17-
1816
/*
1917
* Internal helper to create new instances.
2018
*/
@@ -27,9 +25,8 @@ new_MPFDType_instance(mpfr_prec_t precision)
2725
* set in that case.
2826
*/
2927
if (precision < MPFR_PREC_MIN || precision > MPFR_PREC_MAX) {
30-
PyErr_Format(PyExc_ValueError,
31-
"precision must be between %d and %d.",
32-
MPFR_PREC_MIN, MPFR_PREC_MAX);
28+
PyErr_Format(PyExc_ValueError, "precision must be between %d and %d.", MPFR_PREC_MIN,
29+
MPFR_PREC_MAX);
3330
return NULL;
3431
}
3532

@@ -43,7 +40,7 @@ new_MPFDType_instance(mpfr_prec_t precision)
4340
size_t size = mpfr_custom_get_size(precision);
4441
if (size > NPY_MAX_INT - sizeof(mpf_field)) {
4542
PyErr_SetString(PyExc_TypeError,
46-
"storage of single float would be too large for precision.");
43+
"storage of single float would be too large for precision.");
4744
}
4845
new->base.elsize = sizeof(mpf_storage) + size;
4946
new->base.alignment = _Alignof(mpf_field);
@@ -52,15 +49,13 @@ new_MPFDType_instance(mpfr_prec_t precision)
5249
return new;
5350
}
5451

55-
5652
static MPFDTypeObject *
5753
ensure_canonical(MPFDTypeObject *self)
5854
{
5955
Py_INCREF(self);
6056
return self;
6157
}
6258

63-
6459
static MPFDTypeObject *
6560
common_instance(MPFDTypeObject *dtype1, MPFDTypeObject *dtype2)
6661
{
@@ -74,17 +69,15 @@ common_instance(MPFDTypeObject *dtype1, MPFDTypeObject *dtype2)
7469
}
7570
}
7671

77-
7872
static PyArray_DTypeMeta *
7973
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
8074
{
8175
/*
8276
* Typenum is useful for NumPy, but there it can still be convenient.
8377
* (New-style user dtypes will probably get -1 as type number...)
8478
*/
85-
if (other->type_num >= 0
86-
&& PyTypeNum_ISNUMBER(other->type_num)
87-
&& !PyTypeNum_ISCOMPLEX(other->type_num)) {
79+
if (other->type_num >= 0 && PyTypeNum_ISNUMBER(other->type_num) &&
80+
!PyTypeNum_ISCOMPLEX(other->type_num)) {
8881
/*
8982
* A (simple) builtin numeric type (not complex) promotes to fixed
9083
* precision.
@@ -96,18 +89,15 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
9689
return (PyArray_DTypeMeta *)Py_NotImplemented;
9790
}
9891

99-
10092
/*
10193
* Functions dealing with scalar logic
10294
*/
10395

10496
static PyArray_Descr *
105-
mpf_discover_descriptor_from_pyobject(
106-
PyArray_DTypeMeta *NPY_UNUSED(cls), PyObject *obj)
97+
mpf_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), PyObject *obj)
10798
{
10899
if (Py_TYPE(obj) != &MPFloat_Type) {
109-
PyErr_SetString(PyExc_TypeError,
110-
"Can only store MPFloat in a MPFDType array.");
100+
PyErr_SetString(PyExc_TypeError, "Can only store MPFloat in a MPFDType array.");
111101
return NULL;
112102
}
113103
mpfr_prec_t prec = get_prec_from_object(obj);
@@ -117,7 +107,6 @@ mpf_discover_descriptor_from_pyobject(
117107
return (PyArray_Descr *)new_MPFDType_instance(prec);
118108
}
119109

120-
121110
static int
122111
mpf_setitem(MPFDTypeObject *descr, PyObject *obj, char *dataptr)
123112
{
@@ -167,18 +156,14 @@ mpf_getitem(MPFDTypeObject *descr, char *dataptr)
167156
return (PyObject *)new;
168157
}
169158

170-
171159
static PyType_Slot MPFDType_Slots[] = {
172-
{NPY_DT_ensure_canonical, &ensure_canonical},
173-
{NPY_DT_common_instance, &common_instance},
174-
{NPY_DT_common_dtype, &common_dtype},
175-
{NPY_DT_discover_descr_from_pyobject,
176-
&mpf_discover_descriptor_from_pyobject},
177-
{NPY_DT_setitem, &mpf_setitem},
178-
{NPY_DT_getitem, &mpf_getitem},
179-
{0, NULL}
180-
};
181-
160+
{NPY_DT_ensure_canonical, &ensure_canonical},
161+
{NPY_DT_common_instance, &common_instance},
162+
{NPY_DT_common_dtype, &common_dtype},
163+
{NPY_DT_discover_descr_from_pyobject, &mpf_discover_descriptor_from_pyobject},
164+
{NPY_DT_setitem, &mpf_setitem},
165+
{NPY_DT_getitem, &mpf_getitem},
166+
{0, NULL}};
182167

183168
/*
184169
* The following defines everything type object related (i.e. not NumPy
@@ -195,59 +180,49 @@ MPFDType_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
195180

196181
Py_ssize_t precision;
197182

198-
if (!PyArg_ParseTupleAndKeywords(
199-
args, kwds, "n:MPFDType", kwargs_strs, &precision)) {
183+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "n:MPFDType", kwargs_strs, &precision)) {
200184
return NULL;
201185
}
202186

203187
return (PyObject *)new_MPFDType_instance(precision);
204188
}
205189

206-
207190
static PyObject *
208191
MPFDType_repr(MPFDTypeObject *self)
209192
{
210-
PyObject *res = PyUnicode_FromFormat(
211-
"MPFDType(%ld)", (long)self->precision);
193+
PyObject *res = PyUnicode_FromFormat("MPFDType(%ld)", (long)self->precision);
212194
return res;
213195
}
214196

215-
216197
PyObject *
217198
MPFDType_get_prec(MPFDTypeObject *self)
218199
{
219200
return PyLong_FromLong(self->precision);
220201
}
221202

222-
223203
NPY_NO_EXPORT PyGetSetDef mpfdtype_getsetlist[] = {
224-
{"prec",
225-
(getter)MPFDType_get_prec,
226-
NULL,
227-
NULL, NULL},
228-
{NULL, NULL, NULL, NULL, NULL}, /* Sentinel */
204+
{"prec", (getter)MPFDType_get_prec, NULL, NULL, NULL},
205+
{NULL, NULL, NULL, NULL, NULL}, /* Sentinel */
229206
};
230207

231-
232208
/*
233209
* This is the basic things that you need to create a Python Type/Class in C.
234210
* However, there is a slight difference here because we create a
235211
* PyArray_DTypeMeta, which is a larger struct than a typical type.
236212
* (This should get a bit nicer eventually with Python >3.11.)
237213
*/
238-
PyArray_DTypeMeta MPFDType = {{{
239-
PyVarObject_HEAD_INIT(NULL, 0)
240-
.tp_name = "MPFDType.MPFDType",
241-
.tp_basicsize = sizeof(MPFDTypeObject),
242-
.tp_new = MPFDType_new,
243-
.tp_repr = (reprfunc)MPFDType_repr,
244-
.tp_str = (reprfunc)MPFDType_repr,
245-
.tp_getset = mpfdtype_getsetlist,
246-
}},
247-
/* rest, filled in during DTypeMeta initialization */
214+
PyArray_DTypeMeta MPFDType = {
215+
{{
216+
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "MPFDType.MPFDType",
217+
.tp_basicsize = sizeof(MPFDTypeObject),
218+
.tp_new = MPFDType_new,
219+
.tp_repr = (reprfunc)MPFDType_repr,
220+
.tp_str = (reprfunc)MPFDType_repr,
221+
.tp_getset = mpfdtype_getsetlist,
222+
}},
223+
/* rest, filled in during DTypeMeta initialization */
248224
};
249225

250-
251226
int
252227
init_mpf_dtype(void)
253228
{
@@ -258,7 +233,7 @@ init_mpf_dtype(void)
258233
PyArrayMethod_Spec **casts = init_casts();
259234

260235
PyArrayDTypeMeta_Spec MPFDType_DTypeSpec = {
261-
.flags = NPY_DT_PARAMETRIC,
236+
.flags = NPY_DT_PARAMETRIC | NPY_DT_NUMERIC,
262237
.casts = casts,
263238
.typeobj = &MPFloat_Type,
264239
.slots = MPFDType_Slots,
@@ -271,8 +246,7 @@ init_mpf_dtype(void)
271246
return -1;
272247
}
273248

274-
if (PyArrayInitDTypeMeta_FromSpec(
275-
&MPFDType, &MPFDType_DTypeSpec) < 0) {
249+
if (PyArrayInitDTypeMeta_FromSpec(&MPFDType, &MPFDType_DTypeSpec) < 0) {
276250
free_casts();
277251
return -1;
278252
}

mpfdtype/mpfdtype/tests/test_array.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import pytest
2-
3-
import sys
41
import numpy as np
5-
import operator
62
from numpy.testing import assert_array_equal
73

8-
from mpfdtype import MPFDType, MPFloat
4+
from mpfdtype import MPFDType
95

106

117
def test_advanced_indexing():
@@ -16,3 +12,7 @@ def test_advanced_indexing():
1612
b = arr[[1, 2, 3, 4]]
1713
b[...] = 5 # does not mutate arr (internal references not broken)
1814
assert_array_equal(arr, orig)
15+
16+
17+
def test_is_numeric():
18+
assert MPFDType._is_numeric

quaddtype/quaddtype/src/dtype.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ init_quad_dtype(void)
164164
};
165165

166166
PyArrayDTypeMeta_Spec QuadDType_DTypeSpec = {
167+
.flags = NPY_DT_NUMERIC,
167168
.casts = casts,
168169
.typeobj = QuadScalar_Type,
169170
.slots = QuadDType_Slots,

quaddtype/tests/test_quaddtype.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@ def test_scalar_creation():
1212

1313

1414
def test_create_with_explicit_dtype():
15-
assert repr(
16-
np.array([3.0, 3.1, 3.2], dtype=QuadDType())
17-
) == "array([3.0, 3.1, 3.2], dtype=This is a quad (128-bit float) dtype.)"
15+
assert (
16+
repr(np.array([3.0, 3.1, 3.2], dtype=QuadDType()))
17+
== "array([3.0, 3.1, 3.2], dtype=This is a quad (128-bit float) dtype.)"
18+
)
1819

1920

2021
def test_multiply():
2122
x = np.array([3, 8.0], dtype=QuadDType())
22-
assert str(x * x) == '[9.0 64.0]'
23+
assert str(x * x) == "[9.0 64.0]"
2324

2425

2526
def test_bytes():
2627
"""Check that each quad is 16 bytes."""
2728
x = np.array([3, 8.0, 1.4], dtype=QuadDType())
2829
assert len(x.tobytes()) == x.size * 16
30+
31+
32+
def test_is_numeric():
33+
assert QuadDType._is_numeric

stringdtype/tests/test_stringdtype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,7 @@ def test_creation_functions():
194194

195195
# make sure getitem works too
196196
assert np.empty(3, dtype=StringDType())[0] == ""
197+
198+
199+
def test_is_numeric():
200+
assert not StringDType._is_numeric

unytdtype/tests/test_unytdtype.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_dtype_creation():
88
dtype = UnytDType("m")
99
assert str(dtype) == "UnytDType('m')"
1010

11-
dtype2 = UnytDType(unyt.Unit('m'))
11+
dtype2 = UnytDType(unyt.Unit("m"))
1212
assert str(dtype2) == "UnytDType('m')"
1313
assert dtype == dtype2
1414

@@ -69,5 +69,9 @@ def test_insert_with_different_unit():
6969
def test_cast_to_float64():
7070
meter = UnytScalar(1, unyt.m)
7171
arr = np.array([meter, meter, meter])
72-
conv = arr.astype('float64')
72+
conv = arr.astype("float64")
7373
assert str(conv) == "[1. 1. 1.]"
74+
75+
76+
def test_is_numeric():
77+
assert UnytDType._is_numeric

unytdtype/unytdtype/src/dtype.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ init_unyt_dtype(void)
280280
PyArrayMethod_Spec **casts = get_casts();
281281

282282
PyArrayDTypeMeta_Spec UnytDType_DTypeSpec = {
283-
.flags = NPY_DT_PARAMETRIC,
283+
.flags = (NPY_DT_PARAMETRIC | NPY_DT_NUMERIC),
284284
.casts = casts,
285285
.typeobj = UnytScalar_Type,
286286
.slots = UnytDType_Slots,

0 commit comments

Comments
 (0)