
local_subtensor_merge can complicate graphs #112

Open
@aseyboldt

Description

The `local_subtensor_merge` rewrite often makes the graph worse instead of better:
https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/rewriting/subtensor.py#L475
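To illustrate why merging chained slices is costly in general, here is a plain-Python sketch of merging `x[s1][s2]` into a single slice. It is not PyTensor's implementation, and `merge_slices` is a hypothetical helper. It relies on Python's `range` slicing, which normalizes `None`/negative bounds and clamps them to the length `n`. The merged bounds depend on `n` and on the signs of the steps, so when the length is only known symbolically (`Shape_i{0}` in the optimized graph), a rewrite has to encode the same case analysis with `Switch`/`minimum` operations.

```python
def merge_slices(s1: slice, s2: slice, n: int) -> slice:
    """Return one slice s with seq[s1][s2] == seq[s] for any sequence of length n.

    Slicing a range object normalizes None/negative bounds and clamps them to
    the length; a symbolic rewrite has to express this same case analysis with
    ops because n is unknown at compile time.
    """
    r = range(n)[s1][s2]
    if len(r) == 0:
        return slice(0, 0, 1)  # any empty slice will do
    stop = r.stop
    if r.step < 0 and stop < 0:
        # A negative stop would be read as "counted from the end";
        # None means "continue down to index 0".
        stop = None
    return slice(r.start, stop, r.step)


print(merge_slices(slice(1, -1), slice(1, -1), 10))  # slice(2, 8, 1)
print(merge_slices(slice(1, -1), slice(1, -1), 11))  # slice(2, 9, 1): the stop depends on n
print(merge_slices(slice(None, None, -1), slice(None, None, 2), 5))  # slice(4, None, -2)
```

Even for the simple `[1:-1][1:-1]` case, this normalized form bakes in the length, which is why the symbolic version ends up threading the shape through every bound.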

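```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = x[1:-1][1:-1][1:-1]

# Raise the limit on how many times the equilibrium rewriter may apply each
# rewrite (a multiple of the number of nodes in the graph)
pytensor.config.optdb__max_use_ratio = 20
func = pytensor.function([x], y)

pytensor.dprint(y)     # graph before rewriting
pytensor.dprint(func)  # graph after rewriting
```

Before rewriting:

```
Subtensor{int64:int64:} [id A]
 |Subtensor{int64:int64:} [id B]
 | |Subtensor{int64:int64:} [id C]
 | | |x [id D]
 | | |ScalarConstant{1} [id E]
 | | |ScalarConstant{-1} [id F]
 | |ScalarConstant{1} [id G]
 | |ScalarConstant{-1} [id H]
 |ScalarConstant{1} [id I]
 |ScalarConstant{-1} [id J]
```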

After rewriting:

```
DeepCopyOp [id A] 27
 |Subtensor{int64:int64:int8} [id B] 26
   |x [id C]
   |ScalarFromTensor [id D] 24
   | |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 2)] [id E] 22
   |   |Elemwise{Composite{LE((i0 - i1), 0)}} [id F] 21
   |   | |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2))}}[(0, 0)] [id G] 19
   |   | | |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 1)] [id H] 12
   |   | | | |Elemwise{Composite{LE((i0 - i1), 0)}} [id I] 10
   |   | | | | |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id J] 8
   |   | | | | | |Elemwise{sub,no_inplace} [id K] 7
   |   | | | | |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id L] 4
   |   | | | | |   | |Elemwise{sub,no_inplace} [id M] 3
   |   | | | | |   |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id N] 1
   |   | | | | |   |   | |Shape_i{0} [id O] 0
   |   | | | | |   |   |   |x [id C]
   |   | | | | |   |   |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}}[(0, 0)] [id P] 2
   |   | | | | |   |     |Shape_i{0} [id O] 0
   |   | | | | |   |     |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id N] 1
   |   | | | | |   |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}} [id Q] 6
   |   | | | | |     |Elemwise{sub,no_inplace} [id M] 3
   |   | | | | |     |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id L] 4
   |   | | | | |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}}[(0, 0)] [id R] 9
   |   | | | |   |Elemwise{sub,no_inplace} [id K] 7
   |   | | | |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id J] 8
   |   | | | |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}} [id Q] 6
   |   | | | |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id J] 8
   |   | | | |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id L] 4
   |   | | |TensorFromScalar [id S] 18
   |   | | | |add [id T] 16
   |   | | |   |ScalarFromTensor [id U] 14
   |   | | |   | |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 1)] [id H] 12
   |   | | |   |ScalarFromTensor [id V] 5
   |   | | |     |Elemwise{sub,no_inplace} [id M] 3
   |   | | |Elemwise{sub,no_inplace} [id M] 3
   |   | |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0), i2), 0), i3), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0), i2), 0), i3)}}[(0, 0)] [id W] 20
   |   |   |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 2)] [id X] 11
   |   |   | |Elemwise{Composite{LE((i0 - i1), 0)}} [id I] 10
   |   |   | |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}} [id Q] 6
   |   |   | |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}}[(0, 0)] [id R] 9
   |   |   | |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id L] 4
   |   |   |TensorFromScalar [id Y] 17
   |   |   | |add [id Z] 15
   |   |   |   |ScalarFromTensor [id BA] 13
   |   |   |   | |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 2)] [id X] 11
   |   |   |   |ScalarFromTensor [id V] 5
   |   |   |Elemwise{sub,no_inplace} [id M] 3
   |   |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2))}}[(0, 0)] [id G] 19
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}}[(0, 0)] [id P] 2
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0), i2), 0), i3), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0), i2), 0), i3)}}[(0, 0)] [id W] 20
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id N] 1
   |ScalarFromTensor [id BB] 25
   | |Elemwise{Composite{Switch(i0, 0, minimum((i1 + i2), i3))}}[(0, 1)] [id BC] 23
   |   |Elemwise{Composite{LE((i0 - i1), 0)}} [id F] 21
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(1, i0), 0), i1)}}[(0, 0)] [id P] 2
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}(i0, 0, i1), 0, -1), i2))}}[(0, 0)] [id G] 19
   |   |Elemwise{Composite{Switch(LT(Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0), 0), 0, Composite{Switch(GE(i0, i1), i1, i0)}(Composite{Switch(LT(i0, i1), i2, i0)}((i0 - 1), 0, -1), i0))}} [id N] 1
   |ScalarConstant{1} [id BD]
```

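This rewrite might be fine in some special cases where the shapes or indices are known, but in general I don't see why we would apply it: three cheap slices with constant bounds are replaced by a single slice whose bounds require a long chain of shape-dependent `Switch`/`minimum` computations.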
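As an illustration of such a special case: with step 1, constant non-negative starts and constant negative stops (as in the example above), `x[a:-b][c:-d]` equals `x[a + c : -(b + d)]` for every length of `x`, so the merge is plain constant arithmetic with no shape-dependent logic. The sketch below assumes exactly those conditions; `merge_trims` is a hypothetical helper, and whether the rewrite could be restricted to cases like this is an assumption, not something proposed here.

```python
def merge_trims(a: int, b: int, c: int, d: int) -> slice:
    """Merge x[a:-b][c:-d] into one slice (step 1).

    Valid for constant a, c >= 0 and b, d >= 1; the result does not depend
    on len(x), so no Switch/minimum logic on the shape is needed.
    """
    if a < 0 or c < 0 or b < 1 or d < 1:
        raise ValueError("only non-negative starts and negative stops are handled")
    return slice(a + c, -(b + d))


# The example from this issue: x[1:-1][1:-1][1:-1] -> x[2:-2][1:-1] -> x[3:-3]
inner = merge_trims(1, 1, 1, 1)                      # slice(2, -2)
outer = merge_trims(inner.start, -inner.stop, 1, 1)  # slice(3, -3)
x = list(range(10))
print(outer, x[1:-1][1:-1][1:-1] == x[outer])        # prints: slice(3, -3, None) True
```

The identity also holds when `x` is too short for the slices: both sides are then empty.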
