9
9
10
10
extern " C" {
11
11
#include < Python.h>
12
+ #include < cstdio>
12
13
13
14
#include " numpy/arrayobject.h"
14
15
#include " numpy/ndarraytypes.h"
15
16
#include " numpy/ufuncobject.h"
16
17
17
18
#include " numpy/dtype_api.h"
18
19
}
19
-
20
20
#include " dtype.h"
21
21
#include " umath.h"
22
22
#include " ops.hpp"
@@ -33,18 +33,22 @@ quad_generic_unary_op_strided_loop(PyArrayMethod_Context *context, char *const d
33
33
npy_intp in_stride = strides[0 ];
34
34
npy_intp out_stride = strides[1 ];
35
35
36
+ Sleef_quad in, out;
36
37
while (N--) {
37
- unary_op ((Sleef_quad *)in_ptr, (Sleef_quad *)out_ptr);
38
+ memcpy (&in, in_ptr, sizeof (Sleef_quad));
39
+ unary_op (&in, &out);
40
+ memcpy (out_ptr, &out, sizeof (Sleef_quad));
41
+
38
42
in_ptr += in_stride;
39
43
out_ptr += out_stride;
40
44
}
41
45
return 0 ;
42
46
}
43
47
44
48
static NPY_CASTING
45
- quad_unary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *dtypes[],
46
- QuadPrecDTypeObject * given_descrs[],
47
- QuadPrecDTypeObject *loop_descrs[], npy_intp *unused )
49
+ quad_unary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *const dtypes[],
50
+ PyArray_Descr * const given_descrs[], PyArray_Descr *loop_descrs [],
51
+ npy_intp *NPY_UNUSED (view_offset) )
48
52
{
49
53
Py_INCREF (given_descrs[0 ]);
50
54
loop_descrs[0 ] = given_descrs[0 ];
@@ -57,7 +61,7 @@ quad_unary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
57
61
Py_INCREF (given_descrs[1 ]);
58
62
loop_descrs[1 ] = given_descrs[1 ];
59
63
60
- return NPY_NO_CASTING; // Quad precision is always the same precision
64
+ return NPY_NO_CASTING;
61
65
}
62
66
63
67
template <unary_op_def unary_op>
@@ -156,8 +160,12 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
156
160
npy_intp in2_stride = strides[1 ];
157
161
npy_intp out_stride = strides[2 ];
158
162
163
+ Sleef_quad in1, in2, out;
159
164
while (N--) {
160
- binop ((Sleef_quad *)out_ptr, (Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
165
+ memcpy (&in1, in1_ptr, sizeof (Sleef_quad));
166
+ memcpy (&in2, in2_ptr, sizeof (Sleef_quad));
167
+ binop (&out, &in1, &in2);
168
+ memcpy (out_ptr, &out, sizeof (Sleef_quad));
161
169
162
170
in1_ptr += in1_stride;
163
171
in2_ptr += in2_stride;
@@ -167,35 +175,186 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
167
175
}
168
176
169
177
/*
 * Resolve the operand descriptors for a binary QuadPrecision ufunc loop.
 *
 * The two input descriptors are passed through unchanged. The output
 * descriptor is passed through when the caller supplied one; otherwise a
 * descriptor obtained from new_quaddtype_instance() is used.
 *
 * Every entry written to loop_descrs carries its own reference.
 *
 * Returns NPY_NO_CASTING on success, or (NPY_CASTING)-1 when the output
 * descriptor could not be created.
 */
static NPY_CASTING
quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                   PyArray_Descr *const given_descrs[],
                                   PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
{
    Py_INCREF(given_descrs[0]);
    loop_descrs[0] = given_descrs[0];
    Py_INCREF(given_descrs[1]);
    loop_descrs[1] = given_descrs[1];

    if (given_descrs[2] == NULL) {
        /* No output descriptor requested: create one for the result. */
        PyArray_Descr *out_descr = (PyArray_Descr *)new_quaddtype_instance();
        if (!out_descr) {
            return (NPY_CASTING)-1;
        }
        /*
         * Store out_descr directly. The previous code also incremented
         * given_descrs[0] here even though that descriptor is not the one
         * stored, which leaked one reference per call.
         * NOTE(review): assumes new_quaddtype_instance() returns a new
         * reference -- confirm against its definition.
         */
        loop_descrs[2] = out_descr;
    }
    else {
        Py_INCREF(given_descrs[2]);
        loop_descrs[2] = given_descrs[2];
    }

    return NPY_NO_CASTING;
}
190
202
191
- // todo: skipping the promoter for now, since same type operation will be required
203
+ // helper debugging function
204
+ static const char *
205
+ get_dtype_name (PyArray_DTypeMeta *dtype)
206
+ {
207
+ if (dtype == &QuadPrecDType) {
208
+ return " QuadPrecDType" ;
209
+ }
210
+ else if (dtype == &PyArray_BoolDType) {
211
+ return " BoolDType" ;
212
+ }
213
+ else if (dtype == &PyArray_ByteDType) {
214
+ return " ByteDType" ;
215
+ }
216
+ else if (dtype == &PyArray_UByteDType) {
217
+ return " UByteDType" ;
218
+ }
219
+ else if (dtype == &PyArray_ShortDType) {
220
+ return " ShortDType" ;
221
+ }
222
+ else if (dtype == &PyArray_UShortDType) {
223
+ return " UShortDType" ;
224
+ }
225
+ else if (dtype == &PyArray_IntDType) {
226
+ return " IntDType" ;
227
+ }
228
+ else if (dtype == &PyArray_UIntDType) {
229
+ return " UIntDType" ;
230
+ }
231
+ else if (dtype == &PyArray_LongDType) {
232
+ return " LongDType" ;
233
+ }
234
+ else if (dtype == &PyArray_ULongDType) {
235
+ return " ULongDType" ;
236
+ }
237
+ else if (dtype == &PyArray_LongLongDType) {
238
+ return " LongLongDType" ;
239
+ }
240
+ else if (dtype == &PyArray_ULongLongDType) {
241
+ return " ULongLongDType" ;
242
+ }
243
+ else if (dtype == &PyArray_FloatDType) {
244
+ return " FloatDType" ;
245
+ }
246
+ else if (dtype == &PyArray_DoubleDType) {
247
+ return " DoubleDType" ;
248
+ }
249
+ else if (dtype == &PyArray_LongDoubleDType) {
250
+ return " LongDoubleDType" ;
251
+ }
252
+ else {
253
+ return " UnknownDType" ;
254
+ }
255
+ }
256
+
257
+ static int
258
+ quad_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
259
+ PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
260
+ {
261
+ // printf("quad_ufunc_promoter called for ufunc: %s\n", ufunc->name);
262
+ // printf("Entering quad_ufunc_promoter\n");
263
+ // printf("Ufunc name: %s\n", ufunc->name);
264
+ // printf("nin: %d, nargs: %d\n", ufunc->nin, ufunc->nargs);
265
+
266
+ int nin = ufunc->nin ;
267
+ int nargs = ufunc->nargs ;
268
+ PyArray_DTypeMeta *common = NULL ;
269
+ bool has_quad = false ;
270
+
271
+ // Handle the special case for reductions
272
+ if (op_dtypes[0 ] == NULL ) {
273
+ assert (nin == 2 && ufunc->nout == 1 ); /* must be reduction */
274
+ for (int i = 0 ; i < 3 ; i++) {
275
+ Py_INCREF (op_dtypes[1 ]);
276
+ new_op_dtypes[i] = op_dtypes[1 ];
277
+ // printf("new_op_dtypes[%d] set to %s\n", i, get_dtype_name(new_op_dtypes[i]));
278
+ }
279
+ return 0 ;
280
+ }
281
+
282
+ // Check if any input or signature is QuadPrecision
283
+ for (int i = 0 ; i < nargs; i++) {
284
+ if ((i < nin && op_dtypes[i] == &QuadPrecDType) || (signature[i] == &QuadPrecDType)) {
285
+ has_quad = true ;
286
+ // printf("QuadPrecision detected in input %d or signature\n", i);
287
+ break ;
288
+ }
289
+ }
290
+
291
+ if (has_quad) {
292
+ // If QuadPrecision is involved, use it for all arguments
293
+ common = &QuadPrecDType;
294
+ // printf("Using QuadPrecDType as common type\n");
295
+ }
296
+ else {
297
+ // Check if output signature is homogeneous
298
+ for (int i = nin; i < nargs; i++) {
299
+ if (signature[i] != NULL ) {
300
+ if (common == NULL ) {
301
+ Py_INCREF (signature[i]);
302
+ common = signature[i];
303
+ // printf("Common type set to %s from signature\n", get_dtype_name(common));
304
+ }
305
+ else if (common != signature[i]) {
306
+ Py_CLEAR (common); // Not homogeneous, unset common
307
+ // printf("Output signature not homogeneous, cleared common type\n");
308
+ break ;
309
+ }
310
+ }
311
+ }
312
+
313
+ // If no common output dtype, use standard promotion for inputs
314
+ if (common == NULL ) {
315
+ // printf("Using standard promotion for inputs\n");
316
+ common = PyArray_PromoteDTypeSequence (nin, op_dtypes);
317
+ if (common == NULL ) {
318
+ if (PyErr_ExceptionMatches (PyExc_TypeError)) {
319
+ PyErr_Clear (); // Do not propagate normal promotion errors
320
+ }
321
+ // printf("Exiting quad_ufunc_promoter (promotion failed)\n");
322
+ return -1 ;
323
+ }
324
+ // printf("Common type after promotion: %s\n", get_dtype_name(common));
325
+ }
326
+ }
327
+
328
+ // Set all new_op_dtypes to the common dtype
329
+ for (int i = 0 ; i < nargs; i++) {
330
+ if (signature[i]) {
331
+ // If signature is specified for this argument, use it
332
+ Py_INCREF (signature[i]);
333
+ new_op_dtypes[i] = signature[i];
334
+ // printf("new_op_dtypes[%d] set to %s (from signature)\n", i,
335
+ // get_dtype_name(new_op_dtypes[i]));
336
+ }
337
+ else {
338
+ // Otherwise, use the common dtype
339
+ Py_INCREF (common);
340
+ new_op_dtypes[i] = common;
341
+ // printf("new_op_dtypes[%d] set to %s (from common)\n", i,
342
+ // get_dtype_name(new_op_dtypes[i]));
343
+ }
344
+ }
345
+
346
+ Py_XDECREF (common);
347
+ // printf("Exiting quad_ufunc_promoter\n");
348
+ return 0 ;
349
+ }
192
350
193
351
template <binop_def binop>
194
352
int
195
353
create_quad_binary_ufunc (PyObject *numpy, const char *ufunc_name)
196
354
{
197
355
PyObject *ufunc = PyObject_GetAttrString (numpy, ufunc_name);
198
356
if (ufunc == NULL ) {
357
+ Py_DecRef (ufunc);
199
358
return -1 ;
200
359
}
201
360
@@ -220,6 +379,25 @@ create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
220
379
return -1 ;
221
380
}
222
381
382
+ PyObject *promoter_capsule =
383
+ PyCapsule_New ((void *)&quad_ufunc_promoter, " numpy._ufunc_promoter" , NULL );
384
+ if (promoter_capsule == NULL ) {
385
+ return -1 ;
386
+ }
387
+
388
+ PyObject *DTypes = PyTuple_Pack (3 , &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
389
+ if (DTypes == 0 ) {
390
+ Py_DECREF (promoter_capsule);
391
+ return -1 ;
392
+ }
393
+
394
+ if (PyUFunc_AddPromoter (ufunc, DTypes, promoter_capsule) < 0 ) {
395
+ Py_DECREF (promoter_capsule);
396
+ Py_DECREF (DTypes);
397
+ return -1 ;
398
+ }
399
+ Py_DECREF (promoter_capsule);
400
+ Py_DECREF (DTypes);
223
401
return 0 ;
224
402
}
225
403
@@ -272,6 +450,22 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
272
450
return 0 ;
273
451
}
274
452
453
+ NPY_NO_EXPORT int
454
+ comparison_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
455
+ PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
456
+ {
457
+ PyArray_DTypeMeta *new_signature[NPY_MAXARGS];
458
+
459
+ memcpy (new_signature, signature, 3 * sizeof (PyArray_DTypeMeta *));
460
+ new_signature[2 ] = NULL ;
461
+ int res = quad_ufunc_promoter (ufunc, op_dtypes, new_signature, new_op_dtypes);
462
+ if (res < 0 ) {
463
+ return -1 ;
464
+ }
465
+ Py_XSETREF (new_op_dtypes[2 ], &PyArray_BoolDType);
466
+ return 0 ;
467
+ }
468
+
275
469
template <cmp_def comp>
276
470
int
277
471
create_quad_comparison_ufunc (PyObject *numpy, const char *ufunc_name)
@@ -300,6 +494,26 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
300
494
return -1 ;
301
495
}
302
496
497
+ PyObject *promoter_capsule =
498
+ PyCapsule_New ((void *)&comparison_ufunc_promoter, " numpy._ufunc_promoter" , NULL );
499
+ if (promoter_capsule == NULL ) {
500
+ return -1 ;
501
+ }
502
+
503
+ PyObject *DTypes = PyTuple_Pack (3 , &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArray_BoolDType);
504
+ if (DTypes == 0 ) {
505
+ Py_DECREF (promoter_capsule);
506
+ return -1 ;
507
+ }
508
+
509
+ if (PyUFunc_AddPromoter (ufunc, DTypes, promoter_capsule) < 0 ) {
510
+ Py_DECREF (promoter_capsule);
511
+ Py_DECREF (DTypes);
512
+ return -1 ;
513
+ }
514
+ Py_DECREF (promoter_capsule);
515
+ Py_DECREF (DTypes);
516
+
303
517
return 0 ;
304
518
}
305
519
0 commit comments