@@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors(
1081
1081
* Copied from NumPy, because NumPy doesn't always use it :)
1082
1082
*/
1083
1083
static int
1084
- ufunc_promoter_internal (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
1085
- PyArray_DTypeMeta * signature [],
1086
- PyArray_DTypeMeta * new_op_dtypes [],
1087
- PyArray_DTypeMeta * final_dtype )
1084
+ string_inputs_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
1085
+ PyArray_DTypeMeta * signature [],
1086
+ PyArray_DTypeMeta * new_op_dtypes [],
1087
+ PyArray_DTypeMeta * final_dtype )
1088
1088
{
1089
- /* If nin < 2 promotion is a no-op, so it should not be registered */
1090
- assert (ufunc -> nin > 1 );
1091
- if (op_dtypes [0 ] == NULL ) {
1092
- assert (ufunc -> nin == 2 && ufunc -> nout == 1 ); /* must be reduction */
1093
- Py_INCREF (op_dtypes [1 ]);
1094
- new_op_dtypes [0 ] = op_dtypes [1 ];
1095
- Py_INCREF (op_dtypes [1 ]);
1096
- new_op_dtypes [1 ] = op_dtypes [1 ];
1097
- Py_INCREF (op_dtypes [1 ]);
1098
- new_op_dtypes [2 ] = op_dtypes [1 ];
1099
- return 0 ;
1100
- }
1101
- PyArray_DTypeMeta * common = NULL ;
1102
- /*
1103
- * If a signature is used and homogeneous in its outputs use that
1104
- * (Could/should likely be rather applied to inputs also, although outs
1105
- * only could have some advantage and input dtypes are rarely enforced.)
1106
- */
1107
- for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
1108
- if (signature [i ] != NULL ) {
1109
- if (common == NULL ) {
1110
- Py_INCREF (signature [i ]);
1111
- common = signature [i ];
1112
- }
1113
- else if (common != signature [i ]) {
1114
- Py_CLEAR (common ); /* Not homogeneous, unset common */
1115
- break ;
1116
- }
1117
- }
1118
- }
1119
- Py_XDECREF (common );
1120
-
1121
- /* Otherwise, set all input operands to final_dtype */
1089
+ /* set all input operands to final_dtype */
1122
1090
for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
1123
1091
PyArray_DTypeMeta * tmp = final_dtype ;
1124
1092
if (signature [i ]) {
@@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1127
1095
Py_INCREF (tmp );
1128
1096
new_op_dtypes [i ] = tmp ;
1129
1097
}
1098
+ /* don't touch output dtypes */
1130
1099
for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
1131
1100
Py_XINCREF (op_dtypes [i ]);
1132
1101
new_op_dtypes [i ] = op_dtypes [i ];
@@ -1140,19 +1109,50 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1140
1109
PyArray_DTypeMeta * signature [],
1141
1110
PyArray_DTypeMeta * new_op_dtypes [])
1142
1111
{
1143
- return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
1144
- signature , new_op_dtypes ,
1145
- (PyArray_DTypeMeta * )& PyArray_ObjectDType );
1112
+ return string_inputs_promoter ((PyUFuncObject * )ufunc , op_dtypes , signature ,
1113
+ new_op_dtypes ,
1114
+ (PyArray_DTypeMeta * )& PyArray_ObjectDType );
1146
1115
}
1147
1116
1148
1117
static int
1149
1118
string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
1150
1119
PyArray_DTypeMeta * signature [],
1151
1120
PyArray_DTypeMeta * new_op_dtypes [])
1152
1121
{
1153
- return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
1154
- signature , new_op_dtypes ,
1155
- (PyArray_DTypeMeta * )& StringDType );
1122
+ return string_inputs_promoter ((PyUFuncObject * )ufunc , op_dtypes , signature ,
1123
+ new_op_dtypes ,
1124
+ (PyArray_DTypeMeta * )& StringDType );
1125
+ }
1126
+
1127
+ static int
1128
+ string_multiply_promoter (PyObject * ufunc_obj , PyArray_DTypeMeta * op_dtypes [],
1129
+ PyArray_DTypeMeta * signature [],
1130
+ PyArray_DTypeMeta * new_op_dtypes [])
1131
+ {
1132
+ PyUFuncObject * ufunc = (PyUFuncObject * )ufunc_obj ;
1133
+ for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
1134
+ PyArray_DTypeMeta * tmp = NULL ;
1135
+ if (signature [i ]) {
1136
+ tmp = signature [i ];
1137
+ }
1138
+ else if (op_dtypes [i ] == & PyArray_PyIntAbstractDType ) {
1139
+ tmp = & PyArray_Int64DType ;
1140
+ }
1141
+ else if (op_dtypes [i ]) {
1142
+ tmp = op_dtypes [i ];
1143
+ }
1144
+ else {
1145
+ tmp = (PyArray_DTypeMeta * )& StringDType ;
1146
+ }
1147
+ Py_INCREF (tmp );
1148
+ new_op_dtypes [i ] = tmp ;
1149
+ }
1150
+ /* don't touch output dtypes */
1151
+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
1152
+ Py_XINCREF (op_dtypes [i ]);
1153
+ new_op_dtypes [i ] = op_dtypes [i ];
1154
+ }
1155
+ return 0 ;
1156
1156
}
1157
1157
1158
1158
// Register a ufunc.
@@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1161
1161
int
1162
1162
init_ufunc (PyObject * numpy , const char * ufunc_name , PyArray_DTypeMeta * * dtypes ,
1163
1163
resolve_descriptors_function * resolve_func ,
1164
- PyArrayMethod_StridedLoop * loop_func , const char * loop_name ,
1165
- int nin , int nout , NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
1164
+ PyArrayMethod_StridedLoop * loop_func , int nin , int nout ,
1165
+ NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
1166
1166
{
1167
1167
PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
1168
1168
if (ufunc == NULL ) {
1169
1169
return -1 ;
1170
1170
}
1171
1171
1172
+ char loop_name [256 ] = {0 };
1173
+
1174
+ snprintf (loop_name , sizeof (loop_name ), "string_%s" , ufunc_name );
1175
+
1172
1176
PyArrayMethod_Spec spec = {
1173
1177
.name = loop_name ,
1174
1178
.nin = nin ,
@@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1208
1212
PyArray_DTypeMeta * ldtype , PyArray_DTypeMeta * rdtype ,
1209
1213
PyArray_DTypeMeta * edtype , promoter_function * promoter_impl )
1210
1214
{
1211
- PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
1215
+ PyObject * ufunc = PyObject_GetAttrString (( PyObject * ) numpy , ufunc_name );
1212
1216
1213
1217
if (ufunc == NULL ) {
1214
1218
return -1 ;
@@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1251
1255
\
1252
1256
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
1253
1257
&multiply_resolve_descriptors, \
1254
- &multiply_right_##shortname##_strided_loop, \
1255
- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1258
+ &multiply_right_##shortname##_strided_loop, 2, 1, \
1259
+ NPY_NO_CASTING, 0) < 0) { \
1256
1260
goto error; \
1257
1261
} \
1258
1262
\
@@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1262
1266
\
1263
1267
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
1264
1268
&multiply_resolve_descriptors, \
1265
- &multiply_left_##shortname##_strided_loop, \
1266
- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1269
+ &multiply_left_##shortname##_strided_loop, 2, 1, \
1270
+ NPY_NO_CASTING, 0) < 0) { \
1267
1271
goto error; \
1268
1272
}
1269
1273
@@ -1279,53 +1283,23 @@ init_ufuncs(void)
1279
1283
"greater" , "greater_equal" ,
1280
1284
"less" , "less_equal" };
1281
1285
1286
+ static PyArrayMethod_StridedLoop * strided_loops [6 ] = {
1287
+ & string_equal_strided_loop , & string_not_equal_strided_loop ,
1288
+ & string_greater_strided_loop , & string_greater_equal_strided_loop ,
1289
+ & string_less_strided_loop , & string_less_equal_strided_loop ,
1290
+ };
1291
+
1282
1292
PyArray_DTypeMeta * comparison_dtypes [] = {
1283
1293
(PyArray_DTypeMeta * )& StringDType ,
1284
1294
(PyArray_DTypeMeta * )& StringDType , & PyArray_BoolDType };
1285
1295
1286
- if (init_ufunc (numpy , "equal" , comparison_dtypes ,
1287
- & string_comparison_resolve_descriptors ,
1288
- & string_equal_strided_loop , "string_equal" , 2 , 1 ,
1289
- NPY_NO_CASTING , 0 ) < 0 ) {
1290
- goto error ;
1291
- }
1292
-
1293
- if (init_ufunc (numpy , "not_equal" , comparison_dtypes ,
1294
- & string_comparison_resolve_descriptors ,
1295
- & string_not_equal_strided_loop , "string_not_equal" , 2 , 1 ,
1296
- NPY_NO_CASTING , 0 ) < 0 ) {
1297
- goto error ;
1298
- }
1299
-
1300
- if (init_ufunc (numpy , "greater" , comparison_dtypes ,
1301
- & string_comparison_resolve_descriptors ,
1302
- & string_greater_strided_loop , "string_greater" , 2 , 1 ,
1303
- NPY_NO_CASTING , 0 ) < 0 ) {
1304
- goto error ;
1305
- }
1306
-
1307
- if (init_ufunc (numpy , "greater_equal" , comparison_dtypes ,
1308
- & string_comparison_resolve_descriptors ,
1309
- & string_greater_equal_strided_loop , "string_greater_equal" ,
1310
- 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1311
- goto error ;
1312
- }
1313
-
1314
- if (init_ufunc (numpy , "less" , comparison_dtypes ,
1315
- & string_comparison_resolve_descriptors ,
1316
- & string_less_strided_loop , "string_less" , 2 , 1 ,
1317
- NPY_NO_CASTING , 0 ) < 0 ) {
1318
- goto error ;
1319
- }
1320
-
1321
- if (init_ufunc (numpy , "less_equal" , comparison_dtypes ,
1322
- & string_comparison_resolve_descriptors ,
1323
- & string_less_equal_strided_loop , "string_less_equal" , 2 , 1 ,
1324
- NPY_NO_CASTING , 0 ) < 0 ) {
1325
- goto error ;
1326
- }
1327
-
1328
1296
for (int i = 0 ; i < 6 ; i ++ ) {
1297
+ if (init_ufunc (numpy , comparison_ufunc_names [i ], comparison_dtypes ,
1298
+ & string_comparison_resolve_descriptors ,
1299
+ strided_loops [i ], 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1300
+ goto error ;
1301
+ }
1302
+
1329
1303
if (add_promoter (numpy , comparison_ufunc_names [i ],
1330
1304
(PyArray_DTypeMeta * )& StringDType ,
1331
1305
& PyArray_UnicodeDType , & PyArray_BoolDType ,
@@ -1360,8 +1334,7 @@ init_ufuncs(void)
1360
1334
1361
1335
if (init_ufunc (numpy , "isnan" , isnan_dtypes ,
1362
1336
& string_isnan_resolve_descriptors ,
1363
- & string_isnan_strided_loop , "string_isnan" , 1 , 1 ,
1364
- NPY_NO_CASTING , 0 ) < 0 ) {
1337
+ & string_isnan_strided_loop , 1 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1365
1338
goto error ;
1366
1339
}
1367
1340
@@ -1372,20 +1345,17 @@ init_ufuncs(void)
1372
1345
};
1373
1346
1374
1347
if (init_ufunc (numpy , "maximum" , binary_dtypes , binary_resolve_descriptors ,
1375
- & maximum_strided_loop , "string_maximum" , 2 , 1 ,
1376
- NPY_NO_CASTING , 0 ) < 0 ) {
1348
+ & maximum_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1377
1349
goto error ;
1378
1350
}
1379
1351
1380
1352
if (init_ufunc (numpy , "minimum" , binary_dtypes , binary_resolve_descriptors ,
1381
- & minimum_strided_loop , "string_minimum" , 2 , 1 ,
1382
- NPY_NO_CASTING , 0 ) < 0 ) {
1353
+ & minimum_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1383
1354
goto error ;
1384
1355
}
1385
1356
1386
1357
if (init_ufunc (numpy , "add" , binary_dtypes , binary_resolve_descriptors ,
1387
- & add_strided_loop , "string_add" , 2 , 1 , NPY_NO_CASTING ,
1388
- 0 ) < 0 ) {
1358
+ & add_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1389
1359
goto error ;
1390
1360
}
1391
1361
@@ -1414,6 +1384,20 @@ init_ufuncs(void)
1414
1384
INIT_MULTIPLY (ULongLong , ulonglong );
1415
1385
#endif
1416
1386
1387
+ if (add_promoter (numpy , "multiply" , (PyArray_DTypeMeta * )& StringDType ,
1388
+ & PyArray_PyIntAbstractDType ,
1389
+ (PyArray_DTypeMeta * )& StringDType ,
1390
+ string_multiply_promoter ) < 0 ) {
1391
+ goto error ;
1392
+ }
1393
+
1394
+ if (add_promoter (numpy , "multiply" , & PyArray_PyIntAbstractDType ,
1395
+ (PyArray_DTypeMeta * )& StringDType ,
1396
+ (PyArray_DTypeMeta * )& StringDType ,
1397
+ string_multiply_promoter ) < 0 ) {
1398
+ goto error ;
1399
+ }
1400
+
1417
1401
Py_DECREF (numpy );
1418
1402
return 0 ;
1419
1403
0 commit comments