
Commit 7174d7d

Generalize and simplify local_reduce_join
1 parent 58413a0 commit 7174d7d

File tree: 2 files changed (+87, −71 lines)

  pytensor/tensor/rewriting/math.py
  tests/tensor/rewriting/test_math.py

pytensor/tensor/rewriting/math.py

Lines changed: 37 additions & 51 deletions

@@ -1626,68 +1626,54 @@ def local_reduce_chain(fgraph, node):
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
 
-    Notes
-    -----
-    Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
-    all cases.
-
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    When a, b have a dim length of 1 along the join axis
 
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
 
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
 
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
+    if n_joined_inputs > 2 and not hasattr(node.op.scalar_op, "nfunc_variadic"):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
 
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
+    if not isinstance(join_axis_tensor, Constant):
+        return None
+    join_axis = join_axis_tensor.data
 
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
+    # Check whether reduction happens on joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
 
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inp = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Let other rewrites clean up useless expand_dims
+        new_inp.append(inp.squeeze(join_axis))
 
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
+    ret = Elemwise(node.op.scalar_op)(*new_inp)
 
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction do something about the dtype.
+        return None
 
-        return [ret]
+    return [ret]
 
 
 @register_infer_shape
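
The generalized rewrite now fires for any constant join axis that matches the reduction axis, provided every joined input has length 1 along that axis; with more than two joined inputs it also requires the scalar op to expose nfunc_variadic (as Add and Mul do), which is why the updated tests pass only two inputs to max/min. Below is a minimal sketch, not part of the commit, of the newly supported non-zero-axis case; variable names are illustrative and it assumes a recent PyTensor with default rewrites enabled.

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
# Join two length-1 columns along axis 1, then reduce along that same axis:
# this is the CAReduce(Join(axis=1, ...), axis=1) pattern targeted above.
out = pt.join(1, a[:, None], b[:, None]).sum(axis=1)

f = pytensor.function([a, b], out)
x = np.arange(3).astype(a.dtype)
y = np.ones(3, dtype=b.dtype)
print(f(x, y))      # same result as x + y
pytensor.dprint(f)  # the compiled graph is expected to show an Add Elemwise, not Join/Sum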

tests/tensor/rewriting/test_math.py

Lines changed: 50 additions & 20 deletions

@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )
 
     def test_local_reduce_broadcast_all_0(self):
@@ -3304,62 +3304,92 @@ def test_local_reduce_broadcast_some_1(self):
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )
 
-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including("canonicalize", "specialize")
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out
 
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)
+
+    def test_type(self):
         # Test different axis for the join and the reduction
         # We must force the dtype, of otherwise, this tests will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
 
         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)
 
         # Test a case that was bugged in a old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
 
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )
 
 
 def test_local_useless_adds():
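
Conversely, when a joined input is not length 1 along the join axis the rewrite is skipped, which is what test_not_supported_unequal_shapes exercises; the new negative tests only assert numerical correctness. A small sketch of that unsupported case, illustrative only and not part of the commit:

import numpy as np
import pytensor
import pytensor.tensor as pt

vx = pt.matrix("vx", shape=(1, 3))
vy = pt.matrix("vy", shape=(2, 3))
# vy is not broadcastable along axis 0, so the Join/Sum pair should survive compilation.
out = pt.join(0, vx, vy).sum(axis=0)

f = pytensor.function([vx, vy], out)
x = np.ones((1, 3), dtype=vx.dtype)
y = np.ones((2, 3), dtype=vy.dtype)
print(f(x, y))      # matches np.concatenate([x, y], axis=0).sum(axis=0)
pytensor.dprint(f)  # inspect the graph; a Join is still expected here

The new test class can be run in isolation with: pytest tests/tensor/rewriting/test_math.py -k TestReduceJoin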
