Skip to content

Downstream 1317 #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/extending/creating_a_numba_jax_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ Here's an example for :class:`IfElse`:
return res if n_outs > 1 else res[0]


Step 3: Register the function with the `jax_funcify` dispatcher
Step 3: Register the function with the `_jax_funcify` dispatcher
---------------------------------------------------------------

With the PyTensor `Op` replicated in JAX, we’ll need to register the
function with the PyTensor JAX `Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.

The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.basic._numba_funcify` and
:func:`pytensor.link.jax.dispatch.jax_funcify`.

Here’s an example for the `Eye`\ `Op`:
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/jax/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def jax_funcify_FunctionGraph(
return fgraph_to_python(
fgraph,
jax_funcify,
type_conversion_fn=jax_typify,
const_conversion_fn=jax_typify,
fgraph_name=fgraph_name,
**kwargs,
)
Expand Down
6 changes: 5 additions & 1 deletion pytensor/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# isort: off
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
from pytensor.link.numba.dispatch.basic import (
numba_funcify,
numba_const_convert,
numba_njit,
)

# Load dispatch specializations
import pytensor.link.numba.dispatch.scalar
Expand Down
Loading