Skip to content

Commit 96270bc

Browse files
committed
added dtype promoter functions
1 parent d2e1ded commit 96270bc

File tree

1 file changed

+227
-13
lines changed

1 file changed

+227
-13
lines changed

quaddtype/quaddtype/src/umath.cpp

Lines changed: 227 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
extern "C" {
1111
#include <Python.h>
12+
#include <cstdio>
1213

1314
#include "numpy/arrayobject.h"
1415
#include "numpy/ndarraytypes.h"
1516
#include "numpy/ufuncobject.h"
1617

1718
#include "numpy/dtype_api.h"
1819
}
19-
2020
#include "dtype.h"
2121
#include "umath.h"
2222
#include "ops.hpp"
@@ -33,18 +33,22 @@ quad_generic_unary_op_strided_loop(PyArrayMethod_Context *context, char *const d
3333
npy_intp in_stride = strides[0];
3434
npy_intp out_stride = strides[1];
3535

36+
Sleef_quad in, out;
3637
while (N--) {
37-
unary_op((Sleef_quad *)in_ptr, (Sleef_quad *)out_ptr);
38+
memcpy(&in, in_ptr, sizeof(Sleef_quad));
39+
unary_op(&in, &out);
40+
memcpy(out_ptr, &out, sizeof(Sleef_quad));
41+
3842
in_ptr += in_stride;
3943
out_ptr += out_stride;
4044
}
4145
return 0;
4246
}
4347

4448
static NPY_CASTING
45-
quad_unary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
46-
QuadPrecDTypeObject *given_descrs[],
47-
QuadPrecDTypeObject *loop_descrs[], npy_intp *unused)
49+
quad_unary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
50+
PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
51+
npy_intp *NPY_UNUSED(view_offset))
4852
{
4953
Py_INCREF(given_descrs[0]);
5054
loop_descrs[0] = given_descrs[0];
@@ -57,7 +61,7 @@ quad_unary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
5761
Py_INCREF(given_descrs[1]);
5862
loop_descrs[1] = given_descrs[1];
5963

60-
return NPY_NO_CASTING; // Quad precision is always the same precision
64+
return NPY_NO_CASTING;
6165
}
6266

6367
template <unary_op_def unary_op>
@@ -156,8 +160,12 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
156160
npy_intp in2_stride = strides[1];
157161
npy_intp out_stride = strides[2];
158162

163+
Sleef_quad in1, in2, out;
159164
while (N--) {
160-
binop((Sleef_quad *)out_ptr, (Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
165+
memcpy(&in1, in1_ptr, sizeof(Sleef_quad));
166+
memcpy(&in2, in2_ptr, sizeof(Sleef_quad));
167+
binop(&out, &in1, &in2);
168+
memcpy(out_ptr, &out, sizeof(Sleef_quad));
161169

162170
in1_ptr += in1_stride;
163171
in2_ptr += in2_stride;
@@ -167,35 +175,186 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
167175
}
168176

169177
/*
 * Descriptor resolution for binary quad-precision ufunc loops.
 *
 * Both input descriptors are passed through unchanged.  The output descriptor
 * is the caller-supplied one when present; otherwise a fresh instance is
 * obtained from new_quaddtype_instance().
 *
 * Returns NPY_NO_CASTING on success, or (NPY_CASTING)-1 when the output
 * descriptor could not be created.
 */
static NPY_CASTING
quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                   PyArray_Descr *const given_descrs[],
                                   PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
{
    // Settle the output descriptor first: it is the only step that can fail,
    // so the error path returns before any reference has been taken.
    PyArray_Descr *out_descr = NULL;
    if (given_descrs[2] == NULL) {
        // NOTE(review): presumably returns a new (owned) reference, as it is
        // stored straight into loop_descrs[2] -- confirm against dtype.h.
        out_descr = (PyArray_Descr *)new_quaddtype_instance();
        if (!out_descr) {
            return (NPY_CASTING)-1;
        }
    }
    else {
        Py_INCREF(given_descrs[2]);
        out_descr = given_descrs[2];
    }

    Py_INCREF(given_descrs[0]);
    loop_descrs[0] = given_descrs[0];
    Py_INCREF(given_descrs[1]);
    loop_descrs[1] = given_descrs[1];

    // No additional Py_INCREF(given_descrs[0]) here: the output slot no
    // longer aliases the first input, so an extra reference would leak.
    loop_descrs[2] = out_descr;

    return NPY_NO_CASTING;
}
190202

191-
// todo: skipping the promoter for now, since same-type operation will be required
203+
// helper debugging function
204+
static const char *
205+
get_dtype_name(PyArray_DTypeMeta *dtype)
206+
{
207+
if (dtype == &QuadPrecDType) {
208+
return "QuadPrecDType";
209+
}
210+
else if (dtype == &PyArray_BoolDType) {
211+
return "BoolDType";
212+
}
213+
else if (dtype == &PyArray_ByteDType) {
214+
return "ByteDType";
215+
}
216+
else if (dtype == &PyArray_UByteDType) {
217+
return "UByteDType";
218+
}
219+
else if (dtype == &PyArray_ShortDType) {
220+
return "ShortDType";
221+
}
222+
else if (dtype == &PyArray_UShortDType) {
223+
return "UShortDType";
224+
}
225+
else if (dtype == &PyArray_IntDType) {
226+
return "IntDType";
227+
}
228+
else if (dtype == &PyArray_UIntDType) {
229+
return "UIntDType";
230+
}
231+
else if (dtype == &PyArray_LongDType) {
232+
return "LongDType";
233+
}
234+
else if (dtype == &PyArray_ULongDType) {
235+
return "ULongDType";
236+
}
237+
else if (dtype == &PyArray_LongLongDType) {
238+
return "LongLongDType";
239+
}
240+
else if (dtype == &PyArray_ULongLongDType) {
241+
return "ULongLongDType";
242+
}
243+
else if (dtype == &PyArray_FloatDType) {
244+
return "FloatDType";
245+
}
246+
else if (dtype == &PyArray_DoubleDType) {
247+
return "DoubleDType";
248+
}
249+
else if (dtype == &PyArray_LongDoubleDType) {
250+
return "LongDoubleDType";
251+
}
252+
else {
253+
return "UnknownDType";
254+
}
255+
}
256+
257+
static int
258+
quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
259+
PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
260+
{
261+
// printf("quad_ufunc_promoter called for ufunc: %s\n", ufunc->name);
262+
// printf("Entering quad_ufunc_promoter\n");
263+
// printf("Ufunc name: %s\n", ufunc->name);
264+
// printf("nin: %d, nargs: %d\n", ufunc->nin, ufunc->nargs);
265+
266+
int nin = ufunc->nin;
267+
int nargs = ufunc->nargs;
268+
PyArray_DTypeMeta *common = NULL;
269+
bool has_quad = false;
270+
271+
// Handle the special case for reductions
272+
if (op_dtypes[0] == NULL) {
273+
assert(nin == 2 && ufunc->nout == 1); /* must be reduction */
274+
for (int i = 0; i < 3; i++) {
275+
Py_INCREF(op_dtypes[1]);
276+
new_op_dtypes[i] = op_dtypes[1];
277+
// printf("new_op_dtypes[%d] set to %s\n", i, get_dtype_name(new_op_dtypes[i]));
278+
}
279+
return 0;
280+
}
281+
282+
// Check if any input or signature is QuadPrecision
283+
for (int i = 0; i < nargs; i++) {
284+
if ((i < nin && op_dtypes[i] == &QuadPrecDType) || (signature[i] == &QuadPrecDType)) {
285+
has_quad = true;
286+
// printf("QuadPrecision detected in input %d or signature\n", i);
287+
break;
288+
}
289+
}
290+
291+
if (has_quad) {
292+
// If QuadPrecision is involved, use it for all arguments
293+
common = &QuadPrecDType;
294+
// printf("Using QuadPrecDType as common type\n");
295+
}
296+
else {
297+
// Check if output signature is homogeneous
298+
for (int i = nin; i < nargs; i++) {
299+
if (signature[i] != NULL) {
300+
if (common == NULL) {
301+
Py_INCREF(signature[i]);
302+
common = signature[i];
303+
// printf("Common type set to %s from signature\n", get_dtype_name(common));
304+
}
305+
else if (common != signature[i]) {
306+
Py_CLEAR(common); // Not homogeneous, unset common
307+
// printf("Output signature not homogeneous, cleared common type\n");
308+
break;
309+
}
310+
}
311+
}
312+
313+
// If no common output dtype, use standard promotion for inputs
314+
if (common == NULL) {
315+
// printf("Using standard promotion for inputs\n");
316+
common = PyArray_PromoteDTypeSequence(nin, op_dtypes);
317+
if (common == NULL) {
318+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
319+
PyErr_Clear(); // Do not propagate normal promotion errors
320+
}
321+
// printf("Exiting quad_ufunc_promoter (promotion failed)\n");
322+
return -1;
323+
}
324+
// printf("Common type after promotion: %s\n", get_dtype_name(common));
325+
}
326+
}
327+
328+
// Set all new_op_dtypes to the common dtype
329+
for (int i = 0; i < nargs; i++) {
330+
if (signature[i]) {
331+
// If signature is specified for this argument, use it
332+
Py_INCREF(signature[i]);
333+
new_op_dtypes[i] = signature[i];
334+
// printf("new_op_dtypes[%d] set to %s (from signature)\n", i,
335+
// get_dtype_name(new_op_dtypes[i]));
336+
}
337+
else {
338+
// Otherwise, use the common dtype
339+
Py_INCREF(common);
340+
new_op_dtypes[i] = common;
341+
// printf("new_op_dtypes[%d] set to %s (from common)\n", i,
342+
// get_dtype_name(new_op_dtypes[i]));
343+
}
344+
}
345+
346+
Py_XDECREF(common);
347+
// printf("Exiting quad_ufunc_promoter\n");
348+
return 0;
349+
}
192350

193351
template <binop_def binop>
194352
int
195353
create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
196354
{
197355
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
198356
if (ufunc == NULL) {
357+
Py_DecRef(ufunc);
199358
return -1;
200359
}
201360

@@ -220,6 +379,25 @@ create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
220379
return -1;
221380
}
222381

382+
PyObject *promoter_capsule =
383+
PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
384+
if (promoter_capsule == NULL) {
385+
return -1;
386+
}
387+
388+
PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
389+
if (DTypes == 0) {
390+
Py_DECREF(promoter_capsule);
391+
return -1;
392+
}
393+
394+
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
395+
Py_DECREF(promoter_capsule);
396+
Py_DECREF(DTypes);
397+
return -1;
398+
}
399+
Py_DECREF(promoter_capsule);
400+
Py_DECREF(DTypes);
223401
return 0;
224402
}
225403

@@ -272,6 +450,22 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
272450
return 0;
273451
}
274452

453+
/*
 * Promoter for comparison ufuncs (two inputs, one output).
 *
 * The inputs are promoted through quad_ufunc_promoter with the output
 * signature entry masked out, and the output DType is then forced to
 * PyArray_BoolDType.
 *
 * Returns 0 on success, -1 if the underlying promotion fails.
 */
NPY_NO_EXPORT int
comparison_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
                          PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
{
    PyArray_DTypeMeta *new_signature[NPY_MAXARGS];

    // Copy the three signature slots, then clear the output slot so it does
    // not take part in choosing the common input DType.
    memcpy(new_signature, signature, 3 * sizeof(PyArray_DTypeMeta *));
    new_signature[2] = NULL;

    int res = quad_ufunc_promoter(ufunc, op_dtypes, new_signature, new_op_dtypes);
    if (res < 0) {
        return -1;
    }

    // The slot being overwritten holds an owned reference from the call
    // above, which Py_XSETREF releases.  The replacement must be owned too,
    // so take a reference on the Bool DType before storing it.
    Py_INCREF(&PyArray_BoolDType);
    Py_XSETREF(new_op_dtypes[2], &PyArray_BoolDType);
    return 0;
}
468+
275469
template <cmp_def comp>
276470
int
277471
create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
@@ -300,6 +494,26 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
300494
return -1;
301495
}
302496

497+
PyObject *promoter_capsule =
498+
PyCapsule_New((void *)&comparison_ufunc_promoter, "numpy._ufunc_promoter", NULL);
499+
if (promoter_capsule == NULL) {
500+
return -1;
501+
}
502+
503+
PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArray_BoolDType);
504+
if (DTypes == 0) {
505+
Py_DECREF(promoter_capsule);
506+
return -1;
507+
}
508+
509+
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
510+
Py_DECREF(promoter_capsule);
511+
Py_DECREF(DTypes);
512+
return -1;
513+
}
514+
Py_DECREF(promoter_capsule);
515+
Py_DECREF(DTypes);
516+
303517
return 0;
304518
}
305519

0 commit comments

Comments
 (0)