Skip to content

Commit 1f2c42e

Browse files
authored
Merge pull request #96 from ngoldbaum/pyint-promoter
Add a promoter for multiplying with a Python int
2 parents 01b2245 + 35948ca commit 1f2c42e

File tree

2 files changed

+96
-106
lines changed

2 files changed

+96
-106
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 84 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors(
10811081
* Copied from NumPy, because NumPy doesn't always use it :)
10821082
*/
10831083
static int
1084-
ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1085-
PyArray_DTypeMeta *signature[],
1086-
PyArray_DTypeMeta *new_op_dtypes[],
1087-
PyArray_DTypeMeta *final_dtype)
1084+
string_inputs_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1085+
PyArray_DTypeMeta *signature[],
1086+
PyArray_DTypeMeta *new_op_dtypes[],
1087+
PyArray_DTypeMeta *final_dtype)
10881088
{
1089-
/* If nin < 2 promotion is a no-op, so it should not be registered */
1090-
assert(ufunc->nin > 1);
1091-
if (op_dtypes[0] == NULL) {
1092-
assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */
1093-
Py_INCREF(op_dtypes[1]);
1094-
new_op_dtypes[0] = op_dtypes[1];
1095-
Py_INCREF(op_dtypes[1]);
1096-
new_op_dtypes[1] = op_dtypes[1];
1097-
Py_INCREF(op_dtypes[1]);
1098-
new_op_dtypes[2] = op_dtypes[1];
1099-
return 0;
1100-
}
1101-
PyArray_DTypeMeta *common = NULL;
1102-
/*
1103-
* If a signature is used and homogeneous in its outputs use that
1104-
* (Could/should likely be rather applied to inputs also, although outs
1105-
* only could have some advantage and input dtypes are rarely enforced.)
1106-
*/
1107-
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
1108-
if (signature[i] != NULL) {
1109-
if (common == NULL) {
1110-
Py_INCREF(signature[i]);
1111-
common = signature[i];
1112-
}
1113-
else if (common != signature[i]) {
1114-
Py_CLEAR(common); /* Not homogeneous, unset common */
1115-
break;
1116-
}
1117-
}
1118-
}
1119-
Py_XDECREF(common);
1120-
1121-
/* Otherwise, set all input operands to final_dtype */
1089+
/* set all input operands to final_dtype */
11221090
for (int i = 0; i < ufunc->nargs; i++) {
11231091
PyArray_DTypeMeta *tmp = final_dtype;
11241092
if (signature[i]) {
@@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11271095
Py_INCREF(tmp);
11281096
new_op_dtypes[i] = tmp;
11291097
}
1098+
/* don't touch output dtypes */
11301099
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
11311100
Py_XINCREF(op_dtypes[i]);
11321101
new_op_dtypes[i] = op_dtypes[i];
@@ -1140,19 +1109,50 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11401109
PyArray_DTypeMeta *signature[],
11411110
PyArray_DTypeMeta *new_op_dtypes[])
11421111
{
1143-
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
1144-
signature, new_op_dtypes,
1145-
(PyArray_DTypeMeta *)&PyArray_ObjectDType);
1112+
return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature,
1113+
new_op_dtypes,
1114+
(PyArray_DTypeMeta *)&PyArray_ObjectDType);
11461115
}
11471116

11481117
static int
11491118
string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11501119
PyArray_DTypeMeta *signature[],
11511120
PyArray_DTypeMeta *new_op_dtypes[])
11521121
{
1153-
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
1154-
signature, new_op_dtypes,
1155-
(PyArray_DTypeMeta *)&StringDType);
1122+
return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature,
1123+
new_op_dtypes,
1124+
(PyArray_DTypeMeta *)&StringDType);
1125+
}
1126+
1127+
static int
1128+
string_multiply_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *op_dtypes[],
1129+
PyArray_DTypeMeta *signature[],
1130+
PyArray_DTypeMeta *new_op_dtypes[])
1131+
{
1132+
PyUFuncObject *ufunc = (PyUFuncObject *)ufunc_obj;
1133+
for (int i = 0; i < ufunc->nargs; i++) {
1134+
PyArray_DTypeMeta *tmp = NULL;
1135+
if (signature[i]) {
1136+
tmp = signature[i];
1137+
}
1138+
else if (op_dtypes[i] == &PyArray_PyIntAbstractDType) {
1139+
tmp = &PyArray_Int64DType;
1140+
}
1141+
else if (op_dtypes[i]) {
1142+
tmp = op_dtypes[i];
1143+
}
1144+
else {
1145+
tmp = (PyArray_DTypeMeta *)&StringDType;
1146+
}
1147+
Py_INCREF(tmp);
1148+
new_op_dtypes[i] = tmp;
1149+
}
1150+
/* don't touch output dtypes */
1151+
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
1152+
Py_XINCREF(op_dtypes[i]);
1153+
new_op_dtypes[i] = op_dtypes[i];
1154+
}
1155+
return 0;
11561156
}
11571157

11581158
// Register a ufunc.
@@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11611161
int
11621162
init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
11631163
resolve_descriptors_function *resolve_func,
1164-
PyArrayMethod_StridedLoop *loop_func, const char *loop_name,
1165-
int nin, int nout, NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
1164+
PyArrayMethod_StridedLoop *loop_func, int nin, int nout,
1165+
NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
11661166
{
11671167
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
11681168
if (ufunc == NULL) {
11691169
return -1;
11701170
}
11711171

1172+
char loop_name[256] = {0};
1173+
1174+
snprintf(loop_name, sizeof(loop_name), "string_%s", ufunc_name);
1175+
11721176
PyArrayMethod_Spec spec = {
11731177
.name = loop_name,
11741178
.nin = nin,
@@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12081212
PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype,
12091213
PyArray_DTypeMeta *edtype, promoter_function *promoter_impl)
12101214
{
1211-
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
1215+
PyObject *ufunc = PyObject_GetAttrString((PyObject *)numpy, ufunc_name);
12121216

12131217
if (ufunc == NULL) {
12141218
return -1;
@@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12511255
\
12521256
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
12531257
&multiply_resolve_descriptors, \
1254-
&multiply_right_##shortname##_strided_loop, \
1255-
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1258+
&multiply_right_##shortname##_strided_loop, 2, 1, \
1259+
NPY_NO_CASTING, 0) < 0) { \
12561260
goto error; \
12571261
} \
12581262
\
@@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12621266
\
12631267
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
12641268
&multiply_resolve_descriptors, \
1265-
&multiply_left_##shortname##_strided_loop, \
1266-
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1269+
&multiply_left_##shortname##_strided_loop, 2, 1, \
1270+
NPY_NO_CASTING, 0) < 0) { \
12671271
goto error; \
12681272
}
12691273

@@ -1279,53 +1283,23 @@ init_ufuncs(void)
12791283
"greater", "greater_equal",
12801284
"less", "less_equal"};
12811285

1286+
static PyArrayMethod_StridedLoop *strided_loops[6] = {
1287+
&string_equal_strided_loop, &string_not_equal_strided_loop,
1288+
&string_greater_strided_loop, &string_greater_equal_strided_loop,
1289+
&string_less_strided_loop, &string_less_equal_strided_loop,
1290+
};
1291+
12821292
PyArray_DTypeMeta *comparison_dtypes[] = {
12831293
(PyArray_DTypeMeta *)&StringDType,
12841294
(PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType};
12851295

1286-
if (init_ufunc(numpy, "equal", comparison_dtypes,
1287-
&string_comparison_resolve_descriptors,
1288-
&string_equal_strided_loop, "string_equal", 2, 1,
1289-
NPY_NO_CASTING, 0) < 0) {
1290-
goto error;
1291-
}
1292-
1293-
if (init_ufunc(numpy, "not_equal", comparison_dtypes,
1294-
&string_comparison_resolve_descriptors,
1295-
&string_not_equal_strided_loop, "string_not_equal", 2, 1,
1296-
NPY_NO_CASTING, 0) < 0) {
1297-
goto error;
1298-
}
1299-
1300-
if (init_ufunc(numpy, "greater", comparison_dtypes,
1301-
&string_comparison_resolve_descriptors,
1302-
&string_greater_strided_loop, "string_greater", 2, 1,
1303-
NPY_NO_CASTING, 0) < 0) {
1304-
goto error;
1305-
}
1306-
1307-
if (init_ufunc(numpy, "greater_equal", comparison_dtypes,
1308-
&string_comparison_resolve_descriptors,
1309-
&string_greater_equal_strided_loop, "string_greater_equal",
1310-
2, 1, NPY_NO_CASTING, 0) < 0) {
1311-
goto error;
1312-
}
1313-
1314-
if (init_ufunc(numpy, "less", comparison_dtypes,
1315-
&string_comparison_resolve_descriptors,
1316-
&string_less_strided_loop, "string_less", 2, 1,
1317-
NPY_NO_CASTING, 0) < 0) {
1318-
goto error;
1319-
}
1320-
1321-
if (init_ufunc(numpy, "less_equal", comparison_dtypes,
1322-
&string_comparison_resolve_descriptors,
1323-
&string_less_equal_strided_loop, "string_less_equal", 2, 1,
1324-
NPY_NO_CASTING, 0) < 0) {
1325-
goto error;
1326-
}
1327-
13281296
for (int i = 0; i < 6; i++) {
1297+
if (init_ufunc(numpy, comparison_ufunc_names[i], comparison_dtypes,
1298+
&string_comparison_resolve_descriptors,
1299+
strided_loops[i], 2, 1, NPY_NO_CASTING, 0) < 0) {
1300+
goto error;
1301+
}
1302+
13291303
if (add_promoter(numpy, comparison_ufunc_names[i],
13301304
(PyArray_DTypeMeta *)&StringDType,
13311305
&PyArray_UnicodeDType, &PyArray_BoolDType,
@@ -1360,8 +1334,7 @@ init_ufuncs(void)
13601334

13611335
if (init_ufunc(numpy, "isnan", isnan_dtypes,
13621336
&string_isnan_resolve_descriptors,
1363-
&string_isnan_strided_loop, "string_isnan", 1, 1,
1364-
NPY_NO_CASTING, 0) < 0) {
1337+
&string_isnan_strided_loop, 1, 1, NPY_NO_CASTING, 0) < 0) {
13651338
goto error;
13661339
}
13671340

@@ -1372,20 +1345,17 @@ init_ufuncs(void)
13721345
};
13731346

13741347
if (init_ufunc(numpy, "maximum", binary_dtypes, binary_resolve_descriptors,
1375-
&maximum_strided_loop, "string_maximum", 2, 1,
1376-
NPY_NO_CASTING, 0) < 0) {
1348+
&maximum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
13771349
goto error;
13781350
}
13791351

13801352
if (init_ufunc(numpy, "minimum", binary_dtypes, binary_resolve_descriptors,
1381-
&minimum_strided_loop, "string_minimum", 2, 1,
1382-
NPY_NO_CASTING, 0) < 0) {
1353+
&minimum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
13831354
goto error;
13841355
}
13851356

13861357
if (init_ufunc(numpy, "add", binary_dtypes, binary_resolve_descriptors,
1387-
&add_strided_loop, "string_add", 2, 1, NPY_NO_CASTING,
1388-
0) < 0) {
1358+
&add_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
13891359
goto error;
13901360
}
13911361

@@ -1414,6 +1384,20 @@ init_ufuncs(void)
14141384
INIT_MULTIPLY(ULongLong, ulonglong);
14151385
#endif
14161386

1387+
if (add_promoter(numpy, "multiply", (PyArray_DTypeMeta *)&StringDType,
1388+
&PyArray_PyIntAbstractDType,
1389+
(PyArray_DTypeMeta *)&StringDType,
1390+
string_multiply_promoter) < 0) {
1391+
goto error;
1392+
}
1393+
1394+
if (add_promoter(numpy, "multiply", &PyArray_PyIntAbstractDType,
1395+
(PyArray_DTypeMeta *)&StringDType,
1396+
(PyArray_DTypeMeta *)&StringDType,
1397+
string_multiply_promoter) < 0) {
1398+
goto error;
1399+
}
1400+
14171401
Py_DECREF(numpy);
14181402
return 0;
14191403

stringdtype/tests/test_stringdtype.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out):
643643
@pytest.mark.parametrize(
644644
"other_dtype",
645645
[
646+
None,
646647
"int8",
647648
"int16",
648649
"int32",
@@ -666,13 +667,17 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out):
666667
def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
667668
"""Test the two-argument ufuncs match python builtin behavior."""
668669
arr = np.array(string_list, dtype=dtype)
669-
other_dtype = np.dtype(other_dtype)
670+
if other_dtype is not None:
671+
other_dtype = np.dtype(other_dtype)
670672
try:
671673
len(other)
672674
result = [s * o for s, o in zip(string_list, other)]
673-
other = np.array(other, dtype=other_dtype)
675+
other = np.array(other)
676+
if other_dtype is not None:
677+
other = other.astype(other_dtype)
674678
except TypeError:
675-
other = other_dtype.type(other)
679+
if other_dtype is not None:
680+
other = other_dtype.type(other)
676681
result = [s * other for s in string_list]
677682

678683
if use_out:
@@ -702,7 +707,9 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
702707

703708
try:
704709
len(other)
705-
other = np.append(other, 3).astype(other_dtype)
710+
other = np.append(other, 3)
711+
if other_dtype is not None:
712+
other = other.astype(other_dtype)
706713
except TypeError:
707714
pass
708715

@@ -714,7 +721,7 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
714721
else:
715722
try:
716723
assert res[-1] == dtype.na_object * other[-1]
717-
except IndexError:
724+
except (IndexError, TypeError):
718725
assert res[-1] == dtype.na_object * other
719726
else:
720727
with pytest.raises(TypeError):
@@ -776,7 +783,6 @@ def test_null_roundtripping(dtype):
776783
assert data[1] == arr[1]
777784

778785

779-
@pytest.mark.xfail(strict=True)
780786
def test_string_too_large_error():
781787
arr = np.array(["a", "b", "c"], dtype=StringDType())
782788
with pytest.raises(MemoryError):

0 commit comments

Comments
 (0)