
Commit 054a5e4

CAReduce loop reordering C-impl
1 parent ab38b24 commit 054a5e4

File tree

3 files changed: +325 -188 lines changed


pytensor/tensor/elemwise.py

Lines changed: 100 additions & 101 deletions
@@ -1,4 +1,5 @@
 from copy import copy
+from textwrap import dedent
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -1466,116 +1467,114 @@ def infer_shape(self, fgraph, node, shapes):
             return ((),)
         return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)
 
-    def _c_all(self, node, name, inames, onames, sub):
-        input = node.inputs[0]
-        output = node.outputs[0]
+    def _c_all(self, node, name, input_names, output_names, sub):
+        [inp] = node.inputs
+        [out] = node.outputs
+        ndim = inp.type.ndim
 
-        iname = inames[0]
-        oname = onames[0]
+        [inp_name] = input_names
+        [out_name] = output_names
 
-        idtype = input.type.dtype_specs()[1]
-        odtype = output.type.dtype_specs()[1]
+        inp_dtype = inp.type.dtype_specs()[1]
+        out_dtype = out.type.dtype_specs()[1]
 
         acc_dtype = getattr(self, "acc_dtype", None)
 
         if acc_dtype is not None:
             if acc_dtype == "float16":
                 raise MethodNotDefined("no c_code for float16")
             acc_type = TensorType(shape=node.outputs[0].type.shape, dtype=acc_dtype)
-            adtype = acc_type.dtype_specs()[1]
+            acc_dtype = acc_type.dtype_specs()[1]
         else:
-            adtype = odtype
+            acc_dtype = out_dtype
 
         axis = self.axis
         if axis is None:
-            axis = list(range(input.type.ndim))
+            axis = list(range(inp.type.ndim))
 
         if len(axis) == 0:
+            # This is just an Elemwise cast operation
             # The acc_dtype is never a downcast compared to the input dtype
             # So we just need a cast to the output dtype.
-            var = pytensor.tensor.basic.cast(input, node.outputs[0].dtype)
-            if var is input:
-                var = Elemwise(scalar_identity)(input)
+            var = pytensor.tensor.basic.cast(inp, node.outputs[0].dtype)
+            if var is inp:
+                var = Elemwise(scalar_identity)(inp)
             assert var.dtype == node.outputs[0].dtype
-            return var.owner.op._c_all(var.owner, name, inames, onames, sub)
-
-        order1 = [i for i in range(input.type.ndim) if i not in axis]
-        order = order1 + list(axis)
+            return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)
 
-        nnested = len(order1)
+        inp_dims = list(range(ndim))
+        non_reduced_dims = [i for i in inp_dims if i not in axis]
+        counter = iter(range(ndim))
+        acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]
 
-        sub = dict(sub)
-        for i, (input, iname) in enumerate(zip(node.inputs, inames)):
-            sub[f"lv{i}"] = iname
+        sub = sub.copy()
+        sub["lv0"] = inp_name
+        sub["lv1"] = out_name
+        sub["olv"] = out_name
 
-        decl = ""
-        if adtype != odtype:
+        if acc_dtype != out_dtype:
             # Create an accumulator variable different from the output
-            aname = "acc"
-            decl = acc_type.c_declare(aname, sub)
-            decl += acc_type.c_init(aname, sub)
+            acc_name = "acc"
+            setup = acc_type.c_declare(acc_name, sub) + acc_type.c_init(acc_name, sub)
         else:
             # the output is the accumulator variable
-            aname = oname
-
-        decl += cgen.make_declare([order], [idtype], sub)
-        checks = cgen.make_checks([order], [idtype], sub)
-
-        alloc = ""
-        i += 1
-        sub[f"lv{i}"] = oname
-        sub["olv"] = oname
-
-        # Allocate output buffer
-        alloc += cgen.make_declare(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
-        )
-        alloc += cgen.make_alloc([order1], odtype, sub)
-        alloc += cgen.make_checks(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
+            acc_name = out_name
+            setup = ""
+
+        # Define strides of input array
+        setup += cgen.make_declare(
+            [inp_dims], [inp_dtype], sub, compute_stride_jump=False
+        ) + cgen.make_checks([inp_dims], [inp_dtype], sub, compute_stride_jump=False)
+
+        # Define strides of output array and allocate it
+        out_sub = sub | {"lv0": out_name}
+        alloc = (
+            cgen.make_declare(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
+            + cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+            + cgen.make_checks(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
         )
 
-        if adtype != odtype:
-            # Allocate accumulation buffer
-            sub[f"lv{i}"] = aname
-            sub["olv"] = aname
+        if acc_dtype != out_dtype:
+            # Define strides of accumulation buffer and allocate it
+            sub["lv1"] = acc_name
+            sub["olv"] = acc_name
 
-            alloc += cgen.make_declare(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
-            )
-            alloc += cgen.make_alloc([order1], adtype, sub)
-            alloc += cgen.make_checks(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
+            acc_sub = sub | {"lv0": acc_name}
+            alloc += (
+                cgen.make_declare(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
+                + cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+                + cgen.make_checks(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
             )
 
         identity = self.scalar_op.identity
-
         if np.isposinf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "1"
             else:
-                identity = "NPY_MAX_" + str(input.type.dtype).upper()
+                identity = "NPY_MAX_" + str(inp.type.dtype).upper()
         elif np.isneginf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "-__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "0"
             else:
-                identity = "NPY_MIN_" + str(input.type.dtype).upper()
+                identity = "NPY_MIN_" + str(inp.type.dtype).upper()
         elif identity is None:
             raise TypeError(f"The {self.scalar_op} does not define an identity.")
 
-        task0_decl = f"{adtype}& {aname}_i = *{aname}_iter;\n{aname}_i = {identity};"
-
-        task1_decl = f"{idtype}& {inames[0]}_i = *{inames[0]}_iter;\n"
+        initial_value = f"{acc_name}_i = {identity};"
 
-        task1_code = self.scalar_op.c_code(
+        inner_task = self.scalar_op.c_code(
             Apply(
                 self.scalar_op,
                 [
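
Aside: the dimension bookkeeping introduced in this hunk is compact but dense. Below is a minimal standalone sketch of what it computes, assuming a hypothetical 4-d input reduced over axes 1 and 3 (the ndim and axis values are chosen purely for illustration):

    ndim = 4
    axis = (1, 3)

    inp_dims = list(range(ndim))  # [0, 1, 2, 3]
    non_reduced_dims = [i for i in inp_dims if i not in axis]  # [0, 2]
    counter = iter(range(ndim))
    acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]
    print(acc_dims)  # [0, 'x', 1, 'x']

Reduced axes show up as broadcastable "x" entries in the accumulator's dim pattern, while the surviving axes are renumbered consecutively; acc_dims then drives the stride declarations and checks for the output and accumulator buffers.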
@@ -1588,44 +1587,44 @@ def _c_all(self, node, name, inames, onames, sub):
                 ],
             ),
             None,
-            [f"{aname}_i", f"{inames[0]}_i"],
-            [f"{aname}_i"],
+            [f"{acc_name}_i", f"{inp_name}_i"],
+            [f"{acc_name}_i"],
             sub,
         )
-        code1 = f"""
-        {{
-            {task1_decl}
-            {task1_code}
-        }}
-        """
 
-        if node.inputs[0].type.ndim:
-            if len(axis) == 1:
-                all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
-            else:
-                all_code = (
-                    [("", "")] * nnested
-                    + [(task0_decl, "")]
-                    + [("", "")] * (len(axis) - 2)
-                    + [("", code1), ""]
-                )
+        if out.type.ndim == 0:
+            # Simple case where everything is reduced, no need for loop ordering
+            loop = cgen.make_complete_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )
         else:
-            all_code = [task0_decl + code1]
-        loop = cgen.make_loop_careduce(
-            [order, list(range(nnested)) + ["x"] * len(axis)],
-            [idtype, adtype],
-            all_code,
-            sub,
-        )
+            loop = cgen.make_reordered_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                inp_ndim=ndim,
+                reduction_axes=axis,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )
 
-        end = ""
-        if adtype != odtype:
-            end = f"""
-            PyArray_CopyInto({oname}, {aname});
-            """
-            end += acc_type.c_cleanup(aname, sub)
+        if acc_dtype != out_dtype:
+            cast = dedent(
+                f"""
+                PyArray_CopyInto({out_name}, {acc_name});
+                {acc_type.c_cleanup(acc_name, sub)}
+                """
+            )
+        else:
+            cast = ""
 
-        return decl, checks, alloc, loop, end
+        return setup, alloc, loop, cast
 
     def c_code(self, node, name, inames, onames, sub):
         code = "\n".join(self._c_all(node, name, inames, onames, sub))
@@ -1637,7 +1636,7 @@ def c_headers(self, **kwargs):
 
     def c_code_cache_version_apply(self, node):
         # the version corresponding to the c code in this Op
-        version = [9]
+        version = [10]
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
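
Note: the new helpers make_complete_loop_careduce and make_reordered_loop_careduce are defined in the C-code generator module changed elsewhere in this commit and are not shown above. As a rough Python model of the reordered reduction loop, assuming a sum reduction with identity 0 (careduce_sum is an illustrative stand-in, not the actual generated code):

    import numpy as np

    def careduce_sum(inp, reduction_axes):
        non_reduced_dims = [i for i in range(inp.ndim) if i not in reduction_axes]
        # The accumulator starts at the reduction's identity (0 for sum) and
        # has the reduced axes dropped from its shape.
        acc = np.zeros([inp.shape[i] for i in non_reduced_dims], dtype=inp.dtype)
        # Input elements can be visited in any loop order (the C impl is free
        # to pick a memory-friendly one); the accumulator index is just the
        # input index with the reduced axes removed.
        for idx in np.ndindex(inp.shape):
            acc[tuple(idx[i] for i in non_reduced_dims)] += inp[idx]
        return acc

    x = np.arange(24).reshape(2, 3, 4)
    assert np.array_equal(careduce_sum(x, (1,)), x.sum(axis=1))
    assert np.array_equal(careduce_sum(x, (0, 2)), x.sum(axis=(0, 2)))

Because the accumulator index does not depend on the order in which input elements are visited, the generated C loops can be nested in whatever order best matches the input's strides, which is the point of the reordering.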
