
ENH: lazy_while (and friends) #303

Open
@purepani

Description


While converting magpylib to the array API in magpylib/magpylib#844, I found that a few of the functions (algorithms for some elliptic integrals) follow this unavoidable pattern:

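```python
def refine(val, xp):  # hypothetical wrapper; xp is the array namespace
    tol = 1e-6
    # The loop condition depends on array values, so Python must convert
    # xp.all(...) into a concrete bool on every iteration.
    while xp.all(val < tol):
        val = make_val_better(val)
    return val
```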

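This doesn't work directly under `jax.jit` (and, I assume, other lazy backends), because the `while` condition has to be converted to a Python `bool`, which a traced array cannot provide. JAX offers a `jax.lax.while_loop` primitive to replace this pattern, so it would be good to dispatch to it from an `array_api_extra.lazy_while`.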
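A minimal sketch of what such a dispatcher could look like, assuming it copies the `jax.lax.while_loop(cond_fun, body_fun, init_val)` calling convention. The names `lazy_while` and `_uses_jax` are hypothetical (no such function exists in array-api-extra yet), and Dask is deliberately not handled: every non-JAX backend falls back to an eager Python loop.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def _uses_jax(init_val: Any) -> bool:
    """Hypothetical helper: True if any leaf of the loop state is a JAX array."""
    try:
        import jax
    except ImportError:
        # JAX is not installed, so the state cannot contain JAX arrays.
        return False
    # isinstance(x, jax.Array) is also True for tracers inside jax.jit.
    leaves = jax.tree_util.tree_leaves(init_val)
    return any(isinstance(leaf, jax.Array) for leaf in leaves)


def lazy_while(
    cond_fun: Callable[[T], Any],
    body_fun: Callable[[T], T],
    init_val: T,
) -> T:
    """Hypothetical lazy_while, mirroring jax.lax.while_loop semantics.

    cond_fun(val) returns a scalar boolean; body_fun(val) returns a new
    state with the same structure, shapes and dtypes as init_val.
    """
    if _uses_jax(init_val):
        import jax

        # Traceable loop: the condition is evaluated by XLA rather than
        # converted to a Python bool, so this works inside jax.jit.
        return jax.lax.while_loop(cond_fun, body_fun, init_val)

    # Eager fallback (e.g. NumPy): the condition is concrete, so a plain
    # Python loop is fine.
    val = init_val
    while bool(cond_fun(val)):
        val = body_fun(val)
    return val
```

With this, the loop above would become something like `lazy_while(lambda v: xp.all(v < tol), make_val_better, val)`. One design constraint carried over from `jax.lax.while_loop` is that `make_val_better` must return an array with the same shape and dtype as its input, which is stricter than what the plain Python loop requires.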

There's also a `jax.lax.fori_loop` primitive that may be similarly useful. I have never used Dask and don't know how this would be done there, but I assume it also needs some form of dispatch.
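A companion for `fori_loop` could follow the same dispatch idea. This is a sketch under the assumption that it mirrors `jax.lax.fori_loop(lower, upper, body_fun, init_val)`, where `body_fun(i, val)` receives the loop index; `lazy_fori_loop` and `_uses_jax` are hypothetical names, and non-JAX backends again use an eager Python loop.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def _uses_jax(init_val: Any) -> bool:
    """Hypothetical helper: True if any leaf of the loop state is a JAX array."""
    try:
        import jax
    except ImportError:
        return False
    leaves = jax.tree_util.tree_leaves(init_val)
    return any(isinstance(leaf, jax.Array) for leaf in leaves)


def lazy_fori_loop(
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
) -> T:
    """Hypothetical lazy_fori_loop, mirroring jax.lax.fori_loop semantics."""
    if _uses_jax(init_val):
        import jax

        # Traceable fixed-count loop, usable inside jax.jit.
        return jax.lax.fori_loop(lower, upper, body_fun, init_val)

    # Eager fallback: run body_fun for i = lower, ..., upper - 1.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
```

For example, `lazy_fori_loop(0, 5, lambda i, acc: acc + i, 0)` sums the indices 0 through 4.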

I'll likely do this as part of magpylib/magpylib#844, and I can upstream it after.
