From fe8ac4724abf3c0a4328100eed829217cd6951c5 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Sun, 31 Mar 2024 12:31:36 +0000 Subject: [PATCH 1/4] Changed example for extending JAX. --- doc/extending/creating_a_numba_jax_op.rst | 148 ++++++++++++++-------- 1 file changed, 96 insertions(+), 52 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index abf3f528bf..8f15d39c61 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -7,55 +7,61 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba im This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will focus specifically on the JAX case, but the same mechanisms are used for Numba as well. -Step 1: Identify the PyTensor :class:`Op` you’d like to implement in JAX +Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX ------------------------------------------------------------------------ -Find the source for the PyTensor :class:`Op` you’d like to be supported in JAX, and -identify the function signature and return values. These can be determined by -looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar +Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and +identify the function signature and return values. These can be determined by +looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read :ref:`creating_an_op` if you are not familiar. -For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows: +For example, the :class:`FillDiagonal`\ :class:`Op` current has an :meth:`Op.make_node` as follows: .. code:: python - def make_node(self, n, m, k): - n = as_tensor_variable(n) - m = as_tensor_variable(m) - k = as_tensor_variable(k) - assert n.ndim == 0 - assert m.ndim == 0 - assert k.ndim == 0 - return Apply( - self, - [n, m, k], - [TensorType(dtype=self.dtype, shape=(None, None))()], - ) + def make_node(self, a, val): + a = ptb.as_tensor_variable(a) + val = ptb.as_tensor_variable(val) + if a.ndim < 2: + raise TypeError( + "%s: first parameter must have at least" + " two dimensions" % self.__class__.__name__ + ) + elif val.ndim != 0: + raise TypeError( + f"{self.__class__.__name__}: second parameter must be a scalar" + ) + val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype)) + if val.dtype != a.dtype: + raise TypeError( + "%s: type of second parameter must be the same as" + " the first's" % self.__class__.__name__ + ) + return Apply(self, [a, val], [a.type()]) The :class:`Apply` instance that's returned specifies the exact types of inputs that our JAX implementation will receive and the exact types of outputs it's expected to -return--both in terms of their data types and number of dimensions. +return--both in terms of their data types and number of dimensions/shapes. The actual inputs our implementation will receive are necessarily numeric values or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the general signature of the underlying computation. -More specifically, the :class:`Apply` implies that the inputs come from values that are +More specifically, the :class:`Apply` implies that the inputs come from two values that are automatically converted to PyTensor variables via :func:`as_tensor_variable`, and -the ``assert``\s that follow imply that they must be scalars. According to this +the ``assert``\s that follow imply that the first one must be a tensor with at least two +dimensions (i.e., matrix) and the second must be a scalar. According to this logic, the inputs could have any data type (e.g. floats, ints), so our JAX implementation must be able to handle all the possible data types. It also tells us that there's only one return value, that it has a data type -determined by :attr:`Eye.dtype`, and that it has two non-broadcastable -dimensions. The latter implies that the result is necessarily a matrix. The -former implies that our JAX implementation will need to access the :attr:`dtype` -attribute of the PyTensor :class:`Eye`\ :class:`Op` it's converting. +determined by :meth:`a.type()` i.e., the data type of the original tensor. +This implies that the result is necessarily a matrix. Next, we can look at the :meth:`Op.perform` implementation to see exactly how the inputs and outputs are used to compute the outputs for an :class:`Op` -in Python. This method is effectively what needs to be implemented in JAX. +in Python. This method is effectively what needs to be implemented in JAX. Step 2: Find the relevant JAX method (or something close) @@ -82,11 +88,15 @@ Here's an example for :class:`IfElse`: ) return res if n_outs > 1 else res[0] +In this case, we have to use custom logic to implement the JAX version of +:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use +:meth:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set +them to the value we want. Step 3: Register the function with the `jax_funcify` dispatcher --------------------------------------------------------------- -With the PyTensor `Op` replicated in JAX, we’ll need to register the +With the PyTensor `Op` replicated in JAX, we'll need to register the function with the PyTensor JAX `Linker`. This is done through the use of `singledispatch`. If you don't know how `singledispatch` works, see the `Python documentation `_. @@ -94,35 +104,31 @@ function with the PyTensor JAX `Linker`. This is done through the use of The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and :func:`pytensor.link.jax.dispatch.jax_funcify`. -Here’s an example for the `Eye`\ `Op`: +Here's an example for the `FillDiagonal`\ `Op`: .. code:: python import jax.numpy as jnp - from pytensor.tensor.basic import Eye + from pytensor.tensor.extra_ops import FillDiagonal from pytensor.link.jax.dispatch import jax_funcify - @jax_funcify.register(Eye) - def jax_funcify_Eye(op): + @jax_funcify.register(FillDiagonal) + def jax_funcify_FillDiagonal(op, **kwargs): + def filldiagonal(value, diagonal): + i, j = jnp.diag_indices(min(value.shape[-2:])) + return value.at[..., i, j].set(diagonal) - # Obtain necessary "static" attributes from the Op being converted - dtype = op.dtype - - # Create a JAX jit-able function that implements the Op - def eye(N, M, k): - return jnp.eye(N, M, k, dtype=dtype) - - return eye + return filldiagonal Step 4: Write tests ------------------- Test that your registered `Op` is working correctly by adding tests to the -appropriate test suites in PyTensor (e.g. in ``tests.link.test_jax`` and one of -the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can +appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of +the modules in ``tests.link.numba``). The tests should ensure that your implementation can handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. Check the existing tests for the general outline of these kinds of tests. In most cases, a helper function can be used to easily verify the correspondence @@ -131,23 +137,61 @@ between a JAX/Numba implementation and its `Op`. For example, the :func:`compare_jax_and_py` function streamlines the steps involved in making comparisons with `Op.perform`. -Here's a small example of a test for :class:`Eye`: +Here's a small example of a test for :class:`FillDiagonal`: .. code:: python + import numpy as np + import pytensor.tensor as pt + import pytensor.tensor.basic as ptb + from pytensor.configdefaults import config + from tests.link.jax.test_basic import compare_jax_and_py + from pytensor.graph import FunctionGraph + from pytensor.graph.op import get_test_value + + def test_jax_FillDiagonal(): + """Test JAX conversion of the `FillDiagonal` `Op`.""" + + # Create a symbolic input for the first input of `FillDiagonal` + a = pt.matrix("a") + + # Create test value tag for a + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + # Create a scalar value for the second input + c = ptb.as_tensor(5) - import pytensor.tensor as pt + # Create the output variable + out = pt.fill_diagonal(a, c) + + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + +Note +---- +In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows: + +.. code:: python + def test_jax_Eye(): + """Test JAX conversion of the `Eye` `Op`.""" - def test_jax_Eye(): - """Test JAX conversion of the `Eye` `Op`.""" + # Create a symbolic input for `Eye` + x_at = pt.scalar() - # Create a symbolic input for `Eye` - x_at = pt.scalar() + # Create a variable that is the output of an `Eye` `Op` + eye_var = pt.eye(x_at) - # Create a variable that is the output of an `Eye` `Op` - eye_var = pt.eye(x_at) + # Create an PyTensor `FunctionGraph` + out_fg = FunctionGraph(outputs=[eye_var]) - # Create an PyTensor `FunctionGraph` - out_fg = FunctionGraph(outputs=[eye_var]) + # Pass the graph and any inputs to the testing function + compare_jax_and_py(out_fg, [3]) - # Pass the graph and any inputs to the testing function - compare_jax_and_py(out_fg, [3]) +This one nowadays leads to a test failure due to new restrictions in JAX + JIT, +as reported in issue `#654 `_. +All jitted functions now must have constant shape, which means a graph like the +one of :class:`Eye` can never be translated to JAX, since it's fundamentally a +function with dynamic shapes. In other words, only PyTensor graphs with static shapes +can be translated to JAX at the moment. \ No newline at end of file From 1e6dfdbe260bebf0194087ee33d39b48a5e15ed5 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Sun, 31 Mar 2024 12:31:36 +0000 Subject: [PATCH 2/4] Changed example for extending JAX. --- doc/extending/creating_a_numba_jax_op.rst | 148 ++++++++++++++-------- 1 file changed, 96 insertions(+), 52 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index abf3f528bf..8f15d39c61 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -7,55 +7,61 @@ this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba im This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will focus specifically on the JAX case, but the same mechanisms are used for Numba as well. -Step 1: Identify the PyTensor :class:`Op` you’d like to implement in JAX +Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX ------------------------------------------------------------------------ -Find the source for the PyTensor :class:`Op` you’d like to be supported in JAX, and -identify the function signature and return values. These can be determined by -looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar +Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and +identify the function signature and return values. These can be determined by +looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read :ref:`creating_an_op` if you are not familiar. -For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows: +For example, the :class:`FillDiagonal`\ :class:`Op` current has an :meth:`Op.make_node` as follows: .. code:: python - def make_node(self, n, m, k): - n = as_tensor_variable(n) - m = as_tensor_variable(m) - k = as_tensor_variable(k) - assert n.ndim == 0 - assert m.ndim == 0 - assert k.ndim == 0 - return Apply( - self, - [n, m, k], - [TensorType(dtype=self.dtype, shape=(None, None))()], - ) + def make_node(self, a, val): + a = ptb.as_tensor_variable(a) + val = ptb.as_tensor_variable(val) + if a.ndim < 2: + raise TypeError( + "%s: first parameter must have at least" + " two dimensions" % self.__class__.__name__ + ) + elif val.ndim != 0: + raise TypeError( + f"{self.__class__.__name__}: second parameter must be a scalar" + ) + val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype)) + if val.dtype != a.dtype: + raise TypeError( + "%s: type of second parameter must be the same as" + " the first's" % self.__class__.__name__ + ) + return Apply(self, [a, val], [a.type()]) The :class:`Apply` instance that's returned specifies the exact types of inputs that our JAX implementation will receive and the exact types of outputs it's expected to -return--both in terms of their data types and number of dimensions. +return--both in terms of their data types and number of dimensions/shapes. The actual inputs our implementation will receive are necessarily numeric values or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the general signature of the underlying computation. -More specifically, the :class:`Apply` implies that the inputs come from values that are +More specifically, the :class:`Apply` implies that the inputs come from two values that are automatically converted to PyTensor variables via :func:`as_tensor_variable`, and -the ``assert``\s that follow imply that they must be scalars. According to this +the ``assert``\s that follow imply that the first one must be a tensor with at least two +dimensions (i.e., matrix) and the second must be a scalar. According to this logic, the inputs could have any data type (e.g. floats, ints), so our JAX implementation must be able to handle all the possible data types. It also tells us that there's only one return value, that it has a data type -determined by :attr:`Eye.dtype`, and that it has two non-broadcastable -dimensions. The latter implies that the result is necessarily a matrix. The -former implies that our JAX implementation will need to access the :attr:`dtype` -attribute of the PyTensor :class:`Eye`\ :class:`Op` it's converting. +determined by :meth:`a.type()` i.e., the data type of the original tensor. +This implies that the result is necessarily a matrix. Next, we can look at the :meth:`Op.perform` implementation to see exactly how the inputs and outputs are used to compute the outputs for an :class:`Op` -in Python. This method is effectively what needs to be implemented in JAX. +in Python. This method is effectively what needs to be implemented in JAX. Step 2: Find the relevant JAX method (or something close) @@ -82,11 +88,15 @@ Here's an example for :class:`IfElse`: ) return res if n_outs > 1 else res[0] +In this case, we have to use custom logic to implement the JAX version of +:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use +:meth:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set +them to the value we want. Step 3: Register the function with the `jax_funcify` dispatcher --------------------------------------------------------------- -With the PyTensor `Op` replicated in JAX, we’ll need to register the +With the PyTensor `Op` replicated in JAX, we'll need to register the function with the PyTensor JAX `Linker`. This is done through the use of `singledispatch`. If you don't know how `singledispatch` works, see the `Python documentation `_. @@ -94,35 +104,31 @@ function with the PyTensor JAX `Linker`. This is done through the use of The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and :func:`pytensor.link.jax.dispatch.jax_funcify`. -Here’s an example for the `Eye`\ `Op`: +Here's an example for the `FillDiagonal`\ `Op`: .. code:: python import jax.numpy as jnp - from pytensor.tensor.basic import Eye + from pytensor.tensor.extra_ops import FillDiagonal from pytensor.link.jax.dispatch import jax_funcify - @jax_funcify.register(Eye) - def jax_funcify_Eye(op): + @jax_funcify.register(FillDiagonal) + def jax_funcify_FillDiagonal(op, **kwargs): + def filldiagonal(value, diagonal): + i, j = jnp.diag_indices(min(value.shape[-2:])) + return value.at[..., i, j].set(diagonal) - # Obtain necessary "static" attributes from the Op being converted - dtype = op.dtype - - # Create a JAX jit-able function that implements the Op - def eye(N, M, k): - return jnp.eye(N, M, k, dtype=dtype) - - return eye + return filldiagonal Step 4: Write tests ------------------- Test that your registered `Op` is working correctly by adding tests to the -appropriate test suites in PyTensor (e.g. in ``tests.link.test_jax`` and one of -the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can +appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of +the modules in ``tests.link.numba``). The tests should ensure that your implementation can handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. Check the existing tests for the general outline of these kinds of tests. In most cases, a helper function can be used to easily verify the correspondence @@ -131,23 +137,61 @@ between a JAX/Numba implementation and its `Op`. For example, the :func:`compare_jax_and_py` function streamlines the steps involved in making comparisons with `Op.perform`. -Here's a small example of a test for :class:`Eye`: +Here's a small example of a test for :class:`FillDiagonal`: .. code:: python + import numpy as np + import pytensor.tensor as pt + import pytensor.tensor.basic as ptb + from pytensor.configdefaults import config + from tests.link.jax.test_basic import compare_jax_and_py + from pytensor.graph import FunctionGraph + from pytensor.graph.op import get_test_value + + def test_jax_FillDiagonal(): + """Test JAX conversion of the `FillDiagonal` `Op`.""" + + # Create a symbolic input for the first input of `FillDiagonal` + a = pt.matrix("a") + + # Create test value tag for a + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + # Create a scalar value for the second input + c = ptb.as_tensor(5) - import pytensor.tensor as pt + # Create the output variable + out = pt.fill_diagonal(a, c) + + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + +Note +---- +In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows: + +.. code:: python + def test_jax_Eye(): + """Test JAX conversion of the `Eye` `Op`.""" - def test_jax_Eye(): - """Test JAX conversion of the `Eye` `Op`.""" + # Create a symbolic input for `Eye` + x_at = pt.scalar() - # Create a symbolic input for `Eye` - x_at = pt.scalar() + # Create a variable that is the output of an `Eye` `Op` + eye_var = pt.eye(x_at) - # Create a variable that is the output of an `Eye` `Op` - eye_var = pt.eye(x_at) + # Create an PyTensor `FunctionGraph` + out_fg = FunctionGraph(outputs=[eye_var]) - # Create an PyTensor `FunctionGraph` - out_fg = FunctionGraph(outputs=[eye_var]) + # Pass the graph and any inputs to the testing function + compare_jax_and_py(out_fg, [3]) - # Pass the graph and any inputs to the testing function - compare_jax_and_py(out_fg, [3]) +This one nowadays leads to a test failure due to new restrictions in JAX + JIT, +as reported in issue `#654 `_. +All jitted functions now must have constant shape, which means a graph like the +one of :class:`Eye` can never be translated to JAX, since it's fundamentally a +function with dynamic shapes. In other words, only PyTensor graphs with static shapes +can be translated to JAX at the moment. \ No newline at end of file From 8f2a4a2c70efc905ea3ecced7a437dcfcb357f59 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sat, 6 Apr 2024 22:25:09 +0800 Subject: [PATCH 3/4] Use CumOp in the example --- doc/extending/creating_a_numba_jax_op.rst | 188 ++++++++++++++++------ 1 file changed, 138 insertions(+), 50 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 8f15d39c61..0d5f6460e9 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -16,30 +16,34 @@ looking at the :meth:`Op.make_node` implementation. In general, one needs to be with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read :ref:`creating_an_op` if you are not familiar. -For example, the :class:`FillDiagonal`\ :class:`Op` current has an :meth:`Op.make_node` as follows: +For example, you want to extend support for :class:`CumsumOp`\: .. code:: python - def make_node(self, a, val): - a = ptb.as_tensor_variable(a) - val = ptb.as_tensor_variable(val) - if a.ndim < 2: - raise TypeError( - "%s: first parameter must have at least" - " two dimensions" % self.__class__.__name__ - ) - elif val.ndim != 0: - raise TypeError( - f"{self.__class__.__name__}: second parameter must be a scalar" - ) - val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype)) - if val.dtype != a.dtype: - raise TypeError( - "%s: type of second parameter must be the same as" - " the first's" % self.__class__.__name__ - ) - return Apply(self, [a, val], [a.type()]) + class CumsumOp(Op): + __props__ = ("axis",) + def __new__(typ, *args, **kwargs): + obj = object.__new__(CumOp, *args, **kwargs) + obj.mode = "add" + return obj + + +:class:`CumsumOp` turns out to be a variant of :class:`CumOp`\ :class:`Op` +which currently has an :meth:`Op.make_node` as follows: + +.. code:: python + + def make_node(self, x): + x = ptb.as_tensor_variable(x) + out_type = x.type() + + if self.axis is None: + out_type = vector(dtype=x.dtype) # Flatten + elif self.axis >= x.ndim or self.axis < -x.ndim: + raise ValueError(f"axis(={self.axis}) out of bounds") + + return Apply(self, [x], [out_type]) The :class:`Apply` instance that's returned specifies the exact types of inputs that our JAX implementation will receive and the exact types of outputs it's expected to @@ -48,22 +52,52 @@ The actual inputs our implementation will receive are necessarily numeric values or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the general signature of the underlying computation. -More specifically, the :class:`Apply` implies that the inputs come from two values that are -automatically converted to PyTensor variables via :func:`as_tensor_variable`, and -the ``assert``\s that follow imply that the first one must be a tensor with at least two -dimensions (i.e., matrix) and the second must be a scalar. According to this -logic, the inputs could have any data type (e.g. floats, ints), so our JAX -implementation must be able to handle all the possible data types. +More specifically, the :class:`Apply` implies that there is one input that is +automatically converted to PyTensor variables via :func:`as_tensor_variable`. +There is another parameter, `axis`, that is used to determine the direction +of the operation, hence shape of the output. The check that follows imply that +`axis` must refer to a dimension in the input tensor. The input's elements +could also have any data type (e.g. floats, ints), so our JAX implementation +must be able to handle all the possible data types. It also tells us that there's only one return value, that it has a data type -determined by :meth:`a.type()` i.e., the data type of the original tensor. +determined by :meth:`x.type()` i.e., the data type of the original tensor. This implies that the result is necessarily a matrix. -Next, we can look at the :meth:`Op.perform` implementation to see exactly +Some class may have a more complex behavior. For example, the :class:`CumOp`\ :class:`Op` +also has another variant :class:`CumprodOp`\ :class:`Op` with the exact signature +as :class:`CumsumOp`\ :class:`Op`. The difference lies in that the `mode` attribute in +:class:`CumOp` definition: + +.. code:: python + + class CumOp(COp): + # See function cumsum/cumprod for docstring + + __props__ = ("axis", "mode") + check_input = False + params_type = ParamsType( + c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul")) + ) + + def __init__(self, axis: int | None = None, mode="add"): + if mode not in ("add", "mul"): + raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"') + self.axis = axis + self.mode = mode + + c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) + +`__props__` is used to parametrize the general behavior of the :class:`Op`. One need to +pay attention to this to decide whether the JAX implementation should support all variants +or raise an explicit NotImplementedError for cases that are not supported e.g., when +:class:`CumsumOp` of :class:`CumOp("add")` is supported but not :class:`CumprodOp` of +:class:`CumOp("mul")`. + +Next, we look at the :meth:`Op.perform` implementation to see exactly how the inputs and outputs are used to compute the outputs for an :class:`Op` in Python. This method is effectively what needs to be implemented in JAX. - Step 2: Find the relevant JAX method (or something close) --------------------------------------------------------- @@ -88,10 +122,19 @@ Here's an example for :class:`IfElse`: ) return res if n_outs > 1 else res[0] -In this case, we have to use custom logic to implement the JAX version of -:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use -:meth:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set -them to the value we want. +In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum` +and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum` +and :func:`jax.numpy.cumprod`. + +.. code:: python + + def perform(self, node, inputs, output_storage): + x = inputs[0] + z = output_storage[0] + if self.mode == "add": + z[0] = np.cumsum(x, axis=self.axis) + else: + z[0] = np.cumprod(x, axis=self.axis) Step 3: Register the function with the `jax_funcify` dispatcher --------------------------------------------------------------- @@ -104,24 +147,51 @@ function with the PyTensor JAX `Linker`. This is done through the use of The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and :func:`pytensor.link.jax.dispatch.jax_funcify`. -Here's an example for the `FillDiagonal`\ `Op`: +Here's an example for the `CumOp`\ `Op`: + +.. code:: python + + import jax.numpy as jnp + + from pytensor.tensor.extra_ops import CumOp + from pytensor.link.jax.dispatch import jax_funcify + + + @jax_funcify.register(CumOp) + def jax_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode + + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return jnp.cumsum(x, axis=axis) + else: + return jnp.cumprod(x, axis=axis) + + return cumop + +Suppose `jnp.cumprod` does not exist, we will need to register the function as follows: .. code:: python import jax.numpy as jnp - from pytensor.tensor.extra_ops import FillDiagonal + from pytensor.tensor.extra_ops import CumOp from pytensor.link.jax.dispatch import jax_funcify - @jax_funcify.register(FillDiagonal) - def jax_funcify_FillDiagonal(op, **kwargs): - def filldiagonal(value, diagonal): - i, j = jnp.diag_indices(min(value.shape[-2:])) - return value.at[..., i, j].set(diagonal) + @jax_funcify.register(CumOp) + def jax_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode - return filldiagonal + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return jnp.cumsum(x, axis=axis) + else: + raise NotImplementedError("JAX does not support cumprod function at the moment.") + return cumop Step 4: Write tests ------------------- @@ -137,31 +207,28 @@ between a JAX/Numba implementation and its `Op`. For example, the :func:`compare_jax_and_py` function streamlines the steps involved in making comparisons with `Op.perform`. -Here's a small example of a test for :class:`FillDiagonal`: +Here's a small example of a test for :class:`CumOp` above: .. code:: python + import numpy as np import pytensor.tensor as pt - import pytensor.tensor.basic as ptb from pytensor.configdefaults import config from tests.link.jax.test_basic import compare_jax_and_py from pytensor.graph import FunctionGraph from pytensor.graph.op import get_test_value - def test_jax_FillDiagonal(): - """Test JAX conversion of the `FillDiagonal` `Op`.""" + def test_jax_CumOp(): + """Test JAX conversion of the `CumOp` `Op`.""" - # Create a symbolic input for the first input of `FillDiagonal` + # Create a symbolic input for the first input of `CumOp` a = pt.matrix("a") # Create test value tag for a a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) - # Create a scalar value for the second input - c = ptb.as_tensor(5) - # Create the output variable - out = pt.fill_diagonal(a, c) + out = pt.cumsum(a, axis=0) # Create a PyTensor `FunctionGraph` fgraph = FunctionGraph([a], [out]) @@ -169,6 +236,27 @@ Here's a small example of a test for :class:`FillDiagonal`: # Pass the graph and inputs to the testing function compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + # For the second mode of CumOp + out = pt.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + +If the variant :class:`CumprodOp` is not implemented, we can add a test for it as follows: + +.. code:: python + + import pytest + + def test_jax_CumOp(): + """Test JAX conversion of the `CumOp` `Op`.""" + a = pt.matrix("a") + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + with pytest.raises(NotImplementedError): + out = pt.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + Note ---- In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows: From 4d54f9d406e63f682d06578e83f1edd75fb01833 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Sat, 6 Apr 2024 22:36:47 +0800 Subject: [PATCH 4/4] Delete new_jax_example.ipynb --- new_jax_example.ipynb | 86 ------------------------------------------- 1 file changed, 86 deletions(-) delete mode 100644 new_jax_example.ipynb diff --git a/new_jax_example.ipynb b/new_jax_example.ipynb deleted file mode 100644 index f1736aa80d..0000000000 --- a/new_jax_example.ipynb +++ /dev/null @@ -1,86 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] - } - ], - "source": [ - "# General import\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "import pytensor.tensor as pt\n", - "from pytensor.link.jax.dispatch import jax_funcify\n", - "\n", - "# Import for testing\n", - "import pytest\n", - "from pytensor.configdefaults import config\n", - "from tests.link.jax.test_basic import compare_jax_and_py\n", - "from pytensor.graph import FunctionGraph\n", - "from pytensor.graph.op import get_test_value\n", - "\n", - "# Import for the op to extend to JAX\n", - "from pytensor.tensor.extra_ops import CumOp\n", - "\n", - "@jax_funcify.register(CumOp)\n", - "def jax_funcify_CumOp(op, **kwargs):\n", - " axis = op.axis\n", - " mode = op.mode\n", - "\n", - " def cumop(x, axis=axis, mode=mode):\n", - " if mode == \"add\":\n", - " return jnp.cumsum(x, axis=axis)\n", - " else:\n", - " raise NotImplementedError(\"JAX does not support cumprod function at the moment.\")\n", - "\n", - " return cumop\n", - "\n", - "\n", - "def test_jax_CumOp():\n", - " \"\"\"Test JAX conversion of the `CumOp` `Op`.\"\"\"\n", - " a = pt.matrix(\"a\")\n", - " a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))\n", - " \n", - " out = pt.cumsum(a, axis=0)\n", - " fgraph = FunctionGraph([a], [out])\n", - " compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])\n", - " \n", - " with pytest.raises(NotImplementedError):\n", - " out = pt.cumprod(a, axis=1)\n", - " fgraph = FunctionGraph([a], [out])\n", - " compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])\n", - "\n", - "test_jax_CumOp()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytensor-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}