Skip to content

Commit 3960ea1

Browse files
Apply suggestions from code review
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 3d4bc82 commit 3960ea1

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/array_api_extra/testing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
5555
Tag a function to be tested on lazy backends.
5656
5757
Tag a function, which must be imported in the test module globals, so that when any
58-
tests defined in the same module are executed with `xp=jax.numpy` the function is
58+
tests defined in the same module are executed with ``xp=jax.numpy`` the function is
5959
replaced with a jitted version of itself, and when it is executed with
60-
`xp=dask.array` the function will raise if it attempts to materialize the graph.
60+
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
6161
This will be later expanded to provide test coverage for other lazy backends.
6262
6363
In order for the tag to be effective, the test or a fixture must call
@@ -69,7 +69,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
6969
Function to be tested.
7070
allow_dask_compute : int, optional
7171
Number of times `func` is allowed to internally materialize the Dask graph. This
72-
is typically triggered by `bool()`, `float()`, or `np.asarray()`.
72+
is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``.
7373
7474
Set to 1 if you are aware that `func` converts the input parameters to numpy and
7575
want to let it do so at least for the time being, knowing that it is going to be
@@ -78,16 +78,16 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
7878
If a test needs values higher than 1 to pass, it is a canary that the conversion
7979
to numpy/bool/float is happening multiple times, which translates to multiple
8080
computations of the whole graph. Short of making the function fully lazy, you
81-
should at least add explicit calls to `np.asarray()` early in the function.
81+
should at least add explicit calls to ``np.asarray()`` early in the function.
8282
*Note:* the counter of `allow_dask_compute` resets after each call to `func`, so
8383
a test function that invokes `func` multiple times should still work with this
8484
parameter set to 1.
8585
8686
Default: 0, meaning that `func` must be fully lazy and never materialize the
8787
graph.
8888
jax_jit : bool, optional
89-
Set to True to replace `func` with `jax.jit(func)` after calling the
90-
:func:`patch_lazy_xp_functions` test helper with `xp=jax.numpy`. Set to False if
89+
Set to True to replace `func` with ``jax.jit(func)`` after calling the
90+
:func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False if
9191
`func` is only compatible with eager (non-jitted) JAX. Default: True.
9292
static_argnums : int | Sequence[int], optional
9393
Passed to jax.jit. Positional arguments to treat as static (compile-time
@@ -104,7 +104,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
104104
105105
Examples
106106
--------
107-
In `test_mymodule.py`::
107+
In ``test_mymodule.py``::
108108
109109
from array_api_extra.testing import lazy_xp_function from mymodule import myfunc
110110
@@ -145,12 +145,12 @@ def patch_lazy_xp_functions(
145145
"""
146146
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
147147
148-
If `xp==jax.numpy`, search for all functions which have been tagged with
148+
If ``xp==jax.numpy``, search for all functions which have been tagged with
149149
:func:`lazy_xp_function` in the globals of the module that defines the current test
150150
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
151151
152-
If `xp==dask.array`, wrap the functions with a decorator that disables `compute()`
153-
and `persist()`.
152+
If ``xp==dask.array``, wrap the functions with a decorator that disables ``compute()``
153+
and ``persist()``.
154154
155155
This function should be typically called by your library's `xp` fixture that runs
156156
tests on multiple backends::
@@ -248,7 +248,7 @@ def _allow_dask_compute(
248248
msg = (
249249
f"Called `dask.compute()` or `dask.persist()` {n + 1} times, "
250250
f"but {n_str} calls are allowed. Set "
251-
f"`lazy_xp_function({func_name}, allow_dask_compute={n + 1}) "
251+
f"`lazy_xp_function({func_name}, allow_dask_compute={n + 1})` "
252252
"to allow for more (but note that this will harm performance). "
253253
)
254254

0 commit comments

Comments
 (0)