diff --git a/doc/conf.py b/doc/conf.py index 1729efc4b1..e1143714d3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -36,6 +36,7 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/doc/environment.yml b/doc/environment.yml index d58af79cc6..dc653b1b38 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -24,4 +24,4 @@ dependencies: - pip - pip: - sphinx_sitemap - - -e .. + - -e ..[jax] diff --git a/doc/library/index.rst b/doc/library/index.rst index 08a5b51c34..1b72f0ac84 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -63,6 +63,13 @@ Convert to Variable .. autofunction:: pytensor.as_symbolic(...) +Wrap JAX functions +================== + +.. autofunction:: as_jax_op(...) + + Alias for :func:`pytensor.link.jax.ops.as_jax_op` + Debug ===== diff --git a/pyproject.toml b/pyproject.toml index e82c42753a..8d0adba3df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ tests = [ "pytest-sphinx", ] rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"] -jax = ["jax", "jaxlib"] +jax = ["jax", "jaxlib", "equinox"] numba = ["numba>=0.57", "llvmlite"] [tool.setuptools.packages.find] diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 3c925ac2f2..a7f9aa8058 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v): from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph +try: + import pytensor.link.jax.ops + from pytensor.link.jax.ops import as_jax_op +except ImportError as e: + import_error_as_jax_op = e + + def as_jax_op(*args, **kwargs): + raise ImportError( + "JAX and/or equinox are not installed. Install them" + " to use this function: pip install pytensor[jax]" + ) from import_error_as_jax_op + # isort: on diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py new file mode 100644 index 0000000000..ca780da9c8 --- /dev/null +++ b/pytensor/link/jax/ops.py @@ -0,0 +1,472 @@ +"""Convert a jax function to a pytensor compatible function.""" + +import logging +from collections.abc import Sequence + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten, tree_map, tree_unflatten + +import pytensor.compile.builders +import pytensor.tensor as pt +from pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Op +from pytensor.link.jax.dispatch import jax_funcify + + +log = logging.getLogger(__name__) + + +class JAXOp(Op): + """ + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward computation and reverse-mode differentiation (via the VJPJAXOp class). + + Parameters + ---------- + input_types : list + A list of PyTensor types for each input variable. + output_types : list + A list of PyTensor types for each output variable. + flat_func : callable + The JAX function that computes outputs from inputs. Inputs and outputs have to be provided as flat arrays. + name : str, optional + A custom name for the Op instance. If provided, the class name will be + updated accordingly. + + Example + ------- + This example defines a simple function that sums the input array with a dynamic shape. 
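+
+    >>> import pytensor
+    >>> import pytensor.tensor as pt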
+ + >>> import numpy as np + >>> import jax + >>> import jax.numpy as jnp + >>> from pytensor.tensor import TensorType + >>> + >>> # Create the jax function that sums the input array. + >>> def sum_function(x, y): + ... return jnp.sum(x + y) + >>> + >>> # Create the input and output types, input has a dynamic shape. + >>> input_type = TensorType("float32", shape=(None,)) + >>> output_type = TensorType("float32", shape=(1,)) + >>> + >>> # Instantiate a JAXOp; tree definitions are set to None for simplicity. + >>> op = JAXOp( + ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" + ... ) + >>> # Define symbolic input variables. + >>> x = pt.tensor("x", dtype="float32", shape=(2,)) + >>> y = pt.tensor("y", dtype="float32", shape=(2,)) + >>> # Compile a PyTensor function. + >>> result = op(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print( + ... f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array(14., dtype=float32)] + >>> + >>> # Compute the gradient of op(x, y) with respect to x. + >>> g = pt.grad(result[0], x) + >>> grad_f = pytensor.function([x, y], [g]) + >>> print( + ... grad_f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array([1., 1.], dtype=float32)] + """ + + def __init__(self, input_types, output_types, flat_func, name=None): + self.input_types = input_types + self.output_types = output_types + self.num_inputs = len(input_types) + self.num_outputs = len(output_types) + normalized_flat_func = _normalize_flat_func(flat_func) + self.jitted_func = jax.jit(normalized_flat_func) + + vjp_func = _get_vjp_jax_op(normalized_flat_func, len(output_types)) + normalized_vjp_func = _normalize_flat_func(vjp_func) + self.jitted_vjp = jax.jit(normalized_vjp_func) + self.vjp_jax_op = VJPJAXOp( + self.input_types, + self.jitted_vjp, + name=("VJP" + name) if name is not None else None, + ) + + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, *inputs): + outputs = [pt.as_tensor_variable(typ()) for typ in self.output_types] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_func(*inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_func(*inputs) + + def grad(self, inputs, output_gradients): + # If a output is not used, it gets disconnected by pytensor and won't have a + # gradient. Set gradient here to zero for those outputs. + for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + zero_shape = ( + self.output_types[i].shape + if None not in self.output_types[i].shape + else () + ) + output_gradients[i] = pt.zeros(zero_shape, self.output_types[i].dtype) + + # Compute the gradient. 
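+        # The VJP op receives the primal inputs together with the output
+        # cotangents and returns one gradient term per input.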
+ grad_result = self.vjp_jax_op(inputs, output_gradients) + return grad_result if self.num_inputs > 1 else (grad_result,) + + +class VJPJAXOp(Op): + def __init__(self, input_types, jitted_vjp, name=None): + self.input_types = input_types + self.jitted_vjp = jitted_vjp + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, y0, gz): + y0_converted = [ + pt.as_tensor_variable(y).astype(self.input_types[i].dtype) + for i, y in enumerate(y0) + ] + gz_not_disconnected = [ + pt.as_tensor_variable(g) + for g in gz + if not isinstance(g.type, DisconnectedType) + ] + outputs = [typ() for typ in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0_converted + gz_not_disconnected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp(*inputs) + if len(self.input_types) > 1: + for i, res in enumerate(results): + outputs[i][0] = np.array(res, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_vjp(*inputs) + + +def _normalize_flat_func(func): + def normalized_func(*flat_vars): + out_flat = func(*flat_vars) + if isinstance(out_flat, Sequence): + return tuple(out_flat) if len(out_flat) > 1 else out_flat[0] + else: + return out_flat + + return normalized_func + + +def _get_vjp_jax_op(flat_func, num_out): + def vjp_op(*args): + y0 = args[:-num_out] + gz = args[-num_out:] + if len(gz) == 1: + gz = gz[0] + + def f(*inputs): + return flat_func(*inputs) + + primals, vjp_fn = jax.vjp(f, *y0) + + def broadcast_to_shape(g, shape): + if g.ndim > 0 and g.shape[0] == 1: + g_squeezed = jnp.squeeze(g, axis=0) + else: + g_squeezed = g + return jnp.broadcast_to(g_squeezed, shape) + + gz = tree_map( + lambda g, p: broadcast_to_shape(g, jnp.shape(p)).astype(p.dtype), + gz, + primals, + ) + return vjp_fn(gz) + + return vjp_op + + +def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): + """Return a Pytensor-compatible function from a JAX jittable function. + + This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` + objects. The JAX-jittable function can accept any + nested python structure (a `Pytree + `_) as input, and might return + any nested Python structure. + + Parameters + ---------- + jaxfunc : Callable + A JAX function to be wrapped. + use_infer_static_shape : bool, optional + If True, use static shape inference; otherwise, use runtime shape inference. + Default is True. + name : str, optional + A custom name for the created Pytensor Op instance. If None, the name of jaxfunc + is used. + + Returns + ------- + Callable + A function that wraps the given JAX function so that it can be called with + pytensor.Variable inputs and returns pytensor.Variable outputs. + + Examples + -------- + + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @as_jax_op + ... def add(x, y): + ... return jnp.add(x, y) + >>> x = pt.scalar("x") + >>> y = pt.scalar("y") + >>> result = add(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print(f(1, 2)) + [array(3.)] + + Notes + ----- + The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, + available at + `pymc-labls.io `__. + To accept functions and non pytensor variables as input, the function make use + of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the + variables. 
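+    Arguments that are not PyTensor variables (e.g. Python scalars, strings or
+    callables) are treated as static: they do not become part of the PyTensor
+    graph and are instead treated as constants of the wrapped JAX function.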
Shapes are inferred using + :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. + + """ + + def func(*args, **kwargs): + # 1. Partition inputs into dynamic pytensor variables, wrapped functions and + # static variables. + # Static variables don't take part in the graph. + pt_vars, func_vars, static_vars = _split_inputs(args, kwargs) + + # 2. Get the original variables from the wrapped functions. + vars_from_func = tree_map(lambda f: f.get_vars(), func_vars) + input_dict = {"vars": pt_vars, "vars_from_func": vars_from_func} + + # 3. Flatten the input dictionary. + # e.g. {"a": tensor_a, "b": [tensor_b]} becomes [tensor_a, tensor_b], because + # pytensor ops only accepts lists of pytensor.Variables as input. + pt_vars_flat, pt_vars_treedef = tree_flatten( + input_dict, + ) + pt_types = [var.type for var in pt_vars_flat] + + # 4. Create dummy inputs for shape inference. + shapes = _infer_shapes(pt_vars_flat, use_infer_static_shape) + dummy_in_flat = _create_dummy_inputs_from_shapes( + pt_vars_flat, shapes, use_infer_static_shape + ) + dummy_inputs = tree_unflatten(pt_vars_treedef, dummy_in_flat) + + # 5. Partition the JAX function into dynamic and static parts. + jaxfunc_dynamic, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + flat_func = _flatten_func(jaxfunc_dynamic, pt_vars_treedef) + + # 6. Infer output types using JAX's eval_shape. + out_treedef, pt_types_flat = _infer_output_types(jaxfunc_dynamic, dummy_inputs) + + # 7. Create the Pytensor Op instance. + curr_name = "JAXOp_" + (jaxfunc.__name__ if name is None else name) + op_instance = JAXOp( + pt_types, + pt_types_flat, + flat_func, + name=curr_name, + ) + + # 8. Execute the op and unflatten the outputs. + output_flat = op_instance(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] + outvars = tree_unflatten(out_treedef, output_flat) + + # 9. Combine with static outputs and wrap eventual output functions with + # _WrappedFunc + return _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars) + + return func + + +def _filter_ptvars(x): + return isinstance(x, pt.Variable) + + +def _split_inputs(args, kwargs): + """Split inputs into pytensor variables, static values and wrapped functions.""" + + pt_vars, static_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. 
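+    # Among the remaining non-variable leaves, separate _WrappedFunc instances
+    # (functions returned by a previous as_jax_op call) from truly static values.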
+ func_vars, static_vars = eqx.partition( + static_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + return pt_vars, func_vars, static_vars + + +def _infer_shapes(pt_vars_flat, use_infer_static_shape): + """Infer shapes of pytensor variables.""" + if use_infer_static_shape: + return [pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat] + else: + return pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + + +def _create_dummy_inputs_from_shapes(pt_vars_flat, shapes, use_infer_static_shape): + """Create dummy inputs for the jax function from inferred shapes.""" + if use_infer_static_shape: + return [ + jnp.empty(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) + ] + else: + return [ + jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) + ] + + +def _infer_output_types(jaxfunc_part, dummy_inputs): + """Infer output types using JAX's eval_shape.""" + jax_out = jax.eval_shape(jaxfunc_part, dummy_inputs) + jax_out_flat, out_treedef = tree_flatten(jax_out) + pt_out_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) for var in jax_out_flat + ] + return out_treedef, pt_out_types + + +def _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars): + """Process and combine static outputs with the dynamic ones.""" + static_funcs, static_vars_out = eqx.partition( + static_out_dic["out"], callable, is_leaf=callable + ) + flat_static, func_treedef = tree_flatten(static_funcs, is_leaf=callable) + for i in range(len(flat_static)): + flat_static[i] = _WrappedFunc(jaxfunc, i, *args, **kwargs) + static_funcs = tree_unflatten(func_treedef, flat_static) + static_combined = eqx.combine(static_funcs, static_vars_out, is_leaf=callable) + return eqx.combine(outvars, static_combined, is_leaf=callable) + + +def _partition_jaxfunc(jaxfunc, static_vars, func_vars): + """Split the jax function into dynamic and static components. + + Returns a function that accepts only non-static variables and returns the non-static + variables. The returned static variables are stored in a dictionary and returned, + to allow the referencing after creating the function + + Additionally wrapped functions saved in func_vars are regenerated with + vars["vars_from_func"] as input, to allow the transformation of the variables. + """ + static_out_dic = {"out": None} + + def jaxfunc_partitioned(vars): + dyn_vars, func_vars_input = vars["vars"], vars["vars_from_func"] + evaluated_funcs = tree_map( + lambda f, v: f.get_func_with_vars(v), func_vars, func_vars_input + ) + args, kwargs = eqx.combine( + dyn_vars, static_vars, evaluated_funcs, is_leaf=callable + ) + output = jaxfunc(*args, **kwargs) + out_dyn, static_out = eqx.partition(output, eqx.is_array, is_leaf=callable) + static_out_dic["out"] = static_out + return out_dyn + + return jaxfunc_partitioned, static_out_dic + + +def _flatten_func(jaxfunc, treedef): + def flat_func(*flat_vars): + vars = tree_unflatten(treedef, flat_vars) + out = jaxfunc(vars) + out_flat, _ = tree_flatten(out) + return out_flat + + return flat_func + + +class _WrappedFunc: + def __init__(self, exterior_func, i_func, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.i_func = i_func + # Partition the inputs to separate dynamic variables from static ones. 
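+        # Only pytensor.Variable leaves count as dynamic; every other leaf,
+        # including nested callables, is stored as static state.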
+ vars, static_vars = eqx.partition( + (self.args, self.kwargs), _filter_ptvars, is_leaf=callable + ) + self.vars = vars + self.static_vars = static_vars + self.exterior_func = exterior_func + + def __call__(self, *args, **kwargs): + # If called, assume that args and kwargs are pytensors, so return the result + # as pytensors. + def f(func, *args, **kwargs): + return func(*args, **kwargs) + + return as_jax_op(f)(self, *args, **kwargs) + + def get_vars(self): + return self.vars + + def get_func_with_vars(self, vars): + # Use other variables than the saved ones, to generate the function. This + # is used to transform vars externally from pytensor to JAX, and use the + # then create the function which is returned. + args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) + output = self.exterior_func(*args, **kwargs) + out_funcs, _ = eqx.partition(output, callable, is_leaf=callable) + out_funcs_flat, _ = tree_flatten(out_funcs, is_leaf=callable) + return out_funcs_flat[self.i_func] + + +@jax_funcify.register(JAXOp) +def jax_op_funcify(op, **kwargs): + return op.perform_jax + + +@jax_funcify.register(VJPJAXOp) +def vjp_jax_op_funcify(op, **kwargs): + return op.perform_jax diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py new file mode 100644 index 0000000000..62fd270032 --- /dev/null +++ b/tests/link/jax/test_as_jax_op.py @@ -0,0 +1,474 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import as_jax_op, config, grad +from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import JAXOp +from pytensor.scalar import all_types +from pytensor.tensor import TensorType, tensor +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_two_inputs_single_output(): + rng = np.random.default_rng(1) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return jax.nn.sigmoid(x + y) + + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_two_inputs_tuple_output(): + rng = np.random.default_rng(2) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return jax.nn.sigmoid(x + y), y * 2 + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) + + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + # must_be_device_array is False, because the with disabled jit compilation, + # inputs are not automatically transformed to jax.Array anymore + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x, y) + grad_out 
= grad(pt.sum(out1 + out2), [x, y]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_two_inputs_list_output_one_unused_output(): + # One output is unused, to test whether the wrapper can handle DisconnectedType + rng = np.random.default_rng(3) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return [jax.nn.sigmoid(x + y), y * 2] + + # Test with as_jax_op decorator + out, _ = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out, _ = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_single_input_tuple_output(): + rng = np.random.default_rng(4) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x * 2 + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_scalar_input_tuple_output(): + rng = np.random.default_rng(5) + x = tensor("x", shape=()) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_single_input_list_output(): + rng = np.random.default_rng(6) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return [jax.nn.sigmoid(x), 2 * x] + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage, with unspecified output shapes + jax_op = JAXOp( + [x.type], + [ + TensorType(config.floatX, shape=(None,)), + TensorType(config.floatX, shape=(None,)), + ], + f, + ) + out1, out2 = jax_op(x) + grad_out = 
grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_pytree_input_tuple_output(): + rng = np.random.default_rng(7) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + y_tmp = {"y": y, "y2": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + + # Test with as_jax_op decorator + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + +def test_pytree_input_pytree_output(): + rng = np.random.default_rng(8) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + + # Test with as_jax_op decorator + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + +def test_pytree_input_with_non_graph_args(): + rng = np.random.default_rng(9) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, depth, which_variable): + if which_variable == "x": + var = x + elif which_variable == "y": + var = y["a"] + y["b"][0] + else: + return "Unsupported argument" + for _ in range(depth): + var = jax.nn.sigmoid(var) + return var + + # Test with as_jax_op decorator + # arguments depth and which_variable are not part of the graph + out = f(x, y_tmp, depth=3, which_variable="x") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + out = f(x, y_tmp, depth=7, which_variable="y") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + out = f(x, y_tmp, depth=10, which_variable="z") + assert out == "Unsupported argument" + + +def test_unused_matrix_product(): + # A matrix output is unused, to test whether the wrapper can handle a + # DisconnectedType with a larger dimension. 
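+    # Only the second output is used in the graph and its gradient, so the
+    # matrix product receives a disconnected output gradient, which JAXOp.grad
+    # replaces with zeros of the appropriate shape.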
+ + rng = np.random.default_rng(10) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return x[:, None] @ y[None], jnp.exp(x) + + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [ + TensorType(config.floatX, shape=(3, 3)), + TensorType(config.floatX, shape=(3,)), + ], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_unknown_static_shape(): + rng = np.random.default_rng(11) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape + + def f(x, y): + return x * jnp.ones(3) + + out = as_jax_op(f)(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(None,))], + f, + ) + out = jax_op(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_nested_functions(): + rng = np.random.default_rng(13) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f_internal(y): + def f_ret(t): + return y + t + + def f_ret2(t): + return f_ret(t) + t**2 + + return f_ret, y**2 * jnp.ones(1), f_ret2 + + f, y_pow, f2 = f_internal(y) + + @as_jax_op + def f_outer(x, dict_other): + f, y_pow = dict_other["func"], dict_other["y"] + return x * jnp.ones(3), f(x) * y_pow + + out = f_outer(x, {"func": f, "y": y_pow}) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +class TestDtypes: + @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) + def test_different_in_output(self, in_dtype, out_dtype): + x = tensor("x", shape=(3,), dtype=in_dtype) + y = tensor("y", shape=(3,), dtype=in_dtype) + + if "int" in in_dtype: + test_values = [ + np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + else: + test_values = [ + np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(out_dtype) + + out = f(x, y) + assert out.dtype == out_dtype + + if "float" in in_dtype and "float" in out_dtype: + grad_out = grad(out[0], [x, y]) + assert grad_out[0].dtype == in_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, 
test_values) + + @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) + def test_test_different_inputs(self, in1_dtype, in2_dtype): + x = tensor("x", shape=(3,), dtype=in1_dtype) + y = tensor("y", shape=(3,), dtype=in2_dtype) + + if "int" in in1_dtype: + test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] + else: + test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)] + if "int" in in2_dtype: + test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype)) + else: + test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype)) + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(in1_dtype) + + out = f(x, y) + assert out.dtype == in1_dtype + + if "float" in in1_dtype and "float" in in2_dtype: + # In principle, the gradient should also be defined if the second input is + # an integer, but it doesn't work for some reason. + grad_out = grad(out[0], [x]) + assert grad_out[0].dtype == in1_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values)
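+
+
+# A minimal usage sketch (illustrative only; the leading underscore keeps it out
+# of the pytest collection): wrapping a JAX matrix-vector product whose output
+# shape differs from both input shapes. It relies only on the imports already
+# used in this module.
+def _example_matvec_sketch():
+    x = tensor("x", shape=(2, 3))
+    v = tensor("v", shape=(3,))
+
+    @as_jax_op
+    def matvec(a, b):
+        return jnp.dot(a, b)
+
+    out = matvec(x, v)  # symbolic output with inferred shape (2,)
+    grad_out = grad(pt.sum(out), [x, v])
+
+    fg = FunctionGraph([x, v], [out, *grad_out])
+    test_values = [
+        np.ones((2, 3), dtype=config.floatX),
+        np.ones((3,), dtype=config.floatX),
+    ]
+    compare_jax_and_py(fg, test_values)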