Transform graph to make large constants symbolic inputs #1224

Open
@ricardoV94

Description

JAX jitting can be insanely slow when the graph contains large constants. We could add a helper that converts any large constants into symbolic inputs (we already did some constant folding work on our end anyway), so JAX doesn't get hung up on them.

See related discussion on their side: jax-ml/jax#21300

The idea is to have a pytensor.graph.replace.replace_large_constants_by_inputs function that returns the graph with the large constants replaced by PyTensor input variables, together with the corresponding constant values.
