Skip to content

Commit e8cba6a

Browse files
committed
temporary solution to handle both backends
1 parent 7a85fbf commit e8cba6a

File tree

3 files changed

+70
-41
lines changed

3 files changed

+70
-41
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,15 @@ numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta
272272
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
273273
npy_intp *view_offset)
274274
{
275+
printf("cast.cpp: numpy_to_quad_resolve_descriptors is called\n");
275276
if (given_descrs[1] == NULL) {
276277
loop_descrs[1] = (PyArray_Descr *)new_quaddtype_instance(BACKEND_SLEEF);
277278
if (loop_descrs[1] == nullptr) {
278279
return (NPY_CASTING)-1;
279280
}
280281
}
281282
else {
283+
printf("cast.cpp: numpy_to_quad_resolve_descriptors, I am in ELSE condition\n");
282284
Py_INCREF(given_descrs[1]);
283285
loop_descrs[1] = given_descrs[1];
284286
}
@@ -297,8 +299,12 @@ numpy_to_quad_strided_loop(PyArrayMethod_Context *context, char *const data[],
297299
char *in_ptr = data[0];
298300
char *out_ptr = data[1];
299301

302+
QuadPrecDTypeObject *descr_out1 = (QuadPrecDTypeObject *)context->descriptors[0];
303+
printf("The type of context->descriptor[0] is: %s\n", Py_TYPE(descr_out1)->tp_name);
300304
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)context->descriptors[1];
305+
printf("The type of context->descriptor[1] is: %s\n", Py_TYPE(descr_out)->tp_name);
301306
QuadBackendType backend = descr_out->backend;
307+
printf("cast.cpp: numpy_to_quad_strided_loop with backend: %d\n", backend);
302308
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
303309

304310
while (N--) {

quaddtype/numpy_quaddtype/src/dtype.c

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ quad_store(char *data_ptr, void *x, QuadBackendType backend)
4949
QuadPrecDTypeObject *
5050
new_quaddtype_instance(QuadBackendType backend)
5151
{
52+
// if (backend != BACKEND_SLEEF && backend != BACKEND_LONGDOUBLE)
53+
// {
54+
// PyErr_SetString(PyExc_TypeError,
55+
// "Backend must be sleef or longdouble");
56+
// return NULL;
57+
// }
58+
printf("New Quandtype instance is created with backend: %d\n", backend);
5259
QuadPrecDTypeObject *new = (QuadPrecDTypeObject *)PyArrayDescr_Type.tp_new(
5360
(PyTypeObject *)&QuadPrecDType, NULL, NULL);
5461
if (new == NULL) {
@@ -63,13 +70,15 @@ new_quaddtype_instance(QuadBackendType backend)
6370
static QuadPrecDTypeObject *
6471
ensure_canonical(QuadPrecDTypeObject *self)
6572
{
73+
printf("Ensure Canonical is called\n");
6674
Py_INCREF(self);
6775
return self;
6876
}
6977

7078
static QuadPrecDTypeObject *
7179
common_instance(QuadPrecDTypeObject *dtype1, QuadPrecDTypeObject *dtype2)
7280
{
81+
printf("Common Instance is called\n");
7382
if (dtype1->backend != dtype2->backend) {
7483
PyErr_SetString(PyExc_TypeError,
7584
"Cannot find common instance for QuadPrecDTypes with different backends");
@@ -82,6 +91,7 @@ common_instance(QuadPrecDTypeObject *dtype1, QuadPrecDTypeObject *dtype2)
8291
static PyArray_DTypeMeta *
8392
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
8493
{
94+
printf("Common dtype is called\n");
8595
// Promote integer and floating-point types to QuadPrecDType
8696
if (other->type_num >= 0 &&
8797
(PyTypeNum_ISINTEGER(other->type_num) || PyTypeNum_ISFLOAT(other->type_num))) {
@@ -105,7 +115,9 @@ quadprec_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), P
105115
PyErr_SetString(PyExc_TypeError, "Can only store QuadPrecision in a QuadPrecDType array.");
106116
return NULL;
107117
}
118+
108119
QuadPrecisionObject *quad_obj = (QuadPrecisionObject *)obj;
120+
printf("dtype.c: quadprec_discover_descriptor_from_pyobject is called with backend %d\n", quad_obj->backend);
109121
return (PyArray_Descr *)new_quaddtype_instance(quad_obj->backend);
110122
}
111123

@@ -156,10 +168,11 @@ quadprec_getitem(QuadPrecDTypeObject *descr, char *dataptr)
156168
static PyArray_Descr *
157169
quadprec_default_descr(PyArray_DTypeMeta *cls)
158170
{
159-
QuadPrecDTypeObject *temp = (QuadPrecDTypeObject *)cls;
160-
const char *s1 = (temp->backend == BACKEND_SLEEF) ? "SLEEF" : "LONGDOUBLE";
161-
printf("called with backend: %s\n", s1);
162-
return (PyArray_Descr *)new_quaddtype_instance(temp->backend);
171+
QuadPrecDTypeObject * a = (QuadPrecDTypeObject *)cls;
172+
printf("Default descriptor called with backend: %d\n", a->backend);
173+
QuadPrecDTypeObject * temp = new_quaddtype_instance(a->backend);
174+
printf("Default descriptor made backend: %d\n", temp->backend);
175+
return (PyArray_Descr *)temp;
163176
}
164177

165178
static PyType_Slot QuadPrecDType_Slots[] = {

quaddtype/numpy_quaddtype/src/umath.cpp

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ init_quad_unary_ops(PyObject *numpy)
175175
if (create_quad_unary_ufunc<quad_negative, ld_negative>(numpy, "negative") < 0) {
176176
return -1;
177177
}
178+
if (create_quad_unary_ufunc<quad_positive, ld_positive>(numpy, "positive") < 0) {
179+
return -1;
180+
}
178181
if (create_quad_unary_ufunc<quad_absolute, ld_absolute>(numpy, "absolute") < 0) {
179182
return -1;
180183
}
@@ -224,44 +227,62 @@ quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtyp
224227
PyArray_Descr *const given_descrs[],
225228
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
226229
{
227-
printf("Descriptor Resolver is calledn\n");
228-
Py_INCREF(given_descrs[0]);
229-
loop_descrs[0] = given_descrs[0];
230-
Py_INCREF(given_descrs[1]);
231-
loop_descrs[1] = given_descrs[1];
230+
printf("Descriptor Resolver is called\n");
232231

233232
QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
234233
QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
234+
QuadBackendType target_backend;
235+
235236
const char *s1 = (descr_in1->backend == BACKEND_SLEEF) ? "SLEEF" : "LONGDOUBLE";
236237
const char *s2 = (descr_in2->backend == BACKEND_SLEEF) ? "SLEEF" : "LONGDOUBLE";
237-
printf("1: %s\n", s1);
238-
printf("2: %s\n", s2);
238+
printf("1: %s %d %s\n", s1, descr_in1->backend, Py_TYPE(given_descrs[0])->tp_name);
239+
printf("2: %s %d %s\n", s2, descr_in2->backend, Py_TYPE(given_descrs[1])->tp_name);
239240

241+
// Determine target backend and if casting is needed
242+
NPY_CASTING casting = NPY_NO_CASTING;
240243
if (descr_in1->backend != descr_in2->backend) {
241-
PyErr_SetString(PyExc_TypeError,
242-
"Cannot operate on QuadPrecision objects with different backends");
243-
return (NPY_CASTING)-1;
244+
target_backend = BACKEND_LONGDOUBLE;
245+
casting = NPY_SAFE_CASTING;
246+
printf("Different backends detected. Casting to LONGDOUBLE.\n");
247+
} else {
248+
target_backend = descr_in1->backend;
249+
printf("Unified backend: %s\n", (target_backend == BACKEND_SLEEF) ? "SLEEF" : "LONGDOUBLE");
250+
}
251+
252+
// Set up input descriptors, casting if necessary
253+
for (int i = 0; i < 2; i++) {
254+
if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
255+
loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
256+
if (!loop_descrs[i]) {
257+
return (NPY_CASTING)-1;
258+
}
259+
} else {
260+
Py_INCREF(given_descrs[i]);
261+
loop_descrs[i] = given_descrs[i];
262+
}
244263
}
245264

265+
// Set up output descriptor
246266
if (given_descrs[2] == NULL) {
247-
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(descr_in1->backend);
267+
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
248268
if (!loop_descrs[2]) {
249269
return (NPY_CASTING)-1;
250270
}
251-
}
252-
else {
253-
Py_INCREF(given_descrs[2]);
254-
loop_descrs[2] = given_descrs[2];
255-
}
256-
257-
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)loop_descrs[2];
258-
if (descr_out->backend != descr_in1->backend) {
259-
PyErr_SetString(PyExc_TypeError,
260-
"Output QuadPrecision object must have the same backend as inputs");
261-
return (NPY_CASTING)-1;
271+
} else {
272+
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
273+
if (descr_out->backend != target_backend) {
274+
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
275+
if (!loop_descrs[2]) {
276+
return (NPY_CASTING)-1;
277+
}
278+
} else {
279+
Py_INCREF(given_descrs[2]);
280+
loop_descrs[2] = given_descrs[2];
281+
}
262282
}
263283

264-
return NPY_NO_CASTING;
284+
printf("Casting result: %d\n", casting);
285+
return casting;
265286
}
266287

267288
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
@@ -270,6 +291,7 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
270291
npy_intp const dimensions[], npy_intp const strides[],
271292
NpyAuxData *auxdata)
272293
{
294+
printf("Umath: Generic Strided loop is calledn\n");
273295
npy_intp N = dimensions[0];
274296
char *in1_ptr = data[0], *in2_ptr = data[1];
275297
char *out_ptr = data[2];
@@ -328,22 +350,10 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
328350

329351
// Check if any input or signature is QuadPrecision
330352
for (int i = 0; i < nin; i++) {
353+
printf("iterating on dtype : %s\n", get_dtype_name(op_dtypes[i]));
331354
if (op_dtypes[i] == &QuadPrecDType) {
332355
has_quad = true;
333-
QuadPrecDTypeObject *descr =
334-
(QuadPrecDTypeObject *)PyArray_GetDefaultDescr(op_dtypes[i]);
335-
336-
const char *s = (descr->backend == BACKEND_SLEEF) ? "SLEEF" : "LONGDOUBLE";
337-
printf("QuadPrecision detected in input %d or signature with backend: %s\n", i, s);
338-
if (backend == BACKEND_INVALID)
339-
backend = descr->backend;
340-
else if (backend != BACKEND_INVALID && backend != descr->backend) {
341-
PyErr_SetString(PyExc_TypeError,
342-
"Cannot mix QuadPrecDType arrays with different backends");
343-
return -1;
344-
}
345-
Py_DECREF(descr);
346-
break;
356+
printf("QuadPrecision detected in input %d\n", i);
347357
}
348358
}
349359

0 commit comments

Comments
 (0)