While converting magpylib to the array API in magpylib/magpylib#844, I found that a few of the functions (algorithms for some elliptic integrals) have the unavoidable pattern:
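```python
def refine(val, xp):  # hypothetical wrapper, so that `return` is valid
    tol = 1e-6
    # Data-dependent condition: the trip count is only known at runtime.
    while xp.all(val < tol):
        val = make_val_better(val)  # stand-in for the real update step
    return val
```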
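For concreteness, here is one instance of this pattern. It was chosen for illustration and is not necessarily one of the algorithms magpylib uses: the complete elliptic integral of the first kind, `K(m) = pi / (2 * AGM(1, sqrt(1 - m)))`, computed with the arithmetic-geometric mean. The function name `ellipk_agm` and the tolerance are assumptions. The point is that the `while` condition reads array *values*, which works eagerly with NumPy but fails inside `jax.jit`.

```python
import math

import numpy as np


def ellipk_agm(m, xp=np, tol=1e-12):
    """Hypothetical example: K(m) via the arithmetic-geometric mean, 0 <= m < 1."""
    a = xp.ones_like(m)
    b = xp.sqrt(1 - m)
    # Iterate until every element has converged; how many steps that takes
    # depends on the data, which a plain Python `while` cannot express under jit.
    while xp.any(xp.abs(a - b) > tol):
        a, b = (a + b) / 2, xp.sqrt(a * b)
    return math.pi / (2 * a)


print(ellipk_agm(np.asarray([0.0, 0.5])))
```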
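This doesn't directly work under `jax.jit` (and I assume other lazy backends). Inside a jitted function `val` is a tracer, so the Python `while` cannot evaluate the condition to a concrete `bool` and JAX raises an error. JAX offers a `jax.lax.while_loop` primitive to replace this pattern, so it would be good to dispatch to that through an `array_api_extra.lazy_while`.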
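A minimal sketch of what such a helper could look like, assuming it mirrors the `jax.lax.while_loop(cond_fun, body_fun, init_val)` signature. `lazy_while` does not exist in array_api_extra; its name, the keyword-only `xp` argument and the name-based JAX detection are all assumptions. `array_api_compat.is_jax_namespace` would be a more robust check, if your version provides it.

```python
def _is_jax_namespace(xp):
    # Assumption: recognise JAX by module name (e.g. "jax.numpy") so that
    # jax is only imported when it is actually in use.
    return getattr(xp, "__name__", "").startswith("jax")


def lazy_while(cond_fun, body_fun, init_val, *, xp=None):
    """Hypothetical helper (not an existing array_api_extra API).

    Same semantics as jax.lax.while_loop:
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
    """
    if xp is not None and _is_jax_namespace(xp):
        import jax  # imported lazily, only on the JAX path

        # cond_fun must return a scalar boolean, and body_fun must keep the
        # shape and dtype of the carried value unchanged.
        return jax.lax.while_loop(cond_fun, body_fun, init_val)

    # Eager fallback (NumPy, CuPy, torch, ...): an ordinary Python loop.
    val = init_val
    while bool(cond_fun(val)):
        val = body_fun(val)
    return val


print(lazy_while(lambda v: v < 100, lambda v: v * 2, 1))
```

The carried value can be a tuple of arrays: `jax.lax.while_loop` accepts any pytree, and the eager fallback simply passes the tuple through. One caveat: `jax.lax.while_loop` does not support reverse-mode differentiation, which may matter if the fields are ever differentiated with `jax.grad`.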
There's also a `jax.lax.fori_loop` primitive that may be similarly useful. I have never used dask and have no idea how to do this there, but I assume it also needs to be dispatched somehow.
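`jax.lax.fori_loop` could cover cases where a fixed, conservatively chosen number of iterations is acceptable instead of a convergence test. Below is a sketch of a matching helper, `lazy_fori`. This is a hypothetical name with the `jax.lax.fori_loop(lower, upper, body_fun, init_val)` signature, and the `xp` keyword and the name-based JAX detection are likewise assumptions.

```python
def _is_jax_namespace(xp):
    # Assumption: recognise JAX by module name (e.g. "jax.numpy").
    return getattr(xp, "__name__", "").startswith("jax")


def lazy_fori(lower, upper, body_fun, init_val, *, xp=None):
    """Hypothetical helper mirroring jax.lax.fori_loop.

    Semantics:
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val
    """
    if xp is not None and _is_jax_namespace(xp):
        import jax  # imported lazily, only on the JAX path

        return jax.lax.fori_loop(lower, upper, body_fun, init_val)

    # Eager fallback: an ordinary Python for loop.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


print(lazy_fori(0, 5, lambda i, acc: acc + i, 0))
```

A possible reason to prefer this over a `while` loop: when `lower` and `upper` are Python ints, the trip count is static and JAX lowers `fori_loop` to `lax.scan`, which, unlike `while_loop`, supports reverse-mode autodiff.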
I'll likely do this as part of magpylib/magpylib#844, and I can upstream it after.