diff --git a/doc/dev_start_guide.rst b/doc/dev_start_guide.rst
index 010d0ffb75..6c5f5f581b 100644
--- a/doc/dev_start_guide.rst
+++ b/doc/dev_start_guide.rst
@@ -209,7 +209,8 @@ You can now build the documentation from the root of the project with:
 
 .. code-block:: bash
 
-    python -m sphinx -b html ./doc ./html
+    # -j auto enables a parallel, faster doc build
+    sphinx-build -b html ./doc ./html -j auto
 
 Afterward, you can go to `html/index.html` and navigate the changes in a
 browser. One way to do this is to go to the `html` directory and run:
@@ -219,7 +220,7 @@ Afterward, you can go to `html/index.html` and navigate the changes in a browser
 
     python -m http.server
 
-**Do not commit the `html` directory. The documentation is built automatically.**
+**Do not commit the `html` directory.**
 
 For more documentation customizations such as different formats e.g., PDF,
 refer to the `Sphinx documentation <https://www.sphinx-doc.org/en/master/>`_. Other tools that might help
diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index 42c7304b5c..75063551b7 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -1,5 +1,5 @@
 Adding JAX, Numba and Pytorch support for `Op`\s
-=======================================
+================================================
 
 PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
 this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.
@@ -7,7 +7,7 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Py
 This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.
 
 Step 1: Identify the PyTensor :class:`Op` you'd like to implement
--------------------------------------------------------------------------
+-----------------------------------------------------------------
 
 Find the source for the PyTensor :class:`Op` you'd like to be supported and
 identify the function signature and return values. These can be determined by
@@ -97,8 +97,8 @@ Next, we look at the :meth:`Op.perform` implementation to see exactly how the
 inputs and outputs are used to compute the outputs for an :class:`Op` in
 Python. This method is effectively what needs to be implemented.
 
-Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close)
----------------------------------------------------------
+Step 2: Find the relevant or close method in JAX/Numba/Pytorch
+--------------------------------------------------------------
 
 With a precise idea of what the PyTensor :class:`Op` does we need to figure out how
 to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named
@@ -269,7 +269,7 @@ and :func:`torch.cumprod`
         z[0] = np.cumprod(x, axis=self.axis)
 
 Step 3: Register the function with the respective dispatcher
----------------------------------------------------------------
+------------------------------------------------------------
 
 With the PyTensor `Op` replicated, we'll need to register the function with the
 backends `Linker`. This is done through the use of
@@ -626,28 +626,26 @@ Step 4: Write tests
 Note
 ----
 
-In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows:
+Due to restrictions of the JAX JIT compiler, as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
+PyTensor graphs with dynamic shapes may not be translatable to JAX. For example, this code snippet for the :class:`Eye` `Op`
 
 .. code:: python
 
-    def test_jax_Eye():
-        """Test JAX conversion of the `Eye` `Op`."""
+    x_at = pt.scalar(dtype=np.int64)
+    eye_var = pt.eye(x_at)
+    f = pytensor.function([x_at], eye_var, mode="JAX")
+    f(3)
 
-        # Create a symbolic input for `Eye`
-        x_at = pt.scalar()
+cannot be translated to JAX, since it involves a dynamic shape. This is one issue that may come up when
+linking an `Op` to JAX.
 
-        # Create a variable that is the output of an `Eye` `Op`
-        eye_var = pt.eye(x_at)
+Note that not all dynamic shapes are disallowed.
+For example, a shape that depends only on the shapes of the inputs still works.
+This code snippet produces the result that the example above was expected to give.
 
-        # Create an PyTensor `FunctionGraph`
-        out_fg = FunctionGraph(outputs=[eye_var])
-
-        # Pass the graph and any inputs to the testing function
-        compare_jax_and_py(out_fg, [3])
+.. code:: python
 
-This one nowadays leads to a test failure due to new restrictions in JAX + JIT,
-as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
-All jitted functions now must have constant shape, which means a graph like the
-one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
-function with dynamic shapes. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment.
\ No newline at end of file
+    x_at = pt.vector(dtype=np.int64)
+    eye_var = pt.eye(x_at.shape[0])
+    f = pytensor.function([x_at], eye_var, mode="JAX")
+    f([3, 3, 3])
\ No newline at end of file
diff --git a/doc/internal/metadocumentation.rst b/doc/internal/metadocumentation.rst
index e5c2be28de..a618c1e4ed 100644
--- a/doc/internal/metadocumentation.rst
+++ b/doc/internal/metadocumentation.rst
@@ -8,33 +8,7 @@ Documentation Documentation AKA Meta-Documentation
 
 How to build documentation
 --------------------------
-Let's say you are writing documentation, and want to see the `sphinx
-<http://sphinx.pocoo.org/>`__ output before you push it.
-The documentation will be generated in the ``html`` directory.
-
-.. code-block:: bash
-
-    cd PyTensor/
-    python ./doc/scripts/docgen.py
-
-If you don't want to generate the pdf, do the following:
-
-.. code-block:: bash
-
-    cd PyTensor/
-    python ./doc/scripts/docgen.py --nopdf
-
-
-For more details:
-
-.. code-block:: bash
-
-    $ python doc/scripts/docgen.py --help
-    Usage: doc/scripts/docgen.py [OPTIONS]
-        -o <dir>: output the html files in the specified dir
-        --rst: only compile the doc (requires sphinx)
-        --nopdf: do not produce a PDF file from the doc, only HTML
-        --help: this help
+Refer to the relevant section of :doc:`../dev_start_guide`.
 
 Use ReST for documentation
 --------------------------