1
+ from textwrap import dedent , indent
2
+
1
3
from pytensor .configdefaults import config
2
4
3
5
@@ -8,51 +10,49 @@ def make_declare(loop_orders, dtypes, sub):
8
10
"""
9
11
decl = ""
10
12
for i , (loop_order , dtype ) in enumerate (zip (loop_orders , dtypes )):
11
- var = sub [f"lv{ int ( i ) } " ] # input name corresponding to ith loop variable
13
+ var = sub [f"lv{ i } " ] # input name corresponding to ith loop variable
12
14
# we declare an iteration variable
13
15
# and an integer for the number of dimensions
14
- decl += f"""
15
- { dtype } * { var } _iter;
16
- """
16
+ decl += f"{ dtype } * { var } _iter;\n "
17
17
for j , value in enumerate (loop_order ):
18
18
if value != "x" :
19
19
# If the dimension is not broadcasted, we declare
20
20
# the number of elements in that dimension,
21
21
# the stride in that dimension,
22
22
# and the jump from an iteration to the next
23
23
decl += f"""
24
- npy_intp { var } _n{ int ( value ) } ;
25
- ssize_t { var } _stride{ int ( value ) } ;
26
- int { var } _jump{ int ( value ) } _{ int ( j ) } ;
24
+ npy_intp { var } _n{ value } ;
25
+ ssize_t { var } _stride{ value } ;
26
+ int { var } _jump{ value } _{ j } ;
27
27
"""
28
28
29
29
else :
30
30
# if the dimension is broadcasted, we only need
31
31
# the jump (arbitrary length and stride = 0)
32
- decl += f"""
33
- int { var } _jump{ value } _{ int (j )} ;
34
- """
32
+ decl += f"int { var } _jump{ value } _{ j } ;\n "
35
33
36
34
return decl
37
35
38
36
39
37
def make_checks (loop_orders , dtypes , sub ):
40
38
init = ""
41
39
for i , (loop_order , dtype ) in enumerate (zip (loop_orders , dtypes )):
42
- var = f"%( lv{ int ( i ) } )s"
40
+ var = sub [ f" lv{ i } " ]
43
41
# List of dimensions of var that are not broadcasted
44
42
nonx = [x for x in loop_order if x != "x" ]
45
43
if nonx :
46
44
# If there are dimensions that are not broadcasted
47
45
# this is a check that the number of dimensions of the
48
46
# tensor is as expected.
49
47
min_nd = max (nonx ) + 1
50
- init += f"""
51
- if (PyArray_NDIM({ var } ) < { min_nd } ) {{
52
- PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
53
- %(fail)s
54
- }}
55
- """
48
+ init += dedent (
49
+ f"""
50
+ if (PyArray_NDIM({ var } ) < { min_nd } ) {{
51
+ PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
52
+ { indent (sub ["fail" ], " " * 12 )}
53
+ }}
54
+ """
55
+ )
56
56
57
57
# In loop j, adjust represents the difference of values of the
58
58
# data pointer between the beginning and the end of the
@@ -75,9 +75,7 @@ def make_checks(loop_orders, dtypes, sub):
75
75
adjust = f"{ var } _n{ index } *{ var } _stride{ index } "
76
76
else :
77
77
jump = f"-({ adjust } )"
78
- init += f"""
79
- { var } _jump{ index } _{ j } = { jump } ;
80
- """
78
+ init += f"{ var } _jump{ index } _{ j } = { jump } ;\n "
81
79
adjust = "0"
82
80
check = ""
83
81
@@ -101,34 +99,36 @@ def make_checks(loop_orders, dtypes, sub):
101
99
102
100
j0 , x0 = to_compare [0 ]
103
101
for j , x in to_compare [1 :]:
104
- check += f"""
105
- if (%(lv{ j0 } )s_n{ x0 } != %(lv{ j } )s_n{ x } )
106
- {{
107
- if (%(lv{ j0 } )s_n{ x0 } == 1 || %(lv{ j } )s_n{ x } == 1)
102
+ check += dedent (
103
+ f"""
104
+ if ({ sub [f"lv{ j0 } " ]} _n{ x0 } != { sub [f"lv{ j } " ]} _n{ x } )
108
105
{{
109
- PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
110
- { j0 } ,
111
- { x0 } ,
112
- (long long int) %(lv{ j0 } )s_n{ x0 } ,
113
- { j } ,
114
- { x } ,
115
- (long long int) %(lv{ j } )s_n{ x }
116
- );
117
- }} else {{
118
- PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
106
+ if ({ sub [f"lv{ j0 } " ]} _n{ x0 } == 1 || { sub [f"lv{ j } " ]} _n{ x } == 1)
107
+ {{
108
+ PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
119
109
{ j0 } ,
120
110
{ x0 } ,
121
- (long long int) %( lv{ j0 } )s_n { x0 } ,
111
+ (long long int) { sub [ f" lv{ j0 } " ] } _n { x0 } ,
122
112
{ j } ,
123
113
{ x } ,
124
- (long long int) %(lv{ j } )s_n{ x }
125
- );
114
+ (long long int) { sub [f"lv{ j } " ]} _n{ x }
115
+ );
116
+ }} else {{
117
+ PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
118
+ { j0 } ,
119
+ { x0 } ,
120
+ (long long int) { sub [f"lv{ j0 } " ]} _n{ x0 } ,
121
+ { j } ,
122
+ { x } ,
123
+ (long long int) { sub [f"lv{ j } " ]} _n{ x }
124
+ );
125
+ }}
126
+ { sub ["fail" ]}
126
127
}}
127
- %(fail)s
128
- }}
129
- """
128
+ """
129
+ )
130
130
131
- return init % sub + check % sub
131
+ return init + check
132
132
133
133
134
134
def compute_output_dims_lengths (array_name : str , loop_orders , sub ) -> str :
@@ -144,7 +144,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
144
144
# Borrow the length of the first non-broadcastable input dimension
145
145
for j , candidate in enumerate (candidates ):
146
146
if candidate != "x" :
147
- var = sub [f"lv{ int ( j ) } " ]
147
+ var = sub [f"lv{ j } " ]
148
148
dims_c_code += f"{ array_name } [{ i } ] = { var } _n{ candidate } ;\n "
149
149
break
150
150
# If none is non-broadcastable, the output dimension has a length of 1
@@ -177,35 +177,37 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
177
177
# way that its contiguous dimensions match one of the input's
178
178
# contiguous dimensions, or the dimension with the smallest
179
179
# stride. Right now, it is allocated to be C_CONTIGUOUS.
180
- return f"""
181
- {{
182
- npy_intp dims[{ nd } ];
183
- //npy_intp* dims = (npy_intp*)malloc({ nd } * sizeof(npy_intp));
184
- { init_dims }
185
- if (!{ olv } ) {{
186
- { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims,
187
- { type } ,
188
- { fortran } );
189
- }}
190
- else {{
191
- PyArray_Dims new_dims;
192
- new_dims.len = { nd } ;
193
- new_dims.ptr = dims;
194
- PyObject* success = PyArray_Resize({ olv } , &new_dims, 0, NPY_CORDER);
195
- if (!success) {{
196
- // If we can't resize the ndarray we have we can allocate a new one.
197
- PyErr_Clear();
198
- Py_XDECREF({ olv } );
199
- { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims, { type } , 0);
200
- }} else {{
201
- Py_DECREF(success);
180
+ return dedent (
181
+ f"""
182
+ {{
183
+ npy_intp dims[{ nd } ];
184
+ { init_dims }
185
+ if (!{ olv } ) {{
186
+ { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } ,
187
+ dims,
188
+ { type } ,
189
+ { fortran } );
190
+ }}
191
+ else {{
192
+ PyArray_Dims new_dims;
193
+ new_dims.len = { nd } ;
194
+ new_dims.ptr = dims;
195
+ PyObject* success = PyArray_Resize({ olv } , &new_dims, 0, NPY_CORDER);
196
+ if (!success) {{
197
+ // If we can't resize the ndarray we have we can allocate a new one.
198
+ PyErr_Clear();
199
+ Py_XDECREF({ olv } );
200
+ { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims, { type } , 0);
201
+ }} else {{
202
+ Py_DECREF(success);
203
+ }}
204
+ }}
205
+ if (!{ olv } ) {{
206
+ { fail }
202
207
}}
203
208
}}
204
- if (!{ olv } ) {{
205
- { fail }
206
- }}
207
- }}
208
- """
209
+ """
210
+ )
209
211
210
212
211
213
def make_loop (loop_orders , dtypes , loop_tasks , sub , openmp = None ):
@@ -235,11 +237,11 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
235
237
"""
236
238
237
239
def loop_over (preloop , code , indices , i ):
238
- iterv = f"ITER_{ int ( i ) } "
240
+ iterv = f"ITER_{ i } "
239
241
update = ""
240
242
suitable_n = "1"
241
243
for j , index in enumerate (indices ):
242
- var = sub [f"lv{ int ( j ) } " ]
244
+ var = sub [f"lv{ j } " ]
243
245
dtype = dtypes [j ]
244
246
update += f"{ dtype } &{ var } _i = * ( { var } _iter + { iterv } * { var } _jump{ index } _{ i } );\n "
245
247
@@ -305,21 +307,21 @@ def make_reordered_loop(
305
307
nnested = len (init_loop_orders [0 ])
306
308
307
309
# This is the var from which we'll get the loop order
308
- ovar = sub [f"lv{ int ( olv_index ) } " ]
310
+ ovar = sub [f"lv{ olv_index } " ]
309
311
310
312
# The loops are ordered by (decreasing) absolute values of ovar's strides.
311
313
# The first element of each pair is the absolute value of the stride
312
314
# The second element correspond to the index in the initial loop order
313
315
order_loops = f"""
314
- std::vector< std::pair<int, int> > { ovar } _loops({ int ( nnested ) } );
316
+ std::vector< std::pair<int, int> > { ovar } _loops({ nnested } );
315
317
std::vector< std::pair<int, int> >::iterator { ovar } _loops_it = { ovar } _loops.begin();
316
318
"""
317
319
318
320
# Fill the loop vector with the appropriate <stride, index> pairs
319
321
for i , index in enumerate (init_loop_orders [olv_index ]):
320
322
if index != "x" :
321
323
order_loops += f"""
322
- { ovar } _loops_it->first = abs(PyArray_STRIDES({ ovar } )[{ int ( index ) } ]);
324
+ { ovar } _loops_it->first = abs(PyArray_STRIDES({ ovar } )[{ index } ]);
323
325
"""
324
326
else :
325
327
# Stride is 0 when dimension is broadcastable
@@ -328,7 +330,7 @@ def make_reordered_loop(
328
330
"""
329
331
330
332
order_loops += f"""
331
- { ovar } _loops_it->second = { int ( i ) } ;
333
+ { ovar } _loops_it->second = { i } ;
332
334
++{ ovar } _loops_it;
333
335
"""
334
336
@@ -352,7 +354,7 @@ def make_reordered_loop(
352
354
353
355
for i in range (nnested ):
354
356
declare_totals += f"""
355
- int TOTAL_{ int ( i ) } = init_totals[{ ovar } _loops_it->second];
357
+ int TOTAL_{ i } = init_totals[{ ovar } _loops_it->second];
356
358
++{ ovar } _loops_it;
357
359
"""
358
360
@@ -365,7 +367,7 @@ def get_loop_strides(loop_order, i):
365
367
specified loop_order.
366
368
367
369
"""
368
- var = sub [f"lv{ int ( i ) } " ]
370
+ var = sub [f"lv{ i } " ]
369
371
r = []
370
372
for index in loop_order :
371
373
# Note: the stride variable is not declared for broadcasted variables
@@ -383,7 +385,7 @@ def get_loop_strides(loop_order, i):
383
385
)
384
386
385
387
declare_strides = f"""
386
- int init_strides[{ int ( nvars ) } ][{ int ( nnested ) } ] = {{
388
+ int init_strides[{ nvars } ][{ nnested } ] = {{
387
389
{ strides }
388
390
}};"""
389
391
@@ -394,33 +396,33 @@ def get_loop_strides(loop_order, i):
394
396
"""
395
397
396
398
for i in range (nvars ):
397
- var = sub [f"lv{ int ( i ) } " ]
399
+ var = sub [f"lv{ i } " ]
398
400
declare_strides += f"""
399
401
{ ovar } _loops_rit = { ovar } _loops.rbegin();"""
400
402
for j in reversed (range (nnested )):
401
403
declare_strides += f"""
402
- int { var } _stride_l{ int ( j ) } = init_strides[{ int ( i ) } ][{ ovar } _loops_rit->second];
404
+ int { var } _stride_l{ j } = init_strides[{ i } ][{ ovar } _loops_rit->second];
403
405
++{ ovar } _loops_rit;
404
406
"""
405
407
406
408
declare_iter = ""
407
409
for i , dtype in enumerate (dtypes ):
408
- var = sub [f"lv{ int ( i ) } " ]
410
+ var = sub [f"lv{ i } " ]
409
411
declare_iter += f"{ var } _iter = ({ dtype } *)(PyArray_DATA({ var } ));\n "
410
412
411
413
pointer_update = ""
412
414
for j , dtype in enumerate (dtypes ):
413
- var = sub [f"lv{ int ( j ) } " ]
415
+ var = sub [f"lv{ j } " ]
414
416
pointer_update += f"{ dtype } &{ var } _i = * ( { var } _iter"
415
417
for i in reversed (range (nnested )):
416
- iterv = f"ITER_{ int ( i ) } "
417
- pointer_update += f"+{ var } _stride_l{ int ( i ) } *{ iterv } "
418
+ iterv = f"ITER_{ i } "
419
+ pointer_update += f"+{ var } _stride_l{ i } *{ iterv } "
418
420
pointer_update += ");\n "
419
421
420
422
loop = inner_task
421
423
for i in reversed (range (nnested )):
422
- iterv = f"ITER_{ int ( i ) } "
423
- total = f"TOTAL_{ int ( i ) } "
424
+ iterv = f"ITER_{ i } "
425
+ total = f"TOTAL_{ i } "
424
426
update = ""
425
427
forloop = ""
426
428
# The pointers are defined only in the most inner loop
@@ -434,36 +436,14 @@ def get_loop_strides(loop_order, i):
434
436
435
437
loop = f"""
436
438
{ forloop }
437
- {{ // begin loop { int ( i ) }
439
+ {{ // begin loop { i }
438
440
{ update }
439
441
{ loop }
440
- }} // end loop { int ( i ) }
442
+ }} // end loop { i }
441
443
"""
442
444
443
- return f"{{\n { order_loops } \n { declare_totals } \n { declare_strides } \n { declare_iter } \n { loop } \n }}\n "
444
-
445
-
446
- # print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
447
- # ('double', 'int', 'float'),
448
- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
449
-
450
- # print make_checks(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
451
- # ('double', 'int', 'float'),
452
- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
453
-
454
- # print make_alloc(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
455
- # 'double',
456
- # dict(olv='out', lv0='x', lv1='y', lv2='z', fail="FAIL;"))
457
-
458
- # print make_loop(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
459
- # ('double', 'int', 'float'),
460
- # (("C00;", "C%01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
461
- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
462
-
463
- # print make_loop(((0, 1, 2, 3), (3, 'x', 0, 'x'), (0, 'x', 'x', 'x')),
464
- # ('double', 'int', 'float'),
465
- # (("C00;", "C01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
466
- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
445
+ code = "\n " .join ((order_loops , declare_totals , declare_strides , declare_iter , loop ))
446
+ return f"{{\n { code } \n }}\n "
467
447
468
448
469
449
##################
0 commit comments