 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
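The newly imported `vectorize_graph` is what the `vectorize=True` path added below relies on. As a rough, hypothetical sketch of what it does (this snippet is illustrative and not part of the diff): you build a graph for a single "core" input, then ask for the same graph with a batched replacement, and the result gains a leading batch dimension.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    v = pt.vector("v")             # core input: a single vector
    core_out = (v**2).sum()        # scalar result for that one vector

    batch = pt.matrix("batch")     # batched replacement: one vector per row
    batched_out = vectorize_graph(core_out, replace={v: batch})  # one sum per row

    f = pytensor.function([batch], batched_out)
    print(f([[1.0, 2.0], [3.0, 4.0]]))  # roughly [ 5. 25.]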
@@ -2021,7 +2021,13 @@ def __str__(self):
         Exception args: {args_msg}"""


-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
     """
     Compute the full Jacobian, row by row.

@@ -2051,62 +2057,71 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     output, then a zero variable is returned. The return value is
     of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    # from pytensor.tensor import arange, scalar
+    from pytensor.tensor import eye

     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")

-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }

     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]

-    if expression.ndim == 0:
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression,
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+    if all(expression.type.broadcastable):
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+        jacobian_matrices = vectorize_graph(
+            jacobian_rows,
+            replace={row_tangent: eye(expression_flat.size, dtype=row_tangent.dtype)},
         )
+        # row_index = scalar("idx", dtype="int64")
+        # jacobian_rows = grad(
+        #     expression.ravel()[row_index],
+        #     wrt=wrt,
+        #     **grad_kwargs
+        # )
+        # rows_indices = arange(expression.size)
+        # jacobian_matrices = vectorize_graph(
+        #     jacobian_rows, replace={row_index: rows_indices}
+        # )
+    else:

-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
             )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)


 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
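With the diff applied, calling the new flag would look roughly like the following (a hypothetical usage sketch; the variable names are illustrative and not taken from the diff):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    y = pt.exp(x) * x.sum()  # vector-valued expression of a vector input

    # vectorize=False (default) builds the rows with a scan, one grad call per output entry;
    # vectorize=True builds a single L-op graph and vectorizes it over an identity matrix.
    J = jacobian(y, x, vectorize=True)

    f = pytensor.function([x], J)
    print(f([0.0, 1.0, 2.0]).shape)  # (3, 3): one row per entry of y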