From fe8ac4724abf3c0a4328100eed829217cd6951c5 Mon Sep 17 00:00:00 2001
From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com>
Date: Sun, 31 Mar 2024 12:31:36 +0000
Subject: [PATCH 1/9] Changed example for extending JAX.

---
 doc/extending/creating_a_numba_jax_op.rst | 148 ++++++++++++++--------
 1 file changed, 96 insertions(+), 52 deletions(-)

diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index abf3f528bf..8f15d39c61 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -7,55 +7,61 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba im
 This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`.  It will
 focus specifically on the JAX case, but the same mechanisms are used for Numba as well.

-Step 1: Identify the PyTensor :class:`Op` you’d like to implement in JAX
+Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX
 ------------------------------------------------------------------------

-Find the source for the PyTensor :class:`Op` you’d like to be supported in JAX, and
-identify the function signature and return values. These can be determined by
-looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
+Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and
+identify the function signature and return values. These can be determined by
+looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
 with PyTensor :class:`Op`\s in order to provide a conversion
 implementation, so first read :ref:`creating_an_op` if you are not familiar.

-For example, the :class:`Eye`\ :class:`Op` currently has an :meth:`Op.make_node` as follows:
+For example, the :class:`FillDiagonal`\ :class:`Op` currently has an :meth:`Op.make_node` as follows:

 .. code:: python

-    def make_node(self, n, m, k):
-        n = as_tensor_variable(n)
-        m = as_tensor_variable(m)
-        k = as_tensor_variable(k)
-        assert n.ndim == 0
-        assert m.ndim == 0
-        assert k.ndim == 0
-        return Apply(
-            self,
-            [n, m, k],
-            [TensorType(dtype=self.dtype, shape=(None, None))()],
-        )
+    def make_node(self, a, val):
+        a = ptb.as_tensor_variable(a)
+        val = ptb.as_tensor_variable(val)
+        if a.ndim < 2:
+            raise TypeError(
+                "%s: first parameter must have at least"
+                " two dimensions" % self.__class__.__name__
+            )
+        elif val.ndim != 0:
+            raise TypeError(
+                f"{self.__class__.__name__}: second parameter must be a scalar"
+            )
+        val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype))
+        if val.dtype != a.dtype:
+            raise TypeError(
+                "%s: type of second parameter must be the same as"
+                " the first's" % self.__class__.__name__
+            )
+        return Apply(self, [a, val], [a.type()])

 The :class:`Apply` instance that's returned specifies the exact types of inputs that
 our JAX implementation will receive and the exact types of outputs it's expected to
-return--both in terms of their data types and number of dimensions.
+return--both in terms of their data types and number of dimensions/shapes.
 The actual inputs our implementation will receive are necessarily numeric values
 or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
 general signature of the underlying computation.
-More specifically, the :class:`Apply` implies that the inputs come from values that are
+More specifically, the :class:`Apply` implies that the inputs come from two values that are
 automatically converted to PyTensor variables via :func:`as_tensor_variable`, and
-the ``assert``\s that follow imply that they must be scalars. According to this
+the checks that follow imply that the first one must be a tensor with at least two
+dimensions (i.e., a matrix) and the second must be a scalar. According to this
 logic, the inputs could have any data type (e.g. floats, ints), so our JAX
 implementation must be able to handle all the possible data types.

 It also tells us that there's only one return value, that it has a data type
-determined by :attr:`Eye.dtype`, and that it has two non-broadcastable
-dimensions. The latter implies that the result is necessarily a matrix. The
-former implies that our JAX implementation will need to access the :attr:`dtype`
-attribute of the PyTensor :class:`Eye`\ :class:`Op` it's converting.
+determined by :meth:`a.type()`, i.e., the data type of the original tensor.
+This implies that the result is necessarily a matrix.

 Next, we can look at the :meth:`Op.perform` implementation to see exactly
 how the inputs and outputs are used to compute the outputs for an :class:`Op`
-in Python.  This method is effectively what needs to be implemented in JAX.
+in Python. This method is effectively what needs to be implemented in JAX.


 Step 2: Find the relevant JAX method (or something close)
 ---------------------------------------------------------
@@ -82,11 +88,15 @@ Here's an example for :class:`IfElse`:
     )
     return res if n_outs > 1 else res[0]

+In this case, we have to use custom logic to implement the JAX version of
+:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use
+:func:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set
+them to the value we want.

 Step 3: Register the function with the `jax_funcify` dispatcher
 ---------------------------------------------------------------

-With the PyTensor `Op` replicated in JAX, we’ll need to register the
+With the PyTensor `Op` replicated in JAX, we'll need to register the
 function with the PyTensor JAX `Linker`. This is done through the use of
 `singledispatch`. If you don't know how `singledispatch` works, see the
 `Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
@@ -94,35 +104,31 @@ function with the PyTensor JAX `Linker`. This is done through the use of

 The relevant dispatch functions created by `singledispatch` are
 :func:`pytensor.link.numba.dispatch.numba_funcify` and
 :func:`pytensor.link.jax.dispatch.jax_funcify`.

-Here’s an example for the `Eye`\ `Op`:
+Here's an example for the `FillDiagonal`\ `Op`:

 .. code:: python

     import jax.numpy as jnp

-    from pytensor.tensor.basic import Eye
+    from pytensor.tensor.extra_ops import FillDiagonal
     from pytensor.link.jax.dispatch import jax_funcify


-    @jax_funcify.register(Eye)
-    def jax_funcify_Eye(op):
+    @jax_funcify.register(FillDiagonal)
+    def jax_funcify_FillDiagonal(op, **kwargs):
+        def filldiagonal(value, diagonal):
+            i, j = jnp.diag_indices(min(value.shape[-2:]))
+            return value.at[..., i, j].set(diagonal)

-        # Obtain necessary "static" attributes from the Op being converted
-        dtype = op.dtype
-
-        # Create a JAX jit-able function that implements the Op
-        def eye(N, M, k):
-            return jnp.eye(N, M, k, dtype=dtype)
-
-        return eye
+        return filldiagonal


 Step 4: Write tests
 -------------------

 Test that your registered `Op` is working correctly by adding tests to the
-appropriate test suites in PyTensor (e.g. in ``tests.link.test_jax`` and one of
-the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can
+appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of
+the modules in ``tests.link.numba``). The tests should ensure that your implementation can
 handle the appropriate types of inputs and produce outputs equivalent to
 `Op.perform`. Check the existing tests for the general outline of these kinds
 of tests. In most cases, a helper function can be used to easily verify the correspondence
@@ -131,23 +137,61 @@ between a JAX/Numba implementation and its `Op`. For example, the
 :func:`compare_jax_and_py` function streamlines the steps
 involved in making comparisons with `Op.perform`.

-Here's a small example of a test for :class:`Eye`:
+Here's a small example of a test for :class:`FillDiagonal`:

 .. code:: python

+    import numpy as np
+    import pytensor.tensor as pt
+    import pytensor.tensor.basic as ptb
+    from pytensor.configdefaults import config
+    from tests.link.jax.test_basic import compare_jax_and_py
+    from pytensor.graph import FunctionGraph
+    from pytensor.graph.op import get_test_value
+
+    def test_jax_FillDiagonal():
+        """Test JAX conversion of the `FillDiagonal` `Op`."""
+
+        # Create a symbolic input for the first input of `FillDiagonal`
+        a = pt.matrix("a")
+
+        # Create test value tag for a
+        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
+
+        # Create a scalar value for the second input
+        c = ptb.as_tensor(5)

-    import pytensor.tensor as pt
+        # Create the output variable
+        out = pt.fill_diagonal(a, c)
+
+        # Create a PyTensor `FunctionGraph`
+        fgraph = FunctionGraph([a], [out])
+
+        # Pass the graph and inputs to the testing function
+        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+Note
+----
+In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:
+
+.. code:: python
+
+    def test_jax_Eye():
+        """Test JAX conversion of the `Eye` `Op`."""

-    def test_jax_Eye():
-        """Test JAX conversion of the `Eye` `Op`."""
+        # Create a symbolic input for `Eye`
+        x_at = pt.scalar()

-    # Create a symbolic input for `Eye`
-    x_at = pt.scalar()
+        # Create a variable that is the output of an `Eye` `Op`
+        eye_var = pt.eye(x_at)

-    # Create a variable that is the output of an `Eye` `Op`
-    eye_var = pt.eye(x_at)
+        # Create an PyTensor `FunctionGraph`
+        out_fg = FunctionGraph(outputs=[eye_var])

-    # Create an PyTensor `FunctionGraph`
-    out_fg = FunctionGraph(outputs=[eye_var])
+        # Pass the graph and any inputs to the testing function
+        compare_jax_and_py(out_fg, [3])

-    # Pass the graph and any inputs to the testing function
-    compare_jax_and_py(out_fg, [3])
+Nowadays this test fails due to new restrictions in JAX's JIT compilation,
+as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
+All JIT-compiled functions must now have static shapes, which means a graph like
+the one of :class:`Eye` can never be translated to JAX, since it is fundamentally
+a function with dynamic shapes. In other words, only PyTensor graphs with static
+shapes can be translated to JAX at the moment.
\ No newline at end of file
From 8f2a4a2c70efc905ea3ecced7a437dcfcb357f59 Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Sat, 6 Apr 2024 22:25:09 +0800
Subject: [PATCH 3/9] Use CumOp in the example

---
 doc/extending/creating_a_numba_jax_op.rst | 188 ++++++++++++++++------
 1 file changed, 138 insertions(+), 50 deletions(-)

diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index 8f15d39c61..0d5f6460e9 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -16,30 +16,34 @@ looking at the :meth:`Op.make_node` implementation.
 In general, one needs to be familiar
 with PyTensor :class:`Op`\s in order to provide a conversion
 implementation, so first read :ref:`creating_an_op` if you are not familiar.

-For example, the :class:`FillDiagonal`\ :class:`Op` currently has an :meth:`Op.make_node` as follows:
+For example, suppose you want to extend support for :class:`CumsumOp`:

 .. code:: python

-    def make_node(self, a, val):
-        a = ptb.as_tensor_variable(a)
-        val = ptb.as_tensor_variable(val)
-        if a.ndim < 2:
-            raise TypeError(
-                "%s: first parameter must have at least"
-                " two dimensions" % self.__class__.__name__
-            )
-        elif val.ndim != 0:
-            raise TypeError(
-                f"{self.__class__.__name__}: second parameter must be a scalar"
-            )
-        val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype))
-        if val.dtype != a.dtype:
-            raise TypeError(
-                "%s: type of second parameter must be the same as"
-                " the first's" % self.__class__.__name__
-            )
-        return Apply(self, [a, val], [a.type()])
+    class CumsumOp(Op):
+        __props__ = ("axis",)
+
+        def __new__(typ, *args, **kwargs):
+            obj = object.__new__(CumOp, *args, **kwargs)
+            obj.mode = "add"
+            return obj
+
+
+:class:`CumsumOp` turns out to be a variant of the :class:`CumOp`\ :class:`Op`,
+which currently has an :meth:`Op.make_node` as follows:
+
+.. code:: python
+
+    def make_node(self, x):
+        x = ptb.as_tensor_variable(x)
+        out_type = x.type()
+
+        if self.axis is None:
+            out_type = vector(dtype=x.dtype)  # Flatten
+        elif self.axis >= x.ndim or self.axis < -x.ndim:
+            raise ValueError(f"axis(={self.axis}) out of bounds")
+
+        return Apply(self, [x], [out_type])

 The :class:`Apply` instance that's returned specifies the exact types of inputs that
 our JAX implementation will receive and the exact types of outputs it's expected to
@@ -48,22 +52,52 @@ The actual inputs our implementation will receive are necessarily numeric values
 or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
 general signature of the underlying computation.

-More specifically, the :class:`Apply` implies that the inputs come from two values that are
-automatically converted to PyTensor variables via :func:`as_tensor_variable`, and
-the checks that follow imply that the first one must be a tensor with at least two
-dimensions (i.e., a matrix) and the second must be a scalar. According to this
-logic, the inputs could have any data type (e.g. floats, ints), so our JAX
-implementation must be able to handle all the possible data types.
+More specifically, the :class:`Apply` implies that there is one input that is
+automatically converted to a PyTensor variable via :func:`as_tensor_variable`.
+There is another parameter, `axis`, that determines the dimension along which
+the operation is applied and hence the shape of the output. The check that
+follows implies that `axis` must refer to a dimension of the input tensor. The
+input's elements could have any data type (e.g. floats, ints), so our JAX
+implementation must be able to handle all the possible data types.

 It also tells us that there's only one return value, that it has a data type
-determined by :meth:`a.type()`, i.e., the data type of the original tensor.
-This implies that the result is necessarily a matrix.
+determined by :meth:`x.type()`, i.e., the data type of the original tensor.
+This implies that the result has the same shape and data type as the input
+tensor, except when ``axis`` is ``None``, in which case the input is flattened
+and the result is a vector.

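+A quick, hands-on way to confirm this signature is to build a small graph and
+inspect the :class:`Apply` node that :meth:`Op.make_node` creates. This is only
+an illustrative sketch; the exact ``repr`` output may differ across versions:
+
+.. code:: python
+
+    import pytensor.tensor as pt
+
+    x = pt.matrix("x")
+    out = pt.cumsum(x, axis=0)
+
+    # `out.owner` is the `Apply` node returned by `CumOp.make_node`
+    node = out.owner
+    print(node.op)               # the underlying CumOp (mode="add")
+    print(node.inputs[0].type)   # e.g. TensorType(float64, shape=(None, None))
+    print(node.outputs[0].type)  # same type as the input
+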
+Some :class:`Op`\s may have more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
+also has another variant, :class:`CumprodOp`\ :class:`Op`, with exactly the same
+signature as :class:`CumsumOp`\ :class:`Op`. The difference lies in the `mode`
+attribute of the :class:`CumOp` definition:
+
+.. code:: python
+
+    class CumOp(COp):
+        # See function cumsum/cumprod for docstring
+
+        __props__ = ("axis", "mode")
+        check_input = False
+        params_type = ParamsType(
+            c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
+        )
+
+        def __init__(self, axis: int | None = None, mode="add"):
+            if mode not in ("add", "mul"):
+                raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
+            self.axis = axis
+            self.mode = mode
+
+        c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
+
+`__props__` is used to parametrize the general behavior of the :class:`Op`. One
+needs to pay attention to it when deciding whether the JAX implementation should
+support all variants or raise an explicit ``NotImplementedError`` for the
+unsupported ones, e.g., when the :class:`CumsumOp` variant (:class:`CumOp` with
+``mode="add"``) is supported but the :class:`CumprodOp` variant (``mode="mul"``) is not.

-Next, we can look at the :meth:`Op.perform` implementation to see exactly
+Next, we look at the :meth:`Op.perform` implementation to see exactly
 how the inputs and outputs are used to compute the outputs for an :class:`Op`
 in Python. This method is effectively what needs to be implemented in JAX.

-
 Step 2: Find the relevant JAX method (or something close)
 ---------------------------------------------------------
@@ -88,10 +122,19 @@ Here's an example for :class:`IfElse`:
     )
     return res if n_outs > 1 else res[0]

-In this case, we have to use custom logic to implement the JAX version of
-:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use
-:func:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set
-them to the value we want.
+In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum`
+and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum`
+and :func:`jax.numpy.cumprod`.
+
+.. code:: python
+
+    def perform(self, node, inputs, output_storage):
+        x = inputs[0]
+        z = output_storage[0]
+        if self.mode == "add":
+            z[0] = np.cumsum(x, axis=self.axis)
+        else:
+            z[0] = np.cumprod(x, axis=self.axis)

 Step 3: Register the function with the `jax_funcify` dispatcher
 ---------------------------------------------------------------
@@ -104,24 +147,51 @@ function with the PyTensor JAX `Linker`. This is done through the use of

 The relevant dispatch functions created by `singledispatch` are
 :func:`pytensor.link.numba.dispatch.numba_funcify` and
 :func:`pytensor.link.jax.dispatch.jax_funcify`.

-Here's an example for the `FillDiagonal`\ `Op`:
+Here's an example for the `CumOp`\ `Op`:
+
+.. code:: python
+
+    import jax.numpy as jnp
+
+    from pytensor.tensor.extra_ops import CumOp
+    from pytensor.link.jax.dispatch import jax_funcify
+
+
+    @jax_funcify.register(CumOp)
+    def jax_funcify_CumOp(op, **kwargs):
+        axis = op.axis
+        mode = op.mode
+
+        def cumop(x, axis=axis, mode=mode):
+            if mode == "add":
+                return jnp.cumsum(x, axis=axis)
+            else:
+                return jnp.cumprod(x, axis=axis)
+
+        return cumop
+
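+As noted at the start of this tutorial, the same dispatch mechanism serves the
+Numba backend through :func:`pytensor.link.numba.dispatch.numba_funcify`.
+PyTensor already ships a Numba implementation for :class:`CumOp`, so the
+following is only an illustrative sketch of the registration mechanics. It
+handles just the ``axis=None`` case, since Numba's support for
+:func:`numpy.cumsum`/:func:`numpy.cumprod` does not take an ``axis`` argument:
+
+.. code:: python
+
+    import numpy as np
+    from numba import njit
+
+    from pytensor.tensor.extra_ops import CumOp
+    from pytensor.link.numba.dispatch import numba_funcify
+
+
+    @numba_funcify.register(CumOp)
+    def numba_funcify_CumOp(op, **kwargs):
+        # Resolve the Op's static attributes before compiling
+        if op.axis is not None:
+            raise NotImplementedError("this sketch only handles axis=None")
+
+        if op.mode == "add":
+
+            @njit
+            def cumop(x):
+                return np.cumsum(x)
+
+        else:
+
+            @njit
+            def cumop(x):
+                return np.cumprod(x)
+
+        return cumop
+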
+Suppose `jnp.cumprod` did not exist; we would then need to register the
+function as follows:

 .. code:: python

     import jax.numpy as jnp

-    from pytensor.tensor.extra_ops import FillDiagonal
+    from pytensor.tensor.extra_ops import CumOp
     from pytensor.link.jax.dispatch import jax_funcify


-    @jax_funcify.register(FillDiagonal)
-    def jax_funcify_FillDiagonal(op, **kwargs):
-        def filldiagonal(value, diagonal):
-            i, j = jnp.diag_indices(min(value.shape[-2:]))
-            return value.at[..., i, j].set(diagonal)
+    @jax_funcify.register(CumOp)
+    def jax_funcify_CumOp(op, **kwargs):
+        axis = op.axis
+        mode = op.mode

-        return filldiagonal
+        def cumop(x, axis=axis, mode=mode):
+            if mode == "add":
+                return jnp.cumsum(x, axis=axis)
+            else:
+                raise NotImplementedError(
+                    "JAX does not support the cumprod function at the moment."
+                )
+
+        return cumop

 Step 4: Write tests
 -------------------
@@ -137,23 +207,28 @@ between a JAX/Numba implementation and its `Op`. For example, the
 :func:`compare_jax_and_py` function streamlines the steps
 involved in making comparisons with `Op.perform`.

-Here's a small example of a test for :class:`FillDiagonal`:
+Here's a small example of a test for the :class:`CumOp` above:

 .. code:: python

+    import numpy as np
     import pytensor.tensor as pt
-    import pytensor.tensor.basic as ptb
     from pytensor.configdefaults import config
     from tests.link.jax.test_basic import compare_jax_and_py
     from pytensor.graph import FunctionGraph
     from pytensor.graph.op import get_test_value

-    def test_jax_FillDiagonal():
-        """Test JAX conversion of the `FillDiagonal` `Op`."""
+    def test_jax_CumOp():
+        """Test JAX conversion of the `CumOp` `Op`."""

-        # Create a symbolic input for the first input of `FillDiagonal`
+        # Create a symbolic input for the first input of `CumOp`
         a = pt.matrix("a")

         # Create test value tag for a
         a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

-        # Create a scalar value for the second input
-        c = ptb.as_tensor(5)
-
         # Create the output variable
-        out = pt.fill_diagonal(a, c)
+        out = pt.cumsum(a, axis=0)

         # Create a PyTensor `FunctionGraph`
         fgraph = FunctionGraph([a], [out])

@@ -169,6 +236,27 @@ Here's a small example of a test for :class:`FillDiagonal`:
         # Pass the graph and inputs to the testing function
         compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

+        # For the second mode of CumOp
+        out = pt.cumprod(a, axis=1)
+        fgraph = FunctionGraph([a], [out])
+        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
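+As a quick sanity check alongside the tests, the registered conversion can be
+exercised end-to-end by compiling a function in JAX mode. A minimal sketch,
+assuming the dispatch registration above is in place and ``jax`` is installed:
+
+.. code:: python
+
+    import numpy as np
+    import pytensor
+    import pytensor.tensor as pt
+
+    a = pt.matrix("a")
+    # mode="JAX" routes compilation through the JAX linker
+    f = pytensor.function([a], pt.cumsum(a, axis=0), mode="JAX")
+
+    print(f(np.arange(9, dtype=pytensor.config.floatX).reshape((3, 3))))
+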
+If the :class:`CumprodOp` variant is not implemented, we can add a test for the
+expected failure as follows:
+
+.. code:: python
+
+    import pytest
+
+    def test_jax_CumOp():
+        """Test JAX conversion of the `CumOp` `Op`."""
+        a = pt.matrix("a")
+        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
+
+        with pytest.raises(NotImplementedError):
+            out = pt.cumprod(a, axis=1)
+            fgraph = FunctionGraph([a], [out])
+            compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
 Note
 ----
 In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:

 .. code:: python

     def test_jax_Eye():
         """Test JAX conversion of the `Eye` `Op`."""

         # Create a symbolic input for `Eye`
         x_at = pt.scalar()

         # Create a variable that is the output of an `Eye` `Op`
         eye_var = pt.eye(x_at)

         # Create an PyTensor `FunctionGraph`
         out_fg = FunctionGraph(outputs=[eye_var])

         # Pass the graph and any inputs to the testing function
         compare_jax_and_py(out_fg, [3])

 Nowadays this test fails due to new restrictions in JAX's JIT compilation,
 as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
 All JIT-compiled functions must now have static shapes, which means a graph like
 the one of :class:`Eye` can never be translated to JAX, since it is fundamentally
 a function with dynamic shapes. In other words, only PyTensor graphs with static
 shapes can be translated to JAX at the moment.
\ No newline at end of file

From 4d54f9d406e63f682d06578e83f1edd75fb01833 Mon Sep 17 00:00:00 2001
From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com>
Date: Sat, 6 Apr 2024 22:36:47 +0800
Subject: [PATCH 4/9] Delete new_jax_example.ipynb

---
 new_jax_example.ipynb | 86 ------------------------------------------
 1 file changed, 86 deletions(-)
 delete mode 100644 new_jax_example.ipynb

diff --git a/new_jax_example.ipynb b/new_jax_example.ipynb
deleted file mode 100644
index f1736aa80d..0000000000
--- a/new_jax_example.ipynb
+++ /dev/null
@@ -1,86 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# General import\n",
-    "import jax.numpy as jnp\n",
-    "import numpy as np\n",
-    "import pytensor.tensor as pt\n",
-    "from pytensor.link.jax.dispatch import jax_funcify\n",
-    "\n",
-    "# Import for testing\n",
-    "import pytest\n",
-    "from pytensor.configdefaults import config\n",
-    "from tests.link.jax.test_basic import compare_jax_and_py\n",
-    "from pytensor.graph import FunctionGraph\n",
-    "from pytensor.graph.op import get_test_value\n",
-    "\n",
-    "# Import for the op to extend to JAX\n",
-    "from pytensor.tensor.extra_ops import CumOp\n",
-    "\n",
-    "@jax_funcify.register(CumOp)\n",
-    "def jax_funcify_CumOp(op, **kwargs):\n",
-    "    axis = op.axis\n",
-    "    mode = op.mode\n",
-    "\n",
-    "    def cumop(x, axis=axis, mode=mode):\n",
-    "        if mode == \"add\":\n",
-    "            return jnp.cumsum(x, axis=axis)\n",
-    "        else:\n",
-    "            raise NotImplementedError(\"JAX does not support cumprod function at the moment.\")\n",
-    "\n",
-    "    return cumop\n",
-    "\n",
-    "\n",
-    "def test_jax_CumOp():\n",
-    "    \"\"\"Test JAX conversion of the `CumOp` `Op`.\"\"\"\n",
-    "    a = pt.matrix(\"a\")\n",
-    "    a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))\n",
-    "    \n",
-    "    out = pt.cumsum(a, axis=0)\n",
-    "    fgraph = FunctionGraph([a], [out])\n",
-    "    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])\n",
-    "    \n",
-    "    with pytest.raises(NotImplementedError):\n",
-    "        out = pt.cumprod(a, axis=1)\n",
-    "        fgraph = FunctionGraph([a], [out])\n",
-    "        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])\n",
-    "\n",
-    "test_jax_CumOp()"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "pytensor-dev",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.8"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}

From bfab13cfda43a8ca4d7e433053de09b4d50dad2a Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Mon, 15 Apr 2024 21:49:49 +0800
Subject: [PATCH 5/9] Updated environment.yml with pip dependency for
 pymc-sphinx-theme

---
 environment.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/environment.yml b/environment.yml
index e84f1c5207..8edab2c6c5 100644
--- a/environment.yml
+++ b/environment.yml
@@ -37,6 +37,9 @@ dependencies:
   - pygments
   - pydot
   - ipython
+  - pip
+  - pip:
+    - git+https://github.com/pymc-devs/pymc-sphinx-theme
   # code style
   - ruff
   # developer tools

From ea10b143e1a6359f0bb1d2e1e89aba9ec18946cf Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Tue, 16 Apr 2024 20:35:38 +0800
Subject: [PATCH 6/9] Updated documentation based on new environment file

---
 doc/dev_start_guide.rst | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/doc/dev_start_guide.rst b/doc/dev_start_guide.rst
index 1469731e73..163d873c6f 100644
--- a/doc/dev_start_guide.rst
+++ b/doc/dev_start_guide.rst
@@ -168,15 +168,9 @@ create a virtual environment in the project directory:

 .. code-block:: bash

-    conda env create -n pytensor-dev -f environment.yml
+    conda env create -f environment.yml
     conda activate pytensor-dev

-Afterward, you can install the development dependencies:
-
-.. code-block:: bash
-
-    pip install -r requirements.txt
-
 Next, ``pre-commit`` needs to be configured so that the linting and code
 quality checks are performed before each commit:

@@ -202,12 +196,12 @@ see the `NumPy development guide <https://numpy.org/devdocs/dev/index.html>`_.

 Contributing to the documentation
 ---------------------------------

-To contribute to the documentation, first follow the instructions in the previous section. Afterward, you can install the documentation dependencies in the virtual environment you created:
-
+The documentation build dependencies are included in the virtual environment you
+created above. Alternatively, you can create a separate virtual environment just
+for the documentation using the `environment.yml` file located inside the `doc`
+folder:
+
 .. code-block:: bash

-    pip install -r requirements-rtd.txt
+    conda env create -f doc/environment.yml
+    conda activate pytensor-docs

 You can now build the documentation from the root of the project with:

From 9a9db89dbe0e008320de9795e1f2be31eaae09ab Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Sun, 21 Apr 2024 12:57:19 +0800
Subject: [PATCH 7/9] HP | Updated environment.yml for docs building
 dependency

---
 doc/environment.yml | 2 +-
 environment.yml     | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/doc/environment.yml b/doc/environment.yml
index 7616f86399..c86375ccf1 100644
--- a/doc/environment.yml
+++ b/doc/environment.yml
@@ -12,7 +12,7 @@ dependencies:
   - sphinx>=5.1.0,<6
   - mock
   - pillow
+  - pymc-sphinx-theme
   - pip
   - pip:
-    - git+https://github.com/pymc-devs/pymc-sphinx-theme
     - -e ..
diff --git a/environment.yml b/environment.yml index 8edab2c6c5..0e5d3387af 100644 --- a/environment.yml +++ b/environment.yml @@ -37,9 +37,7 @@ dependencies: - pygments - pydot - ipython - - pip - - pip: - - git+https://github.com/pymc-devs/pymc-sphinx-theme + - pymc-sphinx-theme # code style - ruff # developer tools From c8e8cc9964d153a0afea167598346b683cc81a0f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sun, 21 Apr 2024 13:13:59 +0800 Subject: [PATCH 8/9] HP | Removed unnecessary scripts & updated contributor guide --- doc/dev_start_guide.rst | 2 +- doc/scripts/docgen.py | 111 ---------------------------------------- 2 files changed, 1 insertion(+), 112 deletions(-) delete mode 100644 doc/scripts/docgen.py diff --git a/doc/dev_start_guide.rst b/doc/dev_start_guide.rst index 163d873c6f..9b855c2f0e 100644 --- a/doc/dev_start_guide.rst +++ b/doc/dev_start_guide.rst @@ -209,7 +209,7 @@ You can now build the documentation from the root of the project with: .. code-block:: bash - python doc/scripts/docgen.py + python -m sphinx -b html ./doc ./html Afterward, you can go to `html/index.html` and navigate the changes in a browser. One way to do this is to go to the `html` directory and run: diff --git a/doc/scripts/docgen.py b/doc/scripts/docgen.py deleted file mode 100644 index 6189099092..0000000000 --- a/doc/scripts/docgen.py +++ /dev/null @@ -1,111 +0,0 @@ -import sys -import os -import shutil -import getopt -from collections import defaultdict - -if __name__ == '__main__': - # Equivalent of sys.path[0]/../.. - throot = os.path.abspath( - os.path.join(sys.path[0], os.pardir, os.pardir)) - - options = defaultdict(bool) - opts, args = getopt.getopt( - sys.argv[1:], - 'o:f:', - ['rst', 'help', 'nopdf', 'cache', 'check', 'test']) - options.update({x: y or True for x, y in opts}) - if options['--help']: - print(f'Usage: {sys.argv[0]} [OPTIONS] [files...]') - print(' -o : output the html files in the specified dir') - print(' --cache: use the doctree cache') - print(' --rst: only compile the doc (requires sphinx)') - print(' --nopdf: do not produce a PDF file from the doc, only HTML') - print(' --test: run all the code samples in the documentation') - print(' --check: treat warnings as errors') - print(' --help: this help') - print('If one or more files are specified after the options then only ' - 'those files will be built. Otherwise the whole tree is ' - 'processed. Specifying files will implies --cache.') - sys.exit(0) - - if not(options['--rst'] or options['--test']): - # Default is now rst - options['--rst'] = True - - def mkdir(path): - try: - os.mkdir(path) - except OSError: - pass - - outdir = options['-o'] or (throot + '/html') - files = None - if len(args) != 0: - files = [os.path.abspath(f) for f in args] - currentdir = os.getcwd() - mkdir(outdir) - os.chdir(outdir) - - # Make sure the appropriate 'pytensor' directory is in the PYTHONPATH - pythonpath = os.environ.get('PYTHONPATH', '') - pythonpath = os.pathsep.join([throot, pythonpath]) - sys.path[0:0] = [throot] # We must not use os.environ. 
-
-    # Make sure we don't use other devices to compile documentation
-    env_th_flags = os.environ.get('PYTENSOR_FLAGS', '')
-    os.environ['PYTENSOR_FLAGS'] = 'device=cpu,force_device=True'
-
-    def call_sphinx(builder, workdir):
-        import sphinx
-        if options['--check']:
-            extraopts = ['-W']
-        else:
-            extraopts = []
-        if not options['--cache'] and files is None:
-            extraopts.append('-E')
-        docpath = os.path.join(throot, 'doc')
-        inopt = [docpath, workdir]
-        if files is not None:
-            inopt.extend(files)
-        try:
-            import sphinx.cmd.build
-            ret = sphinx.cmd.build.build_main(
-                ['-b', builder] + extraopts + inopt)
-        except ImportError:
-            # Sphinx < 1.7 - build_main drops first argument
-            ret = sphinx.build_main(
-                ['', '-b', builder] + extraopts + inopt)
-        if ret != 0:
-            sys.exit(ret)
-
-    if options['--all'] or options['--rst']:
-        mkdir("doc")
-        sys.path[0:0] = [os.path.join(throot, 'doc')]
-        call_sphinx('html', '.')
-
-        if not options['--nopdf']:
-            # Generate latex file in a temp directory
-            import tempfile
-            workdir = tempfile.mkdtemp()
-            call_sphinx('latex', workdir)
-            # Compile to PDF
-            os.chdir(workdir)
-            os.system('make')
-            try:
-                shutil.copy(os.path.join(workdir, 'pytensor.pdf'), outdir)
-                os.chdir(outdir)
-                shutil.rmtree(workdir)
-            except OSError as e:
-                print('OSError:', e)
-
-    if options['--test']:
-        mkdir("doc")
-        sys.path[0:0] = [os.path.join(throot, 'doc')]
-        call_sphinx('doctest', '.')
-
-    # To go back to the original current directory.
-    os.chdir(currentdir)
-
-    # Reset PYTENSOR_FLAGS
-    os.environ['PYTENSOR_FLAGS'] = env_th_flags

From 90ad3ea0ba92c570d446ee948a407094b5337fd9 Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Sun, 21 Apr 2024 13:20:37 +0800
Subject: [PATCH 9/9] HP | Updated contributor guide

---
 doc/dev_start_guide.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/dev_start_guide.rst b/doc/dev_start_guide.rst
index 9b855c2f0e..010d0ffb75 100644
--- a/doc/dev_start_guide.rst
+++ b/doc/dev_start_guide.rst
@@ -220,7 +220,7 @@ Afterward, you can go to `html/index.html` and navigate the changes in a browser

    python -m http.server

 **Do not commit the `html` directory. The documentation is built automatically.**

+For further customization of the documentation build, such as other output
+formats (e.g., PDF), refer to the
+`Sphinx documentation <https://www.sphinx-doc.org/en/master/>`_.
+
 Other tools that might help
 ===========================