Skip to content

Commit a8303a0

Browse files
committed
Cleanup elemwise_cgen.py
1 parent 49daa85 commit a8303a0

File tree

1 file changed

+92
-112
lines changed

1 file changed

+92
-112
lines changed

pytensor/tensor/elemwise_cgen.py

Lines changed: 92 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from textwrap import dedent, indent
2+
13
from pytensor.configdefaults import config
24

35

@@ -8,51 +10,49 @@ def make_declare(loop_orders, dtypes, sub):
810
"""
911
decl = ""
1012
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)):
11-
var = sub[f"lv{int(i)}"] # input name corresponding to ith loop variable
13+
var = sub[f"lv{i}"] # input name corresponding to ith loop variable
1214
# we declare an iteration variable
1315
# and an integer for the number of dimensions
14-
decl += f"""
15-
{dtype}* {var}_iter;
16-
"""
16+
decl += f"{dtype}* {var}_iter;\n"
1717
for j, value in enumerate(loop_order):
1818
if value != "x":
1919
# If the dimension is not broadcasted, we declare
2020
# the number of elements in that dimension,
2121
# the stride in that dimension,
2222
# and the jump from an iteration to the next
2323
decl += f"""
24-
npy_intp {var}_n{int(value)};
25-
ssize_t {var}_stride{int(value)};
26-
int {var}_jump{int(value)}_{int(j)};
24+
npy_intp {var}_n{value};
25+
ssize_t {var}_stride{value};
26+
int {var}_jump{value}_{j};
2727
"""
2828

2929
else:
3030
# if the dimension is broadcasted, we only need
3131
# the jump (arbitrary length and stride = 0)
32-
decl += f"""
33-
int {var}_jump{value}_{int(j)};
34-
"""
32+
decl += f"int {var}_jump{value}_{j};\n"
3533

3634
return decl
3735

3836

3937
def make_checks(loop_orders, dtypes, sub):
4038
init = ""
4139
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)):
42-
var = f"%(lv{int(i)})s"
40+
var = sub[f"lv{i}"]
4341
# List of dimensions of var that are not broadcasted
4442
nonx = [x for x in loop_order if x != "x"]
4543
if nonx:
4644
# If there are dimensions that are not broadcasted
4745
# this is a check that the number of dimensions of the
4846
# tensor is as expected.
4947
min_nd = max(nonx) + 1
50-
init += f"""
51-
if (PyArray_NDIM({var}) < {min_nd}) {{
52-
PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
53-
%(fail)s
54-
}}
55-
"""
48+
init += dedent(
49+
f"""
50+
if (PyArray_NDIM({var}) < {min_nd}) {{
51+
PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
52+
{indent(sub["fail"], " " * 12)}
53+
}}
54+
"""
55+
)
5656

5757
# In loop j, adjust represents the difference of values of the
5858
# data pointer between the beginning and the end of the
@@ -75,9 +75,7 @@ def make_checks(loop_orders, dtypes, sub):
7575
adjust = f"{var}_n{index}*{var}_stride{index}"
7676
else:
7777
jump = f"-({adjust})"
78-
init += f"""
79-
{var}_jump{index}_{j} = {jump};
80-
"""
78+
init += f"{var}_jump{index}_{j} = {jump};\n"
8179
adjust = "0"
8280
check = ""
8381

@@ -101,34 +99,36 @@ def make_checks(loop_orders, dtypes, sub):
10199

102100
j0, x0 = to_compare[0]
103101
for j, x in to_compare[1:]:
104-
check += f"""
105-
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
106-
{{
107-
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
102+
check += dedent(
103+
f"""
104+
if ({sub[f"lv{j0}"]}_n{x0} != {sub[f"lv{j}"]}_n{x})
108105
{{
109-
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
110-
{j0},
111-
{x0},
112-
(long long int) %(lv{j0})s_n{x0},
113-
{j},
114-
{x},
115-
(long long int) %(lv{j})s_n{x}
116-
);
117-
}} else {{
118-
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
106+
if ({sub[f"lv{j0}"]}_n{x0} == 1 || {sub[f"lv{j}"]}_n{x} == 1)
107+
{{
108+
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
119109
{j0},
120110
{x0},
121-
(long long int) %(lv{j0})s_n{x0},
111+
(long long int) {sub[f"lv{j0}"]}_n{x0},
122112
{j},
123113
{x},
124-
(long long int) %(lv{j})s_n{x}
125-
);
114+
(long long int) {sub[f"lv{j}"]}_n{x}
115+
);
116+
}} else {{
117+
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
118+
{j0},
119+
{x0},
120+
(long long int) {sub[f"lv{j0}"]}_n{x0},
121+
{j},
122+
{x},
123+
(long long int) {sub[f"lv{j}"]}_n{x}
124+
);
125+
}}
126+
{sub["fail"]}
126127
}}
127-
%(fail)s
128-
}}
129-
"""
128+
"""
129+
)
130130

131-
return init % sub + check % sub
131+
return init + check
132132

133133

134134
def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
@@ -144,7 +144,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
144144
# Borrow the length of the first non-broadcastable input dimension
145145
for j, candidate in enumerate(candidates):
146146
if candidate != "x":
147-
var = sub[f"lv{int(j)}"]
147+
var = sub[f"lv{j}"]
148148
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
149149
break
150150
# If none is non-broadcastable, the output dimension has a length of 1
@@ -177,35 +177,37 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
177177
# way that its contiguous dimensions match one of the input's
178178
# contiguous dimensions, or the dimension with the smallest
179179
# stride. Right now, it is allocated to be C_CONTIGUOUS.
180-
return f"""
181-
{{
182-
npy_intp dims[{nd}];
183-
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
184-
{init_dims}
185-
if (!{olv}) {{
186-
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims,
187-
{type},
188-
{fortran});
189-
}}
190-
else {{
191-
PyArray_Dims new_dims;
192-
new_dims.len = {nd};
193-
new_dims.ptr = dims;
194-
PyObject* success = PyArray_Resize({olv}, &new_dims, 0, NPY_CORDER);
195-
if (!success) {{
196-
// If we can't resize the ndarray we have we can allocate a new one.
197-
PyErr_Clear();
198-
Py_XDECREF({olv});
199-
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}, 0);
200-
}} else {{
201-
Py_DECREF(success);
180+
return dedent(
181+
f"""
182+
{{
183+
npy_intp dims[{nd}];
184+
{init_dims}
185+
if (!{olv}) {{
186+
{olv} = (PyArrayObject*)PyArray_EMPTY({nd},
187+
dims,
188+
{type},
189+
{fortran});
190+
}}
191+
else {{
192+
PyArray_Dims new_dims;
193+
new_dims.len = {nd};
194+
new_dims.ptr = dims;
195+
PyObject* success = PyArray_Resize({olv}, &new_dims, 0, NPY_CORDER);
196+
if (!success) {{
197+
// If we can't resize the ndarray we have we can allocate a new one.
198+
PyErr_Clear();
199+
Py_XDECREF({olv});
200+
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}, 0);
201+
}} else {{
202+
Py_DECREF(success);
203+
}}
204+
}}
205+
if (!{olv}) {{
206+
{fail}
202207
}}
203208
}}
204-
if (!{olv}) {{
205-
{fail}
206-
}}
207-
}}
208-
"""
209+
"""
210+
)
209211

210212

211213
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
@@ -235,11 +237,11 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
235237
"""
236238

237239
def loop_over(preloop, code, indices, i):
238-
iterv = f"ITER_{int(i)}"
240+
iterv = f"ITER_{i}"
239241
update = ""
240242
suitable_n = "1"
241243
for j, index in enumerate(indices):
242-
var = sub[f"lv{int(j)}"]
244+
var = sub[f"lv{j}"]
243245
dtype = dtypes[j]
244246
update += f"{dtype} &{var}_i = * ( {var}_iter + {iterv} * {var}_jump{index}_{i} );\n"
245247

@@ -305,21 +307,21 @@ def make_reordered_loop(
305307
nnested = len(init_loop_orders[0])
306308

307309
# This is the var from which we'll get the loop order
308-
ovar = sub[f"lv{int(olv_index)}"]
310+
ovar = sub[f"lv{olv_index}"]
309311

310312
# The loops are ordered by (decreasing) absolute values of ovar's strides.
311313
# The first element of each pair is the absolute value of the stride
312314
# The second element corresponds to the index in the initial loop order
313315
order_loops = f"""
314-
std::vector< std::pair<int, int> > {ovar}_loops({int(nnested)});
316+
std::vector< std::pair<int, int> > {ovar}_loops({nnested});
315317
std::vector< std::pair<int, int> >::iterator {ovar}_loops_it = {ovar}_loops.begin();
316318
"""
317319

318320
# Fill the loop vector with the appropriate <stride, index> pairs
319321
for i, index in enumerate(init_loop_orders[olv_index]):
320322
if index != "x":
321323
order_loops += f"""
322-
{ovar}_loops_it->first = abs(PyArray_STRIDES({ovar})[{int(index)}]);
324+
{ovar}_loops_it->first = abs(PyArray_STRIDES({ovar})[{index}]);
323325
"""
324326
else:
325327
# Stride is 0 when dimension is broadcastable
@@ -328,7 +330,7 @@ def make_reordered_loop(
328330
"""
329331

330332
order_loops += f"""
331-
{ovar}_loops_it->second = {int(i)};
333+
{ovar}_loops_it->second = {i};
332334
++{ovar}_loops_it;
333335
"""
334336

@@ -352,7 +354,7 @@ def make_reordered_loop(
352354

353355
for i in range(nnested):
354356
declare_totals += f"""
355-
int TOTAL_{int(i)} = init_totals[{ovar}_loops_it->second];
357+
int TOTAL_{i} = init_totals[{ovar}_loops_it->second];
356358
++{ovar}_loops_it;
357359
"""
358360

@@ -365,7 +367,7 @@ def get_loop_strides(loop_order, i):
365367
specified loop_order.
366368
367369
"""
368-
var = sub[f"lv{int(i)}"]
370+
var = sub[f"lv{i}"]
369371
r = []
370372
for index in loop_order:
371373
# Note: the stride variable is not declared for broadcasted variables
@@ -383,7 +385,7 @@ def get_loop_strides(loop_order, i):
383385
)
384386

385387
declare_strides = f"""
386-
int init_strides[{int(nvars)}][{int(nnested)}] = {{
388+
int init_strides[{nvars}][{nnested}] = {{
387389
{strides}
388390
}};"""
389391

@@ -394,33 +396,33 @@ def get_loop_strides(loop_order, i):
394396
"""
395397

396398
for i in range(nvars):
397-
var = sub[f"lv{int(i)}"]
399+
var = sub[f"lv{i}"]
398400
declare_strides += f"""
399401
{ovar}_loops_rit = {ovar}_loops.rbegin();"""
400402
for j in reversed(range(nnested)):
401403
declare_strides += f"""
402-
int {var}_stride_l{int(j)} = init_strides[{int(i)}][{ovar}_loops_rit->second];
404+
int {var}_stride_l{j} = init_strides[{i}][{ovar}_loops_rit->second];
403405
++{ovar}_loops_rit;
404406
"""
405407

406408
declare_iter = ""
407409
for i, dtype in enumerate(dtypes):
408-
var = sub[f"lv{int(i)}"]
410+
var = sub[f"lv{i}"]
409411
declare_iter += f"{var}_iter = ({dtype}*)(PyArray_DATA({var}));\n"
410412

411413
pointer_update = ""
412414
for j, dtype in enumerate(dtypes):
413-
var = sub[f"lv{int(j)}"]
415+
var = sub[f"lv{j}"]
414416
pointer_update += f"{dtype} &{var}_i = * ( {var}_iter"
415417
for i in reversed(range(nnested)):
416-
iterv = f"ITER_{int(i)}"
417-
pointer_update += f"+{var}_stride_l{int(i)}*{iterv}"
418+
iterv = f"ITER_{i}"
419+
pointer_update += f"+{var}_stride_l{i}*{iterv}"
418420
pointer_update += ");\n"
419421

420422
loop = inner_task
421423
for i in reversed(range(nnested)):
422-
iterv = f"ITER_{int(i)}"
423-
total = f"TOTAL_{int(i)}"
424+
iterv = f"ITER_{i}"
425+
total = f"TOTAL_{i}"
424426
update = ""
425427
forloop = ""
426428
# The pointers are defined only in the innermost loop
@@ -434,36 +436,14 @@ def get_loop_strides(loop_order, i):
434436

435437
loop = f"""
436438
{forloop}
437-
{{ // begin loop {int(i)}
439+
{{ // begin loop {i}
438440
{update}
439441
{loop}
440-
}} // end loop {int(i)}
442+
}} // end loop {i}
441443
"""
442444

443-
return f"{{\n{order_loops}\n{declare_totals}\n{declare_strides}\n{declare_iter}\n{loop}\n}}\n"
444-
445-
446-
# print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
447-
# ('double', 'int', 'float'),
448-
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
449-
450-
# print make_checks(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
451-
# ('double', 'int', 'float'),
452-
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
453-
454-
# print make_alloc(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
455-
# 'double',
456-
# dict(olv='out', lv0='x', lv1='y', lv2='z', fail="FAIL;"))
457-
458-
# print make_loop(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
459-
# ('double', 'int', 'float'),
460-
# (("C00;", "C%01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
461-
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
462-
463-
# print make_loop(((0, 1, 2, 3), (3, 'x', 0, 'x'), (0, 'x', 'x', 'x')),
464-
# ('double', 'int', 'float'),
465-
# (("C00;", "C01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
466-
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
445+
code = "\n".join((order_loops, declare_totals, declare_strides, declare_iter, loop))
446+
return f"{{\n{code}\n}}\n"
467447

468448

469449
##################

0 commit comments

Comments
 (0)