From e428fb16b33c33b9c0a2e20e7b563e4479983dd0 Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 12 Dec 2024 17:25:07 +0100 Subject: [PATCH 01/14] first commit for a decorator that transforms JAX to pytensor --- pyproject.toml | 2 +- pytensor/link/jax/ops.py | 424 +++++++++++++++++++++++++++++++ tests/link/jax/test_as_jax_op.py | 26 ++ 3 files changed, 451 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/jax/ops.py create mode 100644 tests/link/jax/test_as_jax_op.py diff --git a/pyproject.toml b/pyproject.toml index e82c42753a..8d0adba3df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ tests = [ "pytest-sphinx", ] rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"] -jax = ["jax", "jaxlib"] +jax = ["jax", "jaxlib", "equinox"] numba = ["numba>=0.57", "llvmlite"] [tool.setuptools.packages.find] diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py new file mode 100644 index 0000000000..130ece6eda --- /dev/null +++ b/pytensor/link/jax/ops.py @@ -0,0 +1,424 @@ +"""Convert a jax function to a pytensor compatible function.""" + +import functools as ft +import logging +from collections.abc import Sequence + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten, tree_map, tree_unflatten + +import pytensor.compile.builders +import pytensor.tensor as pt +from pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Op +from pytensor.link.jax.dispatch import jax_funcify + + +log = logging.getLogger(__name__) + + +def _filter_ptvars(x): + return isinstance(x, pt.Variable) + + +def as_jax_op(jaxfunc, name=None): + """Return a Pytensor from a JAX jittable function. + + This decorator transforms any JAX jittable function into a function that accepts + and returns `pytensor.Variables`. The jax jittable function can accept any + nested python structure (pytrees) as input, and return any nested Python structure. + + It requires to define the output types of the returned values as pytensor types. A + unique name should also be passed in case the name of the jaxfunc is identical to + some other node. The design of this function is based on + https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ + + Parameters + ---------- + jaxfunc : jax jittable function + function for which the node is created, can return multiple tensors as a tuple. + It is required that all return values are able to transformed to + pytensor.Variable. + name: str + Name of the created pytensor Op, defaults to the name of the passed function. + Only used internally in the pytensor graph. + + Returns + ------- + A function which can be used in a pymc.Model as function, is differentiable + and the resulting model can be compiled either with the default C backend, or + the JAX backend. + + + Notes + ----- + The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, + available at + `pymc-labls.io `__. + To accept functions and non pytensor variables as input, the function make use + of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the + variables. Shapes are inferred using + :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. 
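+
+    For intuition: :func:`equinox.partition` splits a nested structure into two
+    pytrees of the same shape, one keeping only the leaves that match a filter
+    (here the ``pytensor.Variable`` leaves) and one keeping the remaining static
+    leaves, with ``None`` placeholders on the other side; :func:`equinox.combine`
+    merges such partial pytrees back into the original structure. This is what
+    lets arbitrary Python objects be passed alongside the symbolic inputs.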
+ """ + + def func(*args, **kwargs): + """Return a pytensor from a jax jittable function.""" + ### Split variables: in the ones that will be transformed to JAX inputs, + ### pytensor.Variables; _WrappedFunc, that are functions that have been returned + ### from a transformed function; and the rest, static variables that are not + ### transformed. + + pt_vars, static_vars_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. + + func_vars, static_vars = eqx.partition( + static_vars_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) + pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) + """ + def func_unwrapped(vars_all, static_vars): + vars, vars_from_func = vars_all["vars"], vars_all["vars_from_func"] + func_vars_evaled = tree_map( + lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func + ) + args, kwargs = eqx.combine(vars, static_vars, func_vars_evaled) + return self.jaxfunc(*args, **kwargs) + """ + + pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + pt_vars_types_flat = [var.type for var in pt_vars_flat] + shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) + + dummy_inputs_jax = jax.tree_util.tree_map( + lambda var, shape: jnp.empty( + [int(dim.eval()) for dim in shape], dtype=var.type.dtype + ), + pt_vars, + shapes_vars, + ) + + # Combine the static variables with the inputs, and split them again in the + # output. Static variables don't take part in the graph, or might be a + # a function that is returned. + jaxfunc_partitioned, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + + func_flattened = _flatten_func(jaxfunc_partitioned, vars_treedef) + + jaxtypes_outvars = jax.eval_shape( + ft.partial(jaxfunc_partitioned, vars=dummy_inputs_jax), + ) + + jaxtypes_outvars_flat, outvars_treedef = tree_flatten(jaxtypes_outvars) + + pttypes_outvars = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) + for var in jaxtypes_outvars_flat + ] + + ### Call the function that accepts flat inputs, which in turn calls the one that + ### combines the inputs and static variables. + jitted_sol_op_jax = jax.jit(func_flattened) + len_gz = len(pttypes_outvars) + + vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) + jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + + if name is None: + curr_name = jaxfunc.__name__ + else: + curr_name = name + + # Get classes that creates a Pytensor Op out of our function that accept + # flattened inputs. They are created each time, to set a custom name for the + # class. 
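+        # Note: creating fresh classes per call also keeps the jax_funcify
+        # registrations below specific to this wrapped function, since the
+        # registered callables close over this local_op instance.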
+ SolOp, VJPSolOp = _return_pytensor_ops_classes(curr_name) + + local_op = SolOp( + vars_treedef, + outvars_treedef, + input_types=pt_vars_types_flat, + output_types=pttypes_outvars, + jitted_sol_op_jax=jitted_sol_op_jax, + jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, + ) + + @jax_funcify.register(SolOp) + def sol_op_jax_funcify(op, **kwargs): + return local_op.perform_jax + + @jax_funcify.register(VJPSolOp) + def vjp_sol_op_jax_funcify(op, **kwargs): + return local_op.vjp_sol_op.perform_jax + + ### Evaluate the Pytensor Op and return unflattened results + output_flat = local_op(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] # tree_unflatten expects a sequence. + outvars = tree_unflatten(outvars_treedef, output_flat) + + static_outfuncs, static_outvars = eqx.partition( + static_out_dic["out"], callable, is_leaf=callable + ) + + static_outfuncs_flat, treedef_outfuncs = jax.tree_util.tree_flatten( + static_outfuncs, is_leaf=callable + ) + for i_func, _ in enumerate(static_outfuncs_flat): + static_outfuncs_flat[i_func] = _WrappedFunc( + jaxfunc, i_func, *args, **kwargs + ) + + static_outfuncs = jax.tree_util.tree_unflatten( + treedef_outfuncs, static_outfuncs_flat + ) + static_vars = eqx.combine(static_outfuncs, static_outvars, is_leaf=callable) + + output = eqx.combine(outvars, static_vars, is_leaf=callable) + + return output + + return func + + +class _WrappedFunc: + def __init__(self, exterior_func, i_func, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.i_func = i_func + vars, static_vars = eqx.partition( + (self.args, self.kwargs), _filter_ptvars, is_leaf=callable + ) + self.vars = vars + self.static_vars = static_vars + self.exterior_func = exterior_func + + def __call__(self, *args, **kwargs): + # If called, assume that args and kwargs are pytensors, so return the result + # as pytensors. + def f(func, *args, **kwargs): + res = func(*args, **kwargs) + return res + + return as_jax_op(f)(self, *args, **kwargs) + + def get_vars(self): + return self.vars + + def get_func_with_vars(self, vars): + # Use other variables than the saved ones, to generate the function. This + # is used to transform vars externally from pytensor to JAX, and use the + # then create the function which is returned. + + args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) + output = self.exterior_func(*args, **kwargs) + outfuncs, _ = eqx.partition(output, callable, is_leaf=callable) + outfuncs_flat, _ = jax.tree_util.tree_flatten(outfuncs, is_leaf=callable) + interior_func = outfuncs_flat[self.i_func] + return interior_func + + +def _get_vjp_sol_op_jax(jaxfunc, len_gz): + def vjp_sol_op_jax(args): + y0 = args[:-len_gz] + gz = args[-len_gz:] + if len(gz) == 1: + gz = gz[0] + + def func(*inputs): + return jaxfunc(inputs) + + primals, vjp_fn = jax.vjp(func, *y0) + gz = tree_map( + lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)), + gz, + primals, + ) + if len(y0) == 1: + return vjp_fn(gz)[0] + else: + return tuple(vjp_fn(gz)) + + return vjp_sol_op_jax + + +def _partition_jaxfunc(jaxfunc, static_vars, func_vars): + """Partition the jax function into static and non-static variables. + + Returns a function that accepts only non-static variables and returns the non-static + variables. 
The returned static variables are stored in a dictionary and returned, + to allow the referencing after creating the function + + Additionally wrapped functions saved in func_vars are regenerated with + vars["vars_from_func"] as input, to allow the transformation of the variables. + """ + static_out_dic = {"out": None} + + def jaxfunc_partitioned(vars): + vars, vars_from_func = vars["vars"], vars["vars_from_func"] + func_vars_evaled = tree_map( + lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func + ) + args, kwargs = eqx.combine( + vars, static_vars, func_vars_evaled, is_leaf=callable + ) + + out = jaxfunc(*args, **kwargs) + outvars, static_out = eqx.partition(out, eqx.is_array, is_leaf=callable) + static_out_dic["out"] = static_out + return outvars + + return jaxfunc_partitioned, static_out_dic + + +### Construct the function that accepts flat inputs and returns flat outputs. +def _flatten_func(jaxfunc, vars_treedef): + def func_flattened(vars_flat): + vars = tree_unflatten(vars_treedef, vars_flat) + outvars = jaxfunc(vars) + outvars_flat, _ = tree_flatten(outvars) + return _normalize_flat_output(outvars_flat) + + return func_flattened + + +def _normalize_flat_output(output): + if len(output) > 1: + return tuple( + output + ) # Transform to tuple because jax makes a difference between + # tuple and list and not pytensor + else: + return output[0] + + +def _return_pytensor_ops_classes(name): + class SolOp(Op): + def __init__( + self, + input_treedef, + output_treeedef, + input_types, + output_types, + jitted_sol_op_jax, + jitted_vjp_sol_op_jax, + ): + self.vjp_sol_op = None + self.input_treedef = input_treedef + self.output_treedef = output_treeedef + self.input_types = input_types + self.output_types = output_types + self.jitted_sol_op_jax = jitted_sol_op_jax + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, *inputs): + self.num_inputs = len(inputs) + + # Define our output variables + outputs = [pt.as_tensor_variable(type()) for type in self.output_types] + self.num_outputs = len(outputs) + + self.vjp_sol_op = VJPSolOp( + self.input_treedef, + self.input_types, + self.jitted_vjp_sol_op_jax, + ) + + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_sol_op_jax(inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_sol_op_jax(inputs) + return results + + def grad(self, inputs, output_gradients): + # If a output is not used, it is disconnected and doesn't have a gradient. + # Set gradient here to zero for those outputs. 
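+            # (The JAX vjp needs a concrete cotangent for every output, so
+            # disconnected outputs are replaced by zeros of a matching shape;
+            # when the static shape is unknown, a scalar zero is used and
+            # broadcast against the primal inside vjp_sol_op_jax.)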
+ for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + if None not in self.output_types[i].shape: + output_gradients[i] = pt.zeros( + self.output_types[i].shape, self.output_types[i].dtype + ) + else: + output_gradients[i] = pt.zeros((), self.output_types[i].dtype) + result = self.vjp_sol_op(inputs, output_gradients) + + if self.num_inputs > 1: + return result + else: + return (result,) # Pytensor requires a tuple here + + # vector-jacobian product Op + class VJPSolOp(Op): + def __init__( + self, + input_treedef, + input_types, + jitted_vjp_sol_op_jax, + ): + self.input_treedef = input_treedef + self.input_types = input_types + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, y0, gz): + y0 = [ + pt.as_tensor_variable( + _y, + ).astype(self.input_types[i].dtype) + for i, _y in enumerate(y0) + ] + gz_not_disconntected = [ + pt.as_tensor_variable(_gz) + for _gz in gz + if not isinstance(_gz.type, DisconnectedType) + ] + outputs = [in_type() for in_type in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0 + gz_not_disconntected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if len(self.input_types) > 1: + for i, result in enumerate(results): + outputs[i][0] = np.array(result, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if self.num_outputs == 1: + if isinstance(results, Sequence): + return results[0] + else: + return results + else: + return tuple(results) + + SolOp.__name__ = name + SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name]) + + VJPSolOp.__name__ = "VJP_" + name + VJPSolOp.__qualname__ = ".".join( + VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name] + ) + + return SolOp, VJPSolOp diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py new file mode 100644 index 0000000000..6feb1124a2 --- /dev/null +++ b/tests/link/jax/test_as_jax_op.py @@ -0,0 +1,26 @@ +import jax +import numpy as np + +from pytensor import config +from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import as_jax_op +from pytensor.tensor import tensor +from tests.link.jax.test_basic import compare_jax_and_py + +def test_as_jax_op1(): + # 2 parameters input, single output + rng = np.random.default_rng(14) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x + y) + + out = f(x, y) + + fg = FunctionGraph([x, y], [out]) + fn, _ = compare_jax_and_py(fg, test_values) From 9cb4cc5c96f3d9d7e4d7add0ead93ea2399fac1c Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 12 Dec 2024 17:47:44 +0100 Subject: [PATCH 02/14] Add more tests --- pytensor/link/jax/ops.py | 7 +- tests/link/jax/test_as_jax_op.py | 368 ++++++++++++++++++++++++++++++- 2 files changed, 371 insertions(+), 4 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 130ece6eda..1b2325293d 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -241,7 +241,11 @@ def func(*inputs): primals, vjp_fn = jax.vjp(func, *y0) gz = tree_map( - lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)), + lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)).astype( + primal.dtype + ), # Also cast to the 
dtype of the primal, this shouldn't be + # necessary, but it happens that the returned dtype of the gradient isn't + # the same anymore. gz, primals, ) @@ -326,6 +330,7 @@ def make_node(self, *inputs): self.num_inputs = len(inputs) # Define our output variables + print(self.output_types) outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 6feb1124a2..8d404d76db 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -1,15 +1,20 @@ import jax +import jax.numpy as jnp import numpy as np +import pytest -from pytensor import config +import pytensor.tensor as pt +from pytensor import config, grad from pytensor.graph.fg import FunctionGraph from pytensor.link.jax.ops import as_jax_op +from pytensor.scalar import all_types from pytensor.tensor import tensor from tests.link.jax.test_basic import compare_jax_and_py + def test_as_jax_op1(): # 2 parameters input, single output - rng = np.random.default_rng(14) + rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) test_values = [ @@ -21,6 +26,363 @@ def f(x, y): return jax.nn.sigmoid(x + y) out = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op2(): + # 2 parameters input, tuple output + rng = np.random.default_rng(2) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x + y), y * 2 + + out, _ = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op3(): + # 2 parameters input, list output + rng = np.random.default_rng(3) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return [jax.nn.sigmoid(x + y), y * 2] + + out, _ = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op4(): + # single 1d input, tuple output + rng = np.random.default_rng(4) + x = tensor("a", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return jax.nn.sigmoid(x), x * 2 + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op5(): + # single 0d input, tuple output + rng = np.random.default_rng(5) + x = tensor("a", shape=()) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return jax.nn.sigmoid(x), x + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op6(): + # 
single input, list output + rng = np.random.default_rng(6) + x = tensor("a", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return [jax.nn.sigmoid(x), 2 * x] + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op7(): + # 2 parameters input with pytree, tuple output + rng = np.random.default_rng(7) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + y_tmp = {"y": y, "y2": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op8(): + # 2 parameters input with pytree, pytree output + rng = np.random.default_rng(8) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op9(): + # 2 parameters input with pytree, pytree output and non-graph argument + rng = np.random.default_rng(9) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, non_model_arg): + return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y) + + out = f(x, y_tmp, "Hello World!") + grad_out = grad(pt.sum(out[0]), [x]) + + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op10(): + # Use "None" in shape specification and have a non-used output of higher rank + rng = np.random.default_rng(10) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x[:, None] @ y[None], jnp.exp(x) + + out = f(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op11(): + # Test unknown static shape + rng = np.random.default_rng(11) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x = pt.cumsum(x) # Now x has an unknown shape + + @as_jax_op + def f(x, y): + return x * jnp.ones(3) + + out = f(x, y) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + - fg = FunctionGraph([x, y], [out]) +def 
test_as_jax_op12(): + # Test non-array return values + rng = np.random.default_rng(12) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, message): + return x * jnp.ones(3), "Success: " + message + + out = f(x, y, "Hi") + grad_out = grad(pt.sum(out[0]), [x]) + + fg = FunctionGraph([x, y], [out[0], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op13(): + # Test nested functions + rng = np.random.default_rng(13) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f_internal(y): + def f_ret(t): + return y + t + + def f_ret2(t): + return f_ret(t) + t**2 + + return f_ret, y**2 * jnp.ones(1), f_ret2 + + f, y_pow, f2 = f_internal(y) + + @as_jax_op + def f_outer(x, dict_other): + f, y_pow = dict_other["func"], dict_other["y"] + return x * jnp.ones(3), f(x) * y_pow + + out = f_outer(x, {"func": f, "y": y_pow}) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +class TestDtypes: + @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) + def test_different_in_output(self, in_dtype, out_dtype): + x = tensor("a", shape=(3,), dtype=in_dtype) + y = tensor("b", shape=(3,), dtype=in_dtype) + + if "int" in in_dtype: + test_values = [ + np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + else: + test_values = [ + np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(out_dtype) + + out = f(x, y) + assert out.dtype == out_dtype + + if "float" in in_dtype and "float" in out_dtype: + grad_out = grad(out[0], [x, y]) + assert grad_out[0].dtype == in_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) + def test_test_different_inputs(self, in1_dtype, in2_dtype): + x = tensor("a", shape=(3,), dtype=in1_dtype) + y = tensor("b", shape=(3,), dtype=in2_dtype) + + if "int" in in1_dtype: + test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] + else: + test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)] + if "int" in in2_dtype: + test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype)) + else: + test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype)) + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(in1_dtype) + + out = f(x, y) + assert out.dtype == in1_dtype + + if "float" in in1_dtype and "float" in in2_dtype: + # In principle, the gradient should also be defined if the second input is + # an integer, but it doesn't work for some reason. 
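+            # (Likely because JAX only produces cotangents for inexact dtypes;
+            # integer primals get float0 cotangents, which this wrapper does
+            # not handle, so only float/float combinations are tested here.)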
+ grad_out = grad(out[0], [x]) + assert grad_out[0].dtype == in1_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) From 43759baf385033dd5ef319b0a0ec92eea393c2af Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 13:18:10 +0100 Subject: [PATCH 03/14] Define JAXOp outside of the decorator --- pytensor/link/jax/ops.py | 263 +++++++++++++++++++-------------------- 1 file changed, 130 insertions(+), 133 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 1b2325293d..60a3581550 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -84,16 +84,8 @@ def func(*args, **kwargs): ) vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) - """ - def func_unwrapped(vars_all, static_vars): - vars, vars_from_func = vars_all["vars"], vars_all["vars_from_func"] - func_vars_evaled = tree_map( - lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func - ) - args, kwargs = eqx.combine(vars, static_vars, func_vars_evaled) - return self.jaxfunc(*args, **kwargs) - """ + # Infer shapes and types of the variables pt_vars_flat, vars_treedef = tree_flatten(pt_vars) pt_vars_types_flat = [var.type for var in pt_vars_flat] shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) @@ -135,17 +127,30 @@ def func_unwrapped(vars_all, static_vars): vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + # Get classes that creates a Pytensor Op out of our function that accept + # flattened inputs. They are created each time, to set a custom name for the + # class. + class JAXOp_local(JAXOp): + pass + + class VJPJAXOp_local(VJPJAXOp): + pass + if name is None: curr_name = jaxfunc.__name__ else: curr_name = name + JAXOp_local.__name__ = curr_name + JAXOp_local.__qualname__ = ".".join( + JAXOp_local.__qualname__.split(".")[:-1] + [curr_name] + ) - # Get classes that creates a Pytensor Op out of our function that accept - # flattened inputs. They are created each time, to set a custom name for the - # class. 
- SolOp, VJPSolOp = _return_pytensor_ops_classes(curr_name) + VJPJAXOp_local.__name__ = "VJP_" + curr_name + VJPJAXOp_local.__qualname__ = ".".join( + VJPJAXOp_local.__qualname__.split(".")[:-1] + ["VJP_" + curr_name] + ) - local_op = SolOp( + local_op = JAXOp_local( vars_treedef, outvars_treedef, input_types=pt_vars_types_flat, @@ -154,14 +159,6 @@ def func_unwrapped(vars_all, static_vars): jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, ) - @jax_funcify.register(SolOp) - def sol_op_jax_funcify(op, **kwargs): - return local_op.perform_jax - - @jax_funcify.register(VJPSolOp) - def vjp_sol_op_jax_funcify(op, **kwargs): - return local_op.vjp_sol_op.perform_jax - ### Evaluate the Pytensor Op and return unflattened results output_flat = local_op(*pt_vars_flat) if not isinstance(output_flat, Sequence): @@ -307,123 +304,123 @@ def _normalize_flat_output(output): return output[0] -def _return_pytensor_ops_classes(name): - class SolOp(Op): - def __init__( - self, - input_treedef, - output_treeedef, - input_types, - output_types, - jitted_sol_op_jax, - jitted_vjp_sol_op_jax, - ): - self.vjp_sol_op = None - self.input_treedef = input_treedef - self.output_treedef = output_treeedef - self.input_types = input_types - self.output_types = output_types - self.jitted_sol_op_jax = jitted_sol_op_jax - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax - - def make_node(self, *inputs): - self.num_inputs = len(inputs) - - # Define our output variables - print(self.output_types) - outputs = [pt.as_tensor_variable(type()) for type in self.output_types] - self.num_outputs = len(outputs) - - self.vjp_sol_op = VJPSolOp( - self.input_treedef, - self.input_types, - self.jitted_vjp_sol_op_jax, - ) +class JAXOp(Op): + def __init__( + self, + input_treedef, + output_treeedef, + input_types, + output_types, + jitted_sol_op_jax, + jitted_vjp_sol_op_jax, + ): + self.vjp_sol_op = None + self.input_treedef = input_treedef + self.output_treedef = output_treeedef + self.input_types = input_types + self.output_types = output_types + self.jitted_sol_op_jax = jitted_sol_op_jax + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, *inputs): + self.num_inputs = len(inputs) + + # Define our output variables + print(self.output_types) + outputs = [pt.as_tensor_variable(type()) for type in self.output_types] + self.num_outputs = len(outputs) + + self.vjp_sol_op = VJPJAXOp( + self.input_treedef, + self.input_types, + self.jitted_vjp_sol_op_jax, + ) - return Apply(self, inputs, outputs) + return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): - results = self.jitted_sol_op_jax(inputs) - if self.num_outputs > 1: - for i in range(self.num_outputs): - outputs[i][0] = np.array(results[i], self.output_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.output_types[0].dtype) + def perform(self, node, inputs, outputs): + results = self.jitted_sol_op_jax(inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_sol_op_jax(inputs) + return results + + def grad(self, inputs, output_gradients): + # If a output is not used, it is disconnected and doesn't have a gradient. + # Set gradient here to zero for those outputs. 
+ for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + if None not in self.output_types[i].shape: + output_gradients[i] = pt.zeros( + self.output_types[i].shape, self.output_types[i].dtype + ) + else: + output_gradients[i] = pt.zeros((), self.output_types[i].dtype) + result = self.vjp_sol_op(inputs, output_gradients) - def perform_jax(self, *inputs): - results = self.jitted_sol_op_jax(inputs) - return results + if self.num_inputs > 1: + return result + else: + return (result,) # Pytensor requires a tuple here + + +# vector-jacobian product Op +class VJPJAXOp(Op): + def __init__( + self, + input_treedef, + input_types, + jitted_vjp_sol_op_jax, + ): + self.input_treedef = input_treedef + self.input_types = input_types + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, y0, gz): + y0 = [ + pt.as_tensor_variable( + _y, + ).astype(self.input_types[i].dtype) + for i, _y in enumerate(y0) + ] + gz_not_disconntected = [ + pt.as_tensor_variable(_gz) + for _gz in gz + if not isinstance(_gz.type, DisconnectedType) + ] + outputs = [in_type() for in_type in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0 + gz_not_disconntected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if len(self.input_types) > 1: + for i, result in enumerate(results): + outputs[i][0] = np.array(result, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) - def grad(self, inputs, output_gradients): - # If a output is not used, it is disconnected and doesn't have a gradient. - # Set gradient here to zero for those outputs. - for i in range(self.num_outputs): - if isinstance(output_gradients[i].type, DisconnectedType): - if None not in self.output_types[i].shape: - output_gradients[i] = pt.zeros( - self.output_types[i].shape, self.output_types[i].dtype - ) - else: - output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_sol_op(inputs, output_gradients) - - if self.num_inputs > 1: - return result - else: - return (result,) # Pytensor requires a tuple here - - # vector-jacobian product Op - class VJPSolOp(Op): - def __init__( - self, - input_treedef, - input_types, - jitted_vjp_sol_op_jax, - ): - self.input_treedef = input_treedef - self.input_types = input_types - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax - - def make_node(self, y0, gz): - y0 = [ - pt.as_tensor_variable( - _y, - ).astype(self.input_types[i].dtype) - for i, _y in enumerate(y0) - ] - gz_not_disconntected = [ - pt.as_tensor_variable(_gz) - for _gz in gz - if not isinstance(_gz.type, DisconnectedType) - ] - outputs = [in_type() for in_type in self.input_types] - self.num_outputs = len(outputs) - return Apply(self, y0 + gz_not_disconntected, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) - if len(self.input_types) > 1: - for i, result in enumerate(results): - outputs[i][0] = np.array(result, self.input_types[i].dtype) + def perform_jax(self, *inputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if self.num_outputs == 1: + if isinstance(results, Sequence): + return results[0] else: - outputs[0][0] = np.array(results, self.input_types[0].dtype) + return results + else: + return tuple(results) - def perform_jax(self, *inputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) - if self.num_outputs == 1: - if isinstance(results, Sequence): - return results[0] - 
else: - return results - else: - return tuple(results) - SolOp.__name__ = name - SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name]) +@jax_funcify.register(JAXOp) +def sol_op_jax_funcify(op, **kwargs): + return op.perform_jax - VJPSolOp.__name__ = "VJP_" + name - VJPSolOp.__qualname__ = ".".join( - VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name] - ) - return SolOp, VJPSolOp +@jax_funcify.register(VJPJAXOp) +def vjp_sol_op_jax_funcify(op, **kwargs): + return op.perform_jax From 2543c9ddb385ec2a57667f408c57faa51bb3c56f Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 18:00:41 +0100 Subject: [PATCH 04/14] Added comment regarding flattening of inputs --- pytensor/link/jax/ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 60a3581550..fba6022232 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -85,8 +85,12 @@ def func(*args, **kwargs): vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) - # Infer shapes and types of the variables + # Flatten nested python structures, e.g. {"a": tensor_a, "b": [tensor_b]} + # becomes [tensor_a, tensor_b], because pytensor ops only accepts lists of + # pytensor.Variables as input. pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + + # Infer shapes and types of the variables pt_vars_types_flat = [var.type for var in pt_vars_flat] shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) From 36a71d211fd617d641eef904431ae5a801701341 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 20:16:48 +0100 Subject: [PATCH 05/14] Add as_jax_op to pytensor.__init__.py and to documentation --- doc/conf.py | 1 + doc/library/index.rst | 7 +++++++ pytensor/__init__.py | 12 ++++++++++++ pytensor/link/jax/ops.py | 31 ++++++++++++++----------------- tests/link/jax/test_as_jax_op.py | 3 +-- 5 files changed, 35 insertions(+), 19 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 1729efc4b1..e1143714d3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -36,6 +36,7 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/doc/library/index.rst b/doc/library/index.rst index 08a5b51c34..1b72f0ac84 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -63,6 +63,13 @@ Convert to Variable .. autofunction:: pytensor.as_symbolic(...) +Wrap JAX functions +================== + +.. autofunction:: as_jax_op(...) + + Alias for :func:`pytensor.link.jax.ops.as_jax_op` + Debug ===== diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 3c925ac2f2..a7f9aa8058 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v): from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph +try: + import pytensor.link.jax.ops + from pytensor.link.jax.ops import as_jax_op +except ImportError as e: + import_error_as_jax_op = e + + def as_jax_op(*args, **kwargs): + raise ImportError( + "JAX and/or equinox are not installed. 
Install them" + " to use this function: pip install pytensor[jax]" + ) from import_error_as_jax_op + # isort: on diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index fba6022232..6cb23470db 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -25,32 +25,29 @@ def _filter_ptvars(x): def as_jax_op(jaxfunc, name=None): - """Return a Pytensor from a JAX jittable function. + """Return a Pytensor function from a JAX jittable function. - This decorator transforms any JAX jittable function into a function that accepts - and returns `pytensor.Variables`. The jax jittable function can accept any - nested python structure (pytrees) as input, and return any nested Python structure. - - It requires to define the output types of the returned values as pytensor types. A - unique name should also be passed in case the name of the jaxfunc is identical to - some other node. The design of this function is based on - https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ + This decorator transforms any JAX-jittable function into a function that accepts + and returns `pytensor.Variable`. The JAX-jittable function can accept any + nested python structure (a `Pytree + `_) as input, and might return + any nested Python structure. Parameters ---------- - jaxfunc : jax jittable function - function for which the node is created, can return multiple tensors as a tuple. - It is required that all return values are able to transformed to - pytensor.Variable. - name: str + jaxfunc : JAX-jittable function + JAX function which will be wrapped in a Pytensor Op. + name: str, optional Name of the created pytensor Op, defaults to the name of the passed function. Only used internally in the pytensor graph. Returns ------- - A function which can be used in a pymc.Model as function, is differentiable - and the resulting model can be compiled either with the default C backend, or - the JAX backend. + Callable : + A function which expects a nested python structure of `pytensor.Variable` and + static variables as inputs and returns `pytensor.Variable` with the same + API as the original jaxfunc. The resulting model can be compiled either with the + default C backend or the JAX backend. Notes diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 8d404d76db..3842278a04 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -4,9 +4,8 @@ import pytest import pytensor.tensor as pt -from pytensor import config, grad +from pytensor import as_jax_op, config, grad from pytensor.graph.fg import FunctionGraph -from pytensor.link.jax.ops import as_jax_op from pytensor.scalar import all_types from pytensor.tensor import tensor from tests.link.jax.test_basic import compare_jax_and_py From d4a0b6a03a0188535dfa33dba310a90e216afaba Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 20:27:16 +0100 Subject: [PATCH 06/14] Add [jax] requirement to readthedocs in order to read the docstring of as_jax_op --- doc/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/environment.yml b/doc/environment.yml index d58af79cc6..dc653b1b38 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -24,4 +24,4 @@ dependencies: - pip - pip: - sphinx_sitemap - - -e .. 
+ - -e ..[jax] From 65984b0ef9c72d5ee0a5b12e74f2931ffc1e62af Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 16 Dec 2024 12:24:44 +0100 Subject: [PATCH 07/14] Added an example to the docstring of as_jax_op --- pytensor/link/jax/ops.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 6cb23470db..11833a7884 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -49,6 +49,33 @@ def as_jax_op(jaxfunc, name=None): API as the original jaxfunc. The resulting model can be compiled either with the default C backend or the JAX backend. + Examples + -------- + + We define a JAX function `f_jax` that accepts a matrix `x`, a vector `y` and a + dictionary as input. This is transformed to a pytensor function with the decorator + `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. + for evaluation and calculating gradients. + + >>> import numpy + >>> import jax.numpy as jnp + >>> import pytensor + >>> import pytensor.tensor as pt + >>> x = pt.tensor("x", shape=(2,)) + >>> y = pt.tensor("y", shape=(2, 2)) + >>> a = pt.tensor("a", shape=()) + >>> args_dict = {"a": a} + >>> @pytensor.as_jax_op + ... def f_jax(x, y, args_dict): + ... z = jnp.dot(x, y) + args_dict["a"] + ... return z + >>> z = f_jax(x, y, args_dict) + >>> z_sum = pt.sum(z) + >>> grad_wrt_a = pt.grad(z_sum, a) + >>> f_all = pytensor.function([x, y, a], [z_sum, grad_wrt_a]) + >>> f_all(numpy.array([1, 2]), numpy.array([[1, 2], [3, 4]]), 1) + [array(19.), array(2.)] + Notes ----- @@ -327,7 +354,6 @@ def make_node(self, *inputs): self.num_inputs = len(inputs) # Define our output variables - print(self.output_types) outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) From 5960947ebfced0cc51f5de11e526b2831f4469fa Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 3 Feb 2025 17:44:35 +0100 Subject: [PATCH 08/14] Use infer_static_shape, currently still with the possibility to use the previous approach for testing purposes --- pytensor/link/jax/ops.py | 41 +++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 11833a7884..8b20370330 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -24,7 +24,7 @@ def _filter_ptvars(x): return isinstance(x, pt.Variable) -def as_jax_op(jaxfunc, name=None): +def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): """Return a Pytensor function from a JAX jittable function. This decorator transforms any JAX-jittable function into a function that accepts @@ -57,10 +57,10 @@ def as_jax_op(jaxfunc, name=None): `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. for evaluation and calculating gradients. 
- >>> import numpy - >>> import jax.numpy as jnp - >>> import pytensor - >>> import pytensor.tensor as pt + >>> import numpy # doctest: +ELLIPSIS + >>> import jax.numpy as jnp # doctest: +ELLIPSIS + >>> import pytensor # doctest: +ELLIPSIS + >>> import pytensor.tensor as pt # doctest: +ELLIPSIS >>> x = pt.tensor("x", shape=(2,)) >>> y = pt.tensor("y", shape=(2, 2)) >>> a = pt.tensor("a", shape=()) @@ -116,16 +116,27 @@ def func(*args, **kwargs): # Infer shapes and types of the variables pt_vars_types_flat = [var.type for var in pt_vars_flat] - shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) - shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) - - dummy_inputs_jax = jax.tree_util.tree_map( - lambda var, shape: jnp.empty( - [int(dim.eval()) for dim in shape], dtype=var.type.dtype - ), - pt_vars, - shapes_vars, - ) + + if use_infer_static_shape: + shapes_vars_flat = [ + pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat + ] + + dummy_inputs_jax_flat = [ + jnp.empty(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) + ] + + else: + shapes_vars_flat = pytensor.compile.builders.infer_shape( + pt_vars_flat, (), () + ) + dummy_inputs_jax_flat = [ + jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) + ] + + dummy_inputs_jax = tree_unflatten(vars_treedef, dummy_inputs_jax_flat) # Combine the static variables with the inputs, and split them again in the # output. Static variables don't take part in the graph, or might be a From 104df833b5d50223ed5f334ebef79662495e9f35 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 3 Feb 2025 17:53:16 +0100 Subject: [PATCH 09/14] Remove `sol` in variable names --- pytensor/link/jax/ops.py | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 8b20370330..992ba5ee20 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -160,11 +160,11 @@ def func(*args, **kwargs): ### Call the function that accepts flat inputs, which in turn calls the one that ### combines the inputs and static variables. - jitted_sol_op_jax = jax.jit(func_flattened) + jitted_jax_op = jax.jit(func_flattened) len_gz = len(pttypes_outvars) - vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) - jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz) + jitted_vjp_jax_op = jax.jit(vjp_jax_op) # Get classes that creates a Pytensor Op out of our function that accept # flattened inputs. 
They are created each time, to set a custom name for the @@ -194,8 +194,8 @@ class VJPJAXOp_local(VJPJAXOp): outvars_treedef, input_types=pt_vars_types_flat, output_types=pttypes_outvars, - jitted_sol_op_jax=jitted_sol_op_jax, - jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, + jitted_jax_op=jitted_jax_op, + jitted_vjp_jax_op=jitted_vjp_jax_op, ) ### Evaluate the Pytensor Op and return unflattened results @@ -265,8 +265,8 @@ def get_func_with_vars(self, vars): return interior_func -def _get_vjp_sol_op_jax(jaxfunc, len_gz): - def vjp_sol_op_jax(args): +def _get_vjp_jax_op(jaxfunc, len_gz): + def vjp_jax_op(args): y0 = args[:-len_gz] gz = args[-len_gz:] if len(gz) == 1: @@ -290,7 +290,7 @@ def func(*inputs): else: return tuple(vjp_fn(gz)) - return vjp_sol_op_jax + return vjp_jax_op def _partition_jaxfunc(jaxfunc, static_vars, func_vars): @@ -350,16 +350,16 @@ def __init__( output_treeedef, input_types, output_types, - jitted_sol_op_jax, - jitted_vjp_sol_op_jax, + jitted_jax_op, + jitted_vjp_jax_op, ): - self.vjp_sol_op = None + self.vjp_jax_op = None self.input_treedef = input_treedef self.output_treedef = output_treeedef self.input_types = input_types self.output_types = output_types - self.jitted_sol_op_jax = jitted_sol_op_jax - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + self.jitted_jax_op = jitted_jax_op + self.jitted_vjp_jax_op = jitted_vjp_jax_op def make_node(self, *inputs): self.num_inputs = len(inputs) @@ -368,16 +368,16 @@ def make_node(self, *inputs): outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) - self.vjp_sol_op = VJPJAXOp( + self.vjp_jax_op = VJPJAXOp( self.input_treedef, self.input_types, - self.jitted_vjp_sol_op_jax, + self.jitted_vjp_jax_op, ) return Apply(self, inputs, outputs) def perform(self, node, inputs, outputs): - results = self.jitted_sol_op_jax(inputs) + results = self.jitted_jax_op(inputs) if self.num_outputs > 1: for i in range(self.num_outputs): outputs[i][0] = np.array(results[i], self.output_types[i].dtype) @@ -385,7 +385,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = np.array(results, self.output_types[0].dtype) def perform_jax(self, *inputs): - results = self.jitted_sol_op_jax(inputs) + results = self.jitted_jax_op(inputs) return results def grad(self, inputs, output_gradients): @@ -399,7 +399,7 @@ def grad(self, inputs, output_gradients): ) else: output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_sol_op(inputs, output_gradients) + result = self.vjp_jax_op(inputs, output_gradients) if self.num_inputs > 1: return result @@ -413,11 +413,11 @@ def __init__( self, input_treedef, input_types, - jitted_vjp_sol_op_jax, + jitted_vjp_jax_op, ): self.input_treedef = input_treedef self.input_types = input_types - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + self.jitted_vjp_jax_op = jitted_vjp_jax_op def make_node(self, y0, gz): y0 = [ @@ -436,7 +436,7 @@ def make_node(self, y0, gz): return Apply(self, y0 + gz_not_disconntected, outputs) def perform(self, node, inputs, outputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + results = self.jitted_vjp_jax_op(tuple(inputs)) if len(self.input_types) > 1: for i, result in enumerate(results): outputs[i][0] = np.array(result, self.input_types[i].dtype) @@ -444,7 +444,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = np.array(results, self.input_types[0].dtype) def perform_jax(self, *inputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + results = 
self.jitted_vjp_jax_op(tuple(inputs)) if self.num_outputs == 1: if isinstance(results, Sequence): return results[0] @@ -455,10 +455,10 @@ def perform_jax(self, *inputs): @jax_funcify.register(JAXOp) -def sol_op_jax_funcify(op, **kwargs): +def jax_op_funcify(op, **kwargs): return op.perform_jax @jax_funcify.register(VJPJAXOp) -def vjp_sol_op_jax_funcify(op, **kwargs): +def vjp_jax_op_funcify(op, **kwargs): return op.perform_jax From e11777e82994bdb62b3b0a080d292640039d6c10 Mon Sep 17 00:00:00 2001 From: Jonas Date: Tue, 4 Feb 2025 20:49:59 +0100 Subject: [PATCH 10/14] Rename tests and make static variables test more meaningfull --- tests/link/jax/test_as_jax_op.py | 54 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 3842278a04..286a1334f7 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -11,8 +11,7 @@ from tests.link.jax.test_basic import compare_jax_and_py -def test_as_jax_op1(): - # 2 parameters input, single output +def test_2in_1out(): rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -33,8 +32,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op2(): - # 2 parameters input, tuple output +def test_2in_tupleout(): rng = np.random.default_rng(2) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -55,8 +53,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op3(): - # 2 parameters input, list output +def test_2in_listout(): rng = np.random.default_rng(3) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -77,8 +74,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op4(): - # single 1d input, tuple output +def test_1din_tupleout(): rng = np.random.default_rng(4) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -96,8 +92,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op5(): - # single 0d input, tuple output +def test_0din_tupleout(): rng = np.random.default_rng(5) x = tensor("a", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -115,8 +110,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op6(): - # single input, list output +def test_1in_listout(): rng = np.random.default_rng(6) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -135,8 +129,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op7(): - # 2 parameters input with pytree, tuple output +def test_pytreein_tupleout(): rng = np.random.default_rng(7) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -159,8 +152,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op8(): - # 2 parameters input with pytree, pytree output +def test_pytreein_pytreeout(): rng = np.random.default_rng(8) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -180,8 +172,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op9(): - # 2 parameters input with pytree, pytree output and non-graph argument +def test_pytreein_pytreeout_w_nongraphargs(): rng = np.random.default_rng(9) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -191,18 +182,35 @@ def test_as_jax_op9(): ] @as_jax_op - def f(x, y, non_model_arg): - return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y) - - out 
= f(x, y_tmp, "Hello World!") - grad_out = grad(pt.sum(out[0]), [x]) + def f(x, y, depth, which_variable): + if which_variable == "x": + var = x + elif which_variable == "y": + var = y["a"] + y["b"][0] + else: + return "Unsupported argument" + for _ in range(depth): + var = jax.nn.sigmoid(var) + return var + # arguments depth and which_variable are not part of the graph + out = f(x, y_tmp, depth=3, which_variable="x") + grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out[0], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + out = f(x, y_tmp, depth=7, which_variable="y") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + out = f(x, y_tmp, depth=10, which_variable="z") + assert out == "Unsupported argument" + def test_as_jax_op10(): # Use "None" in shape specification and have a non-used output of higher rank From d2e788fb82967bb41eae767b3402746e01fb5480 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 5 Feb 2025 17:40:26 +0100 Subject: [PATCH 11/14] More test renaming, forgot a few --- tests/link/jax/test_as_jax_op.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 286a1334f7..d361acec4c 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -11,7 +11,7 @@ from tests.link.jax.test_basic import compare_jax_and_py -def test_2in_1out(): +def test_two_inputs_single_output(): rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -32,7 +32,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_2in_tupleout(): +def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -53,7 +53,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_2in_listout(): +def test_two_inputs_list_output(): rng = np.random.default_rng(3) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -74,7 +74,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_1din_tupleout(): +def test_single_input_tuple_output(): rng = np.random.default_rng(4) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -92,7 +92,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_0din_tupleout(): +def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) x = tensor("a", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -110,7 +110,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_1in_listout(): +def test_single_input_list_output(): rng = np.random.default_rng(6) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -129,7 +129,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_pytreein_tupleout(): +def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -152,7 +152,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_pytreein_pytreeout(): +def test_pytree_input_pytree_output(): rng = np.random.default_rng(8) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -172,7 +172,7 @@ def f(x, y): fn, _ = 
compare_jax_and_py(fg, test_values) -def test_pytreein_pytreeout_w_nongraphargs(): +def test_pytree_input_with_non_graph_args(): rng = np.random.default_rng(9) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -212,8 +212,7 @@ def f(x, y, depth, which_variable): assert out == "Unsupported argument" -def test_as_jax_op10(): - # Use "None" in shape specification and have a non-used output of higher rank +def test_unused_matrix_product_and_exp_gradient(): rng = np.random.default_rng(10) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -235,8 +234,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op11(): - # Test unknown static shape +def test_unknown_static_shape(): rng = np.random.default_rng(11) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -260,8 +258,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op12(): - # Test non-array return values +def test_non_array_return_values(): rng = np.random.default_rng(12) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -283,8 +280,7 @@ def f(x, y, message): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op13(): - # Test nested functions +def test_nested_functions(): rng = np.random.default_rng(13) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) From ab326e540d00a92a9b9fe69c3f023ed560c04ecd Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 5 Feb 2025 17:47:27 +0100 Subject: [PATCH 12/14] Refactoring of ops.py: code is in general cleaner, and JAXOp can now be used without the decorator @as_jax_op --- pytensor/link/jax/ops.py | 692 ++++++++++++++++++++------------------- 1 file changed, 350 insertions(+), 342 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 992ba5ee20..ca780da9c8 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -1,6 +1,5 @@ """Convert a jax function to a pytensor compatible function.""" -import functools as ft import logging from collections.abc import Sequence @@ -20,62 +19,242 @@ log = logging.getLogger(__name__) -def _filter_ptvars(x): - return isinstance(x, pt.Variable) +class JAXOp(Op): + """ + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward computation and reverse-mode differentiation (via the VJPJAXOp class). + + Parameters + ---------- + input_types : list + A list of PyTensor types for each input variable. + output_types : list + A list of PyTensor types for each output variable. + flat_func : callable + The JAX function that computes outputs from inputs. Inputs and outputs have to be provided as flat arrays. + name : str, optional + A custom name for the Op instance. If provided, the class name will be + updated accordingly. + + Example + ------- + This example defines a simple function that sums the input array with a dynamic shape. + + >>> import numpy as np + >>> import jax + >>> import jax.numpy as jnp + >>> from pytensor.tensor import TensorType + >>> + >>> # Create the jax function that sums the input array. + >>> def sum_function(x, y): + ... return jnp.sum(x + y) + >>> + >>> # Create the input and output types, input has a dynamic shape. + >>> input_type = TensorType("float32", shape=(None,)) + >>> output_type = TensorType("float32", shape=(1,)) + >>> + >>> # Instantiate a JAXOp; tree definitions are set to None for simplicity. + >>> op = JAXOp( + ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" + ... ) + >>> # Define symbolic input variables. 
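+ >>> import pytensor
+ >>> import pytensor.tensor as pt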
+ >>> x = pt.tensor("x", dtype="float32", shape=(2,)) + >>> y = pt.tensor("y", dtype="float32", shape=(2,)) + >>> # Compile a PyTensor function. + >>> result = op(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print( + ... f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array(14., dtype=float32)] + >>> + >>> # Compute the gradient of op(x, y) with respect to x. + >>> g = pt.grad(result[0], x) + >>> grad_f = pytensor.function([x, y], [g]) + >>> print( + ... grad_f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array([1., 1.], dtype=float32)] + """ + + def __init__(self, input_types, output_types, flat_func, name=None): + self.input_types = input_types + self.output_types = output_types + self.num_inputs = len(input_types) + self.num_outputs = len(output_types) + normalized_flat_func = _normalize_flat_func(flat_func) + self.jitted_func = jax.jit(normalized_flat_func) + + vjp_func = _get_vjp_jax_op(normalized_flat_func, len(output_types)) + normalized_vjp_func = _normalize_flat_func(vjp_func) + self.jitted_vjp = jax.jit(normalized_vjp_func) + self.vjp_jax_op = VJPJAXOp( + self.input_types, + self.jitted_vjp, + name=("VJP" + name) if name is not None else None, + ) + + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, *inputs): + outputs = [pt.as_tensor_variable(typ()) for typ in self.output_types] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_func(*inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_func(*inputs) + + def grad(self, inputs, output_gradients): + # If a output is not used, it gets disconnected by pytensor and won't have a + # gradient. Set gradient here to zero for those outputs. + for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + zero_shape = ( + self.output_types[i].shape + if None not in self.output_types[i].shape + else () + ) + output_gradients[i] = pt.zeros(zero_shape, self.output_types[i].dtype) + + # Compute the gradient. 
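+ # vjp_jax_op computes the vector-Jacobian product, i.e. the gradients of the
+ # outputs weighted by output_gradients, with respect to every input.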
+ grad_result = self.vjp_jax_op(inputs, output_gradients) + return grad_result if self.num_inputs > 1 else (grad_result,) + + +class VJPJAXOp(Op): + def __init__(self, input_types, jitted_vjp, name=None): + self.input_types = input_types + self.jitted_vjp = jitted_vjp + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, y0, gz): + y0_converted = [ + pt.as_tensor_variable(y).astype(self.input_types[i].dtype) + for i, y in enumerate(y0) + ] + gz_not_disconnected = [ + pt.as_tensor_variable(g) + for g in gz + if not isinstance(g.type, DisconnectedType) + ] + outputs = [typ() for typ in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0_converted + gz_not_disconnected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp(*inputs) + if len(self.input_types) > 1: + for i, res in enumerate(results): + outputs[i][0] = np.array(res, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_vjp(*inputs) + + +def _normalize_flat_func(func): + def normalized_func(*flat_vars): + out_flat = func(*flat_vars) + if isinstance(out_flat, Sequence): + return tuple(out_flat) if len(out_flat) > 1 else out_flat[0] + else: + return out_flat + + return normalized_func + + +def _get_vjp_jax_op(flat_func, num_out): + def vjp_op(*args): + y0 = args[:-num_out] + gz = args[-num_out:] + if len(gz) == 1: + gz = gz[0] + + def f(*inputs): + return flat_func(*inputs) + + primals, vjp_fn = jax.vjp(f, *y0) + + def broadcast_to_shape(g, shape): + if g.ndim > 0 and g.shape[0] == 1: + g_squeezed = jnp.squeeze(g, axis=0) + else: + g_squeezed = g + return jnp.broadcast_to(g_squeezed, shape) + + gz = tree_map( + lambda g, p: broadcast_to_shape(g, jnp.shape(p)).astype(p.dtype), + gz, + primals, + ) + return vjp_fn(gz) + + return vjp_op def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): - """Return a Pytensor function from a JAX jittable function. + """Return a Pytensor-compatible function from a JAX jittable function. - This decorator transforms any JAX-jittable function into a function that accepts - and returns `pytensor.Variable`. The JAX-jittable function can accept any + This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` + objects. The JAX-jittable function can accept any nested python structure (a `Pytree `_) as input, and might return any nested Python structure. Parameters ---------- - jaxfunc : JAX-jittable function - JAX function which will be wrapped in a Pytensor Op. - name: str, optional - Name of the created pytensor Op, defaults to the name of the passed function. - Only used internally in the pytensor graph. + jaxfunc : Callable + A JAX function to be wrapped. + use_infer_static_shape : bool, optional + If True, use static shape inference; otherwise, use runtime shape inference. + Default is True. + name : str, optional + A custom name for the created Pytensor Op instance. If None, the name of jaxfunc + is used. Returns ------- - Callable : - A function which expects a nested python structure of `pytensor.Variable` and - static variables as inputs and returns `pytensor.Variable` with the same - API as the original jaxfunc. The resulting model can be compiled either with the - default C backend or the JAX backend. 
+ Callable + A function that wraps the given JAX function so that it can be called with + pytensor.Variable inputs and returns pytensor.Variable outputs. Examples -------- - We define a JAX function `f_jax` that accepts a matrix `x`, a vector `y` and a - dictionary as input. This is transformed to a pytensor function with the decorator - `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. - for evaluation and calculating gradients. - - >>> import numpy # doctest: +ELLIPSIS - >>> import jax.numpy as jnp # doctest: +ELLIPSIS - >>> import pytensor # doctest: +ELLIPSIS - >>> import pytensor.tensor as pt # doctest: +ELLIPSIS - >>> x = pt.tensor("x", shape=(2,)) - >>> y = pt.tensor("y", shape=(2, 2)) - >>> a = pt.tensor("a", shape=()) - >>> args_dict = {"a": a} - >>> @pytensor.as_jax_op - ... def f_jax(x, y, args_dict): - ... z = jnp.dot(x, y) + args_dict["a"] - ... return z - >>> z = f_jax(x, y, args_dict) - >>> z_sum = pt.sum(z) - >>> grad_wrt_a = pt.grad(z_sum, a) - >>> f_all = pytensor.function([x, y, a], [z_sum, grad_wrt_a]) - >>> f_all(numpy.array([1, 2]), numpy.array([[1, 2], [3, 4]]), 1) - [array(19.), array(2.)] - + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @as_jax_op + ... def add(x, y): + ... return jnp.add(x, y) + >>> x = pt.scalar("x") + >>> y = pt.scalar("y") + >>> result = add(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print(f(1, 2)) + [array(3.)] Notes ----- @@ -87,145 +266,165 @@ def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the variables. Shapes are inferred using :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. + """ def func(*args, **kwargs): - """Return a pytensor from a jax jittable function.""" - ### Split variables: in the ones that will be transformed to JAX inputs, - ### pytensor.Variables; _WrappedFunc, that are functions that have been returned - ### from a transformed function; and the rest, static variables that are not - ### transformed. - - pt_vars, static_vars_tmp = eqx.partition( - (args, kwargs), _filter_ptvars, is_leaf=callable + # 1. Partition inputs into dynamic pytensor variables, wrapped functions and + # static variables. + # Static variables don't take part in the graph. + pt_vars, func_vars, static_vars = _split_inputs(args, kwargs) + + # 2. Get the original variables from the wrapped functions. + vars_from_func = tree_map(lambda f: f.get_vars(), func_vars) + input_dict = {"vars": pt_vars, "vars_from_func": vars_from_func} + + # 3. Flatten the input dictionary. + # e.g. {"a": tensor_a, "b": [tensor_b]} becomes [tensor_a, tensor_b], because + # pytensor ops only accepts lists of pytensor.Variables as input. + pt_vars_flat, pt_vars_treedef = tree_flatten( + input_dict, ) - # is_leaf=callable is used, as libraries like diffrax or equinox might return - # functions that are still seen as a nested pytree structure. We consider them - # as wrappable functions, that will be wrapped with _WrappedFunc. + pt_types = [var.type for var in pt_vars_flat] - func_vars, static_vars = eqx.partition( - static_vars_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + # 4. Create dummy inputs for shape inference. 
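+ # The dummy arrays are only passed to jax.eval_shape in step 6 to infer the
+ # output types; no numerical computation is performed with them.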
+ shapes = _infer_shapes(pt_vars_flat, use_infer_static_shape) + dummy_in_flat = _create_dummy_inputs_from_shapes( + pt_vars_flat, shapes, use_infer_static_shape ) - vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) - pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) + dummy_inputs = tree_unflatten(pt_vars_treedef, dummy_in_flat) - # Flatten nested python structures, e.g. {"a": tensor_a, "b": [tensor_b]} - # becomes [tensor_a, tensor_b], because pytensor ops only accepts lists of - # pytensor.Variables as input. - pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + # 5. Partition the JAX function into dynamic and static parts. + jaxfunc_dynamic, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + flat_func = _flatten_func(jaxfunc_dynamic, pt_vars_treedef) + + # 6. Infer output types using JAX's eval_shape. + out_treedef, pt_types_flat = _infer_output_types(jaxfunc_dynamic, dummy_inputs) + + # 7. Create the Pytensor Op instance. + curr_name = "JAXOp_" + (jaxfunc.__name__ if name is None else name) + op_instance = JAXOp( + pt_types, + pt_types_flat, + flat_func, + name=curr_name, + ) - # Infer shapes and types of the variables - pt_vars_types_flat = [var.type for var in pt_vars_flat] + # 8. Execute the op and unflatten the outputs. + output_flat = op_instance(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] + outvars = tree_unflatten(out_treedef, output_flat) - if use_infer_static_shape: - shapes_vars_flat = [ - pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat - ] + # 9. Combine with static outputs and wrap eventual output functions with + # _WrappedFunc + return _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars) - dummy_inputs_jax_flat = [ - jnp.empty(shape, dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) - ] + return func - else: - shapes_vars_flat = pytensor.compile.builders.infer_shape( - pt_vars_flat, (), () - ) - dummy_inputs_jax_flat = [ - jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) - ] - dummy_inputs_jax = tree_unflatten(vars_treedef, dummy_inputs_jax_flat) +def _filter_ptvars(x): + return isinstance(x, pt.Variable) + - # Combine the static variables with the inputs, and split them again in the - # output. Static variables don't take part in the graph, or might be a - # a function that is returned. - jaxfunc_partitioned, static_out_dic = _partition_jaxfunc( - jaxfunc, static_vars, func_vars - ) +def _split_inputs(args, kwargs): + """Split inputs into pytensor variables, static values and wrapped functions.""" - func_flattened = _flatten_func(jaxfunc_partitioned, vars_treedef) + pt_vars, static_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. 
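+ # A second partition separates already wrapped functions (_WrappedFunc) from
+ # truly static values, so that their pytensor variables can be re-extracted.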
+ func_vars, static_vars = eqx.partition( + static_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + return pt_vars, func_vars, static_vars - jaxtypes_outvars = jax.eval_shape( - ft.partial(jaxfunc_partitioned, vars=dummy_inputs_jax), - ) - jaxtypes_outvars_flat, outvars_treedef = tree_flatten(jaxtypes_outvars) +def _infer_shapes(pt_vars_flat, use_infer_static_shape): + """Infer shapes of pytensor variables.""" + if use_infer_static_shape: + return [pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat] + else: + return pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + - pttypes_outvars = [ - pt.TensorType(dtype=var.dtype, shape=var.shape) - for var in jaxtypes_outvars_flat +def _create_dummy_inputs_from_shapes(pt_vars_flat, shapes, use_infer_static_shape): + """Create dummy inputs for the jax function from inferred shapes.""" + if use_infer_static_shape: + return [ + jnp.empty(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) + ] + else: + return [ + jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) ] - ### Call the function that accepts flat inputs, which in turn calls the one that - ### combines the inputs and static variables. - jitted_jax_op = jax.jit(func_flattened) - len_gz = len(pttypes_outvars) - vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz) - jitted_vjp_jax_op = jax.jit(vjp_jax_op) +def _infer_output_types(jaxfunc_part, dummy_inputs): + """Infer output types using JAX's eval_shape.""" + jax_out = jax.eval_shape(jaxfunc_part, dummy_inputs) + jax_out_flat, out_treedef = tree_flatten(jax_out) + pt_out_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) for var in jax_out_flat + ] + return out_treedef, pt_out_types - # Get classes that creates a Pytensor Op out of our function that accept - # flattened inputs. They are created each time, to set a custom name for the - # class. - class JAXOp_local(JAXOp): - pass - class VJPJAXOp_local(VJPJAXOp): - pass +def _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars): + """Process and combine static outputs with the dynamic ones.""" + static_funcs, static_vars_out = eqx.partition( + static_out_dic["out"], callable, is_leaf=callable + ) + flat_static, func_treedef = tree_flatten(static_funcs, is_leaf=callable) + for i in range(len(flat_static)): + flat_static[i] = _WrappedFunc(jaxfunc, i, *args, **kwargs) + static_funcs = tree_unflatten(func_treedef, flat_static) + static_combined = eqx.combine(static_funcs, static_vars_out, is_leaf=callable) + return eqx.combine(outvars, static_combined, is_leaf=callable) - if name is None: - curr_name = jaxfunc.__name__ - else: - curr_name = name - JAXOp_local.__name__ = curr_name - JAXOp_local.__qualname__ = ".".join( - JAXOp_local.__qualname__.split(".")[:-1] + [curr_name] - ) - VJPJAXOp_local.__name__ = "VJP_" + curr_name - VJPJAXOp_local.__qualname__ = ".".join( - VJPJAXOp_local.__qualname__.split(".")[:-1] + ["VJP_" + curr_name] - ) +def _partition_jaxfunc(jaxfunc, static_vars, func_vars): + """Split the jax function into dynamic and static components. - local_op = JAXOp_local( - vars_treedef, - outvars_treedef, - input_types=pt_vars_types_flat, - output_types=pttypes_outvars, - jitted_jax_op=jitted_jax_op, - jitted_vjp_jax_op=jitted_vjp_jax_op, - ) + Returns a function that accepts only non-static variables and returns the non-static + variables. 
The returned static variables are stored in a dictionary and returned, + to allow the referencing after creating the function - ### Evaluate the Pytensor Op and return unflattened results - output_flat = local_op(*pt_vars_flat) - if not isinstance(output_flat, Sequence): - output_flat = [output_flat] # tree_unflatten expects a sequence. - outvars = tree_unflatten(outvars_treedef, output_flat) + Additionally wrapped functions saved in func_vars are regenerated with + vars["vars_from_func"] as input, to allow the transformation of the variables. + """ + static_out_dic = {"out": None} - static_outfuncs, static_outvars = eqx.partition( - static_out_dic["out"], callable, is_leaf=callable + def jaxfunc_partitioned(vars): + dyn_vars, func_vars_input = vars["vars"], vars["vars_from_func"] + evaluated_funcs = tree_map( + lambda f, v: f.get_func_with_vars(v), func_vars, func_vars_input ) - - static_outfuncs_flat, treedef_outfuncs = jax.tree_util.tree_flatten( - static_outfuncs, is_leaf=callable + args, kwargs = eqx.combine( + dyn_vars, static_vars, evaluated_funcs, is_leaf=callable ) - for i_func, _ in enumerate(static_outfuncs_flat): - static_outfuncs_flat[i_func] = _WrappedFunc( - jaxfunc, i_func, *args, **kwargs - ) + output = jaxfunc(*args, **kwargs) + out_dyn, static_out = eqx.partition(output, eqx.is_array, is_leaf=callable) + static_out_dic["out"] = static_out + return out_dyn - static_outfuncs = jax.tree_util.tree_unflatten( - treedef_outfuncs, static_outfuncs_flat - ) - static_vars = eqx.combine(static_outfuncs, static_outvars, is_leaf=callable) + return jaxfunc_partitioned, static_out_dic - output = eqx.combine(outvars, static_vars, is_leaf=callable) - return output +def _flatten_func(jaxfunc, treedef): + def flat_func(*flat_vars): + vars = tree_unflatten(treedef, flat_vars) + out = jaxfunc(vars) + out_flat, _ = tree_flatten(out) + return out_flat - return func + return flat_func class _WrappedFunc: @@ -233,6 +432,7 @@ def __init__(self, exterior_func, i_func, *args, **kwargs): self.args = args self.kwargs = kwargs self.i_func = i_func + # Partition the inputs to separate dynamic variables from static ones. vars, static_vars = eqx.partition( (self.args, self.kwargs), _filter_ptvars, is_leaf=callable ) @@ -244,8 +444,7 @@ def __call__(self, *args, **kwargs): # If called, assume that args and kwargs are pytensors, so return the result # as pytensors. def f(func, *args, **kwargs): - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) return as_jax_op(f)(self, *args, **kwargs) @@ -256,202 +455,11 @@ def get_func_with_vars(self, vars): # Use other variables than the saved ones, to generate the function. This # is used to transform vars externally from pytensor to JAX, and use the # then create the function which is returned. 
- args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) output = self.exterior_func(*args, **kwargs) - outfuncs, _ = eqx.partition(output, callable, is_leaf=callable) - outfuncs_flat, _ = jax.tree_util.tree_flatten(outfuncs, is_leaf=callable) - interior_func = outfuncs_flat[self.i_func] - return interior_func - - -def _get_vjp_jax_op(jaxfunc, len_gz): - def vjp_jax_op(args): - y0 = args[:-len_gz] - gz = args[-len_gz:] - if len(gz) == 1: - gz = gz[0] - - def func(*inputs): - return jaxfunc(inputs) - - primals, vjp_fn = jax.vjp(func, *y0) - gz = tree_map( - lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)).astype( - primal.dtype - ), # Also cast to the dtype of the primal, this shouldn't be - # necessary, but it happens that the returned dtype of the gradient isn't - # the same anymore. - gz, - primals, - ) - if len(y0) == 1: - return vjp_fn(gz)[0] - else: - return tuple(vjp_fn(gz)) - - return vjp_jax_op - - -def _partition_jaxfunc(jaxfunc, static_vars, func_vars): - """Partition the jax function into static and non-static variables. - - Returns a function that accepts only non-static variables and returns the non-static - variables. The returned static variables are stored in a dictionary and returned, - to allow the referencing after creating the function - - Additionally wrapped functions saved in func_vars are regenerated with - vars["vars_from_func"] as input, to allow the transformation of the variables. - """ - static_out_dic = {"out": None} - - def jaxfunc_partitioned(vars): - vars, vars_from_func = vars["vars"], vars["vars_from_func"] - func_vars_evaled = tree_map( - lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func - ) - args, kwargs = eqx.combine( - vars, static_vars, func_vars_evaled, is_leaf=callable - ) - - out = jaxfunc(*args, **kwargs) - outvars, static_out = eqx.partition(out, eqx.is_array, is_leaf=callable) - static_out_dic["out"] = static_out - return outvars - - return jaxfunc_partitioned, static_out_dic - - -### Construct the function that accepts flat inputs and returns flat outputs. 
-def _flatten_func(jaxfunc, vars_treedef): - def func_flattened(vars_flat): - vars = tree_unflatten(vars_treedef, vars_flat) - outvars = jaxfunc(vars) - outvars_flat, _ = tree_flatten(outvars) - return _normalize_flat_output(outvars_flat) - - return func_flattened - - -def _normalize_flat_output(output): - if len(output) > 1: - return tuple( - output - ) # Transform to tuple because jax makes a difference between - # tuple and list and not pytensor - else: - return output[0] - - -class JAXOp(Op): - def __init__( - self, - input_treedef, - output_treeedef, - input_types, - output_types, - jitted_jax_op, - jitted_vjp_jax_op, - ): - self.vjp_jax_op = None - self.input_treedef = input_treedef - self.output_treedef = output_treeedef - self.input_types = input_types - self.output_types = output_types - self.jitted_jax_op = jitted_jax_op - self.jitted_vjp_jax_op = jitted_vjp_jax_op - - def make_node(self, *inputs): - self.num_inputs = len(inputs) - - # Define our output variables - outputs = [pt.as_tensor_variable(type()) for type in self.output_types] - self.num_outputs = len(outputs) - - self.vjp_jax_op = VJPJAXOp( - self.input_treedef, - self.input_types, - self.jitted_vjp_jax_op, - ) - - return Apply(self, inputs, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_jax_op(inputs) - if self.num_outputs > 1: - for i in range(self.num_outputs): - outputs[i][0] = np.array(results[i], self.output_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.output_types[0].dtype) - - def perform_jax(self, *inputs): - results = self.jitted_jax_op(inputs) - return results - - def grad(self, inputs, output_gradients): - # If a output is not used, it is disconnected and doesn't have a gradient. - # Set gradient here to zero for those outputs. 
- for i in range(self.num_outputs): - if isinstance(output_gradients[i].type, DisconnectedType): - if None not in self.output_types[i].shape: - output_gradients[i] = pt.zeros( - self.output_types[i].shape, self.output_types[i].dtype - ) - else: - output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_jax_op(inputs, output_gradients) - - if self.num_inputs > 1: - return result - else: - return (result,) # Pytensor requires a tuple here - - -# vector-jacobian product Op -class VJPJAXOp(Op): - def __init__( - self, - input_treedef, - input_types, - jitted_vjp_jax_op, - ): - self.input_treedef = input_treedef - self.input_types = input_types - self.jitted_vjp_jax_op = jitted_vjp_jax_op - - def make_node(self, y0, gz): - y0 = [ - pt.as_tensor_variable( - _y, - ).astype(self.input_types[i].dtype) - for i, _y in enumerate(y0) - ] - gz_not_disconntected = [ - pt.as_tensor_variable(_gz) - for _gz in gz - if not isinstance(_gz.type, DisconnectedType) - ] - outputs = [in_type() for in_type in self.input_types] - self.num_outputs = len(outputs) - return Apply(self, y0 + gz_not_disconntected, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_vjp_jax_op(tuple(inputs)) - if len(self.input_types) > 1: - for i, result in enumerate(results): - outputs[i][0] = np.array(result, self.input_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.input_types[0].dtype) - - def perform_jax(self, *inputs): - results = self.jitted_vjp_jax_op(tuple(inputs)) - if self.num_outputs == 1: - if isinstance(results, Sequence): - return results[0] - else: - return results - else: - return tuple(results) + out_funcs, _ = eqx.partition(output, callable, is_leaf=callable) + out_funcs_flat, _ = tree_flatten(out_funcs, is_leaf=callable) + return out_funcs_flat[self.i_func] @jax_funcify.register(JAXOp) From 48fbf0a1a12764e7d03a732f12cb356389632ee6 Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 6 Feb 2025 13:45:46 +0100 Subject: [PATCH 13/14] Clean up tests --- tests/link/jax/test_as_jax_op.py | 124 ++++++++++++++----------------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index d361acec4c..71c34c04e5 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -13,8 +13,8 @@ def test_two_inputs_single_output(): rng = np.random.default_rng(1) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -34,8 +34,8 @@ def f(x, y): def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -44,19 +44,22 @@ def test_two_inputs_tuple_output(): def f(x, y): return jax.nn.sigmoid(x + y), y * 2 - out, _ = f(x, y) - grad_out = grad(pt.sum(out), [x, y]) + out1, out2 = f(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + # must_be_device_array is False, because the with disabled jit compilation, + # inputs are not automatically transformed to jax.Array anymore + fn, _ = 
compare_jax_and_py(fg, test_values, must_be_device_array=False) -def test_two_inputs_list_output(): +def test_two_inputs_list_output_one_unused_output(): + # One output is unused, to test whether the wrapper can handle DisconnectedType rng = np.random.default_rng(3) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -76,63 +79,62 @@ def f(x, y): def test_single_input_tuple_output(): rng = np.random.default_rng(4) - x = tensor("a", shape=(2,)) + x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return jax.nn.sigmoid(x), x * 2 - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) - x = tensor("a", shape=()) + x = tensor("x", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return jax.nn.sigmoid(x), x - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_single_input_list_output(): rng = np.random.default_rng(6) - x = tensor("a", shape=(2,)) + x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return [jax.nn.sigmoid(x), 2 * x] - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) - with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) y_tmp = {"y": y, "y2": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -149,13 +151,13 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_pytree_input_pytree_output(): rng = np.random.default_rng(8) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(1,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) y_tmp = {"a": y, "b": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -171,11 +173,14 @@ def f(x, y): fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + def 
test_pytree_input_with_non_graph_args(): rng = np.random.default_rng(9) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(1,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) y_tmp = {"a": y, "b": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -212,10 +217,13 @@ def f(x, y, depth, which_variable): assert out == "Unsupported argument" -def test_unused_matrix_product_and_exp_gradient(): +def test_unused_matrix_product(): + # A matrix output is unused, to test whether the wrapper can handle a + # DisconnectedType with a larger dimension. + rng = np.random.default_rng(10) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -236,19 +244,19 @@ def f(x, y): def test_unknown_static_shape(): rng = np.random.default_rng(11) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - x = pt.cumsum(x) # Now x has an unknown shape + x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape @as_jax_op def f(x, y): return x * jnp.ones(3) - out = f(x, y) + out = f(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -258,32 +266,10 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_non_array_return_values(): - rng = np.random.default_rng(12) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) - test_values = [ - rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) - ] - - @as_jax_op - def f(x, y, message): - return x * jnp.ones(3), "Success: " + message - - out = f(x, y, "Hi") - grad_out = grad(pt.sum(out[0]), [x]) - - fg = FunctionGraph([x, y], [out[0], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) - - with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) - - def test_nested_functions(): rng = np.random.default_rng(13) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -319,8 +305,8 @@ class TestDtypes: @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) def test_different_in_output(self, in_dtype, out_dtype): - x = tensor("a", shape=(3,), dtype=in_dtype) - y = tensor("b", shape=(3,), dtype=in_dtype) + x = tensor("x", shape=(3,), dtype=in_dtype) + y = tensor("y", shape=(3,), dtype=in_dtype) if "int" in in_dtype: test_values = [ @@ -356,8 +342,8 @@ def f(x, y): @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) def test_test_different_inputs(self, in1_dtype, in2_dtype): - x = tensor("a", shape=(3,), dtype=in1_dtype) - y = tensor("b", shape=(3,), dtype=in2_dtype) + x = tensor("x", shape=(3,), dtype=in1_dtype) + y = tensor("y", shape=(3,), dtype=in2_dtype) if "int" in in1_dtype: test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] From b8c4523bdb1c0598c12f3bf162a2164c43902bed Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 6 Feb 2025 14:13:56 +0100 Subject: [PATCH 14/14] Add to some tests a direct call to JAXOp --- tests/link/jax/test_as_jax_op.py | 131 
+++++++++++++++++++++++++++---- 1 file changed, 114 insertions(+), 17 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 71c34c04e5..62fd270032 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -6,8 +6,9 @@ import pytensor.tensor as pt from pytensor import as_jax_op, config, grad from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import JAXOp from pytensor.scalar import all_types -from pytensor.tensor import tensor +from pytensor.tensor import TensorType, tensor from tests.link.jax.test_basic import compare_jax_and_py @@ -19,11 +20,11 @@ def test_two_inputs_single_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return jax.nn.sigmoid(x + y) - out = f(x, y) + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -31,6 +32,17 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) @@ -40,11 +52,11 @@ def test_two_inputs_tuple_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return jax.nn.sigmoid(x + y), y * 2 - out1, out2 = f(x, y) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out1 + out2), [x, y]) fg = FunctionGraph([x, y], [out1, out2, *grad_out]) @@ -54,6 +66,17 @@ def f(x, y): # inputs are not automatically transformed to jax.Array anymore fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_two_inputs_list_output_one_unused_output(): # One output is unused, to test whether the wrapper can handle DisconnectedType @@ -64,11 +87,11 @@ def test_two_inputs_list_output_one_unused_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return [jax.nn.sigmoid(x + y), y * 2] - out, _ = f(x, y) + # Test with as_jax_op decorator + out, _ = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -76,17 +99,28 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out, _ = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_single_input_tuple_output(): rng = np.random.default_rng(4) x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return jax.nn.sigmoid(x), x * 2 - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = 
grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -94,17 +128,28 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) x = tensor("x", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return jax.nn.sigmoid(x), x - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -112,17 +157,28 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_single_input_list_output(): rng = np.random.default_rng(6) x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return [jax.nn.sigmoid(x), 2 * x] - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -130,6 +186,20 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage, with unspecified output shapes + jax_op = JAXOp( + [x.type], + [ + TensorType(config.floatX, shape=(None,)), + TensorType(config.floatX, shape=(None,)), + ], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) @@ -144,6 +214,7 @@ def test_pytree_input_tuple_output(): def f(x, y): return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + # Test with as_jax_op decorator out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]), [x, y]) @@ -167,6 +238,7 @@ def test_pytree_input_pytree_output(): def f(x, y): return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + # Test with as_jax_op decorator out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) @@ -198,6 +270,7 @@ def f(x, y, depth, which_variable): var = jax.nn.sigmoid(var) return var + # Test with as_jax_op decorator # arguments depth and which_variable are not part of the graph out = f(x, y_tmp, depth=3, which_variable="x") grad_out = grad(pt.sum(out), [x]) @@ -228,11 +301,11 @@ def test_unused_matrix_product(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return x[:, None] @ y[None], jnp.exp(x) - out = f(x, y) + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out[1]), [x]) fg = FunctionGraph([x, y], [out[1], *grad_out]) @@ -241,6 +314,20 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [ + TensorType(config.floatX, 
shape=(3, 3)), + TensorType(config.floatX, shape=(3,)), + ], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_unknown_static_shape(): rng = np.random.default_rng(11) @@ -252,11 +339,10 @@ def test_unknown_static_shape(): x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape - @as_jax_op def f(x, y): return x * jnp.ones(3) - out = f(x_cumsum, y) + out = as_jax_op(f)(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -265,6 +351,17 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(None,))], + f, + ) + out = jax_op(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_nested_functions(): rng = np.random.default_rng(13)