Minor edits to contributing documentation #930

Open · wants to merge 3 commits into base: main

5 changes: 3 additions & 2 deletions doc/dev_start_guide.rst
@@ -209,7 +209,8 @@ You can now build the documentation from the root of the project with:

.. code-block:: bash

-    python -m sphinx -b html ./doc ./html
+    # -j for parallel and faster doc build
+    sphinx-build -b html ./doc ./html -j auto


Afterward, you can go to `html/index.html` and navigate the changes in a browser. One way to do this is to go to the `html` directory and run:
@@ -219,7 +220,7 @@ Afterward, you can go to `html/index.html` and navigate the changes in a browser

python -m http.server

-**Do not commit the `html` directory. The documentation is built automatically.**
+**Do not commit the `html` directory.**
For more documentation customizations such as different formats e.g., PDF, refer to the `Sphinx documentation <https://www.sphinx-doc.org/en/master/usage/builders/index.html>`_.

Other tools that might help
44 changes: 21 additions & 23 deletions doc/extending/creating_a_numba_jax_op.rst
@@ -1,13 +1,13 @@
Adding JAX, Numba and Pytorch support for `Op`\s
=======================================
================================================

PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
this, each :class:`Op` in a PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.

This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.

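As a quick orientation (an illustrative sketch, not part of this PR's diff): selecting a backend only requires passing the corresponding ``mode`` argument to ``pytensor.function``. The small graph below is hypothetical; any graph built from ``Op``\s with registered backend implementations behaves the same way.

.. code:: python

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    # Build a small PyTensor graph from Ops that already have JAX implementations.
    x = pt.vector("x")
    y = pt.exp(x).sum()

    # Compile the graph with the JAX backend; mode="NUMBA" selects Numba instead,
    # and the PyTorch backend is chosen analogously.
    f = pytensor.function([x], y, mode="JAX")
    f(np.arange(3.0))
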
Step 1: Identify the PyTensor :class:`Op` you'd like to implement
------------------------------------------------------------------------
-----------------------------------------------------------------

Find the source for the PyTensor :class:`Op` you'd like to be supported and
identify the function signature and return values. These can be determined by
@@ -97,8 +97,8 @@ Next, we look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented.

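To make the target concrete, the general shape of such a :meth:`Op.perform` method looks roughly as follows (an abridged sketch modeled on the cumulative-product example used later in this tutorial, not a verbatim copy of PyTensor's source):

.. code:: python

    import numpy as np

    # Abridged sketch of the method as it appears on an `Op` subclass such as `CumOp`:
    def perform(self, node, inputs, output_storage):
        (x,) = inputs              # evaluated NumPy values, one per input variable
        (z,) = output_storage      # one pre-allocated cell per output
        z[0] = np.cumprod(x, axis=self.axis)
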
Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close)
---------------------------------------------------------
Step 2: Find the relevant or close method in JAX/Numba/Pytorch
--------------------------------------------------------------

With a precise idea of what the PyTensor :class:`Op` does we need to figure out how
to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named
@@ -269,7 +269,7 @@ and :func:`torch.cumprod`
z[0] = np.cumprod(x, axis=self.axis)

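The context line above comes from the cumulative-product ``Op``'s :meth:`perform`; its JAX counterpart is essentially a direct translation (a sketch, since the dispatch code actually registered in PyTensor may differ in detail):

.. code:: python

    import jax.numpy as jnp

    def cumprod(x, axis=None):
        # jax.numpy mirrors the NumPy API here, so the translation is one line;
        # the Numba and PyTorch versions (e.g. torch.cumprod) are analogous.
        return jnp.cumprod(x, axis=axis)
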
Step 3: Register the function with the respective dispatcher
---------------------------------------------------------------
------------------------------------------------------------

With the PyTensor `Op` replicated, we'll need to register the
function with the backend's `Linker`. This is done through the use of
@@ -626,28 +626,26 @@ Step 4: Write tests

Note
----
-In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows:
+Due to restrictions of the JAX JIT compiler, as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
+PyTensor graphs with dynamic shapes may be untranslatable to JAX. For example, this code snippet for the :class:`Eye` `Op`

.. code:: python

-    def test_jax_Eye():
-        """Test JAX conversion of the `Eye` `Op`."""
+    x_at = pt.scalar(dtype=np.int64)
+    eye_var = pt.eye(x_at)
+    f = pytensor.function([x_at], eye_var, mode="JAX")
+    f(3)

-        # Create a symbolic input for `Eye`
-        x_at = pt.scalar()
+cannot be translated to JAX, since it involves a dynamic shape. This is one issue that may pop up when
+linking an `Op` to JAX.

-        # Create a variable that is the output of an `Eye` `Op`
-        eye_var = pt.eye(x_at)
+Note that not all dynamic shapes are disallowed.
+For example, a graph whose output shape depends only on the shapes of its inputs (not on their values) still works.
+This code snippet gives the result expected in the example above.

-        # Create an PyTensor `FunctionGraph`
-        out_fg = FunctionGraph(outputs=[eye_var])

-        # Pass the graph and any inputs to the testing function
-        compare_jax_and_py(out_fg, [3])
+.. code:: python

-This one nowadays leads to a test failure due to new restrictions in JAX + JIT,
-as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
-All jitted functions now must have constant shape, which means a graph like the
-one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
-function with dynamic shapes. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment.
+    x_at = pt.vector(dtype=np.int64)
+    eye_var = pt.eye(x_at.shape[0])
+    f = pytensor.function([x_at], eye_var, mode="JAX")
+    f([3, 3, 3])
28 changes: 1 addition & 27 deletions doc/internal/metadocumentation.rst
@@ -8,33 +8,7 @@ Documentation Documentation AKA Meta-Documentation
How to build documentation
--------------------------

-Let's say you are writing documentation, and want to see the `sphinx
-<http://sphinx.pocoo.org/>`__ output before you push it.
-The documentation will be generated in the ``html`` directory.

-.. code-block:: bash

-    cd PyTensor/
-    python ./doc/scripts/docgen.py

-If you don't want to generate the pdf, do the following:

-.. code-block:: bash

-    cd PyTensor/
-    python ./doc/scripts/docgen.py --nopdf


-For more details:

-.. code-block:: bash

-    $ python doc/scripts/docgen.py --help
-    Usage: doc/scripts/docgen.py [OPTIONS]
-        -o <dir>: output the html files in the specified dir
-        --rst: only compile the doc (requires sphinx)
-        --nopdf: do not produce a PDF file from the doc, only HTML
-        --help: this help
+Refer to the relevant section of :doc:`../dev_start_guide`.

Use ReST for documentation
--------------------------