
Commit 7f8a913

Generalize and simplify local_reduce_join
1 parent e2e6563 commit 7f8a913

File tree

2 files changed: 87 additions & 71 deletions

pytensor/tensor/rewriting/math.py

Lines changed: 37 additions & 51 deletions
@@ -1627,68 +1627,54 @@ def local_reduce_chain(fgraph, node):
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
 
-    Notes
-    -----
-    Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
-    all cases.
-
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    When a, b have a dim length of 1 along the join axis
 
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
 
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
 
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
+    if n_joined_inputs > 2 and not hasattr(node.op.scalar_op, "nfunc_variadic"):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
 
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
+    if not isinstance(join_axis_tensor, Constant):
+        return None
+    join_axis = join_axis_tensor.data
 
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
+    # Check whether reduction happens on joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
 
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inp = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Let other rewrites clean up useless expand_dims
+        new_inp.append(inp.squeeze(join_axis))
 
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
+    ret = Elemwise(node.op.scalar_op)(*new_inp)
 
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction do something about the dtype.
+        return None
 
-        return [ret]
+    return [ret]
 
 
 @register_infer_shape
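
For orientation, the pattern targeted here is a reduction along the same axis used to join inputs that all have length 1 on that axis, e.g. sum(stack([x, y]), axis=0). A minimal illustrative sketch of how the effect can be observed (assuming a standard PyTensor install; names and data are arbitrary), in the spirit of the tests below:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")

    # Reduce along the stacking axis: each stacked input has a length-1 dim
    # there, so Join + CAReduce can be replaced by a single Elemwise add,
    # roughly x + y.
    out = pt.sum(pt.stack([x, y]), axis=0)
    f = pytensor.function([x, y], out)

    xv = np.arange(6, dtype=x.dtype).reshape(2, 3)
    yv = np.ones((2, 3), dtype=y.dtype)
    np.testing.assert_allclose(f(xv, yv), xv + yv)

    # After rewriting, the compiled graph should end in an Elemwise node
    # rather than a CAReduce over a Join.
    print(f.maker.fgraph.toposort())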

tests/tensor/rewriting/test_math.py

Lines changed: 50 additions & 20 deletions
@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )
 
     def test_local_reduce_broadcast_all_0(self):
@@ -3304,62 +3304,92 @@ def test_local_reduce_broadcast_some_1(self):
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )
 
-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including("canonicalize", "specialize")
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out
 
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)
+
+    def test_type(self):
         # Test different axis for the join and the reduction
         # We must force the dtype, of otherwise, this tests will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
 
         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)
 
         # Test a case that was bugged in a old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
 
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )
 
 
 def test_local_useless_adds():
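
The "dim length of 1 along the join axis" condition in the docstring is what the new code checks via inp.type.broadcastable[join_axis] before squeezing that dim away. A small illustrative sketch of the distinction (assuming standard PyTensor broadcastable semantics; names are arbitrary):

    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")

    # Stack-style joins give each input an explicit length-1 leading dim, so
    # the join axis is statically broadcastable and can be squeezed away.
    print([inp.type.broadcastable[0] for inp in (x[None], y[None])])  # [True, True]

    # A plain join of matrices has unknown length along the join axis; the
    # inputs are not broadcastable there, so Join + CAReduce stays in place.
    print([inp.type.broadcastable[0] for inp in (x, y)])  # [False, False]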
