diff --git a/stringdtype/stringdtype/src/umath.c b/stringdtype/stringdtype/src/umath.c index 737250ef..022db85c 100644 --- a/stringdtype/stringdtype/src/umath.c +++ b/stringdtype/stringdtype/src/umath.c @@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors( * Copied from NumPy, because NumPy doesn't always use it :) */ static int -ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[], - PyArray_DTypeMeta *signature[], - PyArray_DTypeMeta *new_op_dtypes[], - PyArray_DTypeMeta *final_dtype) +string_inputs_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[], + PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[], + PyArray_DTypeMeta *final_dtype) { - /* If nin < 2 promotion is a no-op, so it should not be registered */ - assert(ufunc->nin > 1); - if (op_dtypes[0] == NULL) { - assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */ - Py_INCREF(op_dtypes[1]); - new_op_dtypes[0] = op_dtypes[1]; - Py_INCREF(op_dtypes[1]); - new_op_dtypes[1] = op_dtypes[1]; - Py_INCREF(op_dtypes[1]); - new_op_dtypes[2] = op_dtypes[1]; - return 0; - } - PyArray_DTypeMeta *common = NULL; - /* - * If a signature is used and homogeneous in its outputs use that - * (Could/should likely be rather applied to inputs also, although outs - * only could have some advantage and input dtypes are rarely enforced.) - */ - for (int i = ufunc->nin; i < ufunc->nargs; i++) { - if (signature[i] != NULL) { - if (common == NULL) { - Py_INCREF(signature[i]); - common = signature[i]; - } - else if (common != signature[i]) { - Py_CLEAR(common); /* Not homogeneous, unset common */ - break; - } - } - } - Py_XDECREF(common); - - /* Otherwise, set all input operands to final_dtype */ + /* set all input operands to final_dtype */ for (int i = 0; i < ufunc->nargs; i++) { PyArray_DTypeMeta *tmp = final_dtype; if (signature[i]) { @@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[], Py_INCREF(tmp); new_op_dtypes[i] = tmp; } + /* don't touch output dtypes */ for (int i = ufunc->nin; i < ufunc->nargs; i++) { Py_XINCREF(op_dtypes[i]); new_op_dtypes[i] = op_dtypes[i]; @@ -1140,9 +1109,9 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[]) { - return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes, - signature, new_op_dtypes, - (PyArray_DTypeMeta *)&PyArray_ObjectDType); + return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature, + new_op_dtypes, + (PyArray_DTypeMeta *)&PyArray_ObjectDType); } static int @@ -1150,9 +1119,40 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[]) { - return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes, - signature, new_op_dtypes, - (PyArray_DTypeMeta *)&StringDType); + return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature, + new_op_dtypes, + (PyArray_DTypeMeta *)&StringDType); +} + +static int +string_multiply_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *op_dtypes[], + PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[]) +{ + PyUFuncObject *ufunc = (PyUFuncObject *)ufunc_obj; + for (int i = 0; i < ufunc->nargs; i++) { + PyArray_DTypeMeta *tmp = NULL; + if (signature[i]) { + tmp = signature[i]; + } + else if (op_dtypes[i] == &PyArray_PyIntAbstractDType) { + tmp = &PyArray_Int64DType; + } + else if (op_dtypes[i]) { + tmp = op_dtypes[i]; + } + else { + tmp = (PyArray_DTypeMeta *)&StringDType; + } + Py_INCREF(tmp); + new_op_dtypes[i] = tmp; + } + /* don't touch output dtypes */ + for (int i = ufunc->nin; i < ufunc->nargs; i++) { + Py_XINCREF(op_dtypes[i]); + new_op_dtypes[i] = op_dtypes[i]; + } + return 0; } // Register a ufunc. @@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[], int init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes, resolve_descriptors_function *resolve_func, - PyArrayMethod_StridedLoop *loop_func, const char *loop_name, - int nin, int nout, NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags) + PyArrayMethod_StridedLoop *loop_func, int nin, int nout, + NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags) { PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name); if (ufunc == NULL) { return -1; } + char loop_name[256] = {0}; + + snprintf(loop_name, sizeof(loop_name), "string_%s", ufunc_name); + PyArrayMethod_Spec spec = { .name = loop_name, .nin = nin, @@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype, PyArray_DTypeMeta *edtype, promoter_function *promoter_impl) { - PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name); + PyObject *ufunc = PyObject_GetAttrString((PyObject *)numpy, ufunc_name); if (ufunc == NULL) { return -1; @@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name, \ if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \ &multiply_resolve_descriptors, \ - &multiply_right_##shortname##_strided_loop, \ - "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \ + &multiply_right_##shortname##_strided_loop, 2, 1, \ + NPY_NO_CASTING, 0) < 0) { \ goto error; \ } \ \ @@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name, \ if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \ &multiply_resolve_descriptors, \ - &multiply_left_##shortname##_strided_loop, \ - "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \ + &multiply_left_##shortname##_strided_loop, 2, 1, \ + NPY_NO_CASTING, 0) < 0) { \ goto error; \ } @@ -1279,53 +1283,23 @@ init_ufuncs(void) "greater", "greater_equal", "less", "less_equal"}; + static PyArrayMethod_StridedLoop *strided_loops[6] = { + &string_equal_strided_loop, &string_not_equal_strided_loop, + &string_greater_strided_loop, &string_greater_equal_strided_loop, + &string_less_strided_loop, &string_less_equal_strided_loop, + }; + PyArray_DTypeMeta *comparison_dtypes[] = { (PyArray_DTypeMeta *)&StringDType, (PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType}; - if (init_ufunc(numpy, "equal", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_equal_strided_loop, "string_equal", 2, 1, - NPY_NO_CASTING, 0) < 0) { - goto error; - } - - if (init_ufunc(numpy, "not_equal", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_not_equal_strided_loop, "string_not_equal", 2, 1, - NPY_NO_CASTING, 0) < 0) { - goto error; - } - - if (init_ufunc(numpy, "greater", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_greater_strided_loop, "string_greater", 2, 1, - NPY_NO_CASTING, 0) < 0) { - goto error; - } - - if (init_ufunc(numpy, "greater_equal", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_greater_equal_strided_loop, "string_greater_equal", - 2, 1, NPY_NO_CASTING, 0) < 0) { - goto error; - } - - if (init_ufunc(numpy, "less", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_less_strided_loop, "string_less", 2, 1, - NPY_NO_CASTING, 0) < 0) { - goto error; - } - - if (init_ufunc(numpy, "less_equal", comparison_dtypes, - &string_comparison_resolve_descriptors, - &string_less_equal_strided_loop, "string_less_equal", 2, 1, - NPY_NO_CASTING, 0) < 0) { - goto error; - } - for (int i = 0; i < 6; i++) { + if (init_ufunc(numpy, comparison_ufunc_names[i], comparison_dtypes, + &string_comparison_resolve_descriptors, + strided_loops[i], 2, 1, NPY_NO_CASTING, 0) < 0) { + goto error; + } + if (add_promoter(numpy, comparison_ufunc_names[i], (PyArray_DTypeMeta *)&StringDType, &PyArray_UnicodeDType, &PyArray_BoolDType, @@ -1360,8 +1334,7 @@ init_ufuncs(void) if (init_ufunc(numpy, "isnan", isnan_dtypes, &string_isnan_resolve_descriptors, - &string_isnan_strided_loop, "string_isnan", 1, 1, - NPY_NO_CASTING, 0) < 0) { + &string_isnan_strided_loop, 1, 1, NPY_NO_CASTING, 0) < 0) { goto error; } @@ -1372,20 +1345,17 @@ init_ufuncs(void) }; if (init_ufunc(numpy, "maximum", binary_dtypes, binary_resolve_descriptors, - &maximum_strided_loop, "string_maximum", 2, 1, - NPY_NO_CASTING, 0) < 0) { + &maximum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) { goto error; } if (init_ufunc(numpy, "minimum", binary_dtypes, binary_resolve_descriptors, - &minimum_strided_loop, "string_minimum", 2, 1, - NPY_NO_CASTING, 0) < 0) { + &minimum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) { goto error; } if (init_ufunc(numpy, "add", binary_dtypes, binary_resolve_descriptors, - &add_strided_loop, "string_add", 2, 1, NPY_NO_CASTING, - 0) < 0) { + &add_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) { goto error; } @@ -1414,6 +1384,20 @@ init_ufuncs(void) INIT_MULTIPLY(ULongLong, ulonglong); #endif + if (add_promoter(numpy, "multiply", (PyArray_DTypeMeta *)&StringDType, + &PyArray_PyIntAbstractDType, + (PyArray_DTypeMeta *)&StringDType, + string_multiply_promoter) < 0) { + goto error; + } + + if (add_promoter(numpy, "multiply", &PyArray_PyIntAbstractDType, + (PyArray_DTypeMeta *)&StringDType, + (PyArray_DTypeMeta *)&StringDType, + string_multiply_promoter) < 0) { + goto error; + } + Py_DECREF(numpy); return 0; diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index e1cc27a0..7ca928c8 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -643,6 +643,7 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out): @pytest.mark.parametrize( "other_dtype", [ + None, "int8", "int16", "int32", @@ -666,13 +667,17 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out): def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out): """Test the two-argument ufuncs match python builtin behavior.""" arr = np.array(string_list, dtype=dtype) - other_dtype = np.dtype(other_dtype) + if other_dtype is not None: + other_dtype = np.dtype(other_dtype) try: len(other) result = [s * o for s, o in zip(string_list, other)] - other = np.array(other, dtype=other_dtype) + other = np.array(other) + if other_dtype is not None: + other = other.astype(other_dtype) except TypeError: - other = other_dtype.type(other) + if other_dtype is not None: + other = other_dtype.type(other) result = [s * other for s in string_list] if use_out: @@ -702,7 +707,9 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out): try: len(other) - other = np.append(other, 3).astype(other_dtype) + other = np.append(other, 3) + if other_dtype is not None: + other = other.astype(other_dtype) except TypeError: pass @@ -714,7 +721,7 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out): else: try: assert res[-1] == dtype.na_object * other[-1] - except IndexError: + except (IndexError, TypeError): assert res[-1] == dtype.na_object * other else: with pytest.raises(TypeError): @@ -776,7 +783,6 @@ def test_null_roundtripping(dtype): assert data[1] == arr[1] -@pytest.mark.xfail(strict=True) def test_string_too_large_error(): arr = np.array(["a", "b", "c"], dtype=StringDType()) with pytest.raises(MemoryError):