1
+ #define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
2
+ #define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
3
+ #define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
4
+ #define NPY_TARGET_VERSION NPY_2_0_API_VERSION
5
+ #define NO_IMPORT_ARRAY
6
+ #define NO_IMPORT_UFUNC
7
+
8
+ extern " C" {
9
+ #include < Python.h>
10
+
11
+ #include " numpy/arrayobject.h"
12
+ #include " numpy/ndarraytypes.h"
13
+ #include " numpy/ufuncobject.h"
14
+
15
+ #include " numpy/dtype_api.h"
16
+ }
17
+
18
+ #include " scalar.h"
19
+ #include " dtype.h"
20
+ #include " umath.h"
21
+ #include " ops.hpp"
22
+
23
+ template <unary_op_def unary_op>
24
+ int quad_generic_unary_op_strided_loop (PyArrayMethod_Context *context,
25
+ char *const data[], npy_intp const dimensions[],
26
+ npy_intp const strides[], NpyAuxData *auxdata)
27
+ {
28
+ npy_intp N = dimensions[0 ];
29
+ char *in_ptr = data[0 ];
30
+ char *out_ptr = data[1 ];
31
+ npy_intp in_stride = strides[0 ];
32
+ npy_intp out_stride = strides[1 ];
33
+
34
+ while (N--)
35
+ {
36
+ unary_op ((Sleef_quad *)in_ptr, (Sleef_quad *)out_ptr);
37
+ in_ptr += in_stride;
38
+ out_ptr += out_stride;
39
+ }
40
+ return 0 ;
41
+ }
42
+
43
+ static NPY_CASTING
44
+ quad_unary_op_resolve_descriptors (PyObject *self,
45
+ PyArray_DTypeMeta *dtypes[], QuadPrecDTypeObject *given_descrs[],
46
+ QuadPrecDTypeObject *loop_descrs[], npy_intp *unused)
47
+ {
48
+ Py_INCREF (given_descrs[0 ]);
49
+ loop_descrs[0 ] = given_descrs[0 ];
50
+
51
+ if (given_descrs[1 ] == NULL ) {
52
+ Py_INCREF (given_descrs[0 ]);
53
+ loop_descrs[1 ] = given_descrs[0 ];
54
+ return NPY_NO_CASTING;
55
+ }
56
+ Py_INCREF (given_descrs[1 ]);
57
+ loop_descrs[1 ] = given_descrs[1 ];
58
+
59
+ return NPY_NO_CASTING; // Quad precision is always the same precision
60
+ }
61
+
62
+ template <unary_op_def unary_op>
63
+ int create_quad_unary_ufunc (PyObject *numpy, const char *ufunc_name)
64
+ {
65
+ PyObject *ufunc = PyObject_GetAttrString (numpy, ufunc_name);
66
+ if (ufunc == NULL ) {
67
+ return -1 ;
68
+ }
69
+
70
+ PyArray_DTypeMeta *dtypes[2 ] = {
71
+ &QuadPrecDType, &QuadPrecDType};
72
+
73
+ PyType_Slot slots[] = {
74
+ {NPY_METH_resolve_descriptors, (void *)&quad_unary_op_resolve_descriptors},
75
+ {NPY_METH_strided_loop, (void *)&quad_generic_unary_op_strided_loop<unary_op>},
76
+ {0 , NULL }
77
+ };
78
+
79
+ PyArrayMethod_Spec Spec = {
80
+ .name = " quad_unary_op" ,
81
+ .nin = 1 ,
82
+ .nout = 1 ,
83
+ .casting = NPY_NO_CASTING,
84
+ .flags = (NPY_ARRAYMETHOD_FLAGS)0 ,
85
+ .dtypes = dtypes,
86
+ .slots = slots,
87
+ };
88
+
89
+ if (PyUFunc_AddLoopFromSpec (ufunc, &Spec) < 0 ) {
90
+ return -1 ;
91
+ }
92
+
93
+ return 0 ;
94
+ }
95
+
96
+ int init_quad_unary_ops (PyObject *numpy)
97
+ {
98
+ if (create_quad_unary_ufunc<quad_negative>(numpy, " negative" ) < 0 ) {
99
+ return -1 ;
100
+ }
101
+ if (create_quad_unary_ufunc<quad_absolute>(numpy, " absolute" ) < 0 ) {
102
+ return -1 ;
103
+ }
104
+ return 0 ;
105
+ }
106
+
107
+ // Binary ufuncs
108
+
109
+ template <binop_def binop>
110
+ int quad_generic_binop_strided_loop (PyArrayMethod_Context *context,
111
+ char *const data[], npy_intp const dimensions[],
112
+ npy_intp const strides[], NpyAuxData *auxdata)
113
+ {
114
+ npy_intp N = dimensions[0 ];
115
+ char *in1_ptr = data[0 ], *in2_ptr = data[1 ];
116
+ char *out_ptr = data[2 ];
117
+ npy_intp in1_stride = strides[0 ];
118
+ npy_intp in2_stride = strides[1 ];
119
+ npy_intp out_stride = strides[2 ];
120
+
121
+ while (N--) {
122
+ binop ((Sleef_quad *)out_ptr, (Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
123
+
124
+ in1_ptr += in1_stride;
125
+ in2_ptr += in2_stride;
126
+ out_ptr += out_stride;
127
+ }
128
+ return 0 ;
129
+ }
130
+
131
+ static NPY_CASTING
132
+ quad_binary_op_resolve_descriptors (PyObject *self,
133
+ PyArray_DTypeMeta *dtypes[], QuadPrecDTypeObject *given_descrs[],
134
+ QuadPrecDTypeObject *loop_descrs[], npy_intp *unused)
135
+ {
136
+ Py_INCREF (given_descrs[0 ]);
137
+ loop_descrs[0 ] = given_descrs[0 ];
138
+ Py_INCREF (given_descrs[1 ]);
139
+ loop_descrs[1 ] = given_descrs[1 ];
140
+
141
+ if (given_descrs[2 ] == NULL ) {
142
+ Py_INCREF (given_descrs[0 ]);
143
+ loop_descrs[2 ] = given_descrs[0 ];
144
+ }
145
+ else {
146
+ Py_INCREF (given_descrs[2 ]);
147
+ loop_descrs[2 ] = given_descrs[2 ];
148
+ }
149
+
150
+ return NPY_NO_CASTING; // Quad precision is always the same precision
151
+ }
152
+
153
+ // todo: skipping the promoter for now, since same type operation will be requried
154
+
155
+ template <binop_def binop>
156
+ int create_quad_binary_ufunc (PyObject *numpy, const char *ufunc_name)
157
+ {
158
+ PyObject *ufunc = PyObject_GetAttrString (numpy, ufunc_name);
159
+ if (ufunc == NULL ) {
160
+ return -1 ;
161
+ }
162
+
163
+ PyArray_DTypeMeta *dtypes[3 ] = {
164
+ &QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
165
+
166
+ PyType_Slot slots[] = {
167
+ {NPY_METH_resolve_descriptors,
168
+ (void *)&quad_binary_op_resolve_descriptors},
169
+ {NPY_METH_strided_loop,
170
+ (void *)&quad_generic_binop_strided_loop<binop>},
171
+ {0 , NULL }
172
+ };
173
+
174
+ PyArrayMethod_Spec Spec = {
175
+ .name = " quad_binop" ,
176
+ .nin = 2 ,
177
+ .nout = 1 ,
178
+ .casting = NPY_NO_CASTING,
179
+ .flags = (NPY_ARRAYMETHOD_FLAGS)0 ,
180
+ .dtypes = dtypes,
181
+ .slots = slots,
182
+ };
183
+
184
+ if (PyUFunc_AddLoopFromSpec (ufunc, &Spec) < 0 ) {
185
+ return -1 ;
186
+ }
187
+
188
+ return 0 ;
189
+ }
190
+
191
+ int init_quad_binary_ops (PyObject *numpy)
192
+ {
193
+ if (create_quad_binary_ufunc<quad_add>(numpy, " add" ) < 0 ) {
194
+ return -1 ;
195
+ }
196
+ if (create_quad_binary_ufunc<quad_sub>(numpy, " subtract" ) < 0 ) {
197
+ return -1 ;
198
+ }
199
+
200
+ return 0 ;
201
+ }
202
+
203
+ int init_quad_umath (void )
204
+ {
205
+ PyObject * numpy = PyImport_ImportModule (" numpy" );
206
+ if (!numpy)
207
+ return -1 ;
208
+
209
+ if (init_quad_unary_ops (numpy) < 0 ) {
210
+ goto err;
211
+ }
212
+
213
+ if (init_quad_binary_ops (numpy) < 0 ) {
214
+ goto err;
215
+ }
216
+
217
+ Py_DECREF (numpy);
218
+ return 0 ;
219
+
220
+ err:
221
+ Py_DECREF (numpy);
222
+ return -1 ;
223
+
224
+ }
0 commit comments