Description
In #1429 we use the following expression to lower some xtensor indexing operations to tensor operations:
adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
This creates an unnecessary intermediate tensor and indexing operation. We should add rewrites to:
- Convert arange(...)[slice(start, stop, step)] directly to an equivalent arange(...)
- Convert arange(...)[scalar_index] to just the scalar value
In both cases, we need to handle potentially negative indices (and None entries in slices). All arange arguments and index values may also be symbolic. However, we don't want to risk creating overly complicated graphs like the one in #112.
These optimizations would benefit both tensor and xtensor code, as the tensor rewrites would automatically apply after xtensor lowering.
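To make the slice case concrete, here is a minimal sketch of the parameter arithmetic in plain Python, using range as a stand-in for an integer arange. The helper name slice_arange_params is hypothetical and this is not PyTensor code; a real rewrite would have to express the same steps symbolically, which is exactly where the graph-size concern comes from.

```python
def slice_arange_params(start, stop, step, idx):
    """Return (start, stop, step) of the arange equal to arange(start, stop, step)[idx].

    Hypothetical helper: integer inputs only, with range() standing in for arange.
    """
    # Length of the original sequence.
    n = len(range(start, stop, step))
    # slice.indices resolves None and negative bounds against the length n.
    s_start, s_stop, s_step = idx.indices(n)
    # Element k of the original arange is start + k * step,
    # so positions in the slice map linearly to values.
    return start + s_start * step, start + s_stop * step, step * s_step


new_params = slice_arange_params(0, 10, 2, slice(1, 4))
print(new_params)                # (2, 8, 2)
print(list(range(*new_params)))  # [2, 4, 6]
```

With symbolic inputs, the length computation and the clamping done by slice.indices would turn into switch/minimum/maximum expressions, so a rewrite may want to restrict itself to cases where enough of the parameters are constants.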
Files to change
pytensor/tensor/rewriting/subtensor_lift.py: Add new optimization patterns
Implementation details
Two new pattern rewrites should be added:
- A rewrite that converts Subtensor(ARange, slice(...)) to a direct ARange with adjusted parameters
- A rewrite that converts Subtensor(ARange, scalar) to a constant or scalar value
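For the scalar case, the value can be computed without materializing the sequence. The sketch below shows the arithmetic in plain Python for integer inputs; index_arange_value is a hypothetical name, and the explicit bounds check is an assumption about what the rewrite would need to preserve, since indexing out of range should still fail.

```python
def index_arange_value(start, stop, step, i):
    """Return arange(start, stop, step)[i] without building the sequence.

    Hypothetical helper: integer inputs only, with range() standing in for arange.
    """
    # Length of the original sequence, needed for negative indices and bounds.
    n = len(range(start, stop, step))
    # Negative indices count from the end.
    if i < 0:
        i += n
    # Keep the out-of-bounds behaviour of regular indexing.
    if not 0 <= i < n:
        raise IndexError("index out of bounds for arange")
    return start + i * step


print(index_arange_value(0, 10, 2, -1))  # 8
```

When the index is known to be non-negative and in bounds, the result reduces to start + i * step, which is the simplest graph such a rewrite could produce.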
Current behavior
If we profile or print the current graph for a function like:
import pytensor
import pytensor.tensor as pt
shape_val = pt.lscalar("shape_val")
intermediate = pt.arange(shape_val)[:1]
fn = pytensor.function([shape_val], intermediate)
fn.dprint()
# Subtensor{:stop} [id A] 1
# ├─ ARange{dtype='int64'} [id B] 0
# │ ├─ 0 [id C]
# │ ├─ shape_val [id D]
# │ └─ 1 [id E]
# └─ 1 [id F]
We'll see that the compiled graph includes:
- An ARange operation to create a sequence from 0 to shape_val
- A Subtensor operation to apply the slice that keeps only one value (in this case)
The full ARange output is therefore an unnecessary intermediate tensor: it is built only to be sliced immediately.
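For this particular example, one plausible target is a single arange whose stop is clamped to 1. The check below uses plain Python range as a stand-in for arange; the function name first_slice_direct and the clamped form itself are assumptions for illustration, not the output of an existing rewrite.

```python
def first_slice_direct(shape_val):
    # arange(shape_val)[:1] keeps at most one element,
    # so the stop value can be clamped to 1 up front.
    return list(range(min(shape_val, 1)))


# Compare against the two-step version for negative, zero and positive sizes.
for shape_val in (-2, 0, 1, 5):
    assert list(range(shape_val))[:1] == first_slice_direct(shape_val)
```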