Description
JAX jitting can be extremely slow when there are large constants in the graph. We could add a helper that converts any large constants into symbolic inputs (we already do some constant folding on our end anyway), so JAX doesn't get hung up on them.
See related discussion on their side: jax-ml/jax#21300
The idea is to have a `pytensor.graph.replace.replace_large_constants_by_inputs` helper that returns the graph with the large constants replaced by PyTensor input variables, together with the respective constant values.
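Below is a minimal sketch of what such a helper could do. It uses a hypothetical toy graph rather than PyTensor: the `Variable`/`Input`/`Constant`/`Apply` classes, the `min_size` threshold and `evaluate` are illustrative stand-ins, not PyTensor's API. The helper walks the graph from the outputs and swaps every constant with at least `min_size` elements for a fresh input. It returns the rebuilt outputs along with the new inputs and their values, so a backend like JAX could receive the arrays as function arguments instead of baked-in constants.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


# Hypothetical toy graph classes (stand-ins, NOT PyTensor's Variable/Constant/Apply).
# eq=False keeps identity-based hashing so nodes can be used as dict keys.
@dataclass(eq=False)
class Variable:
    name: str


@dataclass(eq=False)
class Input(Variable):
    pass


@dataclass(eq=False)
class Constant(Variable):
    data: np.ndarray


@dataclass(eq=False)
class Apply(Variable):
    op: Callable[..., np.ndarray]
    inputs: list


def replace_large_constants_by_inputs(outputs, min_size=1000):
    """Hypothetical helper: return (new_outputs, new_inputs, values)."""
    memo = {}  # old node -> rebuilt node, so shared subgraphs are rebuilt once
    new_inputs, values = [], []

    def rebuild(var):
        if var in memo:
            return memo[var]
        if isinstance(var, Constant) and var.data.size >= min_size:
            # Lift the large constant into a new symbolic input and keep its value.
            new = Input(name=f"{var.name}_as_input")
            new_inputs.append(new)
            values.append(var.data)
        elif isinstance(var, Apply):
            new = Apply(var.name, var.op, [rebuild(i) for i in var.inputs])
        else:
            new = var  # existing inputs and small constants are kept as-is
        memo[var] = new
        return new

    return [rebuild(o) for o in outputs], new_inputs, values


def evaluate(var, feed):
    """Evaluate a toy graph node given values for its Input nodes."""
    if isinstance(var, Input):
        return feed[var]
    if isinstance(var, Constant):
        return var.data
    return var.op(*(evaluate(i, feed) for i in var.inputs))


# Usage: one large constant gets lifted, the small scalar constant stays.
x = Input("x")
big = Constant("big", np.arange(10_000, dtype=np.float64))
small = Constant("two", np.array(2.0))
y = Apply("y", lambda a, b, c: (a * b).sum() + c, [x, big, small])

(new_y,), extra_inputs, extra_values = replace_large_constants_by_inputs([y])
feed = {x: np.float64(3.0), **dict(zip(extra_inputs, extra_values))}
```

The rewritten graph computes the same result as the original one once the lifted values are fed back in as inputs. Only the graph itself no longer carries the large array.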