Commit e752fc3

CAReduce loop reordering C-impl
1 parent 00a8a88 commit e752fc3

File tree

3 files changed: +418 -188 lines changed

pytensor/tensor/elemwise.py

Lines changed: 101 additions & 101 deletions
@@ -1,4 +1,5 @@
 from copy import copy
+from textwrap import dedent
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -1448,116 +1449,114 @@ def infer_shape(self, fgraph, node, shapes):
             return ((),)
         return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
 
-    def _c_all(self, node, name, inames, onames, sub):
-        input = node.inputs[0]
-        output = node.outputs[0]
+    def _c_all(self, node, name, input_names, output_names, sub):
+        [inp] = node.inputs
+        [out] = node.outputs
+        ndim = inp.type.ndim
 
-        iname = inames[0]
-        oname = onames[0]
+        [inp_name] = input_names
+        [out_name] = output_names
 
-        idtype = input.type.dtype_specs()[1]
-        odtype = output.type.dtype_specs()[1]
+        inp_dtype = inp.type.dtype_specs()[1]
+        out_dtype = out.type.dtype_specs()[1]
 
         acc_dtype = getattr(self, "acc_dtype", None)
 
         if acc_dtype is not None:
             if acc_dtype == "float16":
                 raise MethodNotDefined("no c_code for float16")
             acc_type = TensorType(shape=node.outputs[0].type.shape, dtype=acc_dtype)
-            adtype = acc_type.dtype_specs()[1]
+            acc_dtype = acc_type.dtype_specs()[1]
         else:
-            adtype = odtype
+            acc_dtype = out_dtype
 
         axis = self.axis
         if axis is None:
-            axis = list(range(input.type.ndim))
+            axis = list(range(inp.type.ndim))
 
         if len(axis) == 0:
+            # This is just an Elemwise cast operation
             # The acc_dtype is never a downcast compared to the input dtype
             # So we just need a cast to the output dtype.
-            var = pytensor.tensor.basic.cast(input, node.outputs[0].dtype)
-            if var is input:
-                var = Elemwise(scalar_identity)(input)
+            var = pytensor.tensor.basic.cast(inp, node.outputs[0].dtype)
+            if var is inp:
+                var = Elemwise(scalar_identity)(inp)
             assert var.dtype == node.outputs[0].dtype
-            return var.owner.op._c_all(var.owner, name, inames, onames, sub)
-
-        order1 = [i for i in range(input.type.ndim) if i not in axis]
-        order = order1 + list(axis)
+            return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)
 
-        nnested = len(order1)
+        inp_dims = list(range(ndim))
+        non_reduced_dims = [i for i in inp_dims if i not in axis]
+        counter = iter(range(ndim))
+        acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]
 
-        sub = dict(sub)
-        for i, (input, iname) in enumerate(zip(node.inputs, inames)):
-            sub[f"lv{i}"] = iname
+        sub = sub.copy()
+        sub["lv0"] = inp_name
+        sub["lv1"] = out_name
+        sub["olv"] = out_name
 
-        decl = ""
-        if adtype != odtype:
+        if acc_dtype != out_dtype:
             # Create an accumulator variable different from the output
-            aname = "acc"
-            decl = acc_type.c_declare(aname, sub)
-            decl += acc_type.c_init(aname, sub)
+            acc_name = "acc"
+            setup = acc_type.c_declare(acc_name, sub) + acc_type.c_init(acc_name, sub)
         else:
             # the output is the accumulator variable
-            aname = oname
-
-        decl += cgen.make_declare([order], [idtype], sub)
-        checks = cgen.make_checks([order], [idtype], sub)
-
-        alloc = ""
-        i += 1
-        sub[f"lv{i}"] = oname
-        sub["olv"] = oname
-
-        # Allocate output buffer
-        alloc += cgen.make_declare(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
-        )
-        alloc += cgen.make_alloc([order1], odtype, sub)
-        alloc += cgen.make_checks(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
+            acc_name = out_name
+            setup = ""
+
+        # Define strides of input array
+        setup += cgen.make_declare(
+            [inp_dims], [inp_dtype], sub, compute_stride_jump=False
+        ) + cgen.make_checks([inp_dims], [inp_dtype], sub, compute_stride_jump=False)
+
+        # Define strides of output array and allocate it
+        out_sub = sub | {"lv0": out_name}
+        alloc = (
+            cgen.make_declare(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
+            + cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+            + cgen.make_checks(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
         )
 
-        if adtype != odtype:
-            # Allocate accumulation buffer
-            sub[f"lv{i}"] = aname
-            sub["olv"] = aname
+        if acc_dtype != out_dtype:
+            # Define strides of accumulation buffer and allocate it
+            sub["lv1"] = acc_name
+            sub["olv"] = acc_name
 
-            alloc += cgen.make_declare(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
-            )
-            alloc += cgen.make_alloc([order1], adtype, sub)
-            alloc += cgen.make_checks(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
+            acc_sub = sub | {"lv0": acc_name}
+            alloc += (
+                cgen.make_declare(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
+                + cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+                + cgen.make_checks(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
             )
 
         identity = self.scalar_op.identity
-
         if np.isposinf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "1"
             else:
-                identity = "NPY_MAX_" + str(input.type.dtype).upper()
+                identity = "NPY_MAX_" + str(inp.type.dtype).upper()
         elif np.isneginf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "-__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "0"
             else:
-                identity = "NPY_MIN_" + str(input.type.dtype).upper()
+                identity = "NPY_MIN_" + str(inp.type.dtype).upper()
         elif identity is None:
             raise TypeError(f"The {self.scalar_op} does not define an identity.")
 
-        task0_decl = f"{adtype}& {aname}_i = *{aname}_iter;\n{aname}_i = {identity};"
-
-        task1_decl = f"{idtype}& {inames[0]}_i = *{inames[0]}_iter;\n"
+        initial_value = f"{acc_name}_i = {identity};"
 
-        task1_code = self.scalar_op.c_code(
+        inner_task = self.scalar_op.c_code(
             Apply(
                 self.scalar_op,
                 [
@@ -1570,44 +1569,45 @@ def _c_all(self, node, name, inames, onames, sub):
                 ],
             ),
             None,
-            [f"{aname}_i", f"{inames[0]}_i"],
-            [f"{aname}_i"],
+            [f"{acc_name}_i", f"{inp_name}_i"],
+            [f"{acc_name}_i"],
             sub,
         )
-        code1 = f"""
-        {{
-            {task1_decl}
-            {task1_code}
-        }}
-        """
 
-        if node.inputs[0].type.ndim:
-            if len(axis) == 1:
-                all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
-            else:
-                all_code = (
-                    [("", "")] * nnested
-                    + [(task0_decl, "")]
-                    + [("", "")] * (len(axis) - 2)
-                    + [("", code1), ""]
-                )
+        if out.type.ndim == 0:
+            # Simple case where everything is reduced, no need for loop ordering
+            loop = cgen.make_complete_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                initial_value=initial_value,
+                inner_task=inner_task,
+                fail_code=sub["fail"],
+            )
         else:
-            all_code = [task0_decl + code1]
-        loop = cgen.make_loop_careduce(
-            [order, list(range(nnested)) + ["x"] * len(axis)],
-            [idtype, adtype],
-            all_code,
-            sub,
-        )
+            loop = cgen.make_reordered_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                inp_ndim=ndim,
+                reduction_axes=axis,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )
 
-        end = ""
-        if adtype != odtype:
-            end = f"""
-            PyArray_CopyInto({oname}, {aname});
-            """
-            end += acc_type.c_cleanup(aname, sub)
+        if acc_dtype != out_dtype:
+            cast = dedent(
+                f"""
+                PyArray_CopyInto({out_name}, {acc_name});
+                {acc_type.c_cleanup(acc_name, sub)}
+                """
+            )
+        else:
+            cast = ""
 
-        return decl, checks, alloc, loop, end
+        return setup, alloc, loop, cast
 
     def c_code(self, node, name, inames, onames, sub):
         code = "\n".join(self._c_all(node, name, inames, onames, sub))
@@ -1619,7 +1619,7 @@ def c_headers(self, **kwargs):
 
     def c_code_cache_version_apply(self, node):
         # the version corresponding to the c code in this Op
-        version = [9]
+        version = [10]
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
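As a sanity check on the new dimension bookkeeping, the two comprehensions introduced at the top of `_c_all` can be run standalone. A minimal sketch, with `ndim` and `axis` chosen purely for illustration:

# Mirrors the bookkeeping added in _c_all: for a 4d input reduced over
# axes (1, 3), the output keeps input dims 0 and 2, and the accumulator
# pattern marks each reduced dim with "x" (broadcastable).
ndim = 4
axis = (1, 3)

inp_dims = list(range(ndim))
non_reduced_dims = [i for i in inp_dims if i not in axis]
counter = iter(range(ndim))
acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]

print(non_reduced_dims)  # [0, 2]
print(acc_dims)          # [0, 'x', 1, 'x']

In the diff above, `acc_dims` is what gets passed to `cgen.make_declare`/`cgen.make_checks` so the output and accumulator are declared with the reduced axes marked as broadcast, while `cgen.make_alloc` receives only `non_reduced_dims`; `c_code` then joins the four returned sections (`setup`, `alloc`, `loop`, `cast`) into the final C implementation.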
