Skip to content

Commit 600b2c8

Browse files
committed
fixed inter-backend cast segment fault
1 parent 672be17 commit 600b2c8

File tree

3 files changed

+77
-46
lines changed

3 files changed

+77
-46
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
2828
QuadPrecDTypeObject *given_descrs[2],
2929
QuadPrecDTypeObject *loop_descrs[2], npy_intp *view_offset)
3030
{
31-
if (given_descrs[0]->backend != given_descrs[1]->backend) {
32-
return NPY_UNSAFE_CASTING;
33-
}
31+
NPY_CASTING casting = NPY_NO_CASTING;
32+
33+
if (given_descrs[0]->backend != given_descrs[1]->backend)
34+
casting = NPY_UNSAFE_CASTING;
3435

3536
Py_INCREF(given_descrs[0]);
3637
loop_descrs[0] = given_descrs[0];
@@ -45,7 +46,7 @@ quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
4546
}
4647

4748
*view_offset = 0;
48-
return NPY_NO_CASTING;
49+
return casting;
4950
}
5051

5152
static int
@@ -62,10 +63,26 @@ quad_to_quad_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
6263
QuadPrecDTypeObject *descr_in = (QuadPrecDTypeObject *)context->descriptors[0];
6364
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)context->descriptors[1];
6465

66+
// inter-backend casting
6567
if (descr_in->backend != descr_out->backend) {
66-
PyErr_SetString(PyExc_TypeError,
67-
"Cannot convert between different quad-precision backends");
68-
return -1;
68+
while (N--) {
69+
quad_value in_val, out_val;
70+
if (descr_in->backend == BACKEND_SLEEF) {
71+
memcpy(&in_val.sleef_value, in_ptr, sizeof(Sleef_quad));
72+
out_val.longdouble_value = Sleef_cast_to_doubleq1(in_val.sleef_value);
73+
}
74+
else {
75+
memcpy(&in_val.longdouble_value, in_ptr, sizeof(long double));
76+
out_val.sleef_value = Sleef_cast_from_doubleq1(in_val.longdouble_value);
77+
}
78+
memcpy(out_ptr, &out_val,
79+
(descr_out->backend == BACKEND_SLEEF) ? sizeof(Sleef_quad)
80+
: sizeof(long double));
81+
in_ptr += in_stride;
82+
out_ptr += out_stride;
83+
}
84+
85+
return 0;
6986
}
7087

7188
size_t elem_size =
@@ -93,10 +110,26 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
93110
QuadPrecDTypeObject *descr_in = (QuadPrecDTypeObject *)context->descriptors[0];
94111
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)context->descriptors[1];
95112

113+
// inter-backend casting
96114
if (descr_in->backend != descr_out->backend) {
97-
PyErr_SetString(PyExc_TypeError,
98-
"Cannot convert between different quad-precision backends");
99-
return -1;
115+
if (descr_in->backend == BACKEND_SLEEF) {
116+
while (N--) {
117+
Sleef_quad in_val = *(Sleef_quad *)in_ptr;
118+
*(long double *)out_ptr = Sleef_cast_to_doubleq1(in_val);
119+
in_ptr += in_stride;
120+
out_ptr += out_stride;
121+
}
122+
}
123+
else {
124+
while (N--) {
125+
long double in_val = *(long double *)in_ptr;
126+
*(Sleef_quad *)out_ptr = Sleef_cast_from_doubleq1(in_val);
127+
in_ptr += in_stride;
128+
out_ptr += out_stride;
129+
}
130+
}
131+
132+
return 0;
100133
}
101134

102135
if (descr_in->backend == BACKEND_SLEEF) {
@@ -627,6 +660,11 @@ add_spec(PyArrayMethod_Spec *spec)
627660
if (spec_count < NUM_CASTS) {
628661
specs[spec_count++] = spec;
629662
}
663+
else {
664+
delete[] spec->dtypes;
665+
delete[] spec->slots;
666+
delete spec;
667+
}
630668
}
631669

632670
// functions to add casts
@@ -682,9 +720,8 @@ add_cast_to(PyArray_DTypeMeta *from)
682720
PyArrayMethod_Spec **
683721
init_casts_internal(void)
684722
{
685-
PyArray_DTypeMeta **quad2quad_dtypes =
686-
new PyArray_DTypeMeta *[2]{&QuadPrecDType, &QuadPrecDType};
687-
PyType_Slot *quad2quad_slots = new PyType_Slot[]{
723+
PyArray_DTypeMeta **quad2quad_dtypes = new PyArray_DTypeMeta *[2]{nullptr, nullptr};
724+
PyType_Slot *quad2quad_slots = new PyType_Slot[4]{
688725
{NPY_METH_resolve_descriptors, (void *)&quad_to_quad_resolve_descriptors},
689726
{NPY_METH_strided_loop, (void *)&quad_to_quad_strided_loop_aligned},
690727
{NPY_METH_unaligned_strided_loop, (void *)&quad_to_quad_strided_loop_unaligned},
@@ -694,7 +731,7 @@ init_casts_internal(void)
694731
.name = "cast_QuadPrec_to_QuadPrec",
695732
.nin = 1,
696733
.nout = 1,
697-
.casting = NPY_NO_CASTING,
734+
.casting = NPY_UNSAFE_CASTING, // since SLEEF -> ld might lose precision
698735
.flags = NPY_METH_SUPPORTS_UNALIGNED,
699736
.dtypes = quad2quad_dtypes,
700737
.slots = quad2quad_slots,
@@ -749,13 +786,13 @@ init_casts(void)
749786
void
750787
free_casts(void)
751788
{
752-
for (auto cast : specs) {
753-
if (cast == nullptr) {
754-
continue;
789+
for (size_t i = 0; i < spec_count; i++) {
790+
if (specs[i]) {
791+
delete[] specs[i]->dtypes;
792+
delete[] specs[i]->slots;
793+
delete specs[i];
794+
specs[i] = nullptr;
755795
}
756-
delete[] cast->dtypes;
757-
delete[] cast->slots;
758-
delete cast;
759796
}
760797
spec_count = 0;
761798
}

quaddtype/numpy_quaddtype/src/scalar_ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ quad_binary_func(PyObject *op1, PyObject *op2)
124124
return (PyObject *)res;
125125
}
126126

127-
// todo: add support with float and int
128127
PyObject *
129128
quad_richcompare(QuadPrecisionObject *self, PyObject *other, int cmp_op)
130129
{
@@ -212,7 +211,8 @@ QuadPrecision_float(QuadPrecisionObject *self)
212211
{
213212
if (self->backend == BACKEND_SLEEF) {
214213
return PyFloat_FromDouble(Sleef_cast_to_doubleq1(self->value.sleef_value));
215-
} else {
214+
}
215+
else {
216216
return PyFloat_FromDouble((double)self->value.longdouble_value);
217217
}
218218
}
@@ -222,12 +222,12 @@ QuadPrecision_int(QuadPrecisionObject *self)
222222
{
223223
if (self->backend == BACKEND_SLEEF) {
224224
return PyLong_FromLongLong(Sleef_cast_to_int64q1(self->value.sleef_value));
225-
} else {
225+
}
226+
else {
226227
return PyLong_FromLongLong((long long)self->value.longdouble_value);
227228
}
228229
}
229230

230-
231231
PyNumberMethods quad_as_scalar = {
232232
.nb_add = (binaryfunc)quad_binary_func<quad_add, ld_add>,
233233
.nb_subtract = (binaryfunc)quad_binary_func<quad_sub, ld_sub>,

quaddtype/numpy_quaddtype/src/umath.cpp

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtyp
282282
// Determine target backend and if casting is needed
283283
NPY_CASTING casting = NPY_NO_CASTING;
284284
if (descr_in1->backend != descr_in2->backend) {
285-
286285
target_backend = BACKEND_LONGDOUBLE;
287286
casting = NPY_SAFE_CASTING;
288287
}
@@ -398,12 +397,11 @@ static int
398397
quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
399398
PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
400399
{
401-
402400
int nin = ufunc->nin;
403401
int nargs = ufunc->nargs;
404402
PyArray_DTypeMeta *common = NULL;
405403
bool has_quad = false;
406-
404+
407405
// Handle the special case for reductions
408406
if (op_dtypes[0] == NULL) {
409407
assert(nin == 2 && ufunc->nout == 1); /* must be reduction */
@@ -417,7 +415,6 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
417415
// Check if any input or signature is QuadPrecision
418416
for (int i = 0; i < nin; i++) {
419417
if (op_dtypes[i] == &QuadPrecDType) {
420-
421418
has_quad = true;
422419
}
423420
}
@@ -461,7 +458,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
461458
else {
462459
// Otherwise, use the common dtype
463460
Py_INCREF(common);
464-
461+
465462
new_op_dtypes[i] = common;
466463
}
467464
}
@@ -563,13 +560,14 @@ init_quad_binary_ops(PyObject *numpy)
563560

564561
static NPY_CASTING
565562
quad_comparison_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
566-
PyArray_Descr *const given_descrs[],
567-
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
563+
PyArray_Descr *const given_descrs[],
564+
PyArray_Descr *loop_descrs[],
565+
npy_intp *NPY_UNUSED(view_offset))
568566
{
569567
QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
570568
QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
571569
QuadBackendType target_backend;
572-
570+
573571
// As dealing with different backends then cast to boolean
574572
NPY_CASTING casting = NPY_NO_CASTING;
575573
if (descr_in1->backend != descr_in2->backend) {
@@ -599,7 +597,7 @@ quad_comparison_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const
599597
if (!loop_descrs[2]) {
600598
return (NPY_CASTING)-1;
601599
}
602-
return casting;
600+
return casting;
603601
}
604602

605603
template <cmp_quad_def sleef_comp, cmp_londouble_def ld_comp>
@@ -626,10 +624,9 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
626624
npy_bool result;
627625

628626
if (backend == BACKEND_SLEEF) {
629-
result = sleef_comp(&in1.sleef_value, &in2.sleef_value);
627+
result = sleef_comp(&in1.sleef_value, &in2.sleef_value);
630628
}
631629
else {
632-
633630
result = ld_comp(&in1.longdouble_value, &in2.longdouble_value);
634631
}
635632

@@ -642,12 +639,11 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
642639
return 0;
643640
}
644641

645-
646642
template <cmp_quad_def sleef_comp, cmp_londouble_def ld_comp>
647643
int
648644
quad_generic_comp_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
649-
npy_intp const dimensions[], npy_intp const strides[],
650-
NpyAuxData *auxdata)
645+
npy_intp const dimensions[], npy_intp const strides[],
646+
NpyAuxData *auxdata)
651647
{
652648
npy_intp N = dimensions[0];
653649
char *in1_ptr = data[0], *in2_ptr = data[1];
@@ -658,19 +654,16 @@ quad_generic_comp_strided_loop_aligned(PyArrayMethod_Context *context, char *con
658654

659655
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
660656
QuadBackendType backend = descr->backend;
661-
while (N--)
662-
{
657+
while (N--) {
663658
quad_value in1 = *(quad_value *)in1_ptr;
664659
quad_value in2 = *(quad_value *)in2_ptr;
665660

666661
npy_bool result;
667662

668-
if (backend == BACKEND_SLEEF)
669-
{
663+
if (backend == BACKEND_SLEEF) {
670664
result = sleef_comp(&in1.sleef_value, &in2.sleef_value);
671-
}
672-
else
673-
{
665+
}
666+
else {
674667
result = ld_comp(&in1.longdouble_value, &in2.longdouble_value);
675668
}
676669

@@ -711,7 +704,8 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
711704

712705
PyType_Slot slots[] = {
713706
{NPY_METH_resolve_descriptors, (void *)&quad_comparison_op_resolve_descriptors},
714-
{NPY_METH_strided_loop, (void *)&quad_generic_comp_strided_loop_aligned<sleef_comp, ld_comp>},
707+
{NPY_METH_strided_loop,
708+
(void *)&quad_generic_comp_strided_loop_aligned<sleef_comp, ld_comp>},
715709
{NPY_METH_unaligned_strided_loop,
716710
(void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
717711
{0, NULL}};

0 commit comments

Comments
 (0)