Reported by @aseyboldt
```python
from timeit import timeit
from functools import partial

import numpy as np
import pytensor
import pytensor.tensor as pt
from jax import block_until_ready

N = 256
r = 10  # number of timing repetitions
x_test = np.random.uniform(size=(N, N, N))
x = pytensor.shared(x_test, name="x", shape=x_test.shape)

for axis in [0, 1, 2]:
    y = x.sum(axis)
    c_fn = pytensor.function([], y, mode="FAST_RUN")
    numba_fn = pytensor.function([], y, mode="NUMBA")
    jax_fn_ = pytensor.function([], y, mode="JAX")
    # Wait for the async JAX computation and copy to a NumPy array,
    # so the full cost is included in the timing
    jax_fn = lambda: np.asarray(block_until_ready(jax_fn_()))
    numpy_fn = partial(np.sum, x_test, axis=axis)

    # Sanity check: all backends agree with NumPy
    np.testing.assert_allclose(c_fn(), numpy_fn())
    np.testing.assert_allclose(numba_fn(), numpy_fn())
    np.testing.assert_allclose(jax_fn(), numpy_fn())

    print(f"\n{axis=}")
    for name, fn in [("C", c_fn), ("numba", numba_fn), ("jax", jax_fn), ("numpy", numpy_fn)]:
        print(f" | {name}: {timeit(fn, number=r) / r: .4f}s")
```
I'm running JAX on the CPU:
```
axis=0
 | C:  1.8741s
 | numba:  1.7838s
 | jax:  0.7674s
 | numpy:  0.0075s

axis=1
 | C:  0.0286s
 | numba:  0.0280s
 | jax:  0.0133s
 | numpy:  0.0083s

axis=2
 | C:  0.0142s
 | numba:  0.0153s
 | jax:  0.0046s
 | numpy:  0.0069s
```
#931 makes Numba much faster on axis=0 (1.78s → 0.02s), at the expense of a regression on axis=2 (0.015s → 0.116s):
```
axis=0
 | C:  1.6623s
 | numba:  0.0203s
 | jax:  0.6826s
 | numpy:  0.0098s

axis=1
 | C:  0.0294s
 | numba:  0.0288s
 | jax:  0.0126s
 | numpy:  0.0078s

axis=2
 | C:  0.0141s
 | numba:  0.1160s
 | jax:  0.0044s
 | numpy:  0.0064s
```
In any case, NumPy is wiping the floor with us :)
Surprisingly, JAX also does badly in the first case, although not as badly as C/Numba. The poor performance is probably down to a bad iteration order / cache-access pattern.
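To make the cache-access point concrete, here is a minimal Numba sketch (my own illustration, not code from the issue) of two loop orderings for the axis=0 reduction of a C-contiguous `(N, N, N)` array. Both compute the same result; only the position of the innermost loop differs.

```python
import numpy as np
from numba import njit


@njit
def sum_axis0_strided(x):
    # Innermost loop walks axis 0: each step jumps n1*n2 elements
    # in memory, so nearly every read is a cache miss.
    n0, n1, n2 = x.shape
    out = np.zeros((n1, n2))
    for j in range(n1):
        for k in range(n2):
            for i in range(n0):
                out[j, k] += x[i, j, k]
    return out


@njit
def sum_axis0_contiguous(x):
    # Innermost loop walks axis 2 (unit stride): memory is read
    # sequentially and each cache line is fully used.
    n0, n1, n2 = x.shape
    out = np.zeros((n1, n2))
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                out[j, k] += x[i, j, k]
    return out


x = np.random.uniform(size=(256, 256, 256))
np.testing.assert_allclose(sum_axis0_strided(x), x.sum(axis=0))
np.testing.assert_allclose(sum_axis0_contiguous(x), x.sum(axis=0))
```

The first ordering is what a naive "iterate over the reduced axis innermost" lowering produces, and it strides N*N elements between consecutive reads. The second reads memory sequentially, which is essentially the access pattern the axis=2 reduction gets for free on a C-contiguous array.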