Reductions along leading axes can be incredibly slow in C and Numba backends #935

Closed
@ricardoV94

Description

Reported by @aseyboldt

from timeit import timeit
from functools import partial

import pytensor
import pytensor.tensor as pt
import numpy as np

N = 256  # edge length of the (N, N, N) test array
r = 10  # number of timed calls per backend

x_test = np.random.uniform(size=(N, N, N))
x = pytensor.shared(x_test, name="x", shape=x_test.shape)

for axis in [0, 1, 2]:
    y = x.sum(axis)
    # Compile the same reduction with each backend
    c_fn = pytensor.function([], y, mode="FAST_RUN")
    numba_fn = pytensor.function([], y, mode="NUMBA")
    jax_fn_ = pytensor.function([], y, mode="JAX")
    # Convert to a NumPy array so the timing covers the full JAX computation
    jax_fn = lambda: np.asarray(jax_fn_())
    numpy_fn = partial(np.sum, x_test, axis=axis)

    # Check every backend against NumPy (this also serves as a warm-up call)
    np.testing.assert_allclose(c_fn(), numpy_fn())
    np.testing.assert_allclose(numba_fn(), numpy_fn())
    np.testing.assert_allclose(jax_fn(), numpy_fn())
    print(f"\n{axis=}")
    for name, fn in [("C", c_fn), ("numba", numba_fn), ("jax", jax_fn), ("numpy", numpy_fn)]:
        print(f"  | {name}: {timeit(fn, number=r) / r: .4f}s")

I'm running JAX on a CPU

axis=0
  | C:  1.8741s
  | numba:  1.7838s
  | jax:  0.7674s
  | numpy:  0.0075s
axis=1
  | C:  0.0286s
  | numba:  0.0280s
  | jax:  0.0133s
  | numpy:  0.0083s
axis=2
  | C:  0.0142s
  | numba:  0.0153s
  | jax:  0.0046s
  | numpy:  0.0069s

#931 makes Numba much faster on axis=0 (1.78s down to 0.02s) at the expense of doing worse on axis=2 (0.015s up to 0.116s):

axis=0
  | C:  1.6623s
  | numba:  0.0203s
  | jax:  0.6826s
  | numpy:  0.0098s
axis=1
  | C:  0.0294s
  | numba:  0.0288s
  | jax:  0.0126s
  | numpy:  0.0078s
axis=2
  | C:  0.0141s
  | numba:  0.1160s
  | jax:  0.0044s
  | numpy:  0.0064s

In any case, NumPy is wiping the floor with us :)

Surprisingly, JAX also does badly in the first case (axis=0), although not as badly as C/Numba. The slowdown is probably due to a bad iteration order, which leads to cache-unfriendly memory access.
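To illustrate what "bad iteration order" could mean here, below is a minimal sketch in plain Python/NumPy. It is not taken from the PyTensor C or Numba backends; the function names `sum_axis0_strided` and `sum_axis0_contiguous` are hypothetical, and which loop order the backends actually generate is an assumption. For a C-contiguous array, the first version walks down axis 0 for each output element, so consecutive reads are a whole `(N, N)` slice apart in memory. The second version keeps axis 0 as the outer loop and adds one contiguous slice at a time into the output.

```python
import numpy as np


def sum_axis0_strided(x):
    """Reduce axis 0 with the reduced axis as the innermost loop.

    For a C-contiguous array, consecutive reads x[i, j, k] and x[i + 1, j, k]
    are x.shape[1] * x.shape[2] elements apart, so almost every read touches
    a different cache line.
    """
    out = np.empty(x.shape[1:], dtype=x.dtype)
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            acc = 0.0
            for i in range(x.shape[0]):
                acc += x[i, j, k]
            out[j, k] = acc
    return out


def sum_axis0_contiguous(x):
    """Reduce axis 0 with the reduced axis as the outermost loop.

    Each x[i] is a contiguous (shape[1], shape[2]) block, so both the input
    slice and the output accumulator are traversed in memory order.
    """
    out = np.zeros(x.shape[1:], dtype=x.dtype)
    for i in range(x.shape[0]):
        out += x[i]
    return out


# Small array so the pure-Python loops finish quickly
x_small = np.random.default_rng(0).uniform(size=(8, 16, 32))
np.testing.assert_allclose(sum_axis0_strided(x_small), x_small.sum(axis=0))
np.testing.assert_allclose(sum_axis0_contiguous(x_small), x_small.sum(axis=0))
```

Both functions return the same result. They differ only in traversal order, so the pure-Python versions are meant to show the access pattern and are not suitable for timing.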
