@@ -55,9 +55,9 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
55
55
Tag a function to be tested on lazy backends.
56
56
57
57
Tag a function, which must be imported in the test module globals, so that when any
58
- tests defined in the same module are executed with `xp=jax.numpy` the function is
58
+ tests defined in the same module are executed with `` xp=jax.numpy` ` the function is
59
59
replaced with a jitted version of itself, and when it is executed with
60
- `xp=dask.array` the function will raise if it attempts to materialize the graph.
60
+ `` xp=dask.array` ` the function will raise if it attempts to materialize the graph.
61
61
This will be later expanded to provide test coverage for other lazy backends.
62
62
63
63
In order for the tag to be effective, the test or a fixture must call
@@ -69,7 +69,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
69
69
Function to be tested.
70
70
allow_dask_compute : int, optional
71
71
Number of times `func` is allowed to internally materialize the Dask graph. This
72
- is typically triggered by `bool()`, `float()`, or `np.asarray()`.
72
+ is typically triggered by `` bool()`` , `` float()`` , or `` np.asarray()` `.
73
73
74
74
Set to 1 if you are aware that `func` converts the input parameters to numpy and
75
75
want to let it do so at least for the time being, knowing that it is going to be
@@ -78,16 +78,16 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
78
78
If a test needs values higher than 1 to pass, it is a canary that the conversion
79
79
to numpy/bool/float is happening multiple times, which translates to multiple
80
80
computations of the whole graph. Short of making the function fully lazy, you
81
- should at least add explicit calls to `np.asarray()` early in the function.
81
+ should at least add explicit calls to `` np.asarray()` ` early in the function.
82
82
*Note:* the counter of `allow_dask_compute` resets after each call to `func`, so
83
83
a test function that invokes `func` multiple times should still work with this
84
84
parameter set to 1.
85
85
86
86
Default: 0, meaning that `func` must be fully lazy and never materialize the
87
87
graph.
88
88
jax_jit : bool, optional
89
- Set to True to replace `func` with `jax.jit(func)` after calling the
90
- :func:`patch_lazy_xp_functions` test helper with `xp=jax.numpy`. Set to False if
89
+ Set to True to replace `func` with `` jax.jit(func)` ` after calling the
90
+ :func:`patch_lazy_xp_functions` test helper with `` xp=jax.numpy` `. Set to False if
91
91
`func` is only compatible with eager (non-jitted) JAX. Default: True.
92
92
static_argnums : int | Sequence[int], optional
93
93
Passed to jax.jit. Positional arguments to treat as static (compile-time
@@ -104,7 +104,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
104
104
105
105
Examples
106
106
--------
107
- In `test_mymodule.py`::
107
+ In `` test_mymodule.py` `::
108
108
109
109
from array_api_extra.testing import lazy_xp_function from mymodule import myfunc
110
110
@@ -145,12 +145,12 @@ def patch_lazy_xp_functions(
145
145
"""
146
146
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
147
147
148
- If `xp==jax.numpy`, search for all functions which have been tagged with
148
+ If `` xp==jax.numpy` `, search for all functions which have been tagged with
149
149
:func:`lazy_xp_function` in the globals of the module that defines the current test
150
150
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
151
151
152
- If `xp==dask.array`, wrap the functions with a decorator that disables `compute()`
153
- and `persist()`.
152
+ If `` xp==dask.array`` , wrap the functions with a decorator that disables `` compute()` `
153
+ and `` persist()` `.
154
154
155
155
This function should be typically called by your library's `xp` fixture that runs
156
156
tests on multiple backends::
@@ -248,7 +248,7 @@ def _allow_dask_compute(
248
248
msg = (
249
249
f"Called `dask.compute()` or `dask.persist()` { n + 1 } times, "
250
250
f"but { n_str } calls are allowed. Set "
251
- f"`lazy_xp_function({ func_name } , allow_dask_compute={ n + 1 } ) "
251
+ f"`lazy_xp_function({ func_name } , allow_dask_compute={ n + 1 } )` "
252
252
"to allow for more (but note that this will harm performance). "
253
253
)
254
254
0 commit comments