
Optimize slices of arange #1431

Open
@ricardoV94

Description

In #1429 we use the following expression to lower some xtensor indexing operations to tensor operations:

adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]

This creates an unnecessary intermediate tensor and indexing operation. We should add rewrites to:

  1. Convert arange(...)[slice(start, stop, step)] directly to an equivalent arange(...)
  2. Convert arange(...)[scalar_index] directly to the scalar value it selects

In both cases, we need to handle potentially negative indices (and None entries in slices). All inputs to arange and to the indexing operation may also be symbolic. However, we want to avoid producing overly complex graphs like the ones in #112.
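The index arithmetic behind the slice rewrite can be sketched for concrete integer inputs. This is a minimal sketch, not PyTensor code. The helper name `slice_of_arange` is hypothetical. It assumes integer arange inputs with a nonzero step, and it relies on Python's `slice.indices` to resolve None and negative slice entries the same way NumPy does. A symbolic rewrite would need to express the length computation and the clipping done by `slice.indices` as graph operations, and that is where the graph can grow large.

```python
def slice_of_arange(start: int, stop: int, step: int, s: slice) -> tuple[int, int, int]:
    """Return (new_start, new_stop, new_step) such that
    range(new_start, new_stop, new_step) has the same elements as
    range(start, stop, step)[s]. Hypothetical helper; integer-only sketch."""
    if step == 0:
        raise ValueError("arange step must be nonzero")
    # Length of arange(start, stop, step): ceil((stop - start) / step), clipped at 0.
    # Integer ceil division: ceil(x / y) == -((-x) // y).
    n = max(0, -((start - stop) // step))
    # Resolve None and negative slice entries against that length.
    a, b, c = s.indices(n)
    # Element k of the arange is start + k * step, so iterating k over
    # range(a, b, c) is the same as an arange over the values themselves.
    return start + a * step, start + b * step, step * c


# Usage: reversing every other element of arange(10).
print(slice_of_arange(0, 10, 1, slice(None, None, -2)))  # (9, -1, -2)
```

Because element k is `start + k * step`, the slice only shifts the bounds and rescales the step. This holds for any sign of `step` or of the slice step. Float aranges are left out because rounding in the length computation could make the two forms differ.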

These optimizations would benefit both tensor and xtensor code, as the tensor rewrites would automatically apply after xtensor lowering.

Files to change

  • pytensor/tensor/rewriting/subtensor_lift.py: Add new optimization patterns

Implementation details

Two new pattern rewrites should be added:

  1. A rewrite that converts Subtensor(ARange, slice(...)) to a direct ARange with adjusted parameters
  2. A rewrite that converts Subtensor(ARange, scalar) to a scalar expression (a constant when all inputs are constant)
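The scalar rewrite reduces to computing `start + i * step` after normalizing a negative index. Here is another minimal sketch under the same assumptions: integer inputs and a nonzero step, with `index_of_arange` as a hypothetical name. It keeps a bounds check because indexing the materialized arange fails on an out-of-bounds index. Whether a rewrite may drop that check is a design question this issue does not settle.

```python
def index_of_arange(start: int, stop: int, step: int, i: int) -> int:
    """Return range(start, stop, step)[i] without building the sequence.
    Hypothetical helper; integer-only sketch."""
    if step == 0:
        raise ValueError("arange step must be nonzero")
    # Length of the arange, as ceil((stop - start) / step) clipped at 0.
    n = max(0, -((start - stop) // step))
    # Negative indices count from the end, as in NumPy.
    if i < 0:
        i += n
    if not 0 <= i < n:
        raise IndexError("index out of bounds for arange")
    return start + i * step


# Usage: last element of arange(2, 20, 3), i.e. of [2, 5, 8, 11, 14, 17].
print(index_of_arange(2, 20, 3, -1))  # 17
```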

Current behavior

If we profile or print the current graph for a function like:

import pytensor
import pytensor.tensor as pt

shape_val = pt.lscalar("shape_val")
intermediate = pt.arange(shape_val)[:1]
fn = pytensor.function([shape_val], intermediate)
fn.dprint()
# Subtensor{:stop} [id A] 1
#  ├─ ARange{dtype='int64'} [id B] 0
#  │  ├─ 0 [id C]
#  │  ├─ shape_val [id D]
#  │  └─ 1 [id E]
#  └─ 1 [id F]

The compiled graph contains:

  1. An ARange operation that creates the sequence from 0 to shape_val
  2. A Subtensor operation that applies the slice, keeping only the first value in this case

The full ARange output is an unnecessary intermediate tensor, because a single ARange with adjusted parameters would produce the same result.
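Working the slice arithmetic by hand for the example above suggests what the rewritten graph could reduce to. `slice(None, 1)` resolved against the arange length `max(shape_val, 0)` gives bounds `(0, min(1, max(shape_val, 0)), 1)`. That yields the same elements as `arange(minimum(shape_val, 1))`, since a negative length already produces an empty arange. In PyTensor the target would presumably be something like `pt.arange(pt.minimum(shape_val, 1))`. This is an assumption, not the rewrite's confirmed output. The small pure-Python check below uses Python's `range` as a stand-in for ARange, and `rewritten_example` is a hypothetical name.

```python
def rewritten_example(shape_val: int) -> list[int]:
    """Hypothetical single-ARange equivalent of arange(shape_val)[:1]."""
    # A single range call: nothing is built first and then sliced.
    return list(range(min(shape_val, 1)))


# Usage: compare with the current two-step graph (ARange, then Subtensor{:stop}).
for shape_val in (-2, 0, 1, 5):
    print(shape_val, rewritten_example(shape_val), list(range(shape_val)[:1]))
```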
