diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 04572b29d0..5154ae63c1 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -2021,7 +2021,13 @@ def __str__(self):
     Exception args: {args_msg}"""
 
 
-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
     """
     Compute the full Jacobian, row by row.
 
@@ -2051,62 +2057,71 @@
     output, then a zero variable is returned. The return value is of same
     type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    # from pytensor.tensor import arange, scalar
+    from pytensor.tensor import eye
 
     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")
 
-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }
 
     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]
 
-    if expression.ndim == 0:
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression,
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+    if all(expression.type.broadcastable):
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+        jacobian_matrices = vectorize_graph(
+            jacobian_rows,
+            replace={row_tangent: eye(expression_flat.size, dtype=row_tangent.dtype)},
         )
+        # row_index = scalar("idx", dtype="int64")
+        # jacobian_rows = grad(
+        #     expression.ravel()[row_index],
+        #     wrt=wrt,
+        #     **grad_kwargs
+        # )
+        # rows_indices = arange(expression.size)
+        # jacobian_matrices = vectorize_graph(
+        #     jacobian_rows, replace={row_index: rows_indices}
+        # )
+    else:
 
-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
             )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)
 
 
 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 061a159fc2..9583d2cf72 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3074,6 +3074,10 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)
 
+    if len(dims) == _x.ndim:
+        # Nothing to ravel
+        return _x
+
     x_reshaped = _x.reshape(dims)
     shape_kept_dims = _x.type.shape[: ndim - 1]
     bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])
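For reference, a minimal usage sketch of the new `vectorize` flag; the graph built here is made up for illustration, while `pt.vector`, `pt.stack`, `pytensor.function`, and `pytensor.gradient.jacobian` are existing PyTensor API:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    # A 2-element output built from a 3-element input, so the Jacobian is 2x3
    y = pt.stack([x.sum() ** 2, (x**3).sum()])

    # vectorize=True builds the Jacobian with Lop + vectorize_graph instead of scan
    J = jacobian(y, x, vectorize=True)
    f = pytensor.function([x], J)
    f([1.0, 2.0, 3.0])

The default (`vectorize=False`) keeps the scan-based path, so existing callers are unaffected.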